| + export USE_LIBUV=1 | |
| + USE_LIBUV=1 | |
| + TRAINER_DIR=/home/willfeng/local/torchtrain | |
| + NGPU=8 | |
| + LOG_RANK=0 | |
| + CONFIG_FILE=./train_configs/llama_1b_full_graph_fsdp.toml | |
| + torchrun --nproc_per_node=8 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0 --role rank --tee 3 train.py --job.config_file ./train_configs/llama_1b_full_graph_fsdp.toml | |
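For reference, the launch above can also be reproduced programmatically. A minimal sketch, using only the environment variable and torchrun invocation shown in the shell trace above (the subprocess wrapper itself is illustrative, not part of torchtrain):

    import os
    import subprocess

    # Same environment and torchrun invocation as in the shell trace above.
    env = dict(os.environ, USE_LIBUV="1")
    cmd = [
        "torchrun",
        "--nproc_per_node=8",
        "--rdzv_endpoint=localhost:5972",
        "--local-ranks-filter", "0",
        "--role", "rank",
        "--tee", "3",
        "train.py",
        "--job.config_file", "./train_configs/llama_1b_full_graph_fsdp.toml",
    ]
    subprocess.run(cmd, env=env, check=True)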
| W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] | |
| W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] ***************************************** | |
| W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. | |
| W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] ***************************************** | |
| [rank0]:2024-03-28 18:03:19,030 - root - INFO - Starting job: LLaMA 7B training | |
| [rank0]:2024-03-28 18:03:21,163 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config | |
| [rank0]:2024-03-28 18:03:21,166 - root - INFO - Building 1-D device mesh with ['dp'], [8] | |
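The 1-D mesh reported above maps onto PyTorch's DeviceMesh API. A minimal sketch, assuming init_device_mesh is the underlying call (the actual torchtrain helper may differ):

    from torch.distributed.device_mesh import init_device_mesh

    # "Building 1-D device mesh with ['dp'], [8]": one data-parallel
    # dimension spanning all 8 ranks (requires a torchrun-initialized
    # process group, as in the launch above).
    mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp",))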
| [rank0]:2024-03-28 18:03:21,183 - root - INFO - Building sentencepiece tokenizer locally from ./torchtrain/datasets/tokenizer/tokenizer.model | |
| [rank0]:2024-03-28 18:03:21,199 - root - INFO - SentencePieceTokenizer built: #words 32000, BOS ID 1, EOS ID 2 | |
| [rank0]:2024-03-28 18:03:21,200 - root - INFO - Preparing alpaca dataset from HuggingFace | |
| [rank0]:2024-03-28 18:03:24,206 - root - INFO - Building llama 1B with ModelArgs(dim=2048, n_layers=18, n_heads=16, n_kv_heads=None, vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=32768, depth_init=True) | |
| [rank0]:2024-03-28 18:03:24,491 - root - INFO - llama 1B size: 1,055,991,808 total parameters | |
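The 1,055,991,808 figure is consistent with the logged ModelArgs. A worked check, assuming untied input/output embeddings and a LLaMA-style SwiGLU FFN with hidden size 5632 (both visible as bf16[32000, 2048] and bf16[5632, 2048] weights in the traced graph below):

    dim, n_layers, vocab, ffn_hidden = 2048, 18, 32000, 5632

    attn = 4 * dim * dim                    # wq, wk, wv, wo
    ffn = 3 * dim * ffn_hidden              # w1, w2, w3
    norms = 2 * dim                         # attention_norm + ffn_norm
    per_layer = attn + ffn + norms          # 51,384,320 per transformer block

    embeds = 2 * vocab * dim                # tok_embeddings + output projection
    total = n_layers * per_layer + embeds + dim  # + final norm weight
    assert total == 1_055_991_808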
| [rank0]:2024-03-28 18:03:24,491 - root - INFO - GPU capacity: NVIDIA PG509-210 (0) with 79.15GiB memory | |
| [rank0]:2024-03-28 18:03:24,557 - root - INFO - Applied FSDP to the model | |
| [rank0]:2024-03-28 18:03:30,309 - root - INFO - Gradient scaling not enabled | |
| [rank0]:2024-03-28 18:03:30,309 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./outputs/llama_1b_full_graph_fsdp/tb/20240328-1803 | |
| [rank0]:2024-03-28 18:03:30,310 - root - INFO - Compiling model with torch.compile | |
| [rank0]:2024-03-28 18:03:30,311 - root - INFO - Compiling model with torch.compile + full-graph compile FSDP | |
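The exact wiring torchtrain uses for this mode is not shown in the log, but the traced graph below goes through torch.distributed._composable.fsdp, i.e. the per-module fully_shard API. A minimal sketch of what "torch.compile + full-graph compile FSDP" plausibly corresponds to; treat the details as assumptions (the mp_policy matches the f32-to-bf16 parameter casts visible in the graph):

    import torch
    from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

    def shard_and_compile(model, mesh):
        # Keep fp32 sharded parameters, compute in bf16 (matches the
        # _to_dtype_if_needed f32 -> bf16 casts in the traced graph below).
        mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
        for layer in model.layers:                   # shard each transformer block
            fully_shard(layer, mesh=mesh, mp_policy=mp)
        fully_shard(model, mesh=mesh, mp_policy=mp)  # then the root module
        return torch.compile(model, fullgraph=True)  # single full graph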
| [rank0]:2024-03-28 18:03:30,372 - root - INFO - Profiling active. Traces will be saved at ./outputs/llama_1b_full_graph_fsdp/profiling/traces | |
| [rank0]:2024-03-28 18:03:40,719 - root - INFO - step: 1 loss: 10.8856 memory: 6.99GiB(8.84%) wps: 198 mfu: 0.46% | |
| [rank0]:2024-03-28 18:03:40,720 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:05:00 | |
| [rank0]:[rank0]:W2024-03-28 18:03:42,059.059000 139804330415936 torch/_logging/_internal.py:1042] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored | |
| [rank0]:/data/users/willfeng/pytorch_yf225/torch/_inductor/lowering.py:1788: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager. | |
| [rank0]: warnings.warn( | |
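The complex-operator warning comes from the rotary embedding: the graph below carries c64 tensors through view_as_complex / view_as_real, which Inductor falls back to eager for. The sketch below is reconstructed from the source lines quoted in the trace (torchtrain model.py:146-151); the freqs_cis broadcast reshape is assumed to have happened already:

    import torch

    def apply_rotary_emb(xq, xk, freqs_cis):
        # Pair up the last dimension into complex numbers, rotate by the
        # precomputed complex frequencies, then flatten back to real pairs.
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)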
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] TRACED GRAPH | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] ===== after FSDP FX passes: ===== | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] def forward(self, primals_1: "i64[1, 2048]", primals_2: "f32[8192000]", primals_3: "f32[256]", primals_4: "f32[8192000]", primals_5: "bf16[32000, 2048]", primals_6: "bf16[2048]", primals_7: "bf16[32000, 2048]", primals_8: "c64[65536, 64]", primals_9: "f32[524288]", primals_10: "f32[524288]", primals_11: "f32[524288]", primals_12: "f32[524288]", primals_13: "f32[1441792]", primals_14: "f32[1441792]", primals_15: "f32[1441792]", primals_16: "f32[256]", primals_17: "f32[256]", primals_18: "bf16[2048, 2048]", primals_19: "bf16[2048, 2048]", primals_20: "bf16[2048, 2048]", primals_21: "bf16[2048, 2048]", primals_22: "bf16[5632, 2048]", primals_23: "bf16[2048, 5632]", primals_24: "bf16[5632, 2048]", primals_25: "bf16[2048]", primals_26: "bf16[2048]", primals_27, primals_28: "f32[524288]", primals_29: "f32[524288]", primals_30: "f32[524288]", primals_31: "f32[524288]", primals_32: "f32[1441792]", primals_33: "f32[1441792]", primals_34: "f32[1441792]", primals_35: "f32[256]", primals_36: "f32[256]", primals_37: "bf16[2048, 2048]", primals_38: "bf16[2048, 2048]", primals_39: "bf16[2048, 2048]", primals_40: "bf16[2048, 2048]", primals_41: "bf16[5632, 2048]", primals_42: "bf16[2048, 5632]", primals_43: "bf16[5632, 2048]", primals_44: "bf16[2048]", primals_45: "bf16[2048]", primals_46: "f32[524288]", primals_47: "f32[524288]", primals_48: "f32[524288]", primals_49: "f32[524288]", primals_50: "f32[1441792]", primals_51: "f32[1441792]", primals_52: "f32[1441792]", primals_53: "f32[256]", primals_54: "f32[256]", primals_55: "bf16[2048, 2048]", primals_56: "bf16[2048, 2048]", primals_57: "bf16[2048, 2048]", primals_58: "bf16[2048, 2048]", primals_59: "bf16[5632, 2048]", primals_60: "bf16[2048, 5632]", primals_61: "bf16[5632, 2048]", primals_62: "bf16[2048]", primals_63: "bf16[2048]", primals_64: "f32[524288]", primals_65: "f32[524288]", primals_66: "f32[524288]", primals_67: "f32[524288]", primals_68: "f32[1441792]", primals_69: "f32[1441792]", primals_70: "f32[1441792]", primals_71: "f32[256]", primals_72: "f32[256]", primals_73: "bf16[2048, 2048]", primals_74: "bf16[2048, 2048]", primals_75: "bf16[2048, 2048]", primals_76: "bf16[2048, 2048]", primals_77: "bf16[5632, 2048]", primals_78: "bf16[2048, 5632]", primals_79: "bf16[5632, 2048]", primals_80: "bf16[2048]", primals_81: "bf16[2048]", primals_82: "f32[524288]", primals_83: "f32[524288]", primals_84: "f32[524288]", primals_85: "f32[524288]", primals_86: "f32[1441792]", primals_87: "f32[1441792]", primals_88: "f32[1441792]", primals_89: "f32[256]", primals_90: "f32[256]", primals_91: "bf16[2048, 2048]", primals_92: "bf16[2048, 2048]", primals_93: "bf16[2048, 2048]", primals_94: "bf16[2048, 2048]", primals_95: "bf16[5632, 2048]", primals_96: "bf16[2048, 5632]", primals_97: "bf16[5632, 2048]", primals_98: "bf16[2048]", primals_99: "bf16[2048]", primals_100: "f32[524288]", primals_101: "f32[524288]", primals_102: "f32[524288]", primals_103: "f32[524288]", primals_104: "f32[1441792]", primals_105: "f32[1441792]", primals_106: "f32[1441792]", primals_107: "f32[256]", primals_108: "f32[256]", primals_109: "bf16[2048, 2048]", primals_110: "bf16[2048, 2048]", primals_111: "bf16[2048, 2048]", primals_112: "bf16[2048, 2048]", primals_113: "bf16[5632, 2048]", primals_114: "bf16[2048, 5632]", primals_115: "bf16[5632, 2048]", primals_116: "bf16[2048]", primals_117: "bf16[2048]", primals_118: "f32[524288]", primals_119: "f32[524288]", primals_120: "f32[524288]", primals_121: "f32[524288]", primals_122: "f32[1441792]", primals_123: "f32[1441792]", primals_124: "f32[1441792]", primals_125: "f32[256]", primals_126: "f32[256]", primals_127: "bf16[2048, 2048]", primals_128: "bf16[2048, 2048]", primals_129: "bf16[2048, 2048]", primals_130: "bf16[2048, 2048]", primals_131: "bf16[5632, 2048]", primals_132: "bf16[2048, 5632]", primals_133: "bf16[5632, 2048]", primals_134: "bf16[2048]", primals_135: "bf16[2048]", primals_136: "f32[524288]", primals_137: "f32[524288]", primals_138: "f32[524288]", primals_139: "f32[524288]", primals_140: "f32[1441792]", primals_141: "f32[1441792]", primals_142: "f32[1441792]", primals_143: "f32[256]", primals_144: "f32[256]", primals_145: "bf16[2048, 2048]", primals_146: "bf16[2048, 2048]", primals_147: "bf16[2048, 2048]", primals_148: "bf16[2048, 2048]", primals_149: "bf16[5632, 2048]", primals_150: "bf16[2048, 5632]", primals_151: "bf16[5632, 2048]", primals_152: "bf16[2048]", primals_153: "bf16[2048]", primals_154: "f32[524288]", primals_155: "f32[524288]", primals_156: "f32[524288]", primals_157: "f32[524288]", primals_158: "f32[1441792]", primals_159: "f32[1441792]", primals_160: "f32[1441792]", primals_161: "f32[256]", primals_162: "f32[256]", primals_163: "bf16[2048, 2048]", primals_164: "bf16[2048, 2048]", primals_165: "bf16[2048, 2048]", primals_166: "bf16[2048, 2048]", primals_167: "bf16[5632, 2048]", primals_168: "bf16[2048, 5632]", primals_169: "bf16[5632, 2048]", primals_170: "bf16[2048]", primals_171: "bf16[2048]", primals_172: "f32[524288]", primals_173: "f32[524288]", primals_174: "f32[524288]", primals_175: "f32[524288]", primals_176: "f32[1441792]", primals_177: "f32[1441792]", primals_178: "f32[1441792]", primals_179: "f32[256]", primals_180: "f32[256]", primals_181: "bf16[2048, 2048]", primals_182: "bf16[2048, 2048]", primals_183: "bf16[2048, 2048]", primals_184: "bf16[2048, 2048]", primals_185: "bf16[5632, 2048]", primals_186: "bf16[2048, 5632]", primals_187: "bf16[5632, 2048]", primals_188: "bf16[2048]", primals_189: "bf16[2048]", primals_190: "f32[524288]", primals_191: "f32[524288]", primals_192: "f32[524288]", primals_193: "f32[524288]", primals_194: "f32[1441792]", primals_195: "f32[1441792]", primals_196: "f32[1441792]", primals_197: "f32[256]", primals_198: "f32[256]", primals_199: "bf16[2048, 2048]", primals_200: "bf16[2048, 2048]", primals_201: "bf16[2048, 2048]", primals_202: "bf16[2048, 2048]", primals_203: "bf16[5632, 2048]", primals_204: "bf16[2048, 5632]", primals_205: "bf16[5632, 2048]", primals_206: "bf16[2048]", primals_207: "bf16[2048]", primals_208: "f32[524288]", primals_209: "f32[524288]", primals_210: "f32[524288]", primals_211: "f32[524288]", primals_212: "f32[1441792]", primals_213: "f32[1441792]", primals_214: "f32[1441792]", primals_215: "f32[256]", primals_216: "f32[256]", primals_217: "bf16[2048, 2048]", primals_218: "bf16[2048, 2048]", primals_219: "bf16[2048, 2048]", primals_220: "bf16[2048, 2048]", primals_221: "bf16[5632, 2048]", primals_222: "bf16[2048, 5632]", primals_223: "bf16[5632, 2048]", primals_224: "bf16[2048]", primals_225: "bf16[2048]", primals_226: "f32[524288]", primals_227: "f32[524288]", primals_228: "f32[524288]", primals_229: "f32[524288]", primals_230: "f32[1441792]", primals_231: "f32[1441792]", primals_232: "f32[1441792]", primals_233: "f32[256]", primals_234: "f32[256]", primals_235: "bf16[2048, 2048]", primals_236: "bf16[2048, 2048]", primals_237: "bf16[2048, 2048]", primals_238: "bf16[2048, 2048]", primals_239: "bf16[5632, 2048]", primals_240: "bf16[2048, 5632]", primals_241: "bf16[5632, 2048]", primals_242: "bf16[2048]", primals_243: "bf16[2048]", primals_244: "f32[524288]", primals_245: "f32[524288]", primals_246: "f32[524288]", primals_247: "f32[524288]", primals_248: "f32[1441792]", primals_249: "f32[1441792]", primals_250: "f32[1441792]", primals_251: "f32[256]", primals_252: "f32[256]", primals_253: "bf16[2048, 2048]", primals_254: "bf16[2048, 2048]", primals_255: "bf16[2048, 2048]", primals_256: "bf16[2048, 2048]", primals_257: "bf16[5632, 2048]", primals_258: "bf16[2048, 5632]", primals_259: "bf16[5632, 2048]", primals_260: "bf16[2048]", primals_261: "bf16[2048]", primals_262: "f32[524288]", primals_263: "f32[524288]", primals_264: "f32[524288]", primals_265: "f32[524288]", primals_266: "f32[1441792]", primals_267: "f32[1441792]", primals_268: "f32[1441792]", primals_269: "f32[256]", primals_270: "f32[256]", primals_271: "bf16[2048, 2048]", primals_272: "bf16[2048, 2048]", primals_273: "bf16[2048, 2048]", primals_274: "bf16[2048, 2048]", primals_275: "bf16[5632, 2048]", primals_276: "bf16[2048, 5632]", primals_277: "bf16[5632, 2048]", primals_278: "bf16[2048]", primals_279: "bf16[2048]", primals_280: "f32[524288]", primals_281: "f32[524288]", primals_282: "f32[524288]", primals_283: "f32[524288]", primals_284: "f32[1441792]", primals_285: "f32[1441792]", primals_286: "f32[1441792]", primals_287: "f32[256]", primals_288: "f32[256]", primals_289: "bf16[2048, 2048]", primals_290: "bf16[2048, 2048]", primals_291: "bf16[2048, 2048]", primals_292: "bf16[2048, 2048]", primals_293: "bf16[5632, 2048]", primals_294: "bf16[2048, 5632]", primals_295: "bf16[5632, 2048]", primals_296: "bf16[2048]", primals_297: "bf16[2048]", primals_298: "f32[524288]", primals_299: "f32[524288]", primals_300: "f32[524288]", primals_301: "f32[524288]", primals_302: "f32[1441792]", primals_303: "f32[1441792]", primals_304: "f32[1441792]", primals_305: "f32[256]", primals_306: "f32[256]", primals_307: "bf16[2048, 2048]", primals_308: "bf16[2048, 2048]", primals_309: "bf16[2048, 2048]", primals_310: "bf16[2048, 2048]", primals_311: "bf16[5632, 2048]", primals_312: "bf16[2048, 5632]", primals_313: "bf16[5632, 2048]", primals_314: "bf16[2048]", primals_315: "bf16[2048]", primals_316: "f32[524288]", primals_317: "f32[524288]", primals_318: "f32[524288]", primals_319: "f32[524288]", primals_320: "f32[1441792]", primals_321: "f32[1441792]", primals_322: "f32[1441792]", primals_323: "f32[256]", primals_324: "f32[256]", primals_325: "bf16[2048, 2048]", primals_326: "bf16[2048, 2048]", primals_327: "bf16[2048, 2048]", primals_328: "bf16[2048, 2048]", primals_329: "bf16[5632, 2048]", primals_330: "bf16[2048, 5632]", primals_331: "bf16[5632, 2048]", primals_332: "bf16[2048]", primals_333: "bf16[2048]"): | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(primals_2, torch.bfloat16); primals_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_1: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_3, torch.bfloat16); primals_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_2: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(16384256, 8, 0, torch.bfloat16, device(type='cuda', index=0), [8192000, 256, 8192000], [convert_element_type, convert_element_type_1, convert_element_type_2]); convert_element_type = convert_element_type_1 = convert_element_type_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem: "bf16[16384256]" = all_gather_copy_in[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1: "bf16[131074048]" = all_gather_copy_in[1]; all_gather_copy_in = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor: "bf16[131074048]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, 8, '0'); getitem = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor: "bf16[131074048]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_1: "bf16[8, 16384256]" = torch.ops.aten.reshape.default(wait_tensor, [8, -1]); wait_tensor = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(view_1, [8192000, 256, 8192000], 1); view_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_5: "bf16[8, 8192000]" = split_with_sizes_1[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided: "bf16[32000, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_5, [32000, 2048], [2048, 1]); getitem_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_9: "bf16[8, 256]" = split_with_sizes_1[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_1: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_9, [2048], [1]); getitem_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_13: "bf16[8, 8192000]" = split_with_sizes_1[2]; split_with_sizes_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_2: "bf16[32000, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_13, [32000, 2048], [2048, 1]); getitem_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/sparse.py:163 in forward, code: return F.embedding( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] embedding: "bf16[1, 2048, 2048]" = torch.ops.aten.embedding.default(contiguous_view_as_strided, primals_1); contiguous_view_as_strided = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:338 in forward, code: freqs_cis = self.freqs_cis[0:seqlen] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] slice_1: "c64[2048, 64]" = torch.ops.aten.slice.Tensor(primals_8, 0, 0, 2048); primals_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:484 in forward, code: h = h.view(-1, self.model_args.dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_4: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(embedding, [-1, 2048]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_3: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_4: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_5: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_6: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_7: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_8: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_9: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_10: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_11: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_1 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_3, convert_element_type_4, convert_element_type_5, convert_element_type_6, convert_element_type_7, convert_element_type_8, convert_element_type_9, convert_element_type_10, convert_element_type_11]); convert_element_type_3 = convert_element_type_4 = convert_element_type_5 = convert_element_type_6 = convert_element_type_7 = convert_element_type_8 = convert_element_type_9 = convert_element_type_10 = convert_element_type_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_14: "bf16[6423040]" = all_gather_copy_in_1[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_15: "bf16[51384320]" = all_gather_copy_in_1[1]; all_gather_copy_in_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_14, 8, '0'); getitem_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_6: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]); wait_tensor_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_5 = torch.ops.aten.split_with_sizes.default(view_6, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_25: "bf16[8, 524288]" = split_with_sizes_5[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_3: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_25, [2048, 2048], [2048, 1]); getitem_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_35: "bf16[8, 524288]" = split_with_sizes_5[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_4: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_35, [2048, 2048], [2048, 1]); getitem_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_45: "bf16[8, 524288]" = split_with_sizes_5[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_5: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_45, [2048, 2048], [2048, 1]); getitem_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_55: "bf16[8, 524288]" = split_with_sizes_5[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_6: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_55, [2048, 2048], [2048, 1]); getitem_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_65: "bf16[8, 1441792]" = split_with_sizes_5[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_7: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_65, [5632, 2048], [2048, 1]); getitem_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_75: "bf16[8, 1441792]" = split_with_sizes_5[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_8: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_75, [2048, 5632], [5632, 1]); getitem_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_85: "bf16[8, 1441792]" = split_with_sizes_5[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_9: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_85, [5632, 2048], [2048, 1]); getitem_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_95: "bf16[8, 256]" = split_with_sizes_5[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_10: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_95, [2048], [1]); getitem_95 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_105: "bf16[8, 256]" = split_with_sizes_5[8]; split_with_sizes_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_11: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_105, [2048], [1]); getitem_105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_12: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(view_4, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_1: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_12, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add); add = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_12, rsqrt); convert_element_type_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_13: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul, torch.bfloat16); mul = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_1: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_13, contiguous_view_as_strided_10); convert_element_type_13 = contiguous_view_as_strided_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_1: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_3, [1, 0]); contiguous_view_as_strided_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_1, permute_1); permute_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_3: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_4, [1, 0]); contiguous_view_as_strided_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_1: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_1, permute_3); permute_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_5: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_5, [1, 0]); contiguous_view_as_strided_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_2: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_1, permute_5); permute_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_15: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm, [1, 2048, 16, 128]); mm = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_16: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_1, [1, 2048, 16, 128]); mm_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_17: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_2, [1, 2048, 16, 128]); mm_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_20: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_15, torch.float32); view_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_18: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_20, [1, 2048, 16, -1, 2]); convert_element_type_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_18); view_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_21: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_16, torch.float32); view_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_19: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_21, [1, 2048, 16, -1, 2]); convert_element_type_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_1: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_19); view_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:122 in reshape_for_broadcast, code: return freqs_cis.view(*shape) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_20: "c64[1, 2048, 1, 64]" = torch.ops.aten.reshape.default(slice_1, [1, 2048, 1, 64]); slice_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_2: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex, view_20); view_as_complex = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_21: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real, [1, 2048, 16, 128]); view_as_real = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_3: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_1, view_20); view_as_complex_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_1: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_22: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_1, [1, 2048, 16, 128]); view_as_real_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_22: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_21, torch.bfloat16); view_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_23: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_22, torch.bfloat16); view_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_6: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_22, [0, 2, 1, 3]); convert_element_type_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_7: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_23, [0, 2, 1, 3]); convert_element_type_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_8: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_17, [0, 2, 1, 3]); view_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_6, permute_7, permute_8, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_106: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_107: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_112: "i64[]" = _scaled_dot_product_flash_attention[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_113: "i64[]" = _scaled_dot_product_flash_attention[7]; _scaled_dot_product_flash_attention = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_9: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_106, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_23: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_9, [2048, -1]); permute_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_11: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_6, [1, 0]); contiguous_view_as_strided_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_3: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_23, permute_11); permute_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_1: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(view_4, mm_3); view_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_26: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_1, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_2: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_26, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_1: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_2: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_1: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_2); add_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_4: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_26, rsqrt_1); convert_element_type_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_27: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_4, torch.bfloat16); mul_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_5: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_27, contiguous_view_as_strided_11); convert_element_type_27 = contiguous_view_as_strided_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_13: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_7, [1, 0]); contiguous_view_as_strided_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_4: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_5, permute_13); permute_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_30: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_4, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_30) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_6: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_30, sigmoid); convert_element_type_30 = sigmoid = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_31: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_15: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_9, [1, 0]); contiguous_view_as_strided_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_5: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_5, permute_15); permute_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_7: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_31, mm_5); convert_element_type_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_17: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_8, [1, 0]); contiguous_view_as_strided_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_6: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_7, permute_17); permute_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
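Taken together, mm_4/mm_5/mm_6 and the fp32 sigmoid-multiply between them are the SwiGLU feed-forward from model.py:299. A sketch of the module these ops come from (dim=2048, hidden_dim=5632, per the shapes above); note F.silu(y) = y * sigmoid(y), which is why the graph spells out an explicit sigmoid and multiply:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    # SwiGLU MLP: w2(silu(w1(x)) * w3(x)), with all three projections bias-free.
    def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```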
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_3: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_1, mm_6); add_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
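add_1 and add_3 (and add_5/add_7 further down) are the two residual connections of each transformer block, per the model.py:410-411 comments. A sketch of just the residual wiring, with the submodules passed in as hypothetical constructor arguments:

```python
import torch.nn as nn

class TransformerBlock(nn.Module):
    # Residual wiring only; the submodules are the ones whose traced ops appear above.
    def __init__(self, attention, feed_forward, attention_norm, ffn_norm):
        super().__init__()
        self.attention = attention
        self.feed_forward = feed_forward
        self.attention_norm = attention_norm
        self.ffn_norm = ffn_norm

    def forward(self, x, freqs_cis):
        h = x + self.attention(self.attention_norm(x), freqs_cis)  # add_1 / add_5
        return h + self.feed_forward(self.ffn_norm(h))             # add_3 / add_7
```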
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_36: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_37: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_38: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_39: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_40: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_41: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_42: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_43: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_44: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
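The convert_element_type calls on primals_28..36 cast this rank's fp32 parameter shards to bf16 before they are all-gathered, via the _to_dtype_if_needed helper quoted above. A sketch of what that helper plausibly does (the no-op guard is an assumption; only the tensor.to(dtype) line is quoted in the graph):

```python
import torch

def _to_dtype_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Cast only when needed, so fp32 sharded params become bf16 all-gather
    # inputs under a bf16 mixed-precision param dtype. (Guard is assumed.)
    if tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor
```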
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_2 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_36, convert_element_type_37, convert_element_type_38, convert_element_type_39, convert_element_type_40, convert_element_type_41, convert_element_type_42, convert_element_type_43, convert_element_type_44]); convert_element_type_36 = convert_element_type_37 = convert_element_type_38 = convert_element_type_39 = convert_element_type_40 = convert_element_type_41 = convert_element_type_42 = convert_element_type_43 = convert_element_type_44 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_115: "bf16[6423040]" = all_gather_copy_in_2[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_116: "bf16[51384320]" = all_gather_copy_in_2[1]; all_gather_copy_in_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
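torch.ops.fsdp.all_gather_copy_in packs the nine bf16 shards (split sizes summing to 6,423,040 elements) into one flat all-gather input and allocates the world-size output buffer (8 × 6,423,040 = 51,384,320). A rough eager-mode sketch of its assumed semantics, where the input is this rank's slice of the output:

```python
import torch

def all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device,
                       inp_split_sizes, param_all_gather_inputs):
    # Assumed semantics: allocate the flat world-size output, take this rank's
    # slice as the all-gather input, and copy each param shard into its split.
    all_gather_output = torch.empty(
        all_gather_input_numel * world_size, dtype=dtype, device=device)
    all_gather_input = all_gather_output.narrow(
        0, all_gather_input_numel * rank, all_gather_input_numel)
    for dst, src in zip(all_gather_input.split(inp_split_sizes),
                        param_all_gather_inputs):
        dst.copy_(src.reshape(-1))
    return all_gather_input, all_gather_output
```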
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_115, 8, '0'); getitem_115 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
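The gather itself is a functional c10d collective: all_gather_into_tensor is issued asynchronously, and wait_tensor is the explicit synchronization point, which lets the compiler schedule compute between issue and wait. An eager sketch using the functional-collectives wrappers referenced in the graph comments:

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def gather_flat_params(all_gather_input: torch.Tensor, group: dist.ProcessGroup):
    # all_gather_tensor lowers to _c10d_functional.all_gather_into_tensor;
    # wait_tensor blocks until the 51,384,320-element buffer is ready.
    out = funcol.all_gather_tensor(all_gather_input, gather_dim=0, group=group)
    return funcol.wait_tensor(out)
```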
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_25: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]); wait_tensor_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_15 = torch.ops.aten.split_with_sizes.default(view_25, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_126: "bf16[8, 524288]" = split_with_sizes_15[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_12: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_126, [2048, 2048], [2048, 1]); getitem_126 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_136: "bf16[8, 524288]" = split_with_sizes_15[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_13: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_136, [2048, 2048], [2048, 1]); getitem_136 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_146: "bf16[8, 524288]" = split_with_sizes_15[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_14: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_146, [2048, 2048], [2048, 1]); getitem_146 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_156: "bf16[8, 524288]" = split_with_sizes_15[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_15: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_156, [2048, 2048], [2048, 1]); getitem_156 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_166: "bf16[8, 1441792]" = split_with_sizes_15[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_16: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_166, [5632, 2048], [2048, 1]); getitem_166 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_176: "bf16[8, 1441792]" = split_with_sizes_15[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_17: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_176, [2048, 5632], [5632, 1]); getitem_176 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_186: "bf16[8, 1441792]" = split_with_sizes_15[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_18: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_186, [5632, 2048], [2048, 1]); getitem_186 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_196: "bf16[8, 256]" = split_with_sizes_15[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_19: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_196, [2048], [1]); getitem_196 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_206: "bf16[8, 256]" = split_with_sizes_15[8]; split_with_sizes_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_20: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_206, [2048], [1]); getitem_206 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
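After the wait, the flat buffer is viewed as [world_size, shard_numel], split per parameter along dim 1, and each slice is reinterpreted as the unsharded parameter (e.g. 8 × 524,288 = 2048 × 2048). torch.ops.fsdp.contiguous_view_as_strided is a custom op; an eager approximation of its assumed behavior for these unpadded shards:

```python
import torch

def unflatten_param(chunk: torch.Tensor, shape, stride):
    # chunk: a [world_size, shard_numel] slice of the gathered buffer. For
    # unpadded shards the rank-major layout is already row-major, so a
    # contiguous copy plus as_strided recovers e.g. a [2048, 2048] weight.
    return chunk.contiguous().view(-1).as_strided(shape, stride)
```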
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_45: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_3, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_3: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_45, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_2: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_4: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_2: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_4); add_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_8: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_45, rsqrt_2); convert_element_type_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_46: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_8, torch.bfloat16); mul_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_9: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_46, contiguous_view_as_strided_19); convert_element_type_46 = contiguous_view_as_strided_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_19: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_12, [1, 0]); contiguous_view_as_strided_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_7: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_9, permute_19); permute_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_21: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_13, [1, 0]); contiguous_view_as_strided_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_8: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_9, permute_21); permute_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_23: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_14, [1, 0]); contiguous_view_as_strided_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_9: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_9, permute_23); permute_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_34: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_7, [1, 2048, 16, 128]); mm_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_35: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_8, [1, 2048, 16, 128]); mm_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_36: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_9, [1, 2048, 16, 128]); mm_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
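mm_7/mm_8/mm_9 are the q/k/v projections against the freshly gathered weights, and the reshapes split each 2048-wide row into 16 heads of dim 128. A sketch of what these six ops compute (weights laid out as nn.Linear stores them):

```python
import torch

bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128

def project_qkv(x, wq, wk, wv):
    # x: [bsz * seqlen, dim]; wq/wk/wv: [dim, dim].
    xq = (x @ wq.t()).view(bsz, seqlen, n_heads, head_dim)  # mm_7 + view_34
    xk = (x @ wk.t()).view(bsz, seqlen, n_heads, head_dim)  # mm_8 + view_35
    xv = (x @ wv.t()).view(bsz, seqlen, n_heads, head_dim)  # mm_9 + view_36
    return xq, xk, xv
```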
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_53: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_34, torch.float32); view_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_37: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_53, [1, 2048, 16, -1, 2]); convert_element_type_53 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_2: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_37); view_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_54: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_35, torch.float32); view_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_38: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_54, [1, 2048, 16, -1, 2]); convert_element_type_54 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_3: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_38); view_38 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_10: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_2, view_20); view_as_complex_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_2: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_40: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_2, [1, 2048, 16, 128]); view_as_real_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_11: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_3, view_20); view_as_complex_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_3: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_41: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_3, [1, 2048, 16, 128]); view_as_real_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_55: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_40, torch.bfloat16); view_40 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_56: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_41, torch.bfloat16); view_41 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
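The view_as_complex / multiply / view_as_real sequence above is rotary position embedding: consecutive pairs of the head dim are treated as complex numbers and rotated by freqs_cis (view_20 in this graph). Reconstructed from the source lines quoted in the comments (model.py:146-151):

```python
import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Pair the last dim into complex numbers, rotate by the precomputed
    # complex frequencies, then flatten back to [bsz, seqlen, n_heads, head_dim].
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```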
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_24: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_55, [0, 2, 1, 3]); convert_element_type_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_25: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_56, [0, 2, 1, 3]); convert_element_type_56 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_26: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_36, [0, 2, 1, 3]); view_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_1 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_24, permute_25, permute_26, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_207: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_1[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_208: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_1[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_213: "i64[]" = _scaled_dot_product_flash_attention_1[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_214: "i64[]" = _scaled_dot_product_flash_attention_1[7]; _scaled_dot_product_flash_attention_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
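Attention is a single fused flash-attention kernel; getitem_207 is the output and getitem_208 the logsumexp kept for backward. The traced scale 0.08838834764831843 is simply 1/sqrt(head_dim) = 1/sqrt(128). A sketch of the call that produced it:

```python
import math
import torch
import torch.nn.functional as F

xq = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")
xk = torch.randn_like(xq)
xv = torch.randn_like(xq)

# is_causal=True applies the lower-triangular mask; the default scale is
# 1 / sqrt(head_dim), which matches the explicit scale recorded in the graph.
out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
assert math.isclose(1 / math.sqrt(128), 0.08838834764831843)
```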
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_27: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_42: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_27, [2048, -1]); permute_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_29: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_15, [1, 0]); contiguous_view_as_strided_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_10: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_42, permute_29); permute_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_5: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_3, mm_10); add_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_59: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_5, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_4: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_59, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_3: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True); pow_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_6: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05); mean_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_3: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_6); add_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_12: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_59, rsqrt_3); convert_element_type_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_60: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_12, torch.bfloat16); mul_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_13: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_60, contiguous_view_as_strided_20); convert_element_type_60 = contiguous_view_as_strided_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_31: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_16, [1, 0]); contiguous_view_as_strided_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_11: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_13, permute_31); permute_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_63: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_11, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_1: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_63) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_14: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_63, sigmoid_1); convert_element_type_63 = sigmoid_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_64: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_33: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_18, [1, 0]); contiguous_view_as_strided_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_12: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_13, permute_33); permute_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_15: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_64, mm_12); convert_element_type_64 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_35: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_17, [1, 0]); contiguous_view_as_strided_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_13: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_15, permute_35); permute_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_7: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_5, mm_13); add_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_69: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_70: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_71: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_72: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_73: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_74: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_75: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_76: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_77: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_3 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_69, convert_element_type_70, convert_element_type_71, convert_element_type_72, convert_element_type_73, convert_element_type_74, convert_element_type_75, convert_element_type_76, convert_element_type_77]); convert_element_type_69 = convert_element_type_70 = convert_element_type_71 = convert_element_type_72 = convert_element_type_73 = convert_element_type_74 = convert_element_type_75 = convert_element_type_76 = convert_element_type_77 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_216: "bf16[6423040]" = all_gather_copy_in_3[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_217: "bf16[51384320]" = all_gather_copy_in_3[1]; all_gather_copy_in_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_216, 8, '0'); getitem_216 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_44: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]); wait_tensor_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_25 = torch.ops.aten.split_with_sizes.default(view_44, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_44 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_227: "bf16[8, 524288]" = split_with_sizes_25[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_21: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_227, [2048, 2048], [2048, 1]); getitem_227 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_237: "bf16[8, 524288]" = split_with_sizes_25[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_22: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_237, [2048, 2048], [2048, 1]); getitem_237 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_247: "bf16[8, 524288]" = split_with_sizes_25[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_23: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_247, [2048, 2048], [2048, 1]); getitem_247 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_257: "bf16[8, 524288]" = split_with_sizes_25[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_24: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_257, [2048, 2048], [2048, 1]); getitem_257 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_267: "bf16[8, 1441792]" = split_with_sizes_25[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_25: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_267, [5632, 2048], [2048, 1]); getitem_267 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_277: "bf16[8, 1441792]" = split_with_sizes_25[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_26: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_277, [2048, 5632], [5632, 1]); getitem_277 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_287: "bf16[8, 1441792]" = split_with_sizes_25[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_27: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_287, [5632, 2048], [2048, 1]); getitem_287 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_297: "bf16[8, 256]" = split_with_sizes_25[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_28: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_297, [2048], [1]); getitem_297 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_307: "bf16[8, 256]" = split_with_sizes_25[8]; split_with_sizes_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_29: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_307, [2048], [1]); getitem_307 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_78: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_7, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_5: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_78, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_4: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True); pow_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_8: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05); mean_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_4: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_16: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_78, rsqrt_4); convert_element_type_78 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_79: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_16, torch.bfloat16); mul_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_17: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_79, contiguous_view_as_strided_28); convert_element_type_79 = contiguous_view_as_strided_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_37: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_21, [1, 0]); contiguous_view_as_strided_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_14: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_17, permute_37); permute_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_39: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_22, [1, 0]); contiguous_view_as_strided_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_15: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_17, permute_39); permute_39 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_41: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_23, [1, 0]); contiguous_view_as_strided_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_16: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_17, permute_41); permute_41 = None | |
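`permute_37/39/41` plus `mm_14/15/16` are the attention's three input projections: a bias-free `F.linear(input, weight)` lowers to a transpose of the all-gathered weight followed by a plain `aten.mm`, and all three matmuls share the normalized activations `mul_17`. An equivalent in eager code, with placeholder weights (the real ones come from the all-gather earlier in the trace):

    import torch

    x  = torch.randn(2048, 2048, dtype=torch.bfloat16)  # mul_17: normalized activations
    wq = torch.randn(2048, 2048, dtype=torch.bfloat16)  # contiguous_view_as_strided_21
    wk = torch.randn(2048, 2048, dtype=torch.bfloat16)  # contiguous_view_as_strided_22
    wv = torch.randn(2048, 2048, dtype=torch.bfloat16)  # contiguous_view_as_strided_23

    # F.linear(x, W) with bias=None is x @ W.T: the permute_* + mm_* pairs in the graph
    xq = x @ wq.t()  # mm_14
    xk = x @ wk.t()  # mm_15
    xv = x @ wv.t()  # mm_16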
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_53: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_14, [1, 2048, 16, 128]); mm_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_54: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_15, [1, 2048, 16, 128]); mm_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_55: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_16, [1, 2048, 16, 128]); mm_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_86: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_53, torch.float32); view_53 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_56: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_86, [1, 2048, 16, -1, 2]); convert_element_type_86 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_4: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_56); view_56 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_87: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_54, torch.float32); view_54 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_57: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_87, [1, 2048, 16, -1, 2]); convert_element_type_87 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_5: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_57); view_57 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_18: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_4, view_20); view_as_complex_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_4: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_59: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_4, [1, 2048, 16, 128]); view_as_real_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_19: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_5, view_20); view_as_complex_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_5: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_60: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_5, [1, 2048, 16, 128]); view_as_real_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_88: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_59, torch.bfloat16); view_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_89: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_60, torch.bfloat16); view_60 = None | |
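`view_as_complex_4/5`, `mul_18/19`, and `view_as_real_4/5` above are `apply_rotary_emb` from model.py:146-151: the q/k heads are upcast to fp32, reinterpreted as complex pairs, rotated by `freqs_cis` (the broadcasted `view_20` in this graph), flattened back to real, and cast to bf16. A self-contained sketch of the same math; the `freqs_cis` below is a random placeholder with an assumed broadcast shape, not the real precomputed table:

    import torch

    def apply_rotary_emb(xq, xk, freqs_cis):
        # view_56/57 + view_as_complex_4/5: pair up the last dim, reinterpret as complex64
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # mul_18/19 + view_as_real_4/5 + view_59/60: rotate, flatten back to real pairs
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        # convert_element_type_88/89: back to the activation dtype
        return xq_out.type_as(xq), xk_out.type_as(xk)

    xq = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
    xk = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
    freqs_cis = torch.randn(1, 2048, 1, 64, dtype=torch.complex64)  # assumed shape
    xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cis)  # bf16[1, 2048, 16, 128] each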
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_42: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_88, [0, 2, 1, 3]); convert_element_type_88 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_43: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_89, [0, 2, 1, 3]); convert_element_type_89 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_44: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_55, [0, 2, 1, 3]); view_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_2 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_42, permute_43, permute_44, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_308: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_2[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_309: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_2[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_314: "i64[]" = _scaled_dot_product_flash_attention_2[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_315: "i64[]" = _scaled_dot_product_flash_attention_2[7]; _scaled_dot_product_flash_attention_2 = None | |
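`_scaled_dot_product_flash_attention_2` is the traced form of `F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)` at model.py:254. The hard-coded `scale = 0.08838834764831843` is simply the default 1/sqrt(head_dim) for head_dim=128, and the extra `getitem_*` outputs (logsumexp, RNG state) are kept for the backward pass. Eager-mode equivalent:

    import math
    import torch
    import torch.nn.functional as F

    xq = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16)  # permute_42
    xk = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16)  # permute_43
    xv = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16)  # permute_44

    out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)  # getitem_308
    # The scale baked into the graph is just the default softmax scaling:
    assert math.isclose(1 / math.sqrt(xq.shape[-1]), 0.08838834764831843)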
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_45: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_308, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_61: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_45, [2048, -1]); permute_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_47: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_24, [1, 0]); contiguous_view_as_strided_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_17: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_61, permute_47); permute_47 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_9: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_7, mm_17); add_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_92: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_9, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_6: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_92, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_5: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_6, [-1], True); pow_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_10: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_5, 1e-05); mean_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_5: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_20: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_92, rsqrt_5); convert_element_type_92 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_93: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_20, torch.bfloat16); mul_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_21: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_93, contiguous_view_as_strided_29); convert_element_type_93 = contiguous_view_as_strided_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_49: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_25, [1, 0]); contiguous_view_as_strided_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_18: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_21, permute_49); permute_49 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_96: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_18, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_2: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_96) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_22: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_96, sigmoid_2); convert_element_type_96 = sigmoid_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_97: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_51: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_27, [1, 0]); contiguous_view_as_strided_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_19: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_21, permute_51); permute_51 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_23: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_97, mm_19); convert_element_type_97 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_53: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_26, [1, 0]); contiguous_view_as_strided_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_20: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_23, permute_53); permute_53 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_11: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_9, mm_20); add_9 = None | |
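`mm_18` -> `sigmoid_2`/`mul_22` -> `mm_19` -> `mul_23` -> `mm_20` is the SwiGLU feed-forward from model.py:299: `F.silu(self.w1(x))` decomposes into an fp32 upcast, a sigmoid, and a multiply; the result gates the `w3` projection elementwise, and `w2` projects back down, with `add_11` closing the residual at model.py:411. A hedged eager sketch with the 2048 -> 5632 -> 2048 shapes from the trace (the preceding `ffn_norm` is omitted since it appears separately above):

    import torch
    import torch.nn.functional as F

    def swiglu_ffn(x, w1, w2, w3):
        # mm_18 + sigmoid_2/mul_22: silu computed in fp32, then cast back to bf16
        gate = F.silu((x @ w1.t()).float()).to(x.dtype)
        # mm_19 + mul_23: elementwise gate against the w3 projection
        hidden = gate * (x @ w3.t())
        # mm_20: project back down to the model dimension
        return hidden @ w2.t()

    x  = torch.randn(2048, 2048, dtype=torch.bfloat16)
    w1 = torch.randn(5632, 2048, dtype=torch.bfloat16)  # contiguous_view_as_strided_25
    w3 = torch.randn(5632, 2048, dtype=torch.bfloat16)  # contiguous_view_as_strided_27
    w2 = torch.randn(2048, 5632, dtype=torch.bfloat16)  # contiguous_view_as_strided_26
    out = x + swiglu_ffn(x, w1, w2, w3)                 # add_11 (norm omitted)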
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_102: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_103: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_104: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_105: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_106: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_107: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_108: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_109: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_110: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None | |
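`convert_element_type_102` through `convert_element_type_110` are FSDP's mixed-precision cast (`_to_dtype_if_needed`): the fp32 sharded parameters for the next transformer block are cast to bf16 before the all-gather. The nine shard numels match this block's parameters split across 8 ranks; the mapping below is inferred from the full shapes recovered later in the trace, and the arithmetic checks out:

    world_size = 8
    params = {  # assumed shard -> parameter mapping, inferred from the trace
        "wq": 2048 * 2048, "wk": 2048 * 2048, "wv": 2048 * 2048, "wo": 2048 * 2048,
        "w1": 5632 * 2048, "w2": 2048 * 5632, "w3": 5632 * 2048,
        "attention_norm": 2048, "ffn_norm": 2048,
    }
    shards = [n // world_size for n in params.values()]
    assert shards == [524288] * 4 + [1441792] * 3 + [256] * 2
    assert sum(shards) == 6423040                # all_gather_input_numel below
    assert sum(shards) * world_size == 51384320  # numel of the all-gather output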
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_4 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_102, convert_element_type_103, convert_element_type_104, convert_element_type_105, convert_element_type_106, convert_element_type_107, convert_element_type_108, convert_element_type_109, convert_element_type_110]); convert_element_type_102 = convert_element_type_103 = convert_element_type_104 = convert_element_type_105 = convert_element_type_106 = convert_element_type_107 = convert_element_type_108 = convert_element_type_109 = convert_element_type_110 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_317: "bf16[6423040]" = all_gather_copy_in_4[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_318: "bf16[51384320]" = all_gather_copy_in_4[1]; all_gather_copy_in_4 = None | |
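`torch.ops.fsdp.all_gather_copy_in` is a custom op from this FSDP branch; judging by its two outputs, it produces the full bf16[51384320] all-gather output buffer and packs the nine flattened shards into this rank's bf16[6423040] input slice. A rough plain-torch approximation (the exact allocation and aliasing behavior of the real op is an assumption here):

    import torch

    split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
    world_size, rank = 8, 0
    inputs = [torch.randn(n, dtype=torch.bfloat16) for n in split_sizes]

    numel = sum(split_sizes)                                                   # 6423040
    all_gather_output = torch.empty(numel * world_size, dtype=torch.bfloat16)  # getitem_318
    all_gather_input = all_gather_output.narrow(0, numel * rank, numel)        # getitem_317
    torch.cat(inputs, out=all_gather_input)  # copy-in: pack the shards contiguously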
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_4: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_317, 8, '0'); getitem_317 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_4: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None | |
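`all_gather_into_tensor_4` and `wait_tensor_4` are the traced functional collectives from `torch.distributed._functional_collectives`: the gather is issued on process group '0', and the explicit wait node lets the compiler schedule compute between issue and wait. A sketch of the same pair in eager code; it assumes an already-initialized process group (e.g. under torchrun):

    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    def gather_shard(shard):
        # all_gather_into_tensor_4: concatenate every rank's shard along dim 0
        gathered = funcol.all_gather_tensor(shard, gather_dim=0, group=dist.group.WORLD)
        # wait_tensor_4: the synchronization point, made explicit in the traced graph
        return funcol.wait_tensor(gathered)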
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_63: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_4, [8, -1]); wait_tensor_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_35 = torch.ops.aten.split_with_sizes.default(view_63, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_63 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_328: "bf16[8, 524288]" = split_with_sizes_35[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_30: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_328, [2048, 2048], [2048, 1]); getitem_328 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_338: "bf16[8, 524288]" = split_with_sizes_35[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_31: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_338, [2048, 2048], [2048, 1]); getitem_338 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_348: "bf16[8, 524288]" = split_with_sizes_35[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_32: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_348, [2048, 2048], [2048, 1]); getitem_348 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_358: "bf16[8, 524288]" = split_with_sizes_35[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_33: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_358, [2048, 2048], [2048, 1]); getitem_358 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_368: "bf16[8, 1441792]" = split_with_sizes_35[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_34: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_368, [5632, 2048], [2048, 1]); getitem_368 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_378: "bf16[8, 1441792]" = split_with_sizes_35[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_35: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_378, [2048, 5632], [5632, 1]); getitem_378 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_388: "bf16[8, 1441792]" = split_with_sizes_35[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_36: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_388, [5632, 2048], [2048, 1]); getitem_388 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_398: "bf16[8, 256]" = split_with_sizes_35[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_37: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_398, [2048], [1]); getitem_398 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_408: "bf16[8, 256]" = split_with_sizes_35[8]; split_with_sizes_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_38: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_408, [2048], [1]); getitem_408 = None | |
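The copy-out above reshapes the gathered bf16[51384320] buffer to [8, 6423040], splits it column-wise back into the nine per-parameter shard groups, and turns each [8, n] slice into a full unsharded parameter via `torch.ops.fsdp.contiguous_view_as_strided` (another custom op from this branch). Since every target stride in the trace is row-major contiguous, a plain `contiguous().view(...)` reproduces the shapes; that equivalence is an assumption about the op, sketched below:

    import torch

    world_size = 8
    split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
    out_shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048),
                  (2048,), (2048,)]

    gathered = torch.randn(world_size * sum(split_sizes), dtype=torch.bfloat16)  # wait_tensor_4
    ranks_view = gathered.reshape(world_size, -1)  # view_63: bf16[8, 6423040]
    splits = ranks_view.split(split_sizes, dim=1)  # split_with_sizes_35
    params = [s.contiguous().view(shape) for s, shape in zip(splits, out_shapes)]
    assert [tuple(p.shape) for p in params] == out_shapes  # contiguous_view_as_strided_30..38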
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_111: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_11, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_7: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_111, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_6: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_7, [-1], True); pow_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_12: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_6, 1e-05); mean_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_6: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_12); add_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_24: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_111, rsqrt_6); convert_element_type_111 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_112: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_24, torch.bfloat16); mul_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_25: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_112, contiguous_view_as_strided_37); convert_element_type_112 = contiguous_view_as_strided_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_55: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_30, [1, 0]); contiguous_view_as_strided_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_21: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_25, permute_55); permute_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_57: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_31, [1, 0]); contiguous_view_as_strided_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_22: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_25, permute_57); permute_57 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_59: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_32, [1, 0]); contiguous_view_as_strided_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_23: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_25, permute_59); permute_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_72: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_21, [1, 2048, 16, 128]); mm_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_73: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_22, [1, 2048, 16, 128]); mm_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_74: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_23, [1, 2048, 16, 128]); mm_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_119: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_72, torch.float32); view_72 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_75: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_119, [1, 2048, 16, -1, 2]); convert_element_type_119 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_6: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_75); view_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_120: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_73, torch.float32); view_73 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_76: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_120, [1, 2048, 16, -1, 2]); convert_element_type_120 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_7: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_76); view_76 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_26: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_6, view_20); view_as_complex_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_6: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_78: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_6, [1, 2048, 16, 128]); view_as_real_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_27: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_7, view_20); view_as_complex_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_7: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_79: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_7, [1, 2048, 16, 128]); view_as_real_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_121: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_78, torch.bfloat16); view_78 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_122: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_79, torch.bfloat16); view_79 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_60: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_121, [0, 2, 1, 3]); convert_element_type_121 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_61: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_122, [0, 2, 1, 3]); convert_element_type_122 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_62: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_74, [0, 2, 1, 3]); view_74 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_3 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_60, permute_61, permute_62, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_409: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_3[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_410: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_3[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_415: "i64[]" = _scaled_dot_product_flash_attention_3[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_416: "i64[]" = _scaled_dot_product_flash_attention_3[7]; _scaled_dot_product_flash_attention_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_63: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_409, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_80: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_63, [2048, -1]); permute_63 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_65: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_33, [1, 0]); contiguous_view_as_strided_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_24: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_80, permute_65); permute_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_13: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_11, mm_24); add_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_125: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_13, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_8: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_125, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_7: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_8, [-1], True); pow_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_14: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_7, 1e-05); mean_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_7: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_14); add_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_28: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_125, rsqrt_7); convert_element_type_125 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_126: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_28, torch.bfloat16); mul_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_29: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_126, contiguous_view_as_strided_38); convert_element_type_126 = contiguous_view_as_strided_38 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_67: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_34, [1, 0]); contiguous_view_as_strided_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_25: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_29, permute_67); permute_67 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_129: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_25, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_3: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_129) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_30: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_129, sigmoid_3); convert_element_type_129 = sigmoid_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_130: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_69: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_36, [1, 0]); contiguous_view_as_strided_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_26: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_29, permute_69); permute_69 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_31: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_130, mm_26); convert_element_type_130 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_71: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_35, [1, 0]); contiguous_view_as_strided_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_27: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_31, permute_71); permute_71 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_15: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_13, mm_27); add_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
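Annotation: the ops ending at add_15 above are the inlined SwiGLU feed-forward from model.py:299 plus the pre-norm residual from model.py:411. A minimal eager-mode sketch of what this part of the graph computes, assuming the 1B config shapes (dim=2048, hidden=5632); the fp32 upcast around silu mirrors convert_element_type_129/130:

import torch
import torch.nn.functional as F

def swiglu_ffn(x, w1, w2, w3):
    # x: [tokens, 2048] bf16; w1/w3: [5632, 2048]; w2: [2048, 5632].
    # F.linear(x, w) computes x @ w.t(), which is why the graph pairs
    # aten.permute([1, 0]) with aten.mm for every projection.
    gate = F.silu(F.linear(x, w1).float()).to(x.dtype)  # silu done in fp32
    return F.linear(gate * F.linear(x, w3), w2)

h = torch.randn(2048, 2048, dtype=torch.bfloat16)       # block input (add_13)
normed = torch.randn(2048, 2048, dtype=torch.bfloat16)  # ffn_norm(h) (mul_29)
w1 = torch.randn(5632, 2048, dtype=torch.bfloat16)
w2 = torch.randn(2048, 5632, dtype=torch.bfloat16)
w3 = torch.randn(5632, 2048, dtype=torch.bfloat16)
out = h + swiglu_ffn(normed, w1, w2, w3)                # add_15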
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_135: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_136: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_137: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_138: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_139: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_140: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_141: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_142: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_143: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
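Annotation: convert_element_type_135..143 come from FSDP's _to_dtype_if_needed, casting each fp32 sharded parameter of the next transformer block (four attention mats, three FFN mats, two norm weights per rank) to bf16 before the all-gather, so the collective moves half the bytes. A rough standalone equivalent, assuming the per-rank shard numels shown in the graph:

import torch

# Per-rank shard numels for one block; they sum to 6_423_040.
shard_numels = [524288] * 4 + [1441792] * 3 + [256] * 2
fp32_shards = [torch.randn(n) for n in shard_numels]

def to_dtype_if_needed(tensor, dtype):
    # no-op if already the target dtype, otherwise an out-of-place cast
    return tensor if tensor.dtype == dtype else tensor.to(dtype)

bf16_shards = [to_dtype_if_needed(s, torch.bfloat16) for s in fp32_shards]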
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_5 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_135, convert_element_type_136, convert_element_type_137, convert_element_type_138, convert_element_type_139, convert_element_type_140, convert_element_type_141, convert_element_type_142, convert_element_type_143]); convert_element_type_135 = convert_element_type_136 = convert_element_type_137 = convert_element_type_138 = convert_element_type_139 = convert_element_type_140 = convert_element_type_141 = convert_element_type_142 = convert_element_type_143 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_418: "bf16[6423040]" = all_gather_copy_in_5[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_419: "bf16[51384320]" = all_gather_copy_in_5[1]; all_gather_copy_in_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_5: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_418, 8, '0'); getitem_418 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_5: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
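Annotation: all_gather_copy_in packs those nine bf16 shards into one flat 6,423,040-element input plus a 51,384,320-element (8x) output buffer; the functional all_gather_into_tensor fills the output across the 8-rank group, and wait_tensor marks where the compute stream later blocks on the collective. Sketched below with the plain torch.distributed API instead of the traced torch.ops.fsdp/_c10d_functional ops, assuming an already-initialized 8-rank process group; this illustrates the data movement, not the compiled kernels:

import torch
import torch.distributed as dist

def gather_params(bf16_shards, world_size=8):
    # copy-in: one flat input buffer holding all shards back to back
    inp = torch.cat([s.reshape(-1) for s in bf16_shards])       # [6_423_040]
    out = torch.empty(world_size * inp.numel(),
                      dtype=inp.dtype, device=inp.device)       # [51_384_320]
    # a single all-gather for the whole parameter group; the returned
    # work handle plays the role of wait_tensor in the traced graph
    work = dist.all_gather_into_tensor(out, inp, async_op=True)
    work.wait()
    return out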
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_82: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_5, [8, -1]); wait_tensor_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_45 = torch.ops.aten.split_with_sizes.default(view_82, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_82 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_429: "bf16[8, 524288]" = split_with_sizes_45[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_39: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_429, [2048, 2048], [2048, 1]); getitem_429 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_439: "bf16[8, 524288]" = split_with_sizes_45[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_40: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_439, [2048, 2048], [2048, 1]); getitem_439 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_449: "bf16[8, 524288]" = split_with_sizes_45[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_41: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_449, [2048, 2048], [2048, 1]); getitem_449 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_459: "bf16[8, 524288]" = split_with_sizes_45[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_42: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_459, [2048, 2048], [2048, 1]); getitem_459 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_469: "bf16[8, 1441792]" = split_with_sizes_45[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_43: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_469, [5632, 2048], [2048, 1]); getitem_469 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_479: "bf16[8, 1441792]" = split_with_sizes_45[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_44: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_479, [2048, 5632], [5632, 1]); getitem_479 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_489: "bf16[8, 1441792]" = split_with_sizes_45[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_45: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_489, [5632, 2048], [2048, 1]); getitem_489 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_499: "bf16[8, 256]" = split_with_sizes_45[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_46: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_499, [2048], [1]); getitem_499 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_509: "bf16[8, 256]" = split_with_sizes_45[8]; split_with_sizes_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_47: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_509, [2048], [1]); getitem_509 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
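Annotation: foreach_all_gather_copy_out then reverses the packing: the gathered [51384320] buffer is viewed as [8, 6423040] (one row per rank), split column-wise into nine per-parameter strips, and each strip is reassembled into its unsharded layout by the custom torch.ops.fsdp.contiguous_view_as_strided op. For the dim-0 sharding used here, where rank r holds the r-th contiguous slice of each flattened parameter, a plain contiguous-copy-and-view reproduces the values; this is only a hypothetical stand-in, and the exact semantics of the experimental fsdp op may differ:

import torch

gathered = torch.randn(51_384_320, dtype=torch.bfloat16)  # wait_tensor_5 output
numels = [524288] * 4 + [1441792] * 3 + [256] * 2
shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048),
          (2048,), (2048,)]

rows = gathered.reshape(8, -1)                # one row per rank
strips = torch.split(rows, numels, dim=1)     # nine [8, numel // 8] strips
# Each rank's row is one contiguous chunk of the flattened parameter, so a
# contiguous copy of the strip viewed with the target shape rebuilds it.
params = [s.contiguous().view(shape) for s, shape in zip(strips, shapes)]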
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_144: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_15, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_9: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_144, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_8: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_9, [-1], True); pow_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_16: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_8, 1e-05); mean_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_8: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_16); add_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_32: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_144, rsqrt_8); convert_element_type_144 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_145: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_32, torch.bfloat16); mul_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_33: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_145, contiguous_view_as_strided_46); convert_element_type_145 = contiguous_view_as_strided_46 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
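Annotation: pow_9 through mul_33 are one complete traced RMSNorm (here attention_norm): upcast to fp32, mean of squares over the last dim, rsqrt(mean + 1e-05), scale, downcast to bf16, then multiply by the gathered norm weight. The eager form, taken directly from the model.py lines quoted in the log:

import torch

def rms_norm(x, weight, eps=1e-5):
    # fp32 math between two bf16 casts, exactly as the graph shows
    xf = x.float()
    out = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return out.type_as(x) * weight

x = torch.randn(2048, 2048, dtype=torch.bfloat16)   # add_15, flattened tokens
w = torch.ones(2048, dtype=torch.bfloat16)          # contiguous_view_as_strided_46
y = rms_norm(x, w)                                  # mul_33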
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_73: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_39, [1, 0]); contiguous_view_as_strided_39 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_28: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_33, permute_73); permute_73 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_75: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_40, [1, 0]); contiguous_view_as_strided_40 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_29: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_33, permute_75); permute_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_77: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_41, [1, 0]); contiguous_view_as_strided_41 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_30: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_33, permute_77); permute_77 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_91: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_28, [1, 2048, 16, 128]); mm_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_92: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_29, [1, 2048, 16, 128]); mm_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_93: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_30, [1, 2048, 16, 128]); mm_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
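Annotation: mm_28/mm_29/mm_30 are the wq/wk/wv projections applied to the same normed activations, then reshaped to [bsz, seqlen, n_heads, head_dim] = [1, 2048, 16, 128] (n_kv_heads defaults to n_heads in this config, per the ModelArgs logged at startup). A compact sketch:

import torch
import torch.nn.functional as F

bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128
x = torch.randn(bsz * seqlen, 2048, dtype=torch.bfloat16)   # mul_33
wq, wk, wv = (torch.randn(2048, 2048, dtype=torch.bfloat16) for _ in range(3))

# three independent matmuls against the same input, as in mm_28..mm_30
xq = F.linear(x, wq).view(bsz, seqlen, n_heads, head_dim)
xk = F.linear(x, wk).view(bsz, seqlen, n_heads, head_dim)
xv = F.linear(x, wv).view(bsz, seqlen, n_heads, head_dim)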
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_152: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_91, torch.float32); view_91 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_94: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_152, [1, 2048, 16, -1, 2]); convert_element_type_152 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_8: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_94); view_94 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_153: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_92, torch.float32); view_92 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_95: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_153, [1, 2048, 16, -1, 2]); convert_element_type_153 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_9: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_95); view_95 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_34: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_8, view_20); view_as_complex_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_8: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_97: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_8, [1, 2048, 16, 128]); view_as_real_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_35: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_9, view_20); view_as_complex_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_9: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_98: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_9, [1, 2048, 16, 128]); view_as_real_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_154: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_97, torch.bfloat16); view_97 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_155: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_98, torch.bfloat16); view_98 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
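Annotation: apply_rotary_emb is traced as complex arithmetic: pair up the last dim and view it as complex64, multiply by the precomputed freqs_cis (view_20, prepared earlier in the graph), view back as real, flatten, and cast to bf16. Eager form from the quoted model.py lines; the broadcast reshape of freqs_cis is an assumption, since that step happens outside this excerpt:

import torch

def apply_rotary_emb(xq, xk, freqs_cis):
    # xq, xk: [1, 2048, 16, 128] bf16; freqs_cis: complex64 [2048, 64]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    fc = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])  # broadcastable
    xq_out = torch.view_as_real(xq_ * fc).flatten(3)
    xk_out = torch.view_as_real(xk_ * fc).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

xq = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
xk = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
freqs_cis = torch.polar(torch.ones(2048, 64), torch.randn(2048, 64))
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)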
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_78: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_154, [0, 2, 1, 3]); convert_element_type_154 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_79: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_155, [0, 2, 1, 3]); convert_element_type_155 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_80: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_93, [0, 2, 1, 3]); view_93 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_4 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_78, permute_79, permute_80, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_510: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_4[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_511: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_4[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_516: "i64[]" = _scaled_dot_product_flash_attention_4[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_517: "i64[]" = _scaled_dot_product_flash_attention_4[7]; _scaled_dot_product_flash_attention_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
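Annotation: after transposing q/k/v to [bs, n_heads, seqlen, head_dim], the graph lowers F.scaled_dot_product_attention straight to the flash-attention kernel; is_causal=True supplies the triangular mask, and scale = 128 ** -0.5 = 0.08838834764831843 is the default 1/sqrt(head_dim). The extra getitem outputs (logsumexp, RNG state) are kept only for the backward pass. Equivalent call:

import torch
import torch.nn.functional as F

xq = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")
xk = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")
xv = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")

# the default scale is 1 / sqrt(head_dim), matching the 0.0883883... above
out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)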
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_81: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_510, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_99: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_81, [2048, -1]); permute_81 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_83: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_42, [1, 0]); contiguous_view_as_strided_42 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_31: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_99, permute_83); permute_83 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_17: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_15, mm_31); add_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
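Annotation: view_99, mm_31, and add_17 close the attention block: transpose the SDPA output back to [bs, seq, heads, head_dim], flatten to [tokens, dim], apply the wo projection, and add the block input (model.py:410, h = x + self.attention(self.attention_norm(x), freqs_cis)). A short sketch; the layers that follow repeat this same attention/FFN/all-gather pattern per transformer block:

import torch
import torch.nn.functional as F

bsz, seqlen, dim = 1, 2048, 2048
attn = torch.randn(bsz, 16, seqlen, 128, dtype=torch.bfloat16)  # getitem_510
wo = torch.randn(dim, dim, dtype=torch.bfloat16)  # contiguous_view_as_strided_42
x = torch.randn(bsz * seqlen, dim, dtype=torch.bfloat16)        # add_15

flat = attn.transpose(1, 2).reshape(bsz * seqlen, -1)   # view_99: [2048, 2048]
h = x + F.linear(flat, wo)                              # add_17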
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_158: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_17, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_10: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_158, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_9: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_10, [-1], True); pow_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_18: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_9, 1e-05); mean_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_9: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_18); add_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_36: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_158, rsqrt_9); convert_element_type_158 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_159: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_36, torch.bfloat16); mul_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_37: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_159, contiguous_view_as_strided_47); convert_element_type_159 = contiguous_view_as_strided_47 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_85: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_43, [1, 0]); contiguous_view_as_strided_43 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_32: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_37, permute_85); permute_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_162: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_32, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_4: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_162) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_38: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_162, sigmoid_4); convert_element_type_162 = sigmoid_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_163: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_87: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_45, [1, 0]); contiguous_view_as_strided_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_33: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_37, permute_87); permute_87 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_39: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_163, mm_33); convert_element_type_163 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_89: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_44, [1, 0]); contiguous_view_as_strided_44 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_34: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_39, permute_89); permute_89 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_19: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_17, mm_34); add_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_168: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_169: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_170: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_171: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_172: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_173: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_174: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_175: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_176: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_6 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_168, convert_element_type_169, convert_element_type_170, convert_element_type_171, convert_element_type_172, convert_element_type_173, convert_element_type_174, convert_element_type_175, convert_element_type_176]); convert_element_type_168 = convert_element_type_169 = convert_element_type_170 = convert_element_type_171 = convert_element_type_172 = convert_element_type_173 = convert_element_type_174 = convert_element_type_175 = convert_element_type_176 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_519: "bf16[6423040]" = all_gather_copy_in_6[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_520: "bf16[51384320]" = all_gather_copy_in_6[1]; all_gather_copy_in_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_6: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_519, 8, '0'); getitem_519 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_6: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_101: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_6, [8, -1]); wait_tensor_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_55 = torch.ops.aten.split_with_sizes.default(view_101, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_101 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_530: "bf16[8, 524288]" = split_with_sizes_55[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_48: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_530, [2048, 2048], [2048, 1]); getitem_530 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_540: "bf16[8, 524288]" = split_with_sizes_55[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_49: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_540, [2048, 2048], [2048, 1]); getitem_540 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_550: "bf16[8, 524288]" = split_with_sizes_55[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_50: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_550, [2048, 2048], [2048, 1]); getitem_550 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_560: "bf16[8, 524288]" = split_with_sizes_55[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_51: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_560, [2048, 2048], [2048, 1]); getitem_560 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_570: "bf16[8, 1441792]" = split_with_sizes_55[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_52: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_570, [5632, 2048], [2048, 1]); getitem_570 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_580: "bf16[8, 1441792]" = split_with_sizes_55[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_53: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_580, [2048, 5632], [5632, 1]); getitem_580 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_590: "bf16[8, 1441792]" = split_with_sizes_55[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_54: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_590, [5632, 2048], [2048, 1]); getitem_590 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_600: "bf16[8, 256]" = split_with_sizes_55[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_55: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_600, [2048], [1]); getitem_600 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_610: "bf16[8, 256]" = split_with_sizes_55[8]; split_with_sizes_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_56: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_610, [2048], [1]); getitem_610 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_177: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_19, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_11: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_177, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_10: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_11, [-1], True); pow_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_20: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_10, 1e-05); mean_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_10: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_20); add_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_40: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_177, rsqrt_10); convert_element_type_177 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_178: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_40, torch.bfloat16); mul_40 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_41: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_178, contiguous_view_as_strided_55); convert_element_type_178 = contiguous_view_as_strided_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_91: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_48, [1, 0]); contiguous_view_as_strided_48 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_35: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_41, permute_91); permute_91 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_93: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_49, [1, 0]); contiguous_view_as_strided_49 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_36: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_41, permute_93); permute_93 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_95: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_50, [1, 0]); contiguous_view_as_strided_50 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_37: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_41, permute_95); permute_95 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_110: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_35, [1, 2048, 16, 128]); mm_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_111: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_36, [1, 2048, 16, 128]); mm_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_112: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_37, [1, 2048, 16, 128]); mm_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_185: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_110, torch.float32); view_110 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_113: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_185, [1, 2048, 16, -1, 2]); convert_element_type_185 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_10: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_113); view_113 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_186: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_111, torch.float32); view_111 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_114: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_186, [1, 2048, 16, -1, 2]); convert_element_type_186 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_11: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_114); view_114 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_42: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_10, view_20); view_as_complex_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_10: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_116: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_10, [1, 2048, 16, 128]); view_as_real_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_43: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_11, view_20); view_as_complex_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_11: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_117: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_11, [1, 2048, 16, 128]); view_as_real_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_187: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_116, torch.bfloat16); view_116 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_188: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_117, torch.bfloat16); view_117 = None | |
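
The region above is the traced body of apply_rotary_emb as cited in the # File: comments: pair the last dim into (..., head_dim/2, 2), reinterpret as complex, rotate by the complex frequencies (view_20 in the graph), flatten back, and cast to the input dtype. A runnable reconstruction, assuming random freqs_cis whose shape and dtype match the run:

    import torch

    def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Upcast to fp32, view pairs of floats as complex, rotate, flatten back.
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq)

    xq = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
    # Unit-magnitude complex rotations, broadcast over the head dimension.
    freqs_cis = torch.polar(torch.ones(2048, 64), torch.randn(2048, 64)).reshape(1, 2048, 1, 64)
    out = apply_rotary_emb(xq, freqs_cis)  # bf16[1, 2048, 16, 128]
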
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_96: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_187, [0, 2, 1, 3]); convert_element_type_187 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_97: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_188, [0, 2, 1, 3]); convert_element_type_188 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_98: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_5 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_96, permute_97, permute_98, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_611: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_5[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_612: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_5[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_617: "i64[]" = _scaled_dot_product_flash_attention_5[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_618: "i64[]" = _scaled_dot_product_flash_attention_5[7]; _scaled_dot_product_flash_attention_5 = None | |
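
The flash-attention node returns a tuple: index 0 is the attention output, index 1 the fp32 logsumexp saved for backward, and indices 6/7 the philox RNG seed/offset (the i64[] scalars above). The hard-coded scale is just the SDPA default 1/sqrt(head_dim) = 128**-0.5. A sketch (the flash kernel itself needs CUDA, as in this run; on CPU this falls back to the math backend):

    import math
    import torch
    import torch.nn.functional as F

    # The constant in the graph is the default scale, nothing custom:
    assert math.isclose(128 ** -0.5, 0.08838834764831843)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device=device)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # bf16[1, 16, 2048, 128]
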
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_99: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_611, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_118: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_99, [2048, -1]); permute_99 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_101: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_51, [1, 0]); contiguous_view_as_strided_51 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_38: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_118, permute_101); permute_101 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_21: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_19, mm_38); add_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_191: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_21, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_12: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_191, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_11: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_12, [-1], True); pow_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_22: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_11, 1e-05); mean_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_11: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_22); add_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_44: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_191, rsqrt_11); convert_element_type_191 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_192: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_44, torch.bfloat16); mul_44 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_45: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_192, contiguous_view_as_strided_56); convert_element_type_192 = contiguous_view_as_strided_56 = None | |
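
The pow/mean/add/rsqrt/mul chain above (eps = 1e-05 in add_22) is RMSNorm computed in fp32 and cast back to bf16 before scaling by the unsharded weight, exactly as the cited model.py:61/74/75 lines read. As a standalone function:

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Matches the traced graph: upcast to fp32, normalize, cast back, scale.
        xf = x.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
        return normed.type_as(x) * weight

    x = torch.randn(2048, 2048, dtype=torch.bfloat16)
    w = torch.ones(2048, dtype=torch.bfloat16)
    y = rms_norm(x, w)  # bf16[2048, 2048]
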
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_103: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_52, [1, 0]); contiguous_view_as_strided_52 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_39: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_45, permute_103); permute_103 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_195: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_39, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_5: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_195) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_46: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_195, sigmoid_5); convert_element_type_195 = sigmoid_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_196: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_105: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_54, [1, 0]); contiguous_view_as_strided_54 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_40: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_45, permute_105); permute_105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_47: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_196, mm_40); convert_element_type_196 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_107: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_53, [1, 0]); contiguous_view_as_strided_53 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_41: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_47, permute_107); permute_107 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_23: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_21, mm_41); add_21 = None | |
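
mm_39/mm_40/mm_41 together form the SwiGLU feed-forward w2(F.silu(w1(x)) * w3(x)) with hidden dim 5632, followed by the residual add; Inductor has decomposed silu into x * sigmoid(x) in fp32. Equivalent eager code, with random weights standing in for the unsharded views:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2048, 2048, dtype=torch.bfloat16)
    w1 = torch.randn(5632, 2048, dtype=torch.bfloat16)  # gate projection
    w3 = torch.randn(5632, 2048, dtype=torch.bfloat16)  # up projection
    w2 = torch.randn(2048, 5632, dtype=torch.bfloat16)  # down projection

    # silu(a) == a * sigmoid(a); the graph computes it in fp32, then casts to bf16.
    out = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)  # bf16[2048, 2048]
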
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_201: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_202: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_203: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_204: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_205: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_206: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_207: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_208: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_209: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_7 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_201, convert_element_type_202, convert_element_type_203, convert_element_type_204, convert_element_type_205, convert_element_type_206, convert_element_type_207, convert_element_type_208, convert_element_type_209]); convert_element_type_201 = convert_element_type_202 = convert_element_type_203 = convert_element_type_204 = convert_element_type_205 = convert_element_type_206 = convert_element_type_207 = convert_element_type_208 = convert_element_type_209 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_620: "bf16[6423040]" = all_gather_copy_in_7[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_621: "bf16[51384320]" = all_gather_copy_in_7[1]; all_gather_copy_in_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_7: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_620, 8, '0'); getitem_620 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_7: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None | |
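
This is the per-layer FSDP all-gather prefetching the next block's parameters: nine bf16 parameter shards are cast and packed into one flat buffer, gathered across the 8 ranks via the functional collective, and wait_tensor is the synchronization point for the async op. The numel bookkeeping in the graph checks out, and each shard size is the full parameter divided by world_size:

    split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
    world_size = 8
    assert sum(split_sizes) == 6423040                  # all-gather input numel per rank
    assert sum(split_sizes) * world_size == 51384320    # gathered output numel
    assert 2048 * 2048 // world_size == 524288          # one rank's shard of a (2048, 2048) weight
    assert 5632 * 2048 // world_size == 1441792         # one rank's shard of a (5632, 2048) weight
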
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_120: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_7, [8, -1]); wait_tensor_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_65 = torch.ops.aten.split_with_sizes.default(view_120, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_120 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_631: "bf16[8, 524288]" = split_with_sizes_65[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_57: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_631, [2048, 2048], [2048, 1]); getitem_631 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_641: "bf16[8, 524288]" = split_with_sizes_65[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_58: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_641, [2048, 2048], [2048, 1]); getitem_641 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_651: "bf16[8, 524288]" = split_with_sizes_65[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_59: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_651, [2048, 2048], [2048, 1]); getitem_651 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_661: "bf16[8, 524288]" = split_with_sizes_65[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_60: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_661, [2048, 2048], [2048, 1]); getitem_661 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_671: "bf16[8, 1441792]" = split_with_sizes_65[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_61: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_671, [5632, 2048], [2048, 1]); getitem_671 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_681: "bf16[8, 1441792]" = split_with_sizes_65[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_62: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_681, [2048, 5632], [5632, 1]); getitem_681 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_691: "bf16[8, 1441792]" = split_with_sizes_65[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_63: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_691, [5632, 2048], [2048, 1]); getitem_691 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_701: "bf16[8, 256]" = split_with_sizes_65[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_64: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_701, [2048], [1]); getitem_701 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_711: "bf16[8, 256]" = split_with_sizes_65[8]; split_with_sizes_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_65: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_711, [2048], [1]); getitem_711 = None | |
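
After the wait, the flat gathered buffer is viewed as (world_size, -1), split column-wise per parameter, and each (8, shard_numel) slice is reinterpreted as the full unsharded parameter shape via the private torch.ops.fsdp.contiguous_view_as_strided op. A plain-PyTorch stand-in for the copy-out, assuming reshape is semantically equivalent (the private op additionally guarantees a strided view without a copy):

    import torch

    world_size = 8
    split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
    gathered = torch.empty(world_size * sum(split_sizes), dtype=torch.bfloat16)

    # Split the (world_size, per_rank_numel) view into one slice per parameter,
    # then reassemble each slice into its full unsharded shape.
    chunks = gathered.view(world_size, -1).split(split_sizes, dim=1)
    wq = chunks[0].reshape(2048, 2048)    # bf16[2048, 2048], cf. contiguous_view_as_strided_57
    w1 = chunks[4].reshape(5632, 2048)    # bf16[5632, 2048], cf. contiguous_view_as_strided_61
    norm_w = chunks[7].reshape(2048)      # bf16[2048],       cf. contiguous_view_as_strided_64
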
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_210: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_23, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_13: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_210, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_12: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_13, [-1], True); pow_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_24: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_12, 1e-05); mean_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_12: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_24); add_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_48: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_210, rsqrt_12); convert_element_type_210 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_211: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_48, torch.bfloat16); mul_48 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_49: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_211, contiguous_view_as_strided_64); convert_element_type_211 = contiguous_view_as_strided_64 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_109: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_57, [1, 0]); contiguous_view_as_strided_57 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_42: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_49, permute_109); permute_109 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_111: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_58, [1, 0]); contiguous_view_as_strided_58 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_43: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_49, permute_111); permute_111 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_113: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_59, [1, 0]); contiguous_view_as_strided_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_44: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_49, permute_113); permute_113 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_129: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_42, [1, 2048, 16, 128]); mm_42 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_130: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_43, [1, 2048, 16, 128]); mm_43 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_131: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_44, [1, 2048, 16, 128]); mm_44 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_218: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_129, torch.float32); view_129 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_132: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_218, [1, 2048, 16, -1, 2]); convert_element_type_218 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_12: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_132); view_132 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_219: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_133: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_219, [1, 2048, 16, -1, 2]); convert_element_type_219 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_13: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_133); view_133 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_50: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_12, view_20); view_as_complex_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_12: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_135: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_12, [1, 2048, 16, 128]); view_as_real_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_51: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_13, view_20); view_as_complex_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_13: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_136: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_13, [1, 2048, 16, 128]); view_as_real_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_220: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_135, torch.bfloat16); view_135 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_221: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_136, torch.bfloat16); view_136 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_114: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_220, [0, 2, 1, 3]); convert_element_type_220 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_115: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_221, [0, 2, 1, 3]); convert_element_type_221 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_116: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_131, [0, 2, 1, 3]); view_131 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_6 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_114, permute_115, permute_116, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_712: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_6[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_713: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_6[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_718: "i64[]" = _scaled_dot_product_flash_attention_6[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_719: "i64[]" = _scaled_dot_product_flash_attention_6[7]; _scaled_dot_product_flash_attention_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_117: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_712, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_137: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_117, [2048, -1]); permute_117 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_119: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_60, [1, 0]); contiguous_view_as_strided_60 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_45: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_137, permute_119); permute_119 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_25: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_23, mm_45); add_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_224: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_25, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_14: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_224, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_13: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_14, [-1], True); pow_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_26: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_13, 1e-05); mean_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_13: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_26); add_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_52: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_224, rsqrt_13); convert_element_type_224 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_225: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_52, torch.bfloat16); mul_52 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_53: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_225, contiguous_view_as_strided_65); convert_element_type_225 = contiguous_view_as_strided_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_121: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_61, [1, 0]); contiguous_view_as_strided_61 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_46: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_53, permute_121); permute_121 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_228: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_46, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_6: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_228) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_54: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_228, sigmoid_6); convert_element_type_228 = sigmoid_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_229: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_123: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_63, [1, 0]); contiguous_view_as_strided_63 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_47: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_53, permute_123); permute_123 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_55: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_229, mm_47); convert_element_type_229 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_125: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_62, [1, 0]); contiguous_view_as_strided_62 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_48: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_55, permute_125); permute_125 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_27: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_25, mm_48); add_25 = None | |
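The run of ops from mm_46 through add_27 is the SwiGLU feed-forward quoted at model.py:299, plus the residual add at model.py:411, flattened into aten/prims calls: w1 and w3 matmuls in bf16, SiLU expanded to sigmoid-and-multiply in fp32, and the w2 projection back to the model dim. A minimal eager-mode sketch of the same math (the weight names here are illustrative, not taken from the graph):

    import torch

    def swiglu_ffn(x, w1, w2, w3):
        # x: (tokens, dim) bf16; w1, w3: (hidden, dim); w2: (dim, hidden)
        h1 = x @ w1.t()                            # mm_46
        h3 = x @ w3.t()                            # mm_47
        gate = h1.float()                          # convert_element_type_228
        gate = gate * torch.sigmoid(gate)          # sigmoid_6 / mul_54, i.e. F.silu in fp32
        return (gate.to(x.dtype) * h3) @ w2.t()    # mul_55, then mm_48

The block output is then added back onto the block input (add_27).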
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_234: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_235: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_236: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_237: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_238: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_239: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_240: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_241: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_242: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_8 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_234, convert_element_type_235, convert_element_type_236, convert_element_type_237, convert_element_type_238, convert_element_type_239, convert_element_type_240, convert_element_type_241, convert_element_type_242]); convert_element_type_234 = convert_element_type_235 = convert_element_type_236 = convert_element_type_237 = convert_element_type_238 = convert_element_type_239 = convert_element_type_240 = convert_element_type_241 = convert_element_type_242 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_721: "bf16[6423040]" = all_gather_copy_in_8[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_722: "bf16[51384320]" = all_gather_copy_in_8[1]; all_gather_copy_in_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_8: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_721, 8, '0'); getitem_721 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_8: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None | |
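This convert -> all_gather_copy_in -> all_gather_into_tensor -> wait_tensor sequence is one FSDP2 per-layer unshard traced into the graph: each rank casts its nine fp32 parameter shards to bf16, packs them into one flat buffer, and all-gathers that buffer across the 8-rank dp mesh. The sizes are consistent with 1/8-sharded layer weights of this model; a small arithmetic check (plain Python, not the custom torch.ops.fsdp kernel):

    split_sizes = [524288, 524288, 524288, 524288,
                   1441792, 1441792, 1441792, 256, 256]
    world_size = 8
    shard_numel = sum(split_sizes)
    assert shard_numel == 6_423_040                    # flat all-gather input per rank
    assert shard_numel * world_size == 51_384_320      # all-gather output numel
    assert 524288 == 2048 * 2048 // world_size         # shard of one 2048x2048 attention weight
    assert 1441792 == 5632 * 2048 // world_size        # shard of one feed-forward weight
    assert 256 == 2048 // world_size                   # shard of one RMSNorm weight

wait_tensor_8 is the explicit synchronization point on the asynchronous collective before any of the gathered parameters are used.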
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_139: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_8, [8, -1]); wait_tensor_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_75 = torch.ops.aten.split_with_sizes.default(view_139, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_139 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_732: "bf16[8, 524288]" = split_with_sizes_75[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_66: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_732, [2048, 2048], [2048, 1]); getitem_732 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_742: "bf16[8, 524288]" = split_with_sizes_75[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_67: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_742, [2048, 2048], [2048, 1]); getitem_742 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_752: "bf16[8, 524288]" = split_with_sizes_75[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_68: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_752, [2048, 2048], [2048, 1]); getitem_752 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_762: "bf16[8, 524288]" = split_with_sizes_75[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_69: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_762, [2048, 2048], [2048, 1]); getitem_762 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_772: "bf16[8, 1441792]" = split_with_sizes_75[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_70: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_772, [5632, 2048], [2048, 1]); getitem_772 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_782: "bf16[8, 1441792]" = split_with_sizes_75[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_71: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_782, [2048, 5632], [5632, 1]); getitem_782 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_792: "bf16[8, 1441792]" = split_with_sizes_75[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_72: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_792, [5632, 2048], [2048, 1]); getitem_792 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_802: "bf16[8, 256]" = split_with_sizes_75[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_73: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_802, [2048], [1]); getitem_802 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_812: "bf16[8, 256]" = split_with_sizes_75[8]; split_with_sizes_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_74: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_812, [2048], [1]); getitem_812 = None | |
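After the wait, the flat [51384320] buffer is reshaped to [8, 6423040] (one row per rank), split back into the nine per-parameter columns, and each [8, n] slice is turned into an unsharded parameter via torch.ops.fsdp.contiguous_view_as_strided. An eager-mode equivalent of the copy-out for the first weight, assuming the usual dim-0 sharding (this version materializes a copy, which the custom op is designed to avoid):

    import torch

    wait_output = torch.empty(51_384_320, dtype=torch.bfloat16)  # stand-in for wait_tensor_8
    split_sizes = [524288, 524288, 524288, 524288,
                   1441792, 1441792, 1441792, 256, 256]
    out = wait_output.view(8, -1)                        # view_139
    chunks = out.split(split_sizes, dim=1)               # split_with_sizes_75
    wq_full = chunks[0].contiguous().view(2048, 2048)    # ~ contiguous_view_as_strided_66

Each rank's row holds 256 consecutive rows of the 2048x2048 matrix, so stacking the rows in rank order reproduces the full weight.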
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_243: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_27, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_15: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_243, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_14: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_15, [-1], True); pow_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_28: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_14, 1e-05); mean_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_14: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_28); add_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_56: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_243, rsqrt_14); convert_element_type_243 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_244: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_56, torch.bfloat16); mul_56 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_57: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_244, contiguous_view_as_strided_73); convert_element_type_244 = contiguous_view_as_strided_73 = None | |
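convert_element_type_243 through mul_57 are RMSNorm (model.py:61 and 74-75) decomposed into primitives: upcast to fp32, y = x * rsqrt(mean(x^2) + eps), downcast, then scale by the freshly gathered weight. As a compact sketch:

    import torch

    def rms_norm(x, weight, eps=1e-5):
        xf = x.float()                                              # convert_element_type_243
        inv = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)  # pow_15 / mean_14 / add_28 / rsqrt_14
        return (xf * inv).to(x.dtype) * weight                     # mul_56, cast back, mul_57

Doing the reduction in fp32 and casting back only after the rsqrt keeps the normalization numerically stable under bf16 training.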
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_127: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_66, [1, 0]); contiguous_view_as_strided_66 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_49: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_57, permute_127); permute_127 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_129: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_67, [1, 0]); contiguous_view_as_strided_67 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_50: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_57, permute_129); permute_129 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_131: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_68, [1, 0]); contiguous_view_as_strided_68 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_51: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_57, permute_131); permute_131 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_148: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_49, [1, 2048, 16, 128]); mm_49 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_149: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_50, [1, 2048, 16, 128]); mm_50 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_150: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_51, [1, 2048, 16, 128]); mm_51 = None | |
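The three projections are then reshaped from (tokens, 2048) into per-head layout. With dim=2048 and n_heads=16 from the ModelArgs above, head_dim = 2048 / 16 = 128:

    bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 2048 // 16
    # each of view_148..view_150 is x.view(bsz, seqlen, n_heads, head_dim) -> (1, 2048, 16, 128)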
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_251: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_151: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_251, [1, 2048, 16, -1, 2]); convert_element_type_251 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_14: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_151); view_151 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_252: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_149, torch.float32); view_149 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_152: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_252, [1, 2048, 16, -1, 2]); convert_element_type_252 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_15: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_152); view_152 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_58: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_14, view_20); view_as_complex_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_14: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_154: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_14, [1, 2048, 16, 128]); view_as_real_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_59: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_15, view_20); view_as_complex_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_15: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_155: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_15, [1, 2048, 16, 128]); view_as_real_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_253: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_254: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_155, torch.bfloat16); view_155 = None | |
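view_as_complex_14 through the final type_as casts are the rotary embedding from model.py:146-151: adjacent fp32 element pairs are reinterpreted as complex numbers, rotated by an elementwise complex multiply with freqs_cis (view_20 in this graph), and viewed back as real pairs. The source lines quoted in the # File: comments reassemble to:

    import torch

    def apply_rotary_emb(xq, xk, freqs_cis):
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)

The complex multiply is what applies the per-position rotation; flatten(3) restores the (..., 128) head dimension.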
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_132: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_253, [0, 2, 1, 3]); convert_element_type_253 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_133: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_254, [0, 2, 1, 3]); convert_element_type_254 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_134: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_150, [0, 2, 1, 3]); view_150 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_7 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_132, permute_133, permute_134, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_813: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_7[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_814: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_7[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_819: "i64[]" = _scaled_dot_product_flash_attention_7[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_820: "i64[]" = _scaled_dot_product_flash_attention_7[7]; _scaled_dot_product_flash_attention_7 = None | |
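Attention itself stays a single fused kernel: F.scaled_dot_product_attention with is_causal=True lowers to aten._scaled_dot_product_flash_attention, returning the output (getitem_813) plus the logsumexp (getitem_814) and RNG state (getitem_819/820) kept for the backward pass. The scale baked into the call is the default head_dim**-0.5:

    import math
    head_dim = 128
    assert abs(1 / math.sqrt(head_dim) - 0.08838834764831843) < 1e-15

so no explicit scale needed to be passed at the call site in model.py:254.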
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_135: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_813, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_156: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_135, [2048, -1]); permute_135 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_137: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_69, [1, 0]); contiguous_view_as_strided_69 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_52: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_156, permute_137); permute_137 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_29: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_27, mm_52); add_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_257: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_29, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_16: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_257, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_15: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_16, [-1], True); pow_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_30: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_15, 1e-05); mean_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_15: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_30); add_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_60: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_257, rsqrt_15); convert_element_type_257 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_258: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_60, torch.bfloat16); mul_60 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_61: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_258, contiguous_view_as_strided_74); convert_element_type_258 = contiguous_view_as_strided_74 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_139: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_70, [1, 0]); contiguous_view_as_strided_70 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_53: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_61, permute_139); permute_139 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_261: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_53, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_7: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_261) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_62: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_261, sigmoid_7); convert_element_type_261 = sigmoid_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_262: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_141: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_72, [1, 0]); contiguous_view_as_strided_72 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_54: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_61, permute_141); permute_141 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_63: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_262, mm_54); convert_element_type_262 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_143: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_71, [1, 0]); contiguous_view_as_strided_71 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_55: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_63, permute_143); permute_143 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_31: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_29, mm_55); add_29 = None | |
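add_29 and add_31 are the two residual connections of one pre-norm transformer block, matching the source quoted at model.py:410-411:

    h = x + self.attention(self.attention_norm(x), freqs_cis)
    out = h + self.feed_forward(self.ffn_norm(h))

The parameters gathered at all_gather_copy_in_8 feed the block that ends at add_31, and the same unshard-then-compute pattern repeats below (all_gather_copy_in_9 onward) for the next of the model's 18 layers.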
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_267: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_268: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_269: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_270: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_271: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_272: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_273: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_274: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_275: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_9 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_267, convert_element_type_268, convert_element_type_269, convert_element_type_270, convert_element_type_271, convert_element_type_272, convert_element_type_273, convert_element_type_274, convert_element_type_275]); convert_element_type_267 = convert_element_type_268 = convert_element_type_269 = convert_element_type_270 = convert_element_type_271 = convert_element_type_272 = convert_element_type_273 = convert_element_type_274 = convert_element_type_275 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_822: "bf16[6423040]" = all_gather_copy_in_9[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_823: "bf16[51384320]" = all_gather_copy_in_9[1]; all_gather_copy_in_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_9: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_822, 8, '0'); getitem_822 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_9: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_158: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_9, [8, -1]); wait_tensor_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_85 = torch.ops.aten.split_with_sizes.default(view_158, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_158 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_833: "bf16[8, 524288]" = split_with_sizes_85[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_75: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_833, [2048, 2048], [2048, 1]); getitem_833 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_843: "bf16[8, 524288]" = split_with_sizes_85[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_76: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_843, [2048, 2048], [2048, 1]); getitem_843 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_853: "bf16[8, 524288]" = split_with_sizes_85[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_77: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_853, [2048, 2048], [2048, 1]); getitem_853 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_863: "bf16[8, 524288]" = split_with_sizes_85[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_78: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_863, [2048, 2048], [2048, 1]); getitem_863 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_873: "bf16[8, 1441792]" = split_with_sizes_85[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_79: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_873, [5632, 2048], [2048, 1]); getitem_873 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_883: "bf16[8, 1441792]" = split_with_sizes_85[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_80: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_883, [2048, 5632], [5632, 1]); getitem_883 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_893: "bf16[8, 1441792]" = split_with_sizes_85[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_81: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_893, [5632, 2048], [2048, 1]); getitem_893 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_903: "bf16[8, 256]" = split_with_sizes_85[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_82: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_903, [2048], [1]); getitem_903 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_913: "bf16[8, 256]" = split_with_sizes_85[8]; split_with_sizes_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_83: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_913, [2048], [1]); getitem_913 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_276: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_31, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_17: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_276, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_16: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_17, [-1], True); pow_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_32: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_16, 1e-05); mean_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_16: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_32); add_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_64: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_276, rsqrt_16); convert_element_type_276 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_277: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_64, torch.bfloat16); mul_64 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_65: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_277, contiguous_view_as_strided_82); convert_element_type_277 = contiguous_view_as_strided_82 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_145: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_75, [1, 0]); contiguous_view_as_strided_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_56: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_65, permute_145); permute_145 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_147: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_76, [1, 0]); contiguous_view_as_strided_76 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_57: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_65, permute_147); permute_147 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_149: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_77, [1, 0]); contiguous_view_as_strided_77 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_58: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_65, permute_149); permute_149 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_167: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_56, [1, 2048, 16, 128]); mm_56 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_168: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_57, [1, 2048, 16, 128]); mm_57 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_169: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_58, [1, 2048, 16, 128]); mm_58 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_284: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_167, torch.float32); view_167 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_170: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_284, [1, 2048, 16, -1, 2]); convert_element_type_284 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_16: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_170); view_170 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_285: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_168, torch.float32); view_168 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_171: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_285, [1, 2048, 16, -1, 2]); convert_element_type_285 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_17: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_171); view_171 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_66: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_16, view_20); view_as_complex_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_16: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_173: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_16, [1, 2048, 16, 128]); view_as_real_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_67: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_17, view_20); view_as_complex_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_17: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_174: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_17, [1, 2048, 16, 128]); view_as_real_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_286: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_173, torch.bfloat16); view_173 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_287: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_174, torch.bfloat16); view_174 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
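The `view_as_complex` / complex `mul` / `view_as_real` cluster above is the rotary embedding. Reassembled from the source lines quoted in the trace (model.py:146-151); the broadcast reshape of `freqs_cis` (which appears in the graph as the precomputed `view_20`) is elided here:

```python
import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # view_as_complex_*: pair up the last dim as (real, imag) halves
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # mul_66 / mul_67: one complex multiply rotates each (real, imag) pair
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # convert_element_type_*: cast back to the activation dtype (bf16 here)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```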
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_150: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_286, [0, 2, 1, 3]); convert_element_type_286 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_151: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_287, [0, 2, 1, 3]); convert_element_type_287 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_152: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_169, [0, 2, 1, 3]); view_169 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_8 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_150, permute_151, permute_152, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_914: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_8[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_915: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_8[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_920: "i64[]" = _scaled_dot_product_flash_attention_8[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_921: "i64[]" = _scaled_dot_product_flash_attention_8[7]; _scaled_dot_product_flash_attention_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
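The `_scaled_dot_product_flash_attention` node above is what `F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)` (model.py:254, quoted in the trace) lowers to on GPU with bf16 inputs. A small sketch with the traced shapes; note that with no explicit `scale` argument, SDPA defaults to 1/sqrt(head_dim), which is exactly the 0.08838834764831843 recorded in the graph:

```python
import math
import torch
import torch.nn.functional as F

bs, n_heads, seqlen, head_dim = 1, 16, 2048, 128   # shapes from the trace
xq = torch.randn(bs, n_heads, seqlen, head_dim)
xk = torch.randn_like(xq)
xv = torch.randn_like(xq)

# Default scale is 1/sqrt(head_dim) = 1/sqrt(128) = 0.08838834764831843,
# matching the scale kwarg in _scaled_dot_product_flash_attention_8/9.
assert math.isclose(1.0 / math.sqrt(head_dim), 0.08838834764831843)
out = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=0.0, is_causal=True)
```

The extra `getitem_*` outputs (logsumexp, rng state, etc.) are saved by the flash kernel for the backward pass.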
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_153: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_914, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_175: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_153, [2048, -1]); permute_153 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_155: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_78, [1, 0]); contiguous_view_as_strided_78 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_59: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_175, permute_155); permute_155 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_33: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_31, mm_59); add_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_290: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_33, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_18: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_290, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_17: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_18, [-1], True); pow_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_34: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_17, 1e-05); mean_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_17: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_34); add_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_68: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_290, rsqrt_17); convert_element_type_290 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_291: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_68, torch.bfloat16); mul_68 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_69: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_291, contiguous_view_as_strided_83); convert_element_type_291 = contiguous_view_as_strided_83 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_157: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_79, [1, 0]); contiguous_view_as_strided_79 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_60: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_69, permute_157); permute_157 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_294: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_60, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_8: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_294) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_70: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_294, sigmoid_8); convert_element_type_294 = sigmoid_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_295: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_159: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_81, [1, 0]); contiguous_view_as_strided_81 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_61: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_69, permute_159); permute_159 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_71: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_295, mm_61); convert_element_type_295 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_161: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_80, [1, 0]); contiguous_view_as_strided_80 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_62: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_71, permute_161); permute_161 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_35: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_33, mm_62); add_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
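With `add_35` the layer is complete. The `sigmoid` + `mul` pair implements `F.silu` in fp32, and the two residual adds (`add_33`, `add_35`; likewise `add_37`, `add_39` below) are the block structure from model.py:299 and 410-411, quoted in the trace. A sketch of just those pieces, with hidden_dim = 5632 taken from the traced mm shapes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """SwiGLU FFN matching model.py:299; dim=2048, hidden_dim=5632 per the trace."""

    def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)   # mm_60
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)   # mm_62
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)   # mm_61

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # sigmoid_* and mul_* in the graph are F.silu computed in fp32
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# Residual structure from model.py:410-411:
#   h   = x + attention(attention_norm(x), freqs_cis)   # add_33 / add_37
#   out = h + feed_forward(ffn_norm(h))                 # add_35 / add_39
```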
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_300: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_301: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_302: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_303: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_304: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_305: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_306: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_307: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_308: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_10 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_300, convert_element_type_301, convert_element_type_302, convert_element_type_303, convert_element_type_304, convert_element_type_305, convert_element_type_306, convert_element_type_307, convert_element_type_308]); convert_element_type_300 = convert_element_type_301 = convert_element_type_302 = convert_element_type_303 = convert_element_type_304 = convert_element_type_305 = convert_element_type_306 = convert_element_type_307 = convert_element_type_308 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_923: "bf16[6423040]" = all_gather_copy_in_10[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_924: "bf16[51384320]" = all_gather_copy_in_10[1]; all_gather_copy_in_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_10: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_923, 8, '0'); getitem_923 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_10: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_177: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_10, [8, -1]); wait_tensor_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_95 = torch.ops.aten.split_with_sizes.default(view_177, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_177 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_934: "bf16[8, 524288]" = split_with_sizes_95[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_84: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_934, [2048, 2048], [2048, 1]); getitem_934 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_944: "bf16[8, 524288]" = split_with_sizes_95[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_85: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_944, [2048, 2048], [2048, 1]); getitem_944 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_954: "bf16[8, 524288]" = split_with_sizes_95[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_86: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_954, [2048, 2048], [2048, 1]); getitem_954 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_964: "bf16[8, 524288]" = split_with_sizes_95[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_87: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_964, [2048, 2048], [2048, 1]); getitem_964 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_974: "bf16[8, 1441792]" = split_with_sizes_95[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_88: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_974, [5632, 2048], [2048, 1]); getitem_974 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_984: "bf16[8, 1441792]" = split_with_sizes_95[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_89: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_984, [2048, 5632], [5632, 1]); getitem_984 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_994: "bf16[8, 1441792]" = split_with_sizes_95[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_90: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_994, [5632, 2048], [2048, 1]); getitem_994 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1004: "bf16[8, 256]" = split_with_sizes_95[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_91: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1004, [2048], [1]); getitem_1004 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1014: "bf16[8, 256]" = split_with_sizes_95[8]; split_with_sizes_95 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_92: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1014, [2048], [1]); getitem_1014 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
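The block above is the next layer's FSDP parameter all-gather traced into the forward graph: nine bf16 shards per rank (524288 x 4, 1441792 x 3, 256 x 2, totaling 6,423,040 elements, consistent with one layer's four 2048x2048 attention weights, three 5632x2048-sized FFN weights, and two 2048-element norm weights sharded 8 ways) are packed into one flat buffer, all-gathered across 8 ranks into a 51,384,320-element output, then split and re-viewed as unsharded parameters. `torch.ops.fsdp.all_gather_copy_in` and `torch.ops.fsdp.contiguous_view_as_strided` are custom ops from this prototype branch; the following is a hedged, conceptually equivalent sketch using only public `torch.distributed` APIs, not the real implementation:

```python
import torch
import torch.distributed as dist

def foreach_all_gather_sketch(shards, shapes, group):
    """Conceptual stand-in for the traced foreach_all_gather path.

    shards: this rank's 1-D bf16 shard of each parameter
    shapes: the unsharded shape of each parameter
    """
    world_size = dist.get_world_size(group)
    split_sizes = [s.numel() for s in shards]
    # all_gather_copy_in: pack every shard into one flat input buffer
    inp = torch.cat([s.reshape(-1) for s in shards])        # getitem_923 (6423040)
    out = inp.new_empty(world_size * inp.numel())           # getitem_924 (51384320)
    # all_gather_into_tensor_* + wait_tensor_* in the graph
    dist.all_gather_into_tensor(out, inp, group=group)
    # split_with_sizes_* + contiguous_view_as_strided_*: each [world_size, shard]
    # column block, flattened rank-major, is the unsharded dim-0-sharded param
    splits = out.view(world_size, -1).split(split_sizes, dim=1)
    return [chunk.reshape(shape) for chunk, shape in zip(splits, shapes)]
```

`chunk.reshape(shape)` copies the strided chunk contiguous before viewing it, which is roughly what the `contiguous_view_as_strided` op's name suggests it does in one step (e.g. a [8, 524288] chunk becomes the [2048, 2048] weight with strides [2048, 1]).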
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_309: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_35, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_19: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_309, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_18: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_19, [-1], True); pow_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_36: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_18, 1e-05); mean_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_18: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_36); add_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_72: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_309, rsqrt_18); convert_element_type_309 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_310: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_72, torch.bfloat16); mul_72 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_73: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_310, contiguous_view_as_strided_91); convert_element_type_310 = contiguous_view_as_strided_91 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_163: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_84, [1, 0]); contiguous_view_as_strided_84 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_63: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_73, permute_163); permute_163 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_165: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_85, [1, 0]); contiguous_view_as_strided_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_64: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_73, permute_165); permute_165 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_167: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_86, [1, 0]); contiguous_view_as_strided_86 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_65: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_73, permute_167); permute_167 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_186: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_63, [1, 2048, 16, 128]); mm_63 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_187: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_64, [1, 2048, 16, 128]); mm_64 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_188: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_65, [1, 2048, 16, 128]); mm_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_317: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_186, torch.float32); view_186 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_189: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_317, [1, 2048, 16, -1, 2]); convert_element_type_317 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_18: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_189); view_189 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_318: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_187, torch.float32); view_187 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_190: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_318, [1, 2048, 16, -1, 2]); convert_element_type_318 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_19: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_190); view_190 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_74: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_18, view_20); view_as_complex_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_18: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_192: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_18, [1, 2048, 16, 128]); view_as_real_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_75: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_19, view_20); view_as_complex_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_19: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_193: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_19, [1, 2048, 16, 128]); view_as_real_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_319: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_192, torch.bfloat16); view_192 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_320: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_193, torch.bfloat16); view_193 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_168: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_319, [0, 2, 1, 3]); convert_element_type_319 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_169: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_320, [0, 2, 1, 3]); convert_element_type_320 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_170: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_188, [0, 2, 1, 3]); view_188 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_9 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_168, permute_169, permute_170, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1015: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_9[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1016: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_9[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1021: "i64[]" = _scaled_dot_product_flash_attention_9[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1022: "i64[]" = _scaled_dot_product_flash_attention_9[7]; _scaled_dot_product_flash_attention_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_171: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1015, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_194: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_171, [2048, -1]); permute_171 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_173: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_87, [1, 0]); contiguous_view_as_strided_87 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_66: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_194, permute_173); permute_173 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_37: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_35, mm_66); add_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_323: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_37, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_20: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_323, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_19: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_20, [-1], True); pow_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_38: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_19, 1e-05); mean_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_19: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_38); add_38 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_76: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_323, rsqrt_19); convert_element_type_323 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_324: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_76, torch.bfloat16); mul_76 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_77: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_324, contiguous_view_as_strided_92); convert_element_type_324 = contiguous_view_as_strided_92 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_175: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_88, [1, 0]); contiguous_view_as_strided_88 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_67: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_77, permute_175); permute_175 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_327: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_67, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_9: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_327) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_78: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_327, sigmoid_9); convert_element_type_327 = sigmoid_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_328: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_177: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_90, [1, 0]); contiguous_view_as_strided_90 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_68: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_77, permute_177); permute_177 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_79: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_328, mm_68); convert_element_type_328 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_179: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_89, [1, 0]); contiguous_view_as_strided_89 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_69: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_79, permute_179); permute_179 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_39: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_37, mm_69); add_37 = None | |
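
Taken together, mm_67 (w1), mm_68 (w3), mul_79, and mm_69 (w2) are the SwiGLU feed-forward quoted at model.py:299, return self.w2(F.silu(self.w1(x)) * self.w3(x)), and add_39 is the residual from model.py:411. A sketch of that module with the dimensions visible in the shapes above (2048 -> 5632 -> 2048); the class below is an illustrative stand-in, not the torchtrain source:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FeedForward(nn.Module):
        # SwiGLU: w1/w3 expand dim -> hidden_dim, w2 projects back; no biases.
        def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

    h = torch.randn(2048, 2048)
    out = h + FeedForward()(h)   # cf. add_39; the real model feeds ffn_norm(h), not h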
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_333: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_334: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_335: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_336: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_337: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_338: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_339: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_340: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_341: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None | |
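
The convert_element_type_333..341 run is FSDP's _to_dtype_if_needed casting each fp32 sharded parameter (primals_190..198) to bf16 before the all-gather, i.e. the effect of a mixed-precision param dtype of torch.bfloat16. A sketch of the idea; the dtype guard is assumed from the function's name, since the log only quotes the tensor.to(dtype) line:

    import torch

    def to_dtype_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # Cast only when the dtype actually differs (assumed guard).
        return tensor if tensor.dtype == dtype else tensor.to(dtype)

    shard_numels = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
    fp32_shards = [torch.randn(n) for n in shard_numels]
    bf16_shards = [to_dtype_if_needed(s, torch.bfloat16) for s in fp32_shards]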
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_11 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_333, convert_element_type_334, convert_element_type_335, convert_element_type_336, convert_element_type_337, convert_element_type_338, convert_element_type_339, convert_element_type_340, convert_element_type_341]); convert_element_type_333 = convert_element_type_334 = convert_element_type_335 = convert_element_type_336 = convert_element_type_337 = convert_element_type_338 = convert_element_type_339 = convert_element_type_340 = convert_element_type_341 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1024: "bf16[6423040]" = all_gather_copy_in_11[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1025: "bf16[51384320]" = all_gather_copy_in_11[1]; all_gather_copy_in_11 = None | |
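
torch.ops.fsdp.all_gather_copy_in packs the nine bf16 shards into one flat per-rank input buffer of 6,423,040 elements (4 x 524,288 + 3 x 1,441,792 + 2 x 256) and allocates the world-size output buffer of 51,384,320 = 8 x 6,423,040 elements. A plain-PyTorch approximation of the data movement; the real op is a fused custom kernel, and the variable names below are illustrative:

    import torch

    world_size = 8
    split_sizes = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
    shards = [torch.randn(n, dtype=torch.bfloat16) for n in split_sizes]

    all_gather_input = torch.cat(shards)                   # cf. getitem_1024: bf16[6423040]
    all_gather_output = torch.empty(world_size * all_gather_input.numel(),
                                    dtype=torch.bfloat16)  # cf. getitem_1025: bf16[51384320]
    assert all_gather_input.numel() == 6_423_040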
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_11: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1024, 8, '0'); getitem_1024 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_11: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None | |
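
The gather itself goes through the functional collectives: all_gather_into_tensor returns immediately, and the separate wait_tensor is the synchronization point, which is what lets the compiled graph overlap this communication with the layer compute above it. A sketch using the public wrappers, assuming a process group has already been initialized:

    import torch.distributed._functional_collectives as funcol

    def gather_flat_params(all_gather_input, group):
        # Kick off the collective; 'out' is an async tensor handle.
        out = funcol.all_gather_tensor(all_gather_input, gather_dim=0, group=group)
        # ... independent compute can run here while the collective is in flight ...
        return funcol.wait_tensor(out)   # block until the gathered data is ready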
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_196: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_11, [8, -1]); wait_tensor_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_105 = torch.ops.aten.split_with_sizes.default(view_196, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_196 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1035: "bf16[8, 524288]" = split_with_sizes_105[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_93: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1035, [2048, 2048], [2048, 1]); getitem_1035 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1045: "bf16[8, 524288]" = split_with_sizes_105[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_94: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1045, [2048, 2048], [2048, 1]); getitem_1045 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1055: "bf16[8, 524288]" = split_with_sizes_105[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_95: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1055, [2048, 2048], [2048, 1]); getitem_1055 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1065: "bf16[8, 524288]" = split_with_sizes_105[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_96: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1065, [2048, 2048], [2048, 1]); getitem_1065 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1075: "bf16[8, 1441792]" = split_with_sizes_105[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_97: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1075, [5632, 2048], [2048, 1]); getitem_1075 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1085: "bf16[8, 1441792]" = split_with_sizes_105[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_98: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1085, [2048, 5632], [5632, 1]); getitem_1085 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1095: "bf16[8, 1441792]" = split_with_sizes_105[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_99: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1095, [5632, 2048], [2048, 1]); getitem_1095 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1105: "bf16[8, 256]" = split_with_sizes_105[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_100: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1105, [2048], [1]); getitem_1105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1115: "bf16[8, 256]" = split_with_sizes_105[8]; split_with_sizes_105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_101: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1115, [2048], [1]); getitem_1115 = None | |
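
After the wait, foreach_all_gather_copy_out undoes the packing: the bf16[51384320] buffer is reshaped to [8, 6423040] (one row per rank), split column-wise back into the nine parameter segments, and each [8, numel] slice is viewed as the full unsharded parameter via torch.ops.fsdp.contiguous_view_as_strided. The two bf16[8, 256] -> bf16[2048] entries are the layer's norm weights (256 elements per rank). A rough eager equivalent for two of the parameters; reshape on these non-contiguous slices would copy, which the custom op avoids:

    import torch

    world_size = 8
    split_sizes = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
    gathered = torch.empty(world_size, sum(split_sizes), dtype=torch.bfloat16)

    pieces = torch.split(gathered, split_sizes, dim=1)
    attn_proj = pieces[0].reshape(2048, 2048)   # cf. contiguous_view_as_strided_93
    norm_w    = pieces[7].reshape(2048)         # cf. contiguous_view_as_strided_100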
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_342: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_39, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_21: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_342, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_20: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_21, [-1], True); pow_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_40: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_20, 1e-05); mean_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_20: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_40); add_40 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_80: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_342, rsqrt_20); convert_element_type_342 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_343: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_80, torch.bfloat16); mul_80 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_81: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_343, contiguous_view_as_strided_100); convert_element_type_343 = contiguous_view_as_strided_100 = None | |
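
The pow_21/mean_20/rsqrt_20 group is RMSNorm exactly as quoted from model.py:61/74/75: square, mean, and rsqrt in fp32 with eps 1e-05 (add_40), cast back to bf16, then scale by the freshly gathered weight (contiguous_view_as_strided_100). An eager sketch of that module, following the quoted source lines:

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        def __init__(self, dim: int = 2048, eps: float = 1e-5):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            # model.py:61: x * rsqrt(mean(x^2, -1) + eps)
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            # model.py:74-75: normalize in fp32, cast back, apply the learned scale
            output = self._norm(x.float()).type_as(x)
            return output * self.weight

    y = RMSNorm()(torch.randn(2048, 2048, dtype=torch.bfloat16))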
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_181: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_93, [1, 0]); contiguous_view_as_strided_93 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_70: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_81, permute_181); permute_181 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_183: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_94, [1, 0]); contiguous_view_as_strided_94 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_71: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_81, permute_183); permute_183 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_185: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_95, [1, 0]); contiguous_view_as_strided_95 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_72: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_81, permute_185); permute_185 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_205: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_70, [1, 2048, 16, 128]); mm_70 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_206: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_71, [1, 2048, 16, 128]); mm_71 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_207: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_72, [1, 2048, 16, 128]); mm_72 = None | |
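
mm_70/mm_71/mm_72 are the query/key/value projections, all reading the same normed activation mul_81, and view_205..207 restore the (bsz, seqlen, n_heads, head_dim) layout from model.py:235-237; with the shapes above that is bsz=1, seqlen=2048, 16 heads of dim 128. A sketch, with the wq/wk/wv layer names assumed from the llama convention and only xq shown reshaped:

    import torch
    import torch.nn as nn

    bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128
    dim = n_heads * head_dim
    wq = nn.Linear(dim, dim, bias=False)
    wk = nn.Linear(dim, dim, bias=False)
    wv = nn.Linear(dim, dim, bias=False)

    x = torch.randn(bsz * seqlen, dim)            # the normed input, cf. mul_81
    xq, xk, xv = wq(x), wk(x), wv(x)              # cf. mm_70 / mm_71 / mm_72
    xq = xq.view(bsz, seqlen, n_heads, head_dim)  # cf. view_205, model.py:235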
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_350: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_208: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_350, [1, 2048, 16, -1, 2]); convert_element_type_350 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_20: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_208); view_208 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_351: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_206, torch.float32); view_206 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_209: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_351, [1, 2048, 16, -1, 2]); convert_element_type_351 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_21: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_209); view_209 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_82: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_20, view_20); view_as_complex_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_20: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_211: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_20, [1, 2048, 16, 128]); view_as_real_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_83: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_21, view_20); view_as_complex_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_21: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_212: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_21, [1, 2048, 16, 128]); view_as_real_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_352: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_211, torch.bfloat16); view_211 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_353: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_212, torch.bfloat16); view_212 = None | |
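
apply_rotary_emb (model.py:146-151) is expressed through complex views: pair the last dimension into (..., 64, 2), view as complex, multiply by the precomputed freqs_cis (view_20 above), then view back as real and cast to bf16. A self-contained sketch; the freqs_cis construction below is the standard llama recipe, assumed rather than read from this log:

    import torch

    def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # (1, 2048, 16, 128) -> complex (1, 2048, 16, 64) -> rotate -> real
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq)

    seqlen, head_dim = 2048, 128
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seqlen).float(), inv_freq)         # (2048, 64)
    freqs_cis = torch.polar(torch.ones_like(angles), angles).view(1, seqlen, 1, 64)

    xq = torch.randn(1, seqlen, 16, head_dim, dtype=torch.bfloat16)
    xq = apply_rotary_emb(xq, freqs_cis)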
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_186: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_352, [0, 2, 1, 3]); convert_element_type_352 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_187: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_353, [0, 2, 1, 3]); convert_element_type_353 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_188: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_207, [0, 2, 1, 3]); view_207 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_10 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_186, permute_187, permute_188, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1116: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_10[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1117: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_10[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1122: "i64[]" = _scaled_dot_product_flash_attention_10[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1123: "i64[]" = _scaled_dot_product_flash_attention_10[7]; _scaled_dot_product_flash_attention_10 = None | |
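
model.py:254's F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) lowers to the flash-attention kernel; the scale 0.08838834764831843 is 1/sqrt(head_dim) = 1/sqrt(128), and the extra outputs kept alive (getitem_1117, the logsumexp; getitem_1122/1123, the RNG state) are carried forward for the backward pass. The eager call it came from:

    import math
    import torch
    import torch.nn.functional as F

    xq = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16)  # (bs, heads, seq, head_dim)
    xk, xv = torch.randn_like(xq), torch.randn_like(xq)

    # is_causal=True masks future positions; the default scale is 1/sqrt(head_dim).
    output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
    assert math.isclose(1 / math.sqrt(128), 0.08838834764831843)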
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_189: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1116, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_213: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_189, [2048, -1]); permute_189 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_191: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_96, [1, 0]); contiguous_view_as_strided_96 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_73: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_213, permute_191); permute_191 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_41: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_39, mm_73); add_39 = None | |
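
add_41 closes the attention half of the block (model.py:410), matching the feed-forward residual seen at add_39 and add_43: each layer is the standard pre-norm wiring h = x + attention(attention_norm(x), freqs_cis) followed by out = h + feed_forward(ffn_norm(h)). A structural sketch, with the submodules standing in for the pieces sketched earlier:

    import torch.nn as nn

    class TransformerBlock(nn.Module):
        # Pre-norm residual wiring, as quoted at model.py:410-411.
        def __init__(self, attention, feed_forward, attention_norm, ffn_norm):
            super().__init__()
            self.attention, self.feed_forward = attention, feed_forward
            self.attention_norm, self.ffn_norm = attention_norm, ffn_norm

        def forward(self, x, freqs_cis):
            h = x + self.attention(self.attention_norm(x), freqs_cis)  # cf. add_41
            return h + self.feed_forward(self.ffn_norm(h))             # cf. add_43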
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_356: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_41, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_22: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_356, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_21: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_22, [-1], True); pow_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_42: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_21, 1e-05); mean_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_21: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_42); add_42 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_84: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_356, rsqrt_21); convert_element_type_356 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_357: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_84, torch.bfloat16); mul_84 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_85: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_357, contiguous_view_as_strided_101); convert_element_type_357 = contiguous_view_as_strided_101 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_193: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_97, [1, 0]); contiguous_view_as_strided_97 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_74: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_85, permute_193); permute_193 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_360: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_74, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_10: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_360) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_86: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_360, sigmoid_10); convert_element_type_360 = sigmoid_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_361: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_195: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_99, [1, 0]); contiguous_view_as_strided_99 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_75: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_85, permute_195); permute_195 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_87: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_361, mm_75); convert_element_type_361 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_197: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_98, [1, 0]); contiguous_view_as_strided_98 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_76: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_87, permute_197); permute_197 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_43: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_41, mm_76); add_41 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_366: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_367: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_368: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_369: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_370: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_371: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_372: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_373: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_374: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_12 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_366, convert_element_type_367, convert_element_type_368, convert_element_type_369, convert_element_type_370, convert_element_type_371, convert_element_type_372, convert_element_type_373, convert_element_type_374]); convert_element_type_366 = convert_element_type_367 = convert_element_type_368 = convert_element_type_369 = convert_element_type_370 = convert_element_type_371 = convert_element_type_372 = convert_element_type_373 = convert_element_type_374 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1125: "bf16[6423040]" = all_gather_copy_in_12[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1126: "bf16[51384320]" = all_gather_copy_in_12[1]; all_gather_copy_in_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_12: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1125, 8, '0'); getitem_1125 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_12: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_215: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_12, [8, -1]); wait_tensor_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_115 = torch.ops.aten.split_with_sizes.default(view_215, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_215 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1136: "bf16[8, 524288]" = split_with_sizes_115[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_102: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1136, [2048, 2048], [2048, 1]); getitem_1136 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1146: "bf16[8, 524288]" = split_with_sizes_115[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_103: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1146, [2048, 2048], [2048, 1]); getitem_1146 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1156: "bf16[8, 524288]" = split_with_sizes_115[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_104: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1156, [2048, 2048], [2048, 1]); getitem_1156 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1166: "bf16[8, 524288]" = split_with_sizes_115[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_105: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1166, [2048, 2048], [2048, 1]); getitem_1166 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1176: "bf16[8, 1441792]" = split_with_sizes_115[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_106: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1176, [5632, 2048], [2048, 1]); getitem_1176 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1186: "bf16[8, 1441792]" = split_with_sizes_115[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_107: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1186, [2048, 5632], [5632, 1]); getitem_1186 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1196: "bf16[8, 1441792]" = split_with_sizes_115[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_108: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1196, [5632, 2048], [2048, 1]); getitem_1196 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1206: "bf16[8, 256]" = split_with_sizes_115[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_109: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1206, [2048], [1]); getitem_1206 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1216: "bf16[8, 256]" = split_with_sizes_115[8]; split_with_sizes_115 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_110: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1216, [2048], [1]); getitem_1216 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_375: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_43, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_23: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_375, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_22: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_23, [-1], True); pow_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_44: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_22, 1e-05); mean_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_22: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_44); add_44 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_88: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_375, rsqrt_22); convert_element_type_375 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_376: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_88, torch.bfloat16); mul_88 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_89: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_376, contiguous_view_as_strided_109); convert_element_type_376 = contiguous_view_as_strided_109 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_199: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_102, [1, 0]); contiguous_view_as_strided_102 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_77: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_89, permute_199); permute_199 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_201: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_103, [1, 0]); contiguous_view_as_strided_103 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_78: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_89, permute_201); permute_201 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_203: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_104, [1, 0]); contiguous_view_as_strided_104 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_79: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_89, permute_203); permute_203 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_224: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_77, [1, 2048, 16, 128]); mm_77 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_225: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_78, [1, 2048, 16, 128]); mm_78 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_226: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_79, [1, 2048, 16, 128]); mm_79 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_383: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_224, torch.float32); view_224 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_227: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_383, [1, 2048, 16, -1, 2]); convert_element_type_383 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_22: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_227); view_227 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_384: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_225, torch.float32); view_225 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_228: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_384, [1, 2048, 16, -1, 2]); convert_element_type_384 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_23: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_228); view_228 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_90: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_22, view_20); view_as_complex_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_22: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_230: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_22, [1, 2048, 16, 128]); view_as_real_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_91: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_23, view_20); view_as_complex_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_23: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_231: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_23, [1, 2048, 16, 128]); view_as_real_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_385: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_230, torch.bfloat16); view_230 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_386: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_231, torch.bfloat16); view_231 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
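The view_as_complex / mul / view_as_real / reshape run above is apply_rotary_emb as quoted from model.py:146-151: the 128-wide head dim is paired into 64 complex numbers, rotated by the complex frequencies, and flattened back. A standalone sketch of the same computation; the freqs_cis broadcast shape (1, seqlen, 1, head_dim/2) is an assumption, since the trace only shows the complex multiply against view_20:

    import torch

    def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
        # Pair the last dim into complex numbers (128 -> 64), rotate by the
        # complex frequencies, flatten back -- the view_as_complex / mul /
        # view_as_real / reshape sequence in the graph above.
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)

    xq = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
    xk = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
    freqs_cis = torch.randn(1, 2048, 1, 64, dtype=torch.complex64)  # assumed broadcast shape
    out_q, out_k = apply_rotary_emb(xq, xk, freqs_cis)
    print(out_q.shape)  # torch.Size([1, 2048, 16, 128]) -- cf. view_230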
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_204: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_385, [0, 2, 1, 3]); convert_element_type_385 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_205: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_386, [0, 2, 1, 3]); convert_element_type_386 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_206: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_226, [0, 2, 1, 3]); view_226 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_11 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_204, permute_205, permute_206, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1217: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_11[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1218: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_11[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1223: "i64[]" = _scaled_dot_product_flash_attention_11[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1224: "i64[]" = _scaled_dot_product_flash_attention_11[7]; _scaled_dot_product_flash_attention_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
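The scale baked into the _scaled_dot_product_flash_attention call above is head_dim ** -0.5 for head_dim=128, i.e. the default SDPA scaling, so the traced kernel corresponds to the plain F.scaled_dot_product_attention call quoted from model.py:254. A quick check plus the equivalent eager call (shapes as in permute_204/205/206; float32 here so the sketch runs on CPU):

    import math
    import torch
    import torch.nn.functional as F

    head_dim = 128
    # The scale in the kernel call above is the default 1/sqrt(head_dim):
    assert math.isclose(head_dim ** -0.5, 0.08838834764831843)

    q, k, v = (torch.randn(1, 16, 2048, head_dim) for _ in range(3))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
    print(out.shape)  # torch.Size([1, 16, 2048, 128]) -- cf. getitem_1217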
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_207: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1217, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_232: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_207, [2048, -1]); permute_207 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_209: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_105, [1, 0]); contiguous_view_as_strided_105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_80: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_232, permute_209); permute_209 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_45: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_43, mm_80); add_43 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_389: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_45, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_24: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_389, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_23: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_24, [-1], True); pow_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_46: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_23, 1e-05); mean_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_23: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_46); add_46 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_92: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_389, rsqrt_23); convert_element_type_389 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_390: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_92, torch.bfloat16); mul_92 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_93: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_390, contiguous_view_as_strided_110); convert_element_type_390 = contiguous_view_as_strided_110 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_211: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_106, [1, 0]); contiguous_view_as_strided_106 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_81: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_93, permute_211); permute_211 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_393: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_81, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_11: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_393) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_94: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_393, sigmoid_11); convert_element_type_393 = sigmoid_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_394: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_213: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_108, [1, 0]); contiguous_view_as_strided_108 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_82: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_93, permute_213); permute_213 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_95: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_394, mm_82); convert_element_type_394 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_215: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_107, [1, 0]); contiguous_view_as_strided_107 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_83: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_95, permute_215); permute_215 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_47: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_45, mm_83); add_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
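The mm / sigmoid / mul segments above are the SwiGLU feed-forward quoted from model.py:299, with dim=2048 and hidden_dim=5632 matching the bf16[2048, 5632] mm shapes, followed by the residual add from model.py:411. A standalone sketch:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FeedForward(nn.Module):
        # model.py:299 -- SwiGLU: w2(silu(w1(x)) * w3(x)); the graph computes
        # the silu in fp32 (sigmoid_11 / mul_94) and casts back to bf16.
        def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(F.silu(self.w1(x)) * self.w3(x))

    x = torch.randn(2048, 2048)
    print(FeedForward()(x).shape)  # torch.Size([2048, 2048]) -> feeds the residual add_47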
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_399: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_400: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_401: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_402: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_403: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_404: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_405: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_406: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_407: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
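The nine convert_element_type sizes above are the per-rank parameter shards of the next transformer block being cast to bf16 before all-gather (_to_dtype_if_needed). Their sizes check out as 1/8 slices of the full parameters; the mapping of sizes to specific weights is inferred from the full shapes reconstructed by contiguous_view_as_strided further down:

    # Per-rank shard numels for one transformer block, dim=2048, hidden_dim=5632,
    # world_size=8 (matching the nine convert_element_type sizes above):
    world_size, dim, hidden = 8, 2048, 5632

    attn_proj = dim * dim // world_size      # wq/wk/wv/wo shards -> 524288 each
    ffn_proj = hidden * dim // world_size    # w1/w2/w3 shards    -> 1441792 each
    norm_w = dim // world_size               # attention_norm/ffn_norm -> 256 each

    assert (attn_proj, ffn_proj, norm_w) == (524288, 1441792, 256)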
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_13 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_399, convert_element_type_400, convert_element_type_401, convert_element_type_402, convert_element_type_403, convert_element_type_404, convert_element_type_405, convert_element_type_406, convert_element_type_407]); convert_element_type_399 = convert_element_type_400 = convert_element_type_401 = convert_element_type_402 = convert_element_type_403 = convert_element_type_404 = convert_element_type_405 = convert_element_type_406 = convert_element_type_407 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1226: "bf16[6423040]" = all_gather_copy_in_13[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1227: "bf16[51384320]" = all_gather_copy_in_13[1]; all_gather_copy_in_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_13: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1226, 8, '0'); getitem_1226 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_13: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
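The numels in the all-gather sequence above are internally consistent: all_gather_copy_in packs the nine bf16 shards into one flat buffer, and all_gather_into_tensor over 8 ranks multiplies that numel by the world size before wait_tensor synchronizes on the result:

    # Flat all-gather input numel vs. all-gather output numel, per the trace:
    split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
    world_size = 8

    assert sum(split_sizes) == 6423040                 # getitem_1226: bf16[6423040]
    assert sum(split_sizes) * world_size == 51384320   # wait_tensor_13: bf16[51384320]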
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_234: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_13, [8, -1]); wait_tensor_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_125 = torch.ops.aten.split_with_sizes.default(view_234, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_234 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1237: "bf16[8, 524288]" = split_with_sizes_125[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_111: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1237, [2048, 2048], [2048, 1]); getitem_1237 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1247: "bf16[8, 524288]" = split_with_sizes_125[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_112: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1247, [2048, 2048], [2048, 1]); getitem_1247 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1257: "bf16[8, 524288]" = split_with_sizes_125[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_113: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1257, [2048, 2048], [2048, 1]); getitem_1257 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1267: "bf16[8, 524288]" = split_with_sizes_125[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_114: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1267, [2048, 2048], [2048, 1]); getitem_1267 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1277: "bf16[8, 1441792]" = split_with_sizes_125[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_115: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1277, [5632, 2048], [2048, 1]); getitem_1277 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1287: "bf16[8, 1441792]" = split_with_sizes_125[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_116: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1287, [2048, 5632], [5632, 1]); getitem_1287 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1297: "bf16[8, 1441792]" = split_with_sizes_125[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_117: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1297, [5632, 2048], [2048, 1]); getitem_1297 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1307: "bf16[8, 256]" = split_with_sizes_125[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_118: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1307, [2048], [1]); getitem_1307 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1317: "bf16[8, 256]" = split_with_sizes_125[8]; split_with_sizes_125 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_119: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1317, [2048], [1]); getitem_1317 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
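Each contiguous_view_as_strided call above turns one [8, shard_numel] column of the gathered slab back into a full unsharded parameter (e.g. [8, 524288] -> [2048, 2048], since 8 * 524288 = 2048 * 2048). torch.ops.fsdp.contiguous_view_as_strided is a prototype op, so the sketch below only reproduces the shape bookkeeping with standard ops, assuming dim-0 sharding (each rank holds a contiguous row block of every parameter):

    import torch

    world_size = 8
    gathered = torch.empty(world_size, 6423040, dtype=torch.bfloat16)  # cf. view_234
    splits = gathered.split([524288, 524288, 524288, 524288,
                             1441792, 1441792, 1441792, 256, 256], dim=1)

    # Under dim-0 sharding, rank r's chunk is rows [r*256 : (r+1)*256] of a
    # 2048-row weight, so stacking the chunks row-major recovers the full tensor:
    wq = splits[0].contiguous().view(2048, 2048)    # cf. contiguous_view_as_strided_111
    w1 = splits[4].contiguous().view(5632, 2048)    # cf. contiguous_view_as_strided_115
    attn_norm = splits[7].contiguous().view(2048)   # cf. contiguous_view_as_strided_118
    print(wq.shape, w1.shape, attn_norm.shape)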
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_408: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_47, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_25: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_408, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_24: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_25, [-1], True); pow_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_48: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_24, 1e-05); mean_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_24: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_48); add_48 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_96: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_408, rsqrt_24); convert_element_type_408 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_409: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_96, torch.bfloat16); mul_96 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_97: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_409, contiguous_view_as_strided_118); convert_element_type_409 = contiguous_view_as_strided_118 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_217: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_111, [1, 0]); contiguous_view_as_strided_111 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_84: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_97, permute_217); permute_217 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_219: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_112, [1, 0]); contiguous_view_as_strided_112 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_85: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_97, permute_219); permute_219 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_221: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_113, [1, 0]); contiguous_view_as_strided_113 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_86: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_97, permute_221); permute_221 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_243: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_84, [1, 2048, 16, 128]); mm_84 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_244: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_85, [1, 2048, 16, 128]); mm_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_245: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_86, [1, 2048, 16, 128]); mm_86 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_416: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_243, torch.float32); view_243 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_246: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_416, [1, 2048, 16, -1, 2]); convert_element_type_416 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_24: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_246); view_246 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_417: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_244, torch.float32); view_244 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_247: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_417, [1, 2048, 16, -1, 2]); convert_element_type_417 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_25: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_247); view_247 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_98: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_24, view_20); view_as_complex_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_24: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_249: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_24, [1, 2048, 16, 128]); view_as_real_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_99: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_25, view_20); view_as_complex_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_25: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_250: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_25, [1, 2048, 16, 128]); view_as_real_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_418: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_249, torch.bfloat16); view_249 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_419: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_250, torch.bfloat16); view_250 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_222: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_418, [0, 2, 1, 3]); convert_element_type_418 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_223: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_419, [0, 2, 1, 3]); convert_element_type_419 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_224: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_245, [0, 2, 1, 3]); view_245 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_12 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_222, permute_223, permute_224, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1318: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_12[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1319: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_12[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1324: "i64[]" = _scaled_dot_product_flash_attention_12[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1325: "i64[]" = _scaled_dot_product_flash_attention_12[7]; _scaled_dot_product_flash_attention_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_225: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1318, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_251: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_225, [2048, -1]); permute_225 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_227: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_114, [1, 0]); contiguous_view_as_strided_114 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_87: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_251, permute_227); permute_227 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_49: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_47, mm_87); add_47 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_422: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_49, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_26: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_422, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_25: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_26, [-1], True); pow_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_50: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_25, 1e-05); mean_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_25: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_50); add_50 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_100: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_422, rsqrt_25); convert_element_type_422 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_423: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_100, torch.bfloat16); mul_100 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_101: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_423, contiguous_view_as_strided_119); convert_element_type_423 = contiguous_view_as_strided_119 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_229: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_115, [1, 0]); contiguous_view_as_strided_115 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_88: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_101, permute_229); permute_229 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_426: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_88, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_12: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_426) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_102: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_426, sigmoid_12); convert_element_type_426 = sigmoid_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_427: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_231: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_117, [1, 0]); contiguous_view_as_strided_117 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_89: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_101, permute_231); permute_231 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_103: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_427, mm_89); convert_element_type_427 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_233: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_116, [1, 0]); contiguous_view_as_strided_116 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_90: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_103, permute_233); permute_233 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_51: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_49, mm_90); add_49 = None | |
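[annotation] The block above (mm_88 through add_51) is the inlined SwiGLU feed-forward plus its residual, quoted from model.py:299 and model.py:411. A minimal eager-mode sketch of the same math, with hypothetical weight arguments (w1, w2, w3 standing in for the gathered contiguous_view_as_strided buffers):

    import torch
    import torch.nn.functional as F

    def swiglu_ffn(h, w1, w2, w3):
        # h: [tokens, dim] bf16; w1/w3: [hidden, dim]; w2: [dim, hidden].
        g = F.linear(h, w1).float()             # mm_88, then fp32 upcast (convert_element_type_426)
        g = (g * torch.sigmoid(g)).to(h.dtype)  # F.silu computed in fp32, cast back to bf16
        up = F.linear(h, w3)                    # mm_89
        return F.linear(g * up, w2)             # mul_103, then mm_90

add_51 above is the surrounding residual from model.py:411: out = h + swiglu_ffn(ffn_norm(h), ...).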
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_432: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_433: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_434: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_435: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_436: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_437: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_438: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_439: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_440: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None | |
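[annotation] convert_element_type_432 through _440 are FSDP2 casting the nine fp32 sharded flat parameters of the next layer to bf16 before all-gathering them (mixed-precision param dtype). The quoted _to_dtype_if_needed body is just tensor.to(dtype); the dtype check and the surrounding loop below are assumptions for illustration:

    import torch

    def to_dtype_if_needed(tensor, dtype):
        # The traced nodes correspond to the tensor.to(dtype) branch; the
        # early return when dtypes already match is assumed from the helper's name.
        return tensor if tensor.dtype == dtype else tensor.to(dtype)

    def cast_shards(sharded_params_fp32):
        # sharded_params_fp32: hypothetical list of this rank's fp32 flat shards.
        return [to_dtype_if_needed(p, torch.bfloat16) for p in sharded_params_fp32]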
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_14 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_432, convert_element_type_433, convert_element_type_434, convert_element_type_435, convert_element_type_436, convert_element_type_437, convert_element_type_438, convert_element_type_439, convert_element_type_440]); convert_element_type_432 = convert_element_type_433 = convert_element_type_434 = convert_element_type_435 = convert_element_type_436 = convert_element_type_437 = convert_element_type_438 = convert_element_type_439 = convert_element_type_440 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1327: "bf16[6423040]" = all_gather_copy_in_14[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1328: "bf16[51384320]" = all_gather_copy_in_14[1]; all_gather_copy_in_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_14: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1327, 8, '0'); getitem_1327 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_14: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None | |
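[annotation] The three steps above are FSDP2's foreach all-gather: all_gather_copy_in packs the nine bf16 shards (4 x 524,288 + 3 x 1,441,792 + 2 x 256 = 6,423,040 elements) into one flat send buffer plus an 8x-sized output buffer, all_gather_into_tensor fills the output asynchronously over the 8-rank group, and wait_tensor blocks on the result. A rough equivalent using plain functional collectives, which copies through torch.cat instead of the fused torch.ops.fsdp.all_gather_copy_in op:

    import torch
    import torch.distributed._functional_collectives as funcol

    def foreach_all_gather(shards, group):
        # shards: this rank's per-parameter 1-D bf16 shards (numels sum to 6,423,040 here).
        inp = torch.cat(shards)                                # the "copy-in" into one flat buffer
        out = funcol.all_gather_tensor(inp, gather_dim=0, group=group)
        return funcol.wait_tensor(out)                         # flat bf16 [world_size * inp.numel()]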
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_253: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_14, [8, -1]); wait_tensor_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_135 = torch.ops.aten.split_with_sizes.default(view_253, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_253 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1338: "bf16[8, 524288]" = split_with_sizes_135[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_120: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1338, [2048, 2048], [2048, 1]); getitem_1338 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1348: "bf16[8, 524288]" = split_with_sizes_135[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_121: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1348, [2048, 2048], [2048, 1]); getitem_1348 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1358: "bf16[8, 524288]" = split_with_sizes_135[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_122: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1358, [2048, 2048], [2048, 1]); getitem_1358 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1368: "bf16[8, 524288]" = split_with_sizes_135[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_123: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1368, [2048, 2048], [2048, 1]); getitem_1368 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1378: "bf16[8, 1441792]" = split_with_sizes_135[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_124: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1378, [5632, 2048], [2048, 1]); getitem_1378 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1388: "bf16[8, 1441792]" = split_with_sizes_135[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_125: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1388, [2048, 5632], [5632, 1]); getitem_1388 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1398: "bf16[8, 1441792]" = split_with_sizes_135[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_126: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1398, [5632, 2048], [2048, 1]); getitem_1398 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1408: "bf16[8, 256]" = split_with_sizes_135[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_127: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1408, [2048], [1]); getitem_1408 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1418: "bf16[8, 256]" = split_with_sizes_135[8]; split_with_sizes_135 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_128: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1418, [2048], [1]); getitem_1418 = None | |
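[annotation] foreach_all_gather_copy_out undoes the packing: view the gathered flat buffer as [world_size, shard_numel], split it per parameter, and re-view each [8, shard] block as the full parameter (for these dim-0-sharded weights, that concatenates the 8 row-shards). The custom torch.ops.fsdp.contiguous_view_as_strided op does this without a copy; a plain-PyTorch approximation that materializes copies instead:

    def foreach_all_gather_copy_out(gathered, split_sizes, param_shapes, world_size=8):
        # gathered: flat bf16 [world_size * sum(split_sizes)] from the all-gather above.
        per_param = gathered.view(world_size, -1).split(split_sizes, dim=1)
        # Row-major flattening of each [world_size, shard_numel] block concatenates
        # the rank shards along dim 0, e.g. recovering [2048, 2048] from 8 x 524288.
        return [blk.reshape(shape) for blk, shape in zip(per_param, param_shapes)]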
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_441: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_51, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_27: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_441, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_26: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_27, [-1], True); pow_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_52: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_26, 1e-05); mean_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_26: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_52); add_52 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_104: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_441, rsqrt_26); convert_element_type_441 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_442: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_104, torch.bfloat16); mul_104 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_105: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_442, contiguous_view_as_strided_127); convert_element_type_442 = contiguous_view_as_strided_127 = None | |
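[annotation] The pow/mean/rsqrt/mul chain above is the RMSNorm quoted from model.py:61, 74, and 75: normalize in fp32, cast back to bf16, then scale elementwise by the (all-gathered) weight. As a self-contained function assembled from those quoted lines:

    import torch

    def rms_norm(x, weight, eps=1e-5):
        # Matches the traced sequence: fp32 upcast, mean of squares over the
        # last dim, rsqrt(mean + eps), downcast, elementwise weight scale.
        xf = x.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
        return normed.type_as(x) * weight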
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_235: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_120, [1, 0]); contiguous_view_as_strided_120 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_91: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_105, permute_235); permute_235 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_237: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_121, [1, 0]); contiguous_view_as_strided_121 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_92: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_105, permute_237); permute_237 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_239: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_122, [1, 0]); contiguous_view_as_strided_122 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_93: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_105, permute_239); permute_239 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_262: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_91, [1, 2048, 16, 128]); mm_91 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_263: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_92, [1, 2048, 16, 128]); mm_92 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_264: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_93, [1, 2048, 16, 128]); mm_93 = None | |
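[annotation] mm_91/92/93 are the q, k, v projections of the same normed activations against three gathered [2048, 2048] weights, followed by the per-head reshapes (view_262/263/264). A sketch with the shapes traced above (bsz=1, seqlen=2048, n_heads=16, head_dim=128):

    import torch.nn.functional as F

    def qkv_proj(x, wq, wk, wv, bsz=1, seqlen=2048, n_heads=16, head_dim=128):
        # x: [bsz * seqlen, dim]; each projection is one mm, then a view into heads.
        xq = F.linear(x, wq).view(bsz, seqlen, n_heads, head_dim)
        xk = F.linear(x, wk).view(bsz, seqlen, n_heads, head_dim)
        xv = F.linear(x, wv).view(bsz, seqlen, n_heads, head_dim)
        return xq, xk, xv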
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_449: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_262, torch.float32); view_262 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_265: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_449, [1, 2048, 16, -1, 2]); convert_element_type_449 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_26: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_265); view_265 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_450: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_263, torch.float32); view_263 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_266: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_450, [1, 2048, 16, -1, 2]); convert_element_type_450 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_27: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_266); view_266 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_106: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_26, view_20); view_as_complex_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_26: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_268: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_26, [1, 2048, 16, 128]); view_as_real_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_107: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_27, view_20); view_as_complex_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_27: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_269: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_27, [1, 2048, 16, 128]); view_as_real_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_451: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_268, torch.bfloat16); view_268 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_452: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_269, torch.bfloat16); view_269 = None | |
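[annotation] The view_as_complex / complex-multiply / view_as_real sequence is the rotary position embedding quoted from model.py:146-151, computed in fp32 (the c64 nodes) and cast back to bf16 by the two converts above. The quoted code, assembled:

    import torch

    def apply_rotary_emb(xq, xk, freqs_cis):
        # Pair adjacent features into complex numbers, rotate by the complex
        # freqs_cis, then unpack back into real pairs and flatten.
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)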
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_240: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_451, [0, 2, 1, 3]); convert_element_type_451 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_241: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_452, [0, 2, 1, 3]); convert_element_type_452 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_242: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_264, [0, 2, 1, 3]); view_264 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_13 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_240, permute_241, permute_242, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1419: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_13[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1420: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_13[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1425: "i64[]" = _scaled_dot_product_flash_attention_13[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1426: "i64[]" = _scaled_dot_product_flash_attention_13[7]; _scaled_dot_product_flash_attention_13 = None | |
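[annotation] The graph lowers F.scaled_dot_product_attention (model.py:254) straight to the flash-attention kernel; scale=0.08838834764831843 is 1/sqrt(head_dim)=1/sqrt(128), and the extra outputs (logsumexp getitem_1420, RNG state getitem_1425/1426) are saved for the backward pass. An eager-mode equivalent of the kernel call, reusing xq/xk/xv from the sketch above:

    import math
    import torch.nn.functional as F

    # xq/xk/xv: [1, 16, 2048, 128] after the transposes above.
    out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
    # The default scale is 1 / math.sqrt(query.size(-1)) = 1 / math.sqrt(128) ~ 0.0883883.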
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_243: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1419, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_270: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_243, [2048, -1]); permute_243 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_245: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_123, [1, 0]); contiguous_view_as_strided_123 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_94: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_270, permute_245); permute_245 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_53: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_51, mm_94); add_51 = None | |
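[annotation] The attention epilogue: transpose heads back (permute_243), flatten to [tokens, dim] (view_270), project through the output weight (mm_94), and add the block's first residual (add_53, model.py:410). A sketch, with wo standing in for the gathered output-projection weight:

    import torch.nn.functional as F

    def attn_epilogue(attn_out, wo, residual):
        # attn_out: [bsz, n_heads, seqlen, head_dim] from SDPA above.
        bsz, n_heads, seqlen, head_dim = attn_out.shape
        merged = attn_out.transpose(1, 2).reshape(bsz * seqlen, n_heads * head_dim)
        return residual + F.linear(merged, wo)  # residual: [bsz * seqlen, dim]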
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_455: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_53, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_28: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_455, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_27: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_28, [-1], True); pow_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_54: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_27, 1e-05); mean_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_27: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_54); add_54 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_108: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_455, rsqrt_27); convert_element_type_455 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_456: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_108, torch.bfloat16); mul_108 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_109: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_456, contiguous_view_as_strided_128); convert_element_type_456 = contiguous_view_as_strided_128 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_247: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_124, [1, 0]); contiguous_view_as_strided_124 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_95: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_109, permute_247); permute_247 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_459: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_95, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_13: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_459) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_110: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_459, sigmoid_13); convert_element_type_459 = sigmoid_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_460: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_249: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_126, [1, 0]); contiguous_view_as_strided_126 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_96: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_109, permute_249); permute_249 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_111: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_460, mm_96); convert_element_type_460 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_251: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_125, [1, 0]); contiguous_view_as_strided_125 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_97: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_111, permute_251); permute_251 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_55: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_53, mm_97); add_53 = None | |
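[annotation] add_55 closes the second transformer block traced in this stretch; from here the graph repeats the same all-gather-then-compute pattern for the next layer's parameter group. Pulling the quoted model.py:410-411 lines together, each block is just two residual sub-layers:

    def transformer_block(x, freqs_cis, attention, attention_norm, feed_forward, ffn_norm):
        # The adds traced above (add_53, add_55) are exactly these two skip connections.
        h = x + attention(attention_norm(x), freqs_cis)
        return h + feed_forward(ffn_norm(h))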
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_465: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_466: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_467: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_468: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_469: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_470: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_471: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_472: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_473: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_15 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_465, convert_element_type_466, convert_element_type_467, convert_element_type_468, convert_element_type_469, convert_element_type_470, convert_element_type_471, convert_element_type_472, convert_element_type_473]); convert_element_type_465 = convert_element_type_466 = convert_element_type_467 = convert_element_type_468 = convert_element_type_469 = convert_element_type_470 = convert_element_type_471 = convert_element_type_472 = convert_element_type_473 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1428: "bf16[6423040]" = all_gather_copy_in_15[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1429: "bf16[51384320]" = all_gather_copy_in_15[1]; all_gather_copy_in_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_15: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1428, 8, '0'); getitem_1428 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_15: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_272: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_15, [8, -1]); wait_tensor_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_145 = torch.ops.aten.split_with_sizes.default(view_272, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_272 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1439: "bf16[8, 524288]" = split_with_sizes_145[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_129: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1439, [2048, 2048], [2048, 1]); getitem_1439 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1449: "bf16[8, 524288]" = split_with_sizes_145[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_130: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1449, [2048, 2048], [2048, 1]); getitem_1449 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1459: "bf16[8, 524288]" = split_with_sizes_145[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_131: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1459, [2048, 2048], [2048, 1]); getitem_1459 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1469: "bf16[8, 524288]" = split_with_sizes_145[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_132: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1469, [2048, 2048], [2048, 1]); getitem_1469 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1479: "bf16[8, 1441792]" = split_with_sizes_145[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_133: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1479, [5632, 2048], [2048, 1]); getitem_1479 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1489: "bf16[8, 1441792]" = split_with_sizes_145[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_134: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1489, [2048, 5632], [5632, 1]); getitem_1489 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1499: "bf16[8, 1441792]" = split_with_sizes_145[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_135: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1499, [5632, 2048], [2048, 1]); getitem_1499 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1509: "bf16[8, 256]" = split_with_sizes_145[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_136: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1509, [2048], [1]); getitem_1509 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1519: "bf16[8, 256]" = split_with_sizes_145[8]; split_with_sizes_145 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_137: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1519, [2048], [1]); getitem_1519 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_474: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_55, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_29: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_474, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_28: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_29, [-1], True); pow_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_56: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_28, 1e-05); mean_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_28: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_56); add_56 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_112: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_474, rsqrt_28); convert_element_type_474 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_475: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_112, torch.bfloat16); mul_112 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_113: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_475, contiguous_view_as_strided_136); convert_element_type_475 = contiguous_view_as_strided_136 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_253: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_129, [1, 0]); contiguous_view_as_strided_129 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_98: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_113, permute_253); permute_253 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_255: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_130, [1, 0]); contiguous_view_as_strided_130 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_99: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_113, permute_255); permute_255 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_257: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_131, [1, 0]); contiguous_view_as_strided_131 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_100: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_113, permute_257); permute_257 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
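Each bias-free `F.linear` lowers post-grad to an `aten.permute` of the weight plus an `aten.mm`, so the single normalized activation (`mul_113`) feeds three back-to-back matmuls for the q/k/v projections (`mm_98`/`mm_99`/`mm_100`) against three all-gathered weight views. The equivalence, as a quick standalone check:

```python
import torch

x = torch.randn(2048, 2048)
w = torch.randn(2048, 2048)

# Bias-free F.linear(x, w) is x @ w.T, i.e. exactly the permute + mm pair
# that appears in the graph.
out_linear = torch.nn.functional.linear(x, w)
out_mm = torch.mm(x, w.permute(1, 0))
assert torch.allclose(out_linear, out_mm)
```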
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_281: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_98, [1, 2048, 16, 128]); mm_98 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_282: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_99, [1, 2048, 16, 128]); mm_99 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_283: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_100, [1, 2048, 16, 128]); mm_100 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_482: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_281, torch.float32); view_281 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_284: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_482, [1, 2048, 16, -1, 2]); convert_element_type_482 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_28: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_284); view_284 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_483: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_282, torch.float32); view_282 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_285: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_483, [1, 2048, 16, -1, 2]); convert_element_type_483 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_29: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_285); view_285 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_114: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_28, view_20); view_as_complex_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_28: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_287: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_28, [1, 2048, 16, 128]); view_as_real_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_115: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_29, view_20); view_as_complex_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_29: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_288: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_29, [1, 2048, 16, 128]); view_as_real_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_484: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_287, torch.bfloat16); view_287 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_485: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_288, torch.bfloat16); view_288 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
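The rotary-embedding region (model.py:146-151, quoted above) reinterprets each 128-wide head as 64 complex numbers, multiplies by the precomputed per-position phases (`view_20`, the `freqs_cis` buffer), and views the result back as real pairs before casting to bf16. The quoted source assembled into one function; `freqs_cis` is assumed precomputed as a complex tensor broadcastable over the batch and head dims:

```python
import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # (bsz, seqlen, n_heads, head_dim) -> complex (bsz, seqlen, n_heads, head_dim // 2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # rotate each (even, odd) pair by the per-position complex phase, then
    # flatten the trailing (head_dim // 2, 2) dims back to head_dim
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```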
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_258: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_484, [0, 2, 1, 3]); convert_element_type_484 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_259: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_485, [0, 2, 1, 3]); convert_element_type_485 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_260: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_283, [0, 2, 1, 3]); view_283 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_14 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_258, permute_259, permute_260, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1520: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_14[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1521: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_14[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1526: "i64[]" = _scaled_dot_product_flash_attention_14[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1527: "i64[]" = _scaled_dot_product_flash_attention_14[7]; _scaled_dot_product_flash_attention_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
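`F.scaled_dot_product_attention(..., is_causal=True)` stays a single fused `aten._scaled_dot_product_flash_attention` node post-grad: output slot 0 (`getitem_1520`) is the attention result, slot 1 (`getitem_1521`) is the logsumexp kept for the backward, and slots 6/7 carry RNG state. The `scale` argument spelled out in the kernel call is the default 1/sqrt(head_dim). A small check and equivalent eager call:

```python
import math
import torch
import torch.nn.functional as F

# The kernel argument above, scale = 0.08838834764831843, is 1/sqrt(128),
# the default softmax scaling for head_dim = 128.
assert math.isclose(1 / math.sqrt(128), 0.08838834764831843)

q = torch.randn(1, 16, 2048, 128)  # (bsz, n_heads, seqlen, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # causal mask applied
```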
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_261: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1520, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_289: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_261, [2048, -1]); permute_261 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_263: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_132, [1, 0]); contiguous_view_as_strided_132 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_101: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_289, permute_263); permute_263 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_57: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_55, mm_101); add_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
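The `add_57` node above is the first of the two residual adds quoted from model.py:410-411: a pre-norm residual around attention, then (further down, `add_59`) a pre-norm residual around the feed-forward. Shapes stay [2048, 2048] (tokens x dim) throughout, so each residual is a plain elementwise add. As a skeleton, with the submodules passed in (a sketch mirroring the quoted source, not the full torchtrain class):

```python
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm residual wiring from model.py:410-411 in the trace."""
    def __init__(self, attention, feed_forward, attention_norm, ffn_norm):
        super().__init__()
        self.attention = attention          # the SDPA-based attention above
        self.feed_forward = feed_forward    # the SwiGLU FFN sketched further below
        self.attention_norm = attention_norm
        self.ffn_norm = ffn_norm

    def forward(self, x, freqs_cis):
        h = x + self.attention(self.attention_norm(x), freqs_cis)  # add_57
        return h + self.feed_forward(self.ffn_norm(h))             # add_59
```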
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_488: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_57, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_30: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_488, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_29: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_30, [-1], True); pow_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_58: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_29, 1e-05); mean_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_29: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_58); add_58 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_116: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_488, rsqrt_29); convert_element_type_488 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_489: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_116, torch.bfloat16); mul_116 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_117: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_489, contiguous_view_as_strided_137); convert_element_type_489 = contiguous_view_as_strided_137 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_265: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_133, [1, 0]); contiguous_view_as_strided_133 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_102: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_117, permute_265); permute_265 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_492: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_102, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_14: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_492) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_118: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_492, sigmoid_14); convert_element_type_492 = sigmoid_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_493: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_267: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_135, [1, 0]); contiguous_view_as_strided_135 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_103: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_117, permute_267); permute_267 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_119: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_493, mm_103); convert_element_type_493 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_269: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_134, [1, 0]); contiguous_view_as_strided_134 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_104: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_119, permute_269); permute_269 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_59: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_57, mm_104); add_57 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
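The feed-forward region above (model.py:299) is a SwiGLU: `w1` (`mm_102`) and `w3` (`mm_103`) both project 2048 -> 5632, SiLU is applied to the `w1` branch in fp32 (`sigmoid_14`/`mul_118`) and cast back to bf16, the two branches are multiplied elementwise (`mul_119`), and `w2` (`mm_104`) projects back to 2048 before the residual add (`add_59`). A sketch using the dimensions visible in the shapes above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """SwiGLU FFN matching model.py:299 in the trace: w2(silu(w1(x)) * w3(x))."""
    def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu on the w1 branch runs in fp32 in the trace before the bf16 cast
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```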
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_498: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_499: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_500: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_501: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_502: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_503: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_504: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_505: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_506: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
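Here the graph switches from compute to the next layer's FSDP prefetch. `_to_dtype_if_needed` (quoted from _fsdp_common.py:140) casts each fp32 sharded parameter (`primals_280` through `primals_288`) to bf16 ahead of the all-gather. The shard sizes are 1/8 of each full parameter on this 8-GPU mesh: 524288 = 2048*2048/8 for the four attention weights, 1441792 = 5632*2048/8 for the three FFN weights, and 256 = 2048/8 for the two norm weights. A sketch of the helper; only the `return tensor.to(dtype)` line is quoted in the trace, and the early-return when dtypes already match is an assumption implied by the name:

```python
import torch

def _to_dtype_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # assumed no-op path when the shard is already in the target dtype
    if tensor.dtype == dtype:
        return tensor
    return tensor.to(dtype)  # the line quoted from _fsdp_common.py:140
```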
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_16 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_498, convert_element_type_499, convert_element_type_500, convert_element_type_501, convert_element_type_502, convert_element_type_503, convert_element_type_504, convert_element_type_505, convert_element_type_506]); convert_element_type_498 = convert_element_type_499 = convert_element_type_500 = convert_element_type_501 = convert_element_type_502 = convert_element_type_503 = convert_element_type_504 = convert_element_type_505 = convert_element_type_506 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1529: "bf16[6423040]" = all_gather_copy_in_16[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1530: "bf16[51384320]" = all_gather_copy_in_16[1]; all_gather_copy_in_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
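`torch.ops.fsdp.all_gather_copy_in` packs the nine bf16 shards into one flat input buffer of 6,423,040 elements (4*524288 + 3*1441792 + 2*256, matching `inp_split_sizes` above) and allocates the gathered output at 8x that size, 51,384,320 elements (`getitem_1530`). A plain-torch sketch of what the fused op computes, not its actual implementation (which copies in without the intermediate concat):

```python
import torch

def all_gather_copy_in(inp_split_sizes, param_all_gather_inputs, world_size, dtype, device):
    # pack the per-parameter shards into one contiguous input buffer ...
    all_gather_input = torch.cat([t.reshape(-1) for t in param_all_gather_inputs])
    # ... and pre-allocate the buffer the collective will gather into
    numel = sum(inp_split_sizes)
    all_gather_output = torch.empty(numel * world_size, dtype=dtype, device=device)
    return all_gather_input, all_gather_output
```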
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_16: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1529, 8, '0'); getitem_1529 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_16: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
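The gather itself goes through the traceable functional collectives: `all_gather_into_tensor` (quoted from _functional_collectives.py:229) issues the collective asynchronously, and `wait_tensor` (line 144) blocks until it completes; splitting the two nodes is what lets the compiler schedule the current layer's matmuls between them. Roughly, in user-level terms (assuming a process group initialized via torchrun, and that this private funcol API matches the build in use):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def gather_shards(local_shard: torch.Tensor) -> torch.Tensor:
    # issue the async all-gather; the result is only valid after wait_tensor,
    # mirroring the all_gather_into_tensor_16 -> wait_tensor_16 pair above
    gathered = funcol.all_gather_tensor(local_shard, gather_dim=0, group=dist.group.WORLD)
    return funcol.wait_tensor(gathered)
```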
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_291: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_16, [8, -1]); wait_tensor_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_155 = torch.ops.aten.split_with_sizes.default(view_291, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_291 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1540: "bf16[8, 524288]" = split_with_sizes_155[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_138: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1540, [2048, 2048], [2048, 1]); getitem_1540 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1550: "bf16[8, 524288]" = split_with_sizes_155[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_139: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1550, [2048, 2048], [2048, 1]); getitem_1550 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1560: "bf16[8, 524288]" = split_with_sizes_155[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_140: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1560, [2048, 2048], [2048, 1]); getitem_1560 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1570: "bf16[8, 524288]" = split_with_sizes_155[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_141: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1570, [2048, 2048], [2048, 1]); getitem_1570 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1580: "bf16[8, 1441792]" = split_with_sizes_155[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_142: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1580, [5632, 2048], [2048, 1]); getitem_1580 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1590: "bf16[8, 1441792]" = split_with_sizes_155[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_143: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1590, [2048, 5632], [5632, 1]); getitem_1590 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1600: "bf16[8, 1441792]" = split_with_sizes_155[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_144: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1600, [5632, 2048], [2048, 1]); getitem_1600 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1610: "bf16[8, 256]" = split_with_sizes_155[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_145: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1610, [2048], [1]); getitem_1610 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1620: "bf16[8, 256]" = split_with_sizes_155[8]; split_with_sizes_155 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_146: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1620, [2048], [1]); getitem_1620 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
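The copy-out reverses the packing: the gathered buffer is viewed as [8, 6423040] (one packed row per rank), split column-wise back into the nine per-parameter segments, and each `torch.ops.fsdp.contiguous_view_as_strided` materializes an [8, shard_numel] segment as the full unsharded parameter (e.g. [8, 524288] -> [2048, 2048], [8, 256] -> [2048]). A plain-torch sketch of the reassembly; the real op produces a strided view over the gathered buffer without this copy:

```python
import torch

def copy_out(gathered: torch.Tensor, world_size: int, splits, shapes):
    rows = gathered.reshape(world_size, -1)   # one packed row per rank
    segments = rows.split(splits, dim=1)      # nine per-parameter segments
    # stacking each rank's dim-0 slice in rank order recovers the full parameter
    return [seg.reshape(shape) for seg, shape in zip(segments, shapes)]

splits = [524288] * 4 + [1441792] * 3 + [256] * 2
shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048), (2048,), (2048,)]
params = copy_out(torch.empty(8 * sum(splits)), 8, splits, shapes)
```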
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_507: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_59, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_31: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_507, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_30: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_31, [-1], True); pow_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_60: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_30, 1e-05); mean_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_30: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_60); add_60 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_120: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_507, rsqrt_30); convert_element_type_507 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_508: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_120, torch.bfloat16); mul_120 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_121: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_508, contiguous_view_as_strided_145); convert_element_type_508 = contiguous_view_as_strided_145 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_271: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_138, [1, 0]); contiguous_view_as_strided_138 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_105: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_121, permute_271); permute_271 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_273: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_139, [1, 0]); contiguous_view_as_strided_139 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_106: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_121, permute_273); permute_273 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_275: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_140, [1, 0]); contiguous_view_as_strided_140 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_107: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_121, permute_275); permute_275 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_300: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_105, [1, 2048, 16, 128]); mm_105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_301: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_106, [1, 2048, 16, 128]); mm_106 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_302: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_107, [1, 2048, 16, 128]); mm_107 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_515: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_303: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_515, [1, 2048, 16, -1, 2]); convert_element_type_515 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_30: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_303); view_303 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_516: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_301, torch.float32); view_301 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_304: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_516, [1, 2048, 16, -1, 2]); convert_element_type_516 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_31: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_304); view_304 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_122: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_30, view_20); view_as_complex_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_30: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_306: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_30, [1, 2048, 16, 128]); view_as_real_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_123: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_31, view_20); view_as_complex_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_31: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_307: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_31, [1, 2048, 16, 128]); view_as_real_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_517: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_306, torch.bfloat16); view_306 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_518: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_307, torch.bfloat16); view_307 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_276: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_517, [0, 2, 1, 3]); convert_element_type_517 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_277: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_518, [0, 2, 1, 3]); convert_element_type_518 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_278: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_302, [0, 2, 1, 3]); view_302 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_15 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_276, permute_277, permute_278, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1621: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_15[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1622: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_15[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1627: "i64[]" = _scaled_dot_product_flash_attention_15[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1628: "i64[]" = _scaled_dot_product_flash_attention_15[7]; _scaled_dot_product_flash_attention_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_279: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1621, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_308: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_279, [2048, -1]); permute_279 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_281: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_141, [1, 0]); contiguous_view_as_strided_141 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_108: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_308, permute_281); permute_281 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_61: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_59, mm_108); add_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_521: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_61, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_32: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_521, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_31: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_32, [-1], True); pow_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_62: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_31, 1e-05); mean_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_31: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_62); add_62 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_124: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_521, rsqrt_31); convert_element_type_521 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_522: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_124, torch.bfloat16); mul_124 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_125: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_522, contiguous_view_as_strided_146); convert_element_type_522 = contiguous_view_as_strided_146 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_283: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_142, [1, 0]); contiguous_view_as_strided_142 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_109: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_125, permute_283); permute_283 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_525: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_109, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_15: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_525) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_126: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_525, sigmoid_15); convert_element_type_525 = sigmoid_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_526: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_285: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_144, [1, 0]); contiguous_view_as_strided_144 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_110: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_125, permute_285); permute_285 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_127: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_526, mm_110); convert_element_type_526 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_287: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_143, [1, 0]); contiguous_view_as_strided_143 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_111: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_127, permute_287); permute_287 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_63: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_61, mm_111); add_61 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
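The graph ops above are Inductor's post-grad decomposition of the Llama SwiGLU feed-forward plus its residual: mm_109/mm_110 are the w1/w3 projections, the sigmoid/mul pair is F.silu computed in fp32, mm_111 is w2, and add_63 is the `out = h + self.feed_forward(...)` residual quoted from model.py:299/411. A minimal eager sketch of the same math, with shapes read off the trace (dim=2048, hidden=5632); the weights are random stand-ins:

```python
import torch
import torch.nn.functional as F

def swiglu_ffn(x, w1, w2, w3):
    # F.silu(t) == t * torch.sigmoid(t); the trace upcasts to fp32 for this
    # (convert_element_type_525 / sigmoid_15 / mul_126), then casts back to bf16.
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

x  = torch.randn(2048, 2048, dtype=torch.bfloat16)  # flattened (bsz*seqlen, dim)
w1 = torch.randn(5632, 2048, dtype=torch.bfloat16)  # cf. contiguous_view_as_strided_142
w3 = torch.randn(5632, 2048, dtype=torch.bfloat16)  # cf. contiguous_view_as_strided_144
w2 = torch.randn(2048, 5632, dtype=torch.bfloat16)  # cf. contiguous_view_as_strided_143
out = x + swiglu_ffn(x, w1, w2, w3)                 # residual, cf. add_63
```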
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_531: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_298, torch.bfloat16); primals_298 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_532: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_299, torch.bfloat16); primals_299 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_533: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_300, torch.bfloat16); primals_300 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_534: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16); primals_301 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_535: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_302, torch.bfloat16); primals_302 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_536: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16); primals_303 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_537: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_304, torch.bfloat16); primals_304 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_538: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_305, torch.bfloat16); primals_305 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_539: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_306, torch.bfloat16); primals_306 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_17 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_531, convert_element_type_532, convert_element_type_533, convert_element_type_534, convert_element_type_535, convert_element_type_536, convert_element_type_537, convert_element_type_538, convert_element_type_539]); convert_element_type_531 = convert_element_type_532 = convert_element_type_533 = convert_element_type_534 = convert_element_type_535 = convert_element_type_536 = convert_element_type_537 = convert_element_type_538 = convert_element_type_539 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1630: "bf16[6423040]" = all_gather_copy_in_17[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1631: "bf16[51384320]" = all_gather_copy_in_17[1]; all_gather_copy_in_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
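Here FSDP stages the next layer's nine parameter shards for the all-gather: each fp32 shard is cast to bf16 (`_to_dtype_if_needed`), then `torch.ops.fsdp.all_gather_copy_in` packs them into a flat buffer. The split sizes sum to 6,423,040 elements per rank, and the output buffer holds world_size times that (51,384,320). A hedged pure-torch sketch of what the fused custom op appears to do, based on the shapes in the trace; this is illustrative, not the op's actual kernel:

```python
import torch

def all_gather_copy_in(inp_numel, world_size, rank, dtype, device,
                       split_sizes, shards):
    # One flat buffer sized for the full gathered result (bf16[51384320])...
    all_gather_output = torch.empty(inp_numel * world_size, dtype=dtype, device=device)
    # ...whose rank-th slice is this rank's input view (bf16[6423040]).
    all_gather_input = all_gather_output.narrow(0, rank * inp_numel, inp_numel)
    offset = 0
    for size, shard in zip(split_sizes, shards):
        # copy each flattened parameter shard into its slot
        all_gather_input.narrow(0, offset, size).copy_(shard.reshape(-1))
        offset += size
    return all_gather_input, all_gather_output
```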
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_17: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1630, 8, '0'); getitem_1630 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_17: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
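The staged buffer then goes through the functional collectives API: `all_gather_into_tensor` returns immediately, and `wait_tensor` is the explicit synchronization point (the trace quotes `_functional_collectives.py` itself), which lets the compiler sink the wait and overlap the NCCL kernel with earlier compute. A usage sketch against the public wrapper module; it assumes an initialized default process group of world size 8 (`'0'` in the trace is the group name):

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# inp stands in for getitem_1630, the flat bf16 per-rank input.
inp = torch.empty(6423040, dtype=torch.bfloat16, device="cuda")
gathered = funcol.all_gather_tensor(inp, gather_dim=0, group=dist.group.WORLD)
gathered = funcol.wait_tensor(gathered)  # until here, the collective may still be in flight
```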
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_310: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_17, [8, -1]); wait_tensor_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_165 = torch.ops.aten.split_with_sizes.default(view_310, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_310 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1641: "bf16[8, 524288]" = split_with_sizes_165[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_147: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1641, [2048, 2048], [2048, 1]); getitem_1641 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1651: "bf16[8, 524288]" = split_with_sizes_165[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_148: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1651, [2048, 2048], [2048, 1]); getitem_1651 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1661: "bf16[8, 524288]" = split_with_sizes_165[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_149: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1661, [2048, 2048], [2048, 1]); getitem_1661 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1671: "bf16[8, 524288]" = split_with_sizes_165[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_150: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1671, [2048, 2048], [2048, 1]); getitem_1671 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1681: "bf16[8, 1441792]" = split_with_sizes_165[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_151: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1681, [5632, 2048], [2048, 1]); getitem_1681 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1691: "bf16[8, 1441792]" = split_with_sizes_165[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_152: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1691, [2048, 5632], [5632, 1]); getitem_1691 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1701: "bf16[8, 1441792]" = split_with_sizes_165[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_153: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1701, [5632, 2048], [2048, 1]); getitem_1701 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1711: "bf16[8, 256]" = split_with_sizes_165[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_154: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1711, [2048], [1]); getitem_1711 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1721: "bf16[8, 256]" = split_with_sizes_165[8]; split_with_sizes_165 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_155: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1721, [2048], [1]); getitem_1721 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
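After the wait, `foreach_all_gather_copy_out` carves the 51,384,320-element result back into per-parameter tensors: view as [world_size, per_rank_numel], split by the same sizes, then reinterpret each [8, shard_numel] column block as the unsharded parameter. `torch.ops.fsdp.contiguous_view_as_strided` is a prototype custom op; `.contiguous().view(...)` approximates it in this sketch:

```python
import torch

def copy_out(gathered, split_sizes, param_shapes, world_size=8):
    mat = gathered.view(world_size, -1)            # cf. view_310
    chunks = torch.split(mat, split_sizes, dim=1)  # cf. split_with_sizes_165
    params = []
    for chunk, shape in zip(chunks, param_shapes):
        # e.g. an [8, 524288] chunk becomes the full [2048, 2048] weight
        params.append(chunk.contiguous().view(shape))  # ~ contiguous_view_as_strided
    return params

# Sizes and shapes read off the trace: wq/wk/wv/wo, w1/w2/w3, two norm weights.
sizes  = [524288] * 4 + [1441792] * 3 + [256] * 2
shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048),
          (2048,), (2048,)]
params = copy_out(torch.empty(51384320, dtype=torch.bfloat16), sizes, shapes)
```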
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_540: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_63, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_33: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_540, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_32: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_33, [-1], True); pow_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_64: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_32, 1e-05); mean_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_32: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_64); add_64 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_128: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_540, rsqrt_32); convert_element_type_540 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_541: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_128, torch.bfloat16); mul_128 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_129: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_541, contiguous_view_as_strided_154); convert_element_type_541 = contiguous_view_as_strided_154 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
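The pow/mean/rsqrt chain above is RMSNorm exactly as written in the model.py lines quoted by the graph comments: normalize in fp32, cast back to the input dtype, then scale by the just-gathered weight. An eager sketch:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    xf = x.float()                                                     # convert_element_type_540
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)  # pow_33 .. mul_128
    return normed.type_as(x) * weight                                  # cast back, then mul_129
```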
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_289: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_147, [1, 0]); contiguous_view_as_strided_147 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_112: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_129, permute_289); permute_289 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_291: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_148, [1, 0]); contiguous_view_as_strided_148 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_113: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_129, permute_291); permute_291 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_293: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_149, [1, 0]); contiguous_view_as_strided_149 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_114: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_129, permute_293); permute_293 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_319: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_112, [1, 2048, 16, 128]); mm_112 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_320: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_113, [1, 2048, 16, 128]); mm_113 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_321: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_114, [1, 2048, 16, 128]); mm_114 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
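Next come the q/k/v projections: `F.linear(x, W)` lowers to a weight transpose plus matmul (permute_289/291/293 + mm_112/113/114), and the views split each 2048-wide output into 16 heads of dim 128. A standalone sketch with random stand-in weights:

```python
import torch
import torch.nn.functional as F

bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128
x  = torch.randn(bsz * seqlen, 2048, dtype=torch.bfloat16)  # the normed input (cf. mul_129)
wq = torch.randn(2048, 2048, dtype=torch.bfloat16)
wk = torch.randn(2048, 2048, dtype=torch.bfloat16)
wv = torch.randn(2048, 2048, dtype=torch.bfloat16)

xq = F.linear(x, wq).view(bsz, seqlen, n_heads, head_dim)   # mm_112 + view_319
xk = F.linear(x, wk).view(bsz, seqlen, n_heads, head_dim)   # mm_113 + view_320
xv = F.linear(x, wv).view(bsz, seqlen, n_heads, head_dim)   # mm_114 + view_321
```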
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_548: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_319, torch.float32); view_319 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_322: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_548, [1, 2048, 16, -1, 2]); convert_element_type_548 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_32: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_322); view_322 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_549: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_323: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_549, [1, 2048, 16, -1, 2]); convert_element_type_549 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_33: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_323); view_323 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_130: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_32, view_20); view_as_complex_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_32: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_325: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_32, [1, 2048, 16, 128]); view_as_real_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_131: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_33, view_20); view_as_complex_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_33: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_326: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_33, [1, 2048, 16, 128]); view_as_real_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_550: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_325, torch.bfloat16); view_325 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_551: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
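Rotary embeddings are applied in the complex domain, reproducing the model.py lines quoted in the trace: reinterpret the last dim as (real, imag) pairs, multiply by the precomputed complex frequencies, view back as real, and restore the input dtype. `freqs_cis` is a complex64 tensor broadcastable against [1, 2048, 16, 64] (view_20 in the graph):

```python
import torch

def apply_rotary_emb(xq, xk, freqs_cis):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # view_as_complex_32
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # view_as_complex_33
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)                 # mul_130 / view_325
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)                 # mul_131 / view_326
    return xq_out.type_as(xq), xk_out.type_as(xk)                           # back to bf16
```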
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_294: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_550, [0, 2, 1, 3]); convert_element_type_550 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_295: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_551, [0, 2, 1, 3]); convert_element_type_551 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_296: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_321, [0, 2, 1, 3]); view_321 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_16 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_294, permute_295, permute_296, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1722: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_16[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1723: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_16[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1728: "i64[]" = _scaled_dot_product_flash_attention_16[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1729: "i64[]" = _scaled_dot_product_flash_attention_16[7]; _scaled_dot_product_flash_attention_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
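The attention core transposes to [bs, heads, seq, head_dim] and dispatches to the flash-attention SDPA kernel; the traced scale 0.08838834764831843 is just 1/sqrt(128), the default for head_dim=128. A sketch, assuming a CUDA device (the flash backend requires one for bf16):

```python
import torch
import torch.nn.functional as F

xq = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16, device="cuda")
xk = torch.randn_like(xq)
xv = torch.randn_like(xq)

out = F.scaled_dot_product_attention(
    xq.transpose(1, 2),   # permute_294: (bs, n_heads, seqlen, head_dim)
    xk.transpose(1, 2),   # permute_295
    xv.transpose(1, 2),   # permute_296
    is_causal=True,       # default scale = 1/sqrt(head_dim) = 0.0883883...
)
```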
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_297: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1722, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_327: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_297, [2048, -1]); permute_297 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_299: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_150, [1, 0]); contiguous_view_as_strided_150 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_115: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_327, permute_299); permute_299 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_65: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_63, mm_115); add_63 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
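Closing the attention sub-block: heads are flattened back to [tokens, dim] (permute_297 + view_327), projected through wo (mm_115), and added to the running residual (add_65, the `h = x + self.attention(...)` line from model.py:410). A standalone sketch; the tensors are random stand-ins for getitem_1722, add_63, and contiguous_view_as_strided_150:

```python
import torch
import torch.nn.functional as F

attn_out = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16)  # SDPA output
resid    = torch.randn(2048, 2048, dtype=torch.bfloat16)        # running residual stream
wo       = torch.randn(2048, 2048, dtype=torch.bfloat16)        # output projection weight

flat = attn_out.transpose(1, 2).reshape(2048, -1)  # permute_297 + view_327
h = resid + F.linear(flat, wo)                     # mm_115 + add_65
```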
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_554: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_65, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_34: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_554, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_33: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_34, [-1], True); pow_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_66: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_33, 1e-05); mean_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_33: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_66); add_66 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_132: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_554, rsqrt_33); convert_element_type_554 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_555: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_132, torch.bfloat16); mul_132 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_133: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_555, contiguous_view_as_strided_155); convert_element_type_555 = contiguous_view_as_strided_155 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_301: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_151, [1, 0]); contiguous_view_as_strided_151 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_116: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_133, permute_301); permute_301 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_558: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_116, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_16: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_558) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_134: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_558, sigmoid_16); convert_element_type_558 = sigmoid_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_559: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_303: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_153, [1, 0]); contiguous_view_as_strided_153 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_117: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_133, permute_303); permute_303 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_135: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_559, mm_117); convert_element_type_559 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_305: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_152, [1, 0]); contiguous_view_as_strided_152 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_118: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_135, permute_305); permute_305 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_67: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_65, mm_118); add_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
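At this point the graph has traced one complete TransformerBlock (model.py:410-411, quoted above) twice over, interleaved with the FSDP all-gather that prefetches the next layer's nine parameters. Schematically, each repetition follows this per-layer pattern (a sketch only: `rms_norm` and `swiglu_ffn` are the earlier sketches, and `attention` stands for the QKV/rotary/SDPA/wo sequence):

```python
# Per-layer structure visible in the trace, not the library's actual code.
def transformer_block(x, attention, ffn_weights, attn_norm_w, ffn_norm_w):
    h = x + attention(rms_norm(x, attn_norm_w))                    # model.py:410
    return h + swiglu_ffn(rms_norm(h, ffn_norm_w), *ffn_weights)   # model.py:411
```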
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_564: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16); primals_316 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_565: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_317, torch.bfloat16); primals_317 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_566: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_318, torch.bfloat16); primals_318 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_567: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_319, torch.bfloat16); primals_319 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_568: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_320, torch.bfloat16); primals_320 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_569: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_321, torch.bfloat16); primals_321 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_570: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_322, torch.bfloat16); primals_322 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_571: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_323, torch.bfloat16); primals_323 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_572: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16); primals_324 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_18 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_564, convert_element_type_565, convert_element_type_566, convert_element_type_567, convert_element_type_568, convert_element_type_569, convert_element_type_570, convert_element_type_571, convert_element_type_572]); convert_element_type_564 = convert_element_type_565 = convert_element_type_566 = convert_element_type_567 = convert_element_type_568 = convert_element_type_569 = convert_element_type_570 = convert_element_type_571 = convert_element_type_572 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1731: "bf16[6423040]" = all_gather_copy_in_18[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1732: "bf16[51384320]" = all_gather_copy_in_18[1]; all_gather_copy_in_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_18: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1731, 8, '0'); getitem_1731 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_18: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_329: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_18, [8, -1]); wait_tensor_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_175 = torch.ops.aten.split_with_sizes.default(view_329, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_329 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1742: "bf16[8, 524288]" = split_with_sizes_175[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_156: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1742, [2048, 2048], [2048, 1]); getitem_1742 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1752: "bf16[8, 524288]" = split_with_sizes_175[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_157: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1752, [2048, 2048], [2048, 1]); getitem_1752 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1762: "bf16[8, 524288]" = split_with_sizes_175[2] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_158: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1762, [2048, 2048], [2048, 1]); getitem_1762 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1772: "bf16[8, 524288]" = split_with_sizes_175[3] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_159: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1772, [2048, 2048], [2048, 1]); getitem_1772 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1782: "bf16[8, 1441792]" = split_with_sizes_175[4] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_160: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1782, [5632, 2048], [2048, 1]); getitem_1782 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1792: "bf16[8, 1441792]" = split_with_sizes_175[5] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_161: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1792, [2048, 5632], [5632, 1]); getitem_1792 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1802: "bf16[8, 1441792]" = split_with_sizes_175[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_162: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1802, [5632, 2048], [2048, 1]); getitem_1802 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1812: "bf16[8, 256]" = split_with_sizes_175[7] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_163: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1812, [2048], [1]); getitem_1812 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1822: "bf16[8, 256]" = split_with_sizes_175[8]; split_with_sizes_175 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_164: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1822, [2048], [1]); getitem_1822 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_573: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_67, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_35: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_573, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_34: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_35, [-1], True); pow_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_68: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_34, 1e-05); mean_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_34: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_68); add_68 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_136: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_573, rsqrt_34); convert_element_type_573 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_574: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_136, torch.bfloat16); mul_136 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_137: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_574, contiguous_view_as_strided_163); convert_element_type_574 = contiguous_view_as_strided_163 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_307: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_156, [1, 0]); contiguous_view_as_strided_156 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_119: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_137, permute_307); permute_307 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_309: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_157, [1, 0]); contiguous_view_as_strided_157 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_120: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_137, permute_309); permute_309 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_311: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_158, [1, 0]); contiguous_view_as_strided_158 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_121: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_137, permute_311); permute_311 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_338: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_119, [1, 2048, 16, 128]); mm_119 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_339: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_120, [1, 2048, 16, 128]); mm_120 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_340: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_121, [1, 2048, 16, 128]); mm_121 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
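
mm_119/mm_120/mm_121 above are the bias-free q/k/v projections (F.linear against transposed weights), and view_338/339/340 reshape each result to (bsz, seqlen, n_heads, head_dim) = (1, 2048, 16, 128). A hedged standalone sketch of the same shapes; the tensor and module names here are illustrative, not torchtrain's:

    import torch
    import torch.nn as nn

    bsz, seqlen, dim, n_heads, head_dim = 1, 2048, 2048, 16, 128
    x = torch.randn(bsz * seqlen, dim, dtype=torch.bfloat16)
    wq = nn.Linear(dim, n_heads * head_dim, bias=False).to(torch.bfloat16)
    wk = nn.Linear(dim, n_heads * head_dim, bias=False).to(torch.bfloat16)
    wv = nn.Linear(dim, n_heads * head_dim, bias=False).to(torch.bfloat16)

    xq = wq(x).view(bsz, seqlen, n_heads, head_dim)  # view_338
    xk = wk(x).view(bsz, seqlen, n_heads, head_dim)  # view_339 (n_kv_heads == n_heads here)
    xv = wv(x).view(bsz, seqlen, n_heads, head_dim)  # view_340
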
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_581: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_338, torch.float32); view_338 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_341: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_581, [1, 2048, 16, -1, 2]); convert_element_type_581 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_34: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_341); view_341 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_582: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_339, torch.float32); view_339 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_342: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_582, [1, 2048, 16, -1, 2]); convert_element_type_582 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_35: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_342); view_342 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_138: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_34, view_20); view_as_complex_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_34: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_344: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_34, [1, 2048, 16, 128]); view_as_real_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_139: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_35, view_20); view_as_complex_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_35: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_345: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_35, [1, 2048, 16, 128]); view_as_real_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_583: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_344, torch.bfloat16); view_344 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_584: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_345, torch.bfloat16); view_345 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
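
The view_as_complex / mul / view_as_real / reshape chain above is rotary position embedding applied in fp32. Assembled from the model.py:146-151 `code:` annotations in this dump:

    import torch

    def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
        # pair up the last dim and reinterpret as complex (view_as_complex_34/35)
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # rotate by the precomputed complex frequencies (mul_138/139), return to
        # real pairs (view_as_real_34/35), and flatten head_dim back to 128
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        # cast back to the input dtype (convert_element_type_583/584)
        return xq_out.type_as(xq), xk_out.type_as(xk)
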
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_312: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_583, [0, 2, 1, 3]); convert_element_type_583 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_313: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_584, [0, 2, 1, 3]); convert_element_type_584 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_314: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_340, [0, 2, 1, 3]); view_340 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_17 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_312, permute_313, permute_314, 0.0, True, scale = 0.08838834764831843) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1823: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_17[0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1824: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_17[1] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1829: "i64[]" = _scaled_dot_product_flash_attention_17[6] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1830: "i64[]" = _scaled_dot_product_flash_attention_17[7]; _scaled_dot_product_flash_attention_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
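
aten._scaled_dot_product_flash_attention is what F.scaled_dot_product_attention(..., is_causal=True) lowers to; the scale 0.08838834764831843 is 1/sqrt(head_dim) = 1/sqrt(128). A small standalone check, with shapes taken from the graph (a CUDA device is assumed so the flash kernel is eligible):

    import math
    import torch
    import torch.nn.functional as F

    assert math.isclose(1 / math.sqrt(128), 0.08838834764831843)

    xq = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16, device="cuda")
    xk = torch.randn_like(xq)
    xv = torch.randn_like(xq)
    out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)  # getitem_1823's shape
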
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose( | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_315: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1823, [0, 2, 1, 3]) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_346: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_315, [2048, -1]); permute_315 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_317: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_159, [1, 0]); contiguous_view_as_strided_159 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_122: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_346, permute_317); permute_317 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_69: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_67, mm_122); add_67 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_587: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_69, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_36: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_587, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_35: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_36, [-1], True); pow_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_70: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_35, 1e-05); mean_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_35: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_70); add_70 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_140: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_587, rsqrt_35); convert_element_type_587 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_588: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_140, torch.bfloat16); mul_140 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_141: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_588, contiguous_view_as_strided_164); convert_element_type_588 = contiguous_view_as_strided_164 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_319: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_160, [1, 0]); contiguous_view_as_strided_160 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_123: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_141, permute_319); permute_319 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_591: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_123, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_17: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_591) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_142: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_591, sigmoid_17); convert_element_type_591 = sigmoid_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_592: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_321: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_162, [1, 0]); contiguous_view_as_strided_162 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_124: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_141, permute_321); permute_321 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_143: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_592, mm_124); convert_element_type_592 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_323: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_161, [1, 0]); contiguous_view_as_strided_161 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_125: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_143, permute_323); permute_323 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
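
mm_123 (w1), mm_124 (w3), and mm_125 (w2), with the fp32 sigmoid/multiply pair in between, are the SwiGLU feed-forward from model.py:299. A sketch reconstructed from those annotations, using the hidden size 5632 visible in this 1B config:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FeedForward(nn.Module):
        def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # mm_123
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # mm_125
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # mm_124

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # F.silu decomposes into the fp32 sigmoid_17 / mul_142 pair above
            return self.w2(F.silu(self.w1(x)) * self.w3(x))
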
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h)) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_71: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_69, mm_125); add_69 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
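
add_69 and add_71 are the two pre-norm residual connections of a transformer block (model.py:410-411). In eager form, assuming `block` carries the four submodules named in the annotations:

    import torch

    def block_forward(block, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        h = x + block.attention(block.attention_norm(x), freqs_cis)  # add_69
        return h + block.feed_forward(block.ffn_norm(h))             # add_71
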
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_597: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_37: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_597, 2) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_36: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_37, [-1], True); pow_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_72: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_36, 1e-05); mean_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_36: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_72); add_72 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_144: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_597, rsqrt_36); convert_element_type_597 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_598: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_144, torch.bfloat16); mul_144 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_145: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_598, contiguous_view_as_strided_1); convert_element_type_598 = contiguous_view_as_strided_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:493 in forward, code: h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_347: "bf16[1, 2048, 2048]" = torch.ops.aten.reshape.default(mul_145, [1, 2048, 2048]); mul_145 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_348: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(view_347, [2048, 2048]); view_347 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_325: "bf16[2048, 32000]" = torch.ops.aten.permute.default(contiguous_view_as_strided_2, [1, 0]); contiguous_view_as_strided_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_126: "bf16[2048, 32000]" = torch.ops.aten.mm.default(view_348, permute_325); permute_325 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_349: "bf16[1, 2048, 32000]" = torch.ops.aten.reshape.default(mm_126, [1, 2048, 32000]); mm_126 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:494 in forward, code: output = self.output(h).float() | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_601: "f32[1, 2048, 32000]" = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
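
After the final RMSNorm, view_347/view_348 reshape back to (bsz, seqlen, dim) and flatten for the vocab projection; mm_126 is the [2048, 2048] x [2048, 32000] matmul and convert_element_type_601 upcasts the logits to fp32 (model.py:493-494). A standalone shape check, with illustrative names:

    import torch
    import torch.nn as nn

    bsz, seqlen, dim, vocab_size = 1, 2048, 2048, 32000
    output = nn.Linear(dim, vocab_size, bias=False).to(torch.bfloat16)
    h = torch.randn(bsz * seqlen, dim, dtype=torch.bfloat16)

    h = h.view(bsz, seqlen, dim)    # view_347
    logits = output(h).float()      # mm_126 + convert_element_type_601
    assert logits.shape == (1, 2048, 32000) and logits.dtype == torch.float32
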
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _conj: "c64[1, 2048, 1, 64]" = torch.ops.aten._conj.default(view_20); view_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] return [convert_element_type_601, primals_1, primals_6, primals_7, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, embedding, rsqrt, mul_1, permute_6, permute_7, permute_8, getitem_107, getitem_112, getitem_113, view_23, mm_3, rsqrt_1, mul_5, mm_4, mm_5, mul_7, mm_6, rsqrt_2, mul_9, permute_24, permute_25, permute_26, getitem_208, getitem_213, getitem_214, view_42, mm_10, rsqrt_3, mul_13, mm_11, mm_12, mul_15, mm_13, rsqrt_4, mul_17, permute_42, permute_43, permute_44, getitem_309, getitem_314, getitem_315, view_61, mm_17, rsqrt_5, mul_21, mm_18, mm_19, mul_23, mm_20, rsqrt_6, mul_25, permute_60, permute_61, permute_62, getitem_410, getitem_415, getitem_416, view_80, mm_24, rsqrt_7, mul_29, mm_25, mm_26, mul_31, mm_27, rsqrt_8, mul_33, permute_78, permute_79, permute_80, getitem_511, getitem_516, getitem_517, view_99, mm_31, rsqrt_9, mul_37, mm_32, mm_33, mul_39, mm_34, rsqrt_10, mul_41, permute_96, permute_97, permute_98, getitem_612, getitem_617, getitem_618, view_118, mm_38, rsqrt_11, mul_45, mm_39, mm_40, mul_47, mm_41, rsqrt_12, mul_49, permute_114, permute_115, permute_116, getitem_713, getitem_718, getitem_719, view_137, mm_45, rsqrt_13, mul_53, mm_46, mm_47, mul_55, mm_48, rsqrt_14, mul_57, permute_132, permute_133, permute_134, getitem_814, getitem_819, getitem_820, view_156, mm_52, rsqrt_15, mul_61, mm_53, mm_54, mul_63, mm_55, rsqrt_16, mul_65, permute_150, permute_151, permute_152, 
getitem_915, getitem_920, getitem_921, view_175, mm_59, rsqrt_17, mul_69, mm_60, mm_61, mul_71, mm_62, rsqrt_18, mul_73, permute_168, permute_169, permute_170, getitem_1016, getitem_1021, getitem_1022, view_194, mm_66, rsqrt_19, mul_77, mm_67, mm_68, mul_79, mm_69, rsqrt_20, mul_81, permute_186, permute_187, permute_188, getitem_1117, getitem_1122, getitem_1123, view_213, mm_73, rsqrt_21, mul_85, mm_74, mm_75, mul_87, mm_76, rsqrt_22, mul_89, permute_204, permute_205, permute_206, getitem_1218, getitem_1223, getitem_1224, view_232, mm_80, rsqrt_23, mul_93, mm_81, mm_82, mul_95, mm_83, rsqrt_24, mul_97, permute_222, permute_223, permute_224, getitem_1319, getitem_1324, getitem_1325, view_251, mm_87, rsqrt_25, mul_101, mm_88, mm_89, mul_103, mm_90, rsqrt_26, mul_105, permute_240, permute_241, permute_242, getitem_1420, getitem_1425, getitem_1426, view_270, mm_94, rsqrt_27, mul_109, mm_95, mm_96, mul_111, mm_97, rsqrt_28, mul_113, permute_258, permute_259, permute_260, getitem_1521, getitem_1526, getitem_1527, view_289, mm_101, rsqrt_29, mul_117, mm_102, mm_103, mul_119, mm_104, rsqrt_30, mul_121, permute_276, permute_277, permute_278, getitem_1622, getitem_1627, getitem_1628, view_308, mm_108, rsqrt_31, mul_125, mm_109, mm_110, mul_127, mm_111, rsqrt_32, mul_129, permute_294, permute_295, permute_296, getitem_1723, getitem_1728, getitem_1729, view_327, mm_115, rsqrt_33, mul_133, mm_116, mm_117, mul_135, mm_118, rsqrt_34, mul_137, permute_312, permute_313, permute_314, getitem_1824, getitem_1829, getitem_1830, view_346, mm_122, rsqrt_35, mul_141, mm_123, mm_124, mul_143, mm_125, rsqrt_36, view_348, getitem_1823, _conj, getitem_1722, getitem_1621, getitem_1520, getitem_1419, getitem_1318, getitem_1217, getitem_1116, getitem_1015, getitem_914, getitem_813, getitem_712, getitem_611, getitem_510, getitem_409, getitem_308, getitem_207, getitem_106] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
| [rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] | |
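
The [0/0] forward graph ends here: it returns the fp32 logits (convert_element_type_601) plus a long list of primals and intermediates, typically the tensors saved for the backward pass. The [1/0] dump below is the next graph compiled in the same run. As a hedged aside, dumps like these can be requested through the torch._logging artifact system; the exact flags this run used are not visible in the log:

    # Hypothetical repro sketch; "post_grad_graphs" is the torch._logging artifact
    # behind torch/_inductor/fx_passes/post_grad.py dumps.
    #   TORCH_LOGS="post_grad_graphs" torchrun ... train.py ...
    # or programmatically, before running the compiled model:
    import torch._logging
    torch._logging.set_logs(post_grad_graphs=True)
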
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] TRACED GRAPH | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] ===== after FSDP FX passes: ===== | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module): | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] def forward(self, arg0_1: "f32[]", arg1_1: "f32[2048, 32000]", arg2_1: "i64[2048]", arg3_1: "f32[]", arg4_1: "f32[2048, 32000]", arg5_1: "i64[1, 2048]", arg6_1: "bf16[2048]", arg7_1: "bf16[32000, 2048]", arg8_1: "bf16[2048, 2048]", arg9_1: "bf16[2048, 2048]", arg10_1: "bf16[2048, 2048]", arg11_1: "bf16[2048, 2048]", arg12_1: "bf16[5632, 2048]", arg13_1: "bf16[2048, 5632]", arg14_1: "bf16[5632, 2048]", arg15_1: "bf16[2048]", arg16_1: "bf16[2048]", arg17_1: "bf16[2048, 2048]", arg18_1: "bf16[2048, 2048]", arg19_1: "bf16[2048, 2048]", arg20_1: "bf16[2048, 2048]", arg21_1: "bf16[5632, 2048]", arg22_1: "bf16[2048, 5632]", arg23_1: "bf16[5632, 2048]", arg24_1: "bf16[2048]", arg25_1: "bf16[2048]", arg26_1: "bf16[2048, 2048]", arg27_1: "bf16[2048, 2048]", arg28_1: "bf16[2048, 2048]", arg29_1: "bf16[2048, 2048]", arg30_1: "bf16[5632, 2048]", arg31_1: "bf16[2048, 5632]", arg32_1: "bf16[5632, 2048]", arg33_1: "bf16[2048]", arg34_1: "bf16[2048]", arg35_1: "bf16[2048, 2048]", arg36_1: "bf16[2048, 2048]", arg37_1: "bf16[2048, 2048]", arg38_1: "bf16[2048, 2048]", arg39_1: "bf16[5632, 2048]", arg40_1: "bf16[2048, 5632]", arg41_1: "bf16[5632, 2048]", arg42_1: "bf16[2048]", arg43_1: "bf16[2048]", arg44_1: "bf16[2048, 2048]", arg45_1: "bf16[2048, 2048]", arg46_1: "bf16[2048, 2048]", arg47_1: "bf16[2048, 2048]", arg48_1: "bf16[5632, 2048]", arg49_1: "bf16[2048, 5632]", arg50_1: "bf16[5632, 2048]", arg51_1: "bf16[2048]", arg52_1: "bf16[2048]", arg53_1: "bf16[2048, 2048]", arg54_1: "bf16[2048, 2048]", arg55_1: "bf16[2048, 2048]", arg56_1: "bf16[2048, 2048]", arg57_1: "bf16[5632, 2048]", arg58_1: "bf16[2048, 5632]", arg59_1: "bf16[5632, 2048]", arg60_1: "bf16[2048]", arg61_1: "bf16[2048]", arg62_1: "bf16[2048, 2048]", arg63_1: "bf16[2048, 2048]", arg64_1: "bf16[2048, 2048]", arg65_1: "bf16[2048, 2048]", arg66_1: "bf16[5632, 2048]", arg67_1: "bf16[2048, 5632]", arg68_1: "bf16[5632, 2048]", arg69_1: "bf16[2048]", arg70_1: "bf16[2048]", arg71_1: "bf16[2048, 2048]", arg72_1: "bf16[2048, 2048]", arg73_1: "bf16[2048, 2048]", arg74_1: "bf16[2048, 2048]", arg75_1: "bf16[5632, 2048]", arg76_1: "bf16[2048, 5632]", arg77_1: "bf16[5632, 2048]", arg78_1: "bf16[2048]", arg79_1: "bf16[2048]", arg80_1: "bf16[2048, 2048]", arg81_1: "bf16[2048, 2048]", arg82_1: "bf16[2048, 2048]", arg83_1: "bf16[2048, 2048]", arg84_1: "bf16[5632, 2048]", arg85_1: "bf16[2048, 5632]", arg86_1: "bf16[5632, 2048]", arg87_1: "bf16[2048]", arg88_1: "bf16[2048]", arg89_1: "bf16[2048, 2048]", arg90_1: "bf16[2048, 2048]", arg91_1: "bf16[2048, 2048]", arg92_1: "bf16[2048, 2048]", arg93_1: "bf16[5632, 2048]", arg94_1: "bf16[2048, 5632]", arg95_1: "bf16[5632, 2048]", arg96_1: "bf16[2048]", arg97_1: "bf16[2048]", arg98_1: "bf16[2048, 2048]", arg99_1: "bf16[2048, 2048]", arg100_1: "bf16[2048, 2048]", arg101_1: "bf16[2048, 2048]", arg102_1: "bf16[5632, 2048]", arg103_1: "bf16[2048, 5632]", arg104_1: "bf16[5632, 2048]", arg105_1: "bf16[2048]", arg106_1: "bf16[2048]", arg107_1: "bf16[2048, 2048]", arg108_1: "bf16[2048, 2048]", arg109_1: "bf16[2048, 2048]", arg110_1: "bf16[2048, 2048]", arg111_1: "bf16[5632, 2048]", arg112_1: "bf16[2048, 5632]", arg113_1: "bf16[5632, 2048]", arg114_1: "bf16[2048]", arg115_1: "bf16[2048]", arg116_1: "bf16[2048, 2048]", arg117_1: "bf16[2048, 2048]", arg118_1: "bf16[2048, 2048]", arg119_1: "bf16[2048, 2048]", arg120_1: "bf16[5632, 2048]", arg121_1: "bf16[2048, 5632]", arg122_1: "bf16[5632, 2048]", arg123_1: 
"bf16[2048]", arg124_1: "bf16[2048]", arg125_1: "bf16[2048, 2048]", arg126_1: "bf16[2048, 2048]", arg127_1: "bf16[2048, 2048]", arg128_1: "bf16[2048, 2048]", arg129_1: "bf16[5632, 2048]", arg130_1: "bf16[2048, 5632]", arg131_1: "bf16[5632, 2048]", arg132_1: "bf16[2048]", arg133_1: "bf16[2048]", arg134_1: "bf16[2048, 2048]", arg135_1: "bf16[2048, 2048]", arg136_1: "bf16[2048, 2048]", arg137_1: "bf16[2048, 2048]", arg138_1: "bf16[5632, 2048]", arg139_1: "bf16[2048, 5632]", arg140_1: "bf16[5632, 2048]", arg141_1: "bf16[2048]", arg142_1: "bf16[2048]", arg143_1: "bf16[2048, 2048]", arg144_1: "bf16[2048, 2048]", arg145_1: "bf16[2048, 2048]", arg146_1: "bf16[2048, 2048]", arg147_1: "bf16[5632, 2048]", arg148_1: "bf16[2048, 5632]", arg149_1: "bf16[5632, 2048]", arg150_1: "bf16[2048]", arg151_1: "bf16[2048]", arg152_1: "bf16[2048, 2048]", arg153_1: "bf16[2048, 2048]", arg154_1: "bf16[2048, 2048]", arg155_1: "bf16[2048, 2048]", arg156_1: "bf16[5632, 2048]", arg157_1: "bf16[2048, 5632]", arg158_1: "bf16[5632, 2048]", arg159_1: "bf16[2048]", arg160_1: "bf16[2048]", arg161_1: "bf16[2048, 2048]", arg162_1: "bf16[2048, 2048]", arg163_1: "bf16[2048, 2048]", arg164_1: "bf16[2048, 2048]", arg165_1: "bf16[5632, 2048]", arg166_1: "bf16[2048, 5632]", arg167_1: "bf16[5632, 2048]", arg168_1: "bf16[2048]", arg169_1: "bf16[2048]", arg170_1: "bf16[1, 2048, 2048]", arg171_1: "f32[2048, 1]", arg172_1: "bf16[2048, 2048]", arg173_1: "bf16[1, 16, 2048, 128]", arg174_1: "bf16[1, 16, 2048, 128]", arg175_1: "bf16[1, 16, 2048, 128]", arg176_1: "f32[1, 16, 2048]", arg177_1: "i64[]", arg178_1: "i64[]", arg179_1: "bf16[2048, 2048]", arg180_1: "bf16[2048, 2048]", arg181_1: "f32[2048, 1]", arg182_1: "bf16[2048, 2048]", arg183_1: "bf16[2048, 5632]", arg184_1: "bf16[2048, 5632]", arg185_1: "bf16[2048, 5632]", arg186_1: "bf16[2048, 2048]", arg187_1: "f32[2048, 1]", arg188_1: "bf16[2048, 2048]", arg189_1: "bf16[1, 16, 2048, 128]", arg190_1: "bf16[1, 16, 2048, 128]", arg191_1: "bf16[1, 16, 2048, 128]", arg192_1: "f32[1, 16, 2048]", arg193_1: "i64[]", arg194_1: "i64[]", arg195_1: "bf16[2048, 2048]", arg196_1: "bf16[2048, 2048]", arg197_1: "f32[2048, 1]", arg198_1: "bf16[2048, 2048]", arg199_1: "bf16[2048, 5632]", arg200_1: "bf16[2048, 5632]", arg201_1: "bf16[2048, 5632]", arg202_1: "bf16[2048, 2048]", arg203_1: "f32[2048, 1]", arg204_1: "bf16[2048, 2048]", arg205_1: "bf16[1, 16, 2048, 128]", arg206_1: "bf16[1, 16, 2048, 128]", arg207_1: "bf16[1, 16, 2048, 128]", arg208_1: "f32[1, 16, 2048]", arg209_1: "i64[]", arg210_1: "i64[]", arg211_1: "bf16[2048, 2048]", arg212_1: "bf16[2048, 2048]", arg213_1: "f32[2048, 1]", arg214_1: "bf16[2048, 2048]", arg215_1: "bf16[2048, 5632]", arg216_1: "bf16[2048, 5632]", arg217_1: "bf16[2048, 5632]", arg218_1: "bf16[2048, 2048]", arg219_1: "f32[2048, 1]", arg220_1: "bf16[2048, 2048]", arg221_1: "bf16[1, 16, 2048, 128]", arg222_1: "bf16[1, 16, 2048, 128]", arg223_1: "bf16[1, 16, 2048, 128]", arg224_1: "f32[1, 16, 2048]", arg225_1: "i64[]", arg226_1: "i64[]", arg227_1: "bf16[2048, 2048]", arg228_1: "bf16[2048, 2048]", arg229_1: "f32[2048, 1]", arg230_1: "bf16[2048, 2048]", arg231_1: "bf16[2048, 5632]", arg232_1: "bf16[2048, 5632]", arg233_1: "bf16[2048, 5632]", arg234_1: "bf16[2048, 2048]", arg235_1: "f32[2048, 1]", arg236_1: "bf16[2048, 2048]", arg237_1: "bf16[1, 16, 2048, 128]", arg238_1: "bf16[1, 16, 2048, 128]", arg239_1: "bf16[1, 16, 2048, 128]", arg240_1: "f32[1, 16, 2048]", arg241_1: "i64[]", arg242_1: "i64[]", arg243_1: "bf16[2048, 2048]", arg244_1: "bf16[2048, 2048]", arg245_1: "f32[2048, 1]", 
arg246_1: "bf16[2048, 2048]", arg247_1: "bf16[2048, 5632]", arg248_1: "bf16[2048, 5632]", arg249_1: "bf16[2048, 5632]", arg250_1: "bf16[2048, 2048]", arg251_1: "f32[2048, 1]", arg252_1: "bf16[2048, 2048]", arg253_1: "bf16[1, 16, 2048, 128]", arg254_1: "bf16[1, 16, 2048, 128]", arg255_1: "bf16[1, 16, 2048, 128]", arg256_1: "f32[1, 16, 2048]", arg257_1: "i64[]", arg258_1: "i64[]", arg259_1: "bf16[2048, 2048]", arg260_1: "bf16[2048, 2048]", arg261_1: "f32[2048, 1]", arg262_1: "bf16[2048, 2048]", arg263_1: "bf16[2048, 5632]", arg264_1: "bf16[2048, 5632]", arg265_1: "bf16[2048, 5632]", arg266_1: "bf16[2048, 2048]", arg267_1: "f32[2048, 1]", arg268_1: "bf16[2048, 2048]", arg269_1: "bf16[1, 16, 2048, 128]", arg270_1: "bf16[1, 16, 2048, 128]", arg271_1: "bf16[1, 16, 2048, 128]", arg272_1: "f32[1, 16, 2048]", arg273_1: "i64[]", arg274_1: "i64[]", arg275_1: "bf16[2048, 2048]", arg276_1: "bf16[2048, 2048]", arg277_1: "f32[2048, 1]", arg278_1: "bf16[2048, 2048]", arg279_1: "bf16[2048, 5632]", arg280_1: "bf16[2048, 5632]", arg281_1: "bf16[2048, 5632]", arg282_1: "bf16[2048, 2048]", arg283_1: "f32[2048, 1]", arg284_1: "bf16[2048, 2048]", arg285_1: "bf16[1, 16, 2048, 128]", arg286_1: "bf16[1, 16, 2048, 128]", arg287_1: "bf16[1, 16, 2048, 128]", arg288_1: "f32[1, 16, 2048]", arg289_1: "i64[]", arg290_1: "i64[]", arg291_1: "bf16[2048, 2048]", arg292_1: "bf16[2048, 2048]", arg293_1: "f32[2048, 1]", arg294_1: "bf16[2048, 2048]", arg295_1: "bf16[2048, 5632]", arg296_1: "bf16[2048, 5632]", arg297_1: "bf16[2048, 5632]", arg298_1: "bf16[2048, 2048]", arg299_1: "f32[2048, 1]", arg300_1: "bf16[2048, 2048]", arg301_1: "bf16[1, 16, 2048, 128]", arg302_1: "bf16[1, 16, 2048, 128]", arg303_1: "bf16[1, 16, 2048, 128]", arg304_1: "f32[1, 16, 2048]", arg305_1: "i64[]", arg306_1: "i64[]", arg307_1: "bf16[2048, 2048]", arg308_1: "bf16[2048, 2048]", arg309_1: "f32[2048, 1]", arg310_1: "bf16[2048, 2048]", arg311_1: "bf16[2048, 5632]", arg312_1: "bf16[2048, 5632]", arg313_1: "bf16[2048, 5632]", arg314_1: "bf16[2048, 2048]", arg315_1: "f32[2048, 1]", arg316_1: "bf16[2048, 2048]", arg317_1: "bf16[1, 16, 2048, 128]", arg318_1: "bf16[1, 16, 2048, 128]", arg319_1: "bf16[1, 16, 2048, 128]", arg320_1: "f32[1, 16, 2048]", arg321_1: "i64[]", arg322_1: "i64[]", arg323_1: "bf16[2048, 2048]", arg324_1: "bf16[2048, 2048]", arg325_1: "f32[2048, 1]", arg326_1: "bf16[2048, 2048]", arg327_1: "bf16[2048, 5632]", arg328_1: "bf16[2048, 5632]", arg329_1: "bf16[2048, 5632]", arg330_1: "bf16[2048, 2048]", arg331_1: "f32[2048, 1]", arg332_1: "bf16[2048, 2048]", arg333_1: "bf16[1, 16, 2048, 128]", arg334_1: "bf16[1, 16, 2048, 128]", arg335_1: "bf16[1, 16, 2048, 128]", arg336_1: "f32[1, 16, 2048]", arg337_1: "i64[]", arg338_1: "i64[]", arg339_1: "bf16[2048, 2048]", arg340_1: "bf16[2048, 2048]", arg341_1: "f32[2048, 1]", arg342_1: "bf16[2048, 2048]", arg343_1: "bf16[2048, 5632]", arg344_1: "bf16[2048, 5632]", arg345_1: "bf16[2048, 5632]", arg346_1: "bf16[2048, 2048]", arg347_1: "f32[2048, 1]", arg348_1: "bf16[2048, 2048]", arg349_1: "bf16[1, 16, 2048, 128]", arg350_1: "bf16[1, 16, 2048, 128]", arg351_1: "bf16[1, 16, 2048, 128]", arg352_1: "f32[1, 16, 2048]", arg353_1: "i64[]", arg354_1: "i64[]", arg355_1: "bf16[2048, 2048]", arg356_1: "bf16[2048, 2048]", arg357_1: "f32[2048, 1]", arg358_1: "bf16[2048, 2048]", arg359_1: "bf16[2048, 5632]", arg360_1: "bf16[2048, 5632]", arg361_1: "bf16[2048, 5632]", arg362_1: "bf16[2048, 2048]", arg363_1: "f32[2048, 1]", arg364_1: "bf16[2048, 2048]", arg365_1: "bf16[1, 16, 2048, 128]", arg366_1: "bf16[1, 16, 2048, 128]", 
arg367_1: "bf16[1, 16, 2048, 128]", arg368_1: "f32[1, 16, 2048]", arg369_1: "i64[]", arg370_1: "i64[]", arg371_1: "bf16[2048, 2048]", arg372_1: "bf16[2048, 2048]", arg373_1: "f32[2048, 1]", arg374_1: "bf16[2048, 2048]", arg375_1: "bf16[2048, 5632]", arg376_1: "bf16[2048, 5632]", arg377_1: "bf16[2048, 5632]", arg378_1: "bf16[2048, 2048]", arg379_1: "f32[2048, 1]", arg380_1: "bf16[2048, 2048]", arg381_1: "bf16[1, 16, 2048, 128]", arg382_1: "bf16[1, 16, 2048, 128]", arg383_1: "bf16[1, 16, 2048, 128]", arg384_1: "f32[1, 16, 2048]", arg385_1: "i64[]", arg386_1: "i64[]", arg387_1: "bf16[2048, 2048]", arg388_1: "bf16[2048, 2048]", arg389_1: "f32[2048, 1]", arg390_1: "bf16[2048, 2048]", arg391_1: "bf16[2048, 5632]", arg392_1: "bf16[2048, 5632]", arg393_1: "bf16[2048, 5632]", arg394_1: "bf16[2048, 2048]", arg395_1: "f32[2048, 1]", arg396_1: "bf16[2048, 2048]", arg397_1: "bf16[1, 16, 2048, 128]", arg398_1: "bf16[1, 16, 2048, 128]", arg399_1: "bf16[1, 16, 2048, 128]", arg400_1: "f32[1, 16, 2048]", arg401_1: "i64[]", arg402_1: "i64[]", arg403_1: "bf16[2048, 2048]", arg404_1: "bf16[2048, 2048]", arg405_1: "f32[2048, 1]", arg406_1: "bf16[2048, 2048]", arg407_1: "bf16[2048, 5632]", arg408_1: "bf16[2048, 5632]", arg409_1: "bf16[2048, 5632]", arg410_1: "bf16[2048, 2048]", arg411_1: "f32[2048, 1]", arg412_1: "bf16[2048, 2048]", arg413_1: "bf16[1, 16, 2048, 128]", arg414_1: "bf16[1, 16, 2048, 128]", arg415_1: "bf16[1, 16, 2048, 128]", arg416_1: "f32[1, 16, 2048]", arg417_1: "i64[]", arg418_1: "i64[]", arg419_1: "bf16[2048, 2048]", arg420_1: "bf16[2048, 2048]", arg421_1: "f32[2048, 1]", arg422_1: "bf16[2048, 2048]", arg423_1: "bf16[2048, 5632]", arg424_1: "bf16[2048, 5632]", arg425_1: "bf16[2048, 5632]", arg426_1: "bf16[2048, 2048]", arg427_1: "f32[2048, 1]", arg428_1: "bf16[2048, 2048]", arg429_1: "bf16[1, 16, 2048, 128]", arg430_1: "bf16[1, 16, 2048, 128]", arg431_1: "bf16[1, 16, 2048, 128]", arg432_1: "f32[1, 16, 2048]", arg433_1: "i64[]", arg434_1: "i64[]", arg435_1: "bf16[2048, 2048]", arg436_1: "bf16[2048, 2048]", arg437_1: "f32[2048, 1]", arg438_1: "bf16[2048, 2048]", arg439_1: "bf16[2048, 5632]", arg440_1: "bf16[2048, 5632]", arg441_1: "bf16[2048, 5632]", arg442_1: "bf16[2048, 2048]", arg443_1: "f32[2048, 1]", arg444_1: "bf16[2048, 2048]", arg445_1: "bf16[1, 16, 2048, 128]", arg446_1: "bf16[1, 16, 2048, 128]", arg447_1: "bf16[1, 16, 2048, 128]", arg448_1: "f32[1, 16, 2048]", arg449_1: "i64[]", arg450_1: "i64[]", arg451_1: "bf16[2048, 2048]", arg452_1: "bf16[2048, 2048]", arg453_1: "f32[2048, 1]", arg454_1: "bf16[2048, 2048]", arg455_1: "bf16[2048, 5632]", arg456_1: "bf16[2048, 5632]", arg457_1: "bf16[2048, 5632]", arg458_1: "bf16[2048, 2048]", arg459_1: "f32[2048, 1]", arg460_1: "bf16[2048, 2048]", arg461_1: "bf16[1, 16, 2048, 128]", arg462_1: "c64[1, 2048, 1, 64]", arg463_1: "bf16[1, 16, 2048, 128]", arg464_1: "bf16[1, 16, 2048, 128]", arg465_1: "bf16[1, 16, 2048, 128]", arg466_1: "bf16[1, 16, 2048, 128]", arg467_1: "bf16[1, 16, 2048, 128]", arg468_1: "bf16[1, 16, 2048, 128]", arg469_1: "bf16[1, 16, 2048, 128]", arg470_1: "bf16[1, 16, 2048, 128]", arg471_1: "bf16[1, 16, 2048, 128]", arg472_1: "bf16[1, 16, 2048, 128]", arg473_1: "bf16[1, 16, 2048, 128]", arg474_1: "bf16[1, 16, 2048, 128]", arg475_1: "bf16[1, 16, 2048, 128]", arg476_1: "bf16[1, 16, 2048, 128]", arg477_1: "bf16[1, 16, 2048, 128]", arg478_1: "bf16[1, 16, 2048, 128]", arg479_1: "bf16[1, 16, 2048, 128]", arg480_1: "bf16[32000, 2048]", arg481_1: "f32[8192000]", arg482_1: "f32[256]", arg483_1: "f32[8192000]", arg484_1: "f32[524288]", 
arg485_1: "f32[524288]", arg486_1: "f32[524288]", arg487_1: "f32[524288]", arg488_1: "f32[1441792]", arg489_1: "f32[1441792]", arg490_1: "f32[1441792]", arg491_1: "f32[256]", arg492_1: "f32[256]", arg493_1: "f32[524288]", arg494_1: "f32[524288]", arg495_1: "f32[524288]", arg496_1: "f32[524288]", arg497_1: "f32[1441792]", arg498_1: "f32[1441792]", arg499_1: "f32[1441792]", arg500_1: "f32[256]", arg501_1: "f32[256]", arg502_1: "f32[524288]", arg503_1: "f32[524288]", arg504_1: "f32[524288]", arg505_1: "f32[524288]", arg506_1: "f32[1441792]", arg507_1: "f32[1441792]", arg508_1: "f32[1441792]", arg509_1: "f32[256]", arg510_1: "f32[256]", arg511_1: "f32[524288]", arg512_1: "f32[524288]", arg513_1: "f32[524288]", arg514_1: "f32[524288]", arg515_1: "f32[1441792]", arg516_1: "f32[1441792]", arg517_1: "f32[1441792]", arg518_1: "f32[256]", arg519_1: "f32[256]", arg520_1: "f32[524288]", arg521_1: "f32[524288]", arg522_1: "f32[524288]", arg523_1: "f32[524288]", arg524_1: "f32[1441792]", arg525_1: "f32[1441792]", arg526_1: "f32[1441792]", arg527_1: "f32[256]", arg528_1: "f32[256]", arg529_1: "f32[524288]", arg530_1: "f32[524288]", arg531_1: "f32[524288]", arg532_1: "f32[524288]", arg533_1: "f32[1441792]", arg534_1: "f32[1441792]", arg535_1: "f32[1441792]", arg536_1: "f32[256]", arg537_1: "f32[256]", arg538_1: "f32[524288]", arg539_1: "f32[524288]", arg540_1: "f32[524288]", arg541_1: "f32[524288]", arg542_1: "f32[1441792]", arg543_1: "f32[1441792]", arg544_1: "f32[1441792]", arg545_1: "f32[256]", arg546_1: "f32[256]", arg547_1: "f32[524288]", arg548_1: "f32[524288]", arg549_1: "f32[524288]", arg550_1: "f32[524288]", arg551_1: "f32[1441792]", arg552_1: "f32[1441792]", arg553_1: "f32[1441792]", arg554_1: "f32[256]", arg555_1: "f32[256]", arg556_1: "f32[524288]", arg557_1: "f32[524288]", arg558_1: "f32[524288]", arg559_1: "f32[524288]", arg560_1: "f32[1441792]", arg561_1: "f32[1441792]", arg562_1: "f32[1441792]", arg563_1: "f32[256]", arg564_1: "f32[256]", arg565_1: "f32[524288]", arg566_1: "f32[524288]", arg567_1: "f32[524288]", arg568_1: "f32[524288]", arg569_1: "f32[1441792]", arg570_1: "f32[1441792]", arg571_1: "f32[1441792]", arg572_1: "f32[256]", arg573_1: "f32[256]", arg574_1: "f32[524288]", arg575_1: "f32[524288]", arg576_1: "f32[524288]", arg577_1: "f32[524288]", arg578_1: "f32[1441792]", arg579_1: "f32[1441792]", arg580_1: "f32[1441792]", arg581_1: "f32[256]", arg582_1: "f32[256]", arg583_1: "f32[524288]", arg584_1: "f32[524288]", arg585_1: "f32[524288]", arg586_1: "f32[524288]", arg587_1: "f32[1441792]", arg588_1: "f32[1441792]", arg589_1: "f32[1441792]", arg590_1: "f32[256]", arg591_1: "f32[256]", arg592_1: "f32[524288]", arg593_1: "f32[524288]", arg594_1: "f32[524288]", arg595_1: "f32[524288]", arg596_1: "f32[1441792]", arg597_1: "f32[1441792]", arg598_1: "f32[1441792]", arg599_1: "f32[256]", arg600_1: "f32[256]", arg601_1: "f32[524288]", arg602_1: "f32[524288]", arg603_1: "f32[524288]", arg604_1: "f32[524288]", arg605_1: "f32[1441792]", arg606_1: "f32[1441792]", arg607_1: "f32[1441792]", arg608_1: "f32[256]", arg609_1: "f32[256]", arg610_1: "f32[524288]", arg611_1: "f32[524288]", arg612_1: "f32[524288]", arg613_1: "f32[524288]", arg614_1: "f32[1441792]", arg615_1: "f32[1441792]", arg616_1: "f32[1441792]", arg617_1: "f32[256]", arg618_1: "f32[256]", arg619_1: "f32[524288]", arg620_1: "f32[524288]", arg621_1: "f32[524288]", arg622_1: "f32[524288]", arg623_1: "f32[1441792]", arg624_1: "f32[1441792]", arg625_1: "f32[1441792]", arg626_1: "f32[256]", arg627_1: "f32[256]", arg628_1: "f32[524288]", 
arg629_1: "f32[524288]", arg630_1: "f32[524288]", arg631_1: "f32[524288]", arg632_1: "f32[1441792]", arg633_1: "f32[1441792]", arg634_1: "f32[1441792]", arg635_1: "f32[256]", arg636_1: "f32[256]", arg637_1: "f32[524288]", arg638_1: "f32[524288]", arg639_1: "f32[524288]", arg640_1: "f32[524288]", arg641_1: "f32[1441792]", arg642_1: "f32[1441792]", arg643_1: "f32[1441792]", arg644_1: "f32[256]", arg645_1: "f32[256]"): | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_110: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(arg481_1, torch.bfloat16); arg481_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_111: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg482_1, torch.bfloat16); arg482_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_112: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(arg483_1, torch.bfloat16); arg483_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(16384256, 8, 0, torch.bfloat16, device(type='cuda', index=0), [8192000, 256, 8192000], [convert_element_type_110, convert_element_type_111, convert_element_type_112]); convert_element_type_110 = convert_element_type_111 = convert_element_type_112 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem: "bf16[16384256]" = all_gather_copy_in[0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_1: "bf16[131074048]" = all_gather_copy_in[1]; all_gather_copy_in = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor: "bf16[131074048]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, 8, '0'); getitem = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor: "bf16[131074048]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
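The three nodes above — all_gather_copy_in, all_gather_into_tensor, wait_tensor — are the traced unshard path for the first FSDP parameter group, and the sizes are self-consistent: the packed per-rank shard is 8,192,000 + 256 + 8,192,000 = 16,384,256 bf16 elements, and the gathered output is 8 × 16,384,256 = 131,074,048. The split sizes line up with the 32000×2048 embedding, the final 2048-dim norm weight, and the 32000×2048 output weight sharded over 8 ranks (32000·2048/8 = 8,192,000; 2048/8 = 256), though that grouping is inferred from the numbers, not stated in the log. A minimal sketch of the same semantics in plain torch.distributed (the real torch.ops.fsdp.all_gather_copy_in is a single fused op that also allocates the output buffer):

```python
import torch
import torch.distributed as dist

# Minimal sketch, assuming plain torch.distributed; names are illustrative.
def foreach_all_gather_sketch(param_shards, world_size):
    # "copy-in": pack each flattened bf16 shard into one contiguous buffer
    inp = torch.cat([p.reshape(-1) for p in param_shards])  # 16,384,256 elems here
    out = inp.new_empty(inp.numel() * world_size)           # 131,074,048 elems
    dist.all_gather_into_tensor(out, inp)                   # the collective above
    return out                                              # sync happens at wait_tensor
```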
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_4: "bf16[8, 16384256]" = torch.ops.aten.reshape.default(wait_tensor, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(view_4, [8192000, 256, 8192000], 1); view_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_9: "bf16[8, 256]" = split_with_sizes_2[1]; split_with_sizes_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_5: "bf16[8, 16384256]" = torch.ops.aten.reshape.default(wait_tensor, [8, -1]); wait_tensor = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_3 = torch.ops.aten.split_with_sizes.default(view_5, [8192000, 256, 8192000], 1); view_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_13: "bf16[8, 8192000]" = split_with_sizes_3[2]; split_with_sizes_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_113: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg484_1, torch.bfloat16); arg484_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_114: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg485_1, torch.bfloat16); arg485_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_115: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg486_1, torch.bfloat16); arg486_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_116: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg487_1, torch.bfloat16); arg487_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_117: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg488_1, torch.bfloat16); arg488_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_118: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg489_1, torch.bfloat16); arg489_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_119: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg490_1, torch.bfloat16); arg490_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_120: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg491_1, torch.bfloat16); arg491_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_121: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg492_1, torch.bfloat16); arg492_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_1 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_113, convert_element_type_114, convert_element_type_115, convert_element_type_116, convert_element_type_117, convert_element_type_118, convert_element_type_119, convert_element_type_120, convert_element_type_121]); convert_element_type_113 = convert_element_type_114 = convert_element_type_115 = convert_element_type_116 = convert_element_type_117 = convert_element_type_118 = convert_element_type_119 = convert_element_type_120 = convert_element_type_121 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_14: "bf16[6423040]" = all_gather_copy_in_1[0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_15: "bf16[51384320]" = all_gather_copy_in_1[1]; all_gather_copy_in_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
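This second gather handles a 6,423,040-element group whose nine split sizes match one LLaMA transformer block under the dim=2048 config from the job header: four 2048×2048 attention projections, three 2048×5632 feed-forward projections (5632 is inferred from 1,441,792 · 8 / 2048), and two 2048-dim RMSNorm weights, each sharded over 8 ranks. A quick arithmetic check — the layer mapping is an inference from the sizes, not something the log states:

```python
# Hedged check of the shard sizes against the config (dim=2048, 8 ranks);
# the feed-forward width 5632 is inferred from the numbers in the log.
dim, ffn, ws = 2048, 5632, 8
assert [dim * dim // ws] * 4 == [524288] * 4     # presumably wq, wk, wv, wo
assert [dim * ffn // ws] * 3 == [1441792] * 3    # presumably w1, w2, w3
assert [dim // ws] * 2 == [256] * 2              # two RMSNorm weights
assert 4 * 524288 + 3 * 1441792 + 2 * 256 == 6423040
```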
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_14, 8, '0'); getitem_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_11: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_5 = torch.ops.aten.split_with_sizes.default(view_11, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_25: "bf16[8, 524288]" = split_with_sizes_5[0]; split_with_sizes_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_12: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_12, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_35: "bf16[8, 524288]" = split_with_sizes_6[1]; split_with_sizes_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_13: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_13, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_45: "bf16[8, 524288]" = split_with_sizes_7[2]; split_with_sizes_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_14: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_14, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_14 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_55: "bf16[8, 524288]" = split_with_sizes_8[3]; split_with_sizes_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_15: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_15, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_65: "bf16[8, 1441792]" = split_with_sizes_9[4]; split_with_sizes_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_16: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_16, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_75: "bf16[8, 1441792]" = split_with_sizes_10[5]; split_with_sizes_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_17: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_11 = torch.ops.aten.split_with_sizes.default(view_17, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_85: "bf16[8, 1441792]" = split_with_sizes_11[6]; split_with_sizes_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_18: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(view_18, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_95: "bf16[8, 256]" = split_with_sizes_12[7]; split_with_sizes_12 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_19: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]); wait_tensor_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_13 = torch.ops.aten.split_with_sizes.default(view_19, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_105: "bf16[8, 256]" = split_with_sizes_13[8]; split_with_sizes_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
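The repeated reshape-to-[8, 6423040] plus split_with_sizes pairs above are the copy-out side of the gather: each pair re-materializes the same per-rank view of the output buffer and takes one parameter's columns from it. Functionally the whole sequence reduces to a single view-and-split, sketched here for clarity:

```python
# Sketch of the copy-out splitting; split_sizes are the per-rank shard
# numels from the log ([524288]*4 + [1441792]*3 + [256]*2).
def split_all_gather_output(gathered, world_size, split_sizes):
    per_rank = gathered.reshape(world_size, -1)   # e.g. [8, 6423040]
    return per_rank.split(split_sizes, dim=1)     # one [8, n] view per parameter
```

Each [8, n] slice is then turned back into a full parameter by the contiguous_view_as_strided nodes that appear further down.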
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_129: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg493_1, torch.bfloat16); arg493_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_130: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg494_1, torch.bfloat16); arg494_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_131: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg495_1, torch.bfloat16); arg495_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_132: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg496_1, torch.bfloat16); arg496_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_133: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg497_1, torch.bfloat16); arg497_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_134: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg498_1, torch.bfloat16); arg498_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_135: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg499_1, torch.bfloat16); arg499_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_136: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg500_1, torch.bfloat16); arg500_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_137: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg501_1, torch.bfloat16); arg501_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_2 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_129, convert_element_type_130, convert_element_type_131, convert_element_type_132, convert_element_type_133, convert_element_type_134, convert_element_type_135, convert_element_type_136, convert_element_type_137]); convert_element_type_129 = convert_element_type_130 = convert_element_type_131 = convert_element_type_132 = convert_element_type_133 = convert_element_type_134 = convert_element_type_135 = convert_element_type_136 = convert_element_type_137 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_106: "bf16[6423040]" = all_gather_copy_in_2[0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_107: "bf16[51384320]" = all_gather_copy_in_2[1]; all_gather_copy_in_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:487 in forward, code: _log_softmax_backward_data = torch.ops.aten._log_softmax_backward_data.default(nll_loss_backward, getitem_4, 1, torch.float32); nll_loss_backward = getitem_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] exp: "f32[2048, 32000]" = torch.ops.aten.exp.default(arg4_1); arg4_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:486 in forward, code: nll_loss_backward = torch.ops.aten.nll_loss_backward.default(getitem, getitem_1, getitem_2, None, 1, -100, getitem_3); getitem = getitem_1 = getitem_2 = getitem_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] full: "f32[2048, 32000]" = torch.ops.aten.full.default([2048, 32000], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] unsqueeze: "i64[2048, 1]" = torch.ops.aten.unsqueeze.default(arg2_1, 1); arg2_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] ne: "b8[2048, 1]" = torch.ops.aten.ne.Scalar(unsqueeze, -100) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0)) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] where: "i64[2048, 1]" = torch.ops.aten.where.self(ne, unsqueeze, scalar_tensor); ne = scalar_tensor = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] scatter: "f32[2048, 32000]" = torch.ops.aten.scatter.value(full, 1, where, -1.0); full = where = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] ne_1: "b8[2048, 1]" = torch.ops.aten.ne.Scalar(unsqueeze, -100); unsqueeze = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div: "f32[]" = torch.ops.aten.div.Tensor(arg0_1, arg3_1); arg0_1 = arg3_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] where_1: "f32[2048, 1]" = torch.ops.aten.where.self(ne_1, div, scalar_tensor_1); ne_1 = div = scalar_tensor_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul: "f32[2048, 32000]" = torch.ops.aten.mul.Tensor(scatter, where_1); scatter = where_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:487 in forward, code: _log_softmax_backward_data = torch.ops.aten._log_softmax_backward_data.default(nll_loss_backward, getitem_4, 1, torch.float32); nll_loss_backward = getitem_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_1: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul, [1], True) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_1: "f32[2048, 32000]" = torch.ops.aten.mul.Tensor(exp, sum_1); exp = sum_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sub: "f32[2048, 32000]" = torch.ops.aten.sub.Tensor(mul, mul_1); mul = mul_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:488 in forward, code: view = torch.ops.aten.view.default(_log_softmax_backward_data, [1, 2048, 32000]); _log_softmax_backward_data = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view: "f32[1, 2048, 32000]" = torch.ops.aten.reshape.default(sub, [1, 2048, 32000]); sub = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
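Interleaved with the unshard ops, the trace switches to the loss backward. The nll_loss_backward block scatters −grad/num_valid at each target index (zeroed where the target equals the ignore index −100), and _log_softmax_backward_data then applies grad_in = g − exp(log_probs) · Σ g. A sketch of the composite, with function and variable names chosen here for illustration:

```python
import torch

# Sketch of the traced nll_loss_backward + _log_softmax_backward_data pair.
# grad_out and num_valid are scalars (arg0_1 / arg3_1 in the graph above).
def cross_entropy_backward(grad_out, log_probs, target, num_valid, ignore_index=-100):
    valid = (target != ignore_index).unsqueeze(1)                 # [N, 1] mask
    idx = torch.where(valid, target.unsqueeze(1),
                      torch.zeros_like(target.unsqueeze(1)))      # safe indices
    g = torch.zeros_like(log_probs).scatter(1, idx, -1.0)         # -1 at targets
    g = g * (valid.to(log_probs.dtype) * (grad_out / num_valid))  # scale + mask
    # log-softmax backward: grad - softmax * sum(grad, dim)
    return g - torch.exp(log_probs) * g.sum(1, keepdim=True)
```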
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:950 in forward, code: convert_element_type_110 = torch.ops.prims.convert_element_type.default(trace_wrapped, torch.bfloat16); trace_wrapped = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_122: "bf16[1, 2048, 32000]" = torch.ops.prims.convert_element_type.default(view, torch.bfloat16); view = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:951 in forward, code: view_2 = torch.ops.aten.view.default(convert_element_type_110, [2048, 32000]); convert_element_type_110 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_6: "bf16[2048, 32000]" = torch.ops.aten.reshape.default(convert_element_type_122, [2048, 32000]); convert_element_type_122 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_2: "bf16[32000, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_13, [32000, 2048], [2048, 1]); getitem_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:957 in forward, code: mm_1 = torch.ops.aten.mm.default(view_2, permute_3); view_2 = permute_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_1: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_6, contiguous_view_as_strided_2); contiguous_view_as_strided_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:958 in forward, code: view_3 = torch.ops.aten.view.default(mm_1, [1, 2048, 2048]); mm_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_7: "bf16[1, 2048, 2048]" = torch.ops.aten.reshape.default(mm_1, [1, 2048, 2048]); mm_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:961 in forward, code: view_4 = torch.ops.aten.view.default(view_3, [2048, 2048]); view_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_8: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(view_7, [2048, 2048]); view_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
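mm_1 above is the standard input gradient of the output projection, computed directly against the just-gathered weight: for out = x @ W.T with W of shape [32000, 2048], the backward is grad_x = grad_out @ W, reshaped back to [1, 2048, 2048]. Schematically:

```python
# Sketch of the input gradient traced as mm_1 + reshape; `weight` stands for
# the [32000, 2048] tensor rebuilt by contiguous_view_as_strided_2.
def linear_input_grad(grad_out_2d, weight):
    grad_x = grad_out_2d.mm(weight)      # [2048, 32000] @ [32000, 2048]
    return grad_x.reshape(1, 2048, 2048)
```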
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_1: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_9, [2048], [1]); getitem_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:963 in forward, code: mul_56 = torch.ops.aten.mul.Tensor(view_4, getitem_6); view_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_58: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(view_8, contiguous_view_as_strided_1); contiguous_view_as_strided_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:967 in forward, code: convert_element_type_111 = torch.ops.prims.convert_element_type.default(mul_56, torch.float32); mul_56 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_127: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_58, torch.float32); mul_58 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:969 in forward, code: mul_58 = torch.ops.aten.mul.Tensor(convert_element_type_111, getitem_459); convert_element_type_111 = getitem_459 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_60: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_127, arg459_1) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:655 in forward, code: view_1 = torch.ops.aten.view.default(getitem_170, [-1, 2048]); getitem_170 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_1: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(arg170_1, [-1, 2048]); arg170_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:661 in forward, code: add = torch.ops.aten.add.Tensor(view_1, getitem_180); view_1 = getitem_180 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(view_1, arg180_1); arg180_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:671 in forward, code: add_1 = torch.ops.aten.add.Tensor(add, getitem_186); add = getitem_186 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_1: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add, arg186_1); arg186_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:677 in forward, code: add_2 = torch.ops.aten.add.Tensor(add_1, getitem_196); add_1 = getitem_196 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_2: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_1, arg196_1); arg196_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:687 in forward, code: add_3 = torch.ops.aten.add.Tensor(add_2, getitem_202); add_2 = getitem_202 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_3: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_2, arg202_1); arg202_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:693 in forward, code: add_4 = torch.ops.aten.add.Tensor(add_3, getitem_212); add_3 = getitem_212 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_4: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_3, arg212_1); arg212_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:703 in forward, code: add_5 = torch.ops.aten.add.Tensor(add_4, getitem_218); add_4 = getitem_218 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_5: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_4, arg218_1); arg218_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:709 in forward, code: add_6 = torch.ops.aten.add.Tensor(add_5, getitem_228); add_5 = getitem_228 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_6: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_5, arg228_1); arg228_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:719 in forward, code: add_7 = torch.ops.aten.add.Tensor(add_6, getitem_234); add_6 = getitem_234 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_7: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_6, arg234_1); arg234_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:725 in forward, code: add_8 = torch.ops.aten.add.Tensor(add_7, getitem_244); add_7 = getitem_244 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_8: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_7, arg244_1); arg244_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:735 in forward, code: add_9 = torch.ops.aten.add.Tensor(add_8, getitem_250); add_8 = getitem_250 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_9: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_8, arg250_1); arg250_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:741 in forward, code: add_10 = torch.ops.aten.add.Tensor(add_9, getitem_260); add_9 = getitem_260 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_10: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_9, arg260_1); arg260_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:751 in forward, code: add_11 = torch.ops.aten.add.Tensor(add_10, getitem_266); add_10 = getitem_266 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_11: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_10, arg266_1); arg266_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:757 in forward, code: add_12 = torch.ops.aten.add.Tensor(add_11, getitem_276); add_11 = getitem_276 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_12: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_11, arg276_1); arg276_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:767 in forward, code: add_13 = torch.ops.aten.add.Tensor(add_12, getitem_282); add_12 = getitem_282 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_13: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_12, arg282_1); arg282_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:773 in forward, code: add_14 = torch.ops.aten.add.Tensor(add_13, getitem_292); add_13 = getitem_292 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_14: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_13, arg292_1); arg292_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:783 in forward, code: add_15 = torch.ops.aten.add.Tensor(add_14, getitem_298); add_14 = getitem_298 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_15: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_14, arg298_1); arg298_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:789 in forward, code: add_16 = torch.ops.aten.add.Tensor(add_15, getitem_308); add_15 = getitem_308 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_16: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_15, arg308_1); arg308_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:799 in forward, code: add_17 = torch.ops.aten.add.Tensor(add_16, getitem_314); add_16 = getitem_314 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_17: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_16, arg314_1); arg314_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:805 in forward, code: add_18 = torch.ops.aten.add.Tensor(add_17, getitem_324); add_17 = getitem_324 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_18: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_17, arg324_1); arg324_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:815 in forward, code: add_19 = torch.ops.aten.add.Tensor(add_18, getitem_330); add_18 = getitem_330 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_19: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_18, arg330_1); arg330_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:821 in forward, code: add_20 = torch.ops.aten.add.Tensor(add_19, getitem_340); add_19 = getitem_340 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_20: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_19, arg340_1); arg340_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:831 in forward, code: add_21 = torch.ops.aten.add.Tensor(add_20, getitem_346); add_20 = getitem_346 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_21: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_20, arg346_1); arg346_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:837 in forward, code: add_22 = torch.ops.aten.add.Tensor(add_21, getitem_356); add_21 = getitem_356 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_22: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_21, arg356_1); arg356_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:847 in forward, code: add_23 = torch.ops.aten.add.Tensor(add_22, getitem_362); add_22 = getitem_362 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_23: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_22, arg362_1); arg362_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:853 in forward, code: add_24 = torch.ops.aten.add.Tensor(add_23, getitem_372); add_23 = getitem_372 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_24: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_23, arg372_1); arg372_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:863 in forward, code: add_25 = torch.ops.aten.add.Tensor(add_24, getitem_378); add_24 = getitem_378 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_25: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_24, arg378_1); arg378_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:869 in forward, code: add_26 = torch.ops.aten.add.Tensor(add_25, getitem_388); add_25 = getitem_388 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_26: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_25, arg388_1); arg388_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:879 in forward, code: add_27 = torch.ops.aten.add.Tensor(add_26, getitem_394); add_26 = getitem_394 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_27: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_26, arg394_1); arg394_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:885 in forward, code: add_28 = torch.ops.aten.add.Tensor(add_27, getitem_404); add_27 = getitem_404 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_28: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_27, arg404_1); arg404_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:895 in forward, code: add_29 = torch.ops.aten.add.Tensor(add_28, getitem_410); add_28 = getitem_410 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_29: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_28, arg410_1); arg410_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:901 in forward, code: add_30 = torch.ops.aten.add.Tensor(add_29, getitem_420); add_29 = getitem_420 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_30: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_29, arg420_1); arg420_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:911 in forward, code: add_31 = torch.ops.aten.add.Tensor(add_30, getitem_426); add_30 = getitem_426 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_31: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_30, arg426_1); arg426_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:917 in forward, code: add_32 = torch.ops.aten.add.Tensor(add_31, getitem_436); add_31 = getitem_436 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_32: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_31, arg436_1); arg436_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:927 in forward, code: add_33 = torch.ops.aten.add.Tensor(add_32, getitem_442); add_32 = getitem_442 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_33: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_32, arg442_1); arg442_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:933 in forward, code: add_34 = torch.ops.aten.add.Tensor(add_33, getitem_452); add_33 = getitem_452 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_34: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_33, arg452_1); arg452_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:943 in forward, code: add_35 = torch.ops.aten.add.Tensor(add_34, getitem_458); add_34 = getitem_458 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_35: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_34, arg458_1); arg458_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
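The chain from add through add_35 accumulates thirty-six per-branch gradient tensors on top of view_1. With n_layers=18 in the job config, 36 is consistent with two residual branches (attention and feed-forward) per block feeding gradient back into the hidden state, though that reading is inferred from the counts, not stated in the log. In effect the chain is a running sum:

```python
# Sketch of the add chain; names are illustrative, not from the graph.
grad_hidden = first_contribution                 # view_1 in the trace
for branch_grad in residual_branch_grads:        # 36 [2048, 2048] tensors
    grad_hidden = grad_hidden + branch_grad
```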
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:944 in forward, code: convert_element_type_108 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_108: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:968 in forward, code: mul_57 = torch.ops.aten.mul.Tensor(convert_element_type_111, convert_element_type_108) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_59: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_127, convert_element_type_108); convert_element_type_127 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:970 in forward, code: sum_2 = torch.ops.aten.sum.dim_IntList(mul_57, [1], True); mul_57 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_3: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_59, [1], True); mul_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:974 in forward, code: mul_59 = torch.ops.aten.mul.Scalar(sum_2, -0.5); sum_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_61: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_3, -0.5); sum_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:973 in forward, code: pow_1 = torch.ops.aten.pow.Tensor_Scalar(alias_75, 3); alias_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_1: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg459_1, 3) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:975 in forward, code: mul_60 = torch.ops.aten.mul.Tensor(mul_59, pow_1); mul_59 = pow_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_62: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_61, pow_1); mul_61 = pow_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:976 in forward, code: expand = torch.ops.aten.expand.default(mul_60, [2048, 2048]); mul_60 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_62, [2048, 2048]); mul_62 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:977 in forward, code: div = torch.ops.aten.div.Scalar(expand, 2048); expand = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_1: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand, 2048); expand = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:978 in forward, code: pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_108, 1.0); convert_element_type_108 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_2: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_108, 1.0) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:979 in forward, code: mul_61 = torch.ops.aten.mul.Scalar(pow_2, 2.0); pow_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_63: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_2, 2.0); pow_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:980 in forward, code: mul_62 = torch.ops.aten.mul.Tensor(div, mul_61); div = mul_61 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_64: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_1, mul_63); div_1 = mul_63 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:981 in forward, code: add_36 = torch.ops.aten.add.Tensor(mul_58, mul_62); mul_58 = mul_62 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_36: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_60, mul_64); mul_60 = mul_64 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:982 in forward, code: convert_element_type_112 = torch.ops.prims.convert_element_type.default(add_36, torch.bfloat16); add_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_128: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_36, torch.bfloat16); add_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
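For readers following the dump: the run from sum_3 through add_36 above (sum → mul by -0.5 → pow(rstd, 3) → expand → div by 2048 → 2·x → add) is the decomposed input gradient of RMSNorm. Reassembled by hand, assuming the usual forward y = x * rsqrt(mean(x^2, -1) + eps) * weight and treating arg459_1 as the saved rstd, it is equivalent to this sketch (a minimal reconstruction, not the compiled kernel; names are illustrative):

import torch

def rmsnorm_input_grad(dy, x, weight, eps=1e-5):
    x32, g = x.float(), dy.float() * weight.float()
    rstd = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)  # arg459_1-like
    # sum_3 -> mul_61(-0.5) -> pow_1(rstd, 3) -> mul_62: d rstd / d mean(x^2)
    dms = (g * x32).sum(-1, keepdim=True) * -0.5 * rstd.pow(3)
    # expand -> div_1(/2048) -> mul_63(2*x): mean and square backward
    dx = g * rstd + dms.expand_as(x32) / x32.shape[-1] * 2.0 * x32
    return dx.to(dy.dtype)                  # convert_element_type_128-like cast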
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_8: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_75, [2048, 5632], [5632, 1]); getitem_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:989 in forward, code: mm_3 = torch.ops.aten.mm.default(trace_wrapped_1, permute_8); permute_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_3: "bf16[2048, 5632]" = torch.ops.aten.mm.default(convert_element_type_128, contiguous_view_as_strided_8); contiguous_view_as_strided_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:939 in forward, code: convert_element_type_106 = torch.ops.prims.convert_element_type.default(getitem_455, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_106: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(arg455_1, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:940 in forward, code: sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_106) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_17: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_106) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:941 in forward, code: mul_53 = torch.ops.aten.mul.Tensor(convert_element_type_106, sigmoid_17); convert_element_type_106 = sigmoid_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_55: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_106, sigmoid_17); convert_element_type_106 = sigmoid_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:942 in forward, code: convert_element_type_107 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_107: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_55, torch.bfloat16); mul_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:992 in forward, code: mul_63 = torch.ops.aten.mul.Tensor(mm_3, convert_element_type_107); convert_element_type_107 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_65: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_3, convert_element_type_107); convert_element_type_107 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_9: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_85, [5632, 2048], [2048, 1]); getitem_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:993 in forward, code: mul_64 = torch.ops.aten.mul.Tensor(mm_3, getitem_456); mm_3 = getitem_456 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_66: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_3, arg456_1); mm_3 = arg456_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
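Context for mm_3, mul_65 and mul_66: this is the first half of a SwiGLU FFN backward, out = (silu(x @ w1ᵀ) * (x @ w3ᵀ)) @ w2ᵀ as in the Llama reference model (the layer names are assumed; the dump only shows the saved activations arg455_1 and arg456_1). Given those, the three nodes compute, roughly:

import torch
import torch.nn.functional as F

def swiglu_partial_backward(d_out, x1, x3, w2):
    # x1 = x @ w1.T, x3 = x @ w3.T (saved); w2: [dim, hidden] weight
    dh = d_out @ w2                                  # mm_3
    d_silu_out = dh * x3                             # mul_66
    d_x3 = dh * F.silu(x1.float()).to(x1.dtype)      # mul_65 (silu recomputed)
    return d_silu_out, d_x3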
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1003 in forward, code: full = torch.ops.aten.full.default([2048, 5632], 1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] full_166: "bf16[2048, 5632]" = torch.ops.aten.full.default([2048, 5632], 1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1002 in forward, code: sigmoid_18 = torch.ops.aten.sigmoid.default(getitem_455) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_18: "bf16[2048, 5632]" = torch.ops.aten.sigmoid.default(arg455_1) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1004 in forward, code: sub = torch.ops.aten.sub.Tensor(full, sigmoid_18) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sub_1: "bf16[2048, 5632]" = torch.ops.aten.sub.Tensor(full_166, sigmoid_18) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1005 in forward, code: mul_65 = torch.ops.aten.mul.Tensor(getitem_455, sub); getitem_455 = sub = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_67: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(arg455_1, sub_1); arg455_1 = sub_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1006 in forward, code: add_37 = torch.ops.aten.add.Scalar(mul_65, 1); mul_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_37: "bf16[2048, 5632]" = torch.ops.aten.add.Scalar(mul_67, 1); mul_67 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1007 in forward, code: mul_66 = torch.ops.aten.mul.Tensor(sigmoid_18, add_37); sigmoid_18 = add_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_68: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(sigmoid_18, add_37); sigmoid_18 = add_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1008 in forward, code: mul_67 = torch.ops.aten.mul.Tensor(mul_64, mul_66); mul_64 = mul_66 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_69: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mul_66, mul_68); mul_66 = mul_68 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
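The full_166 → sigmoid_18 → sub_1 → add_37 → mul_68 run above spells out the SiLU derivative, d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), which mul_69 then applies to the incoming gradient. A quick check against autograd:

import torch

x = torch.randn(2048, 5632, requires_grad=True)
torch.nn.functional.silu(x).sum().backward()

s = torch.sigmoid(x.detach())
manual = s * (1 + x.detach() * (1 - s))   # the sub/add/mul chain above
torch.testing.assert_close(x.grad, manual, rtol=1e-5, atol=1e-6)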
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_7: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_65, [5632, 2048], [2048, 1]); getitem_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1015 in forward, code: add_38 = torch.ops.aten.add.Tensor(mm_5, mm_7); mm_5 = mm_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_plus_mm_35: "bf16[2048, 2048]" = torch__inductor_fx_passes_post_grad_mm_plus_mm(mat1 = mul_65, mat2 = contiguous_view_as_strided_9, mat3 = mul_69, mat4 = contiguous_view_as_strided_7); contiguous_view_as_strided_9 = contiguous_view_as_strided_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
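torch__inductor_fx_passes_post_grad_mm_plus_mm above is an Inductor post-grad pattern replacement: the source line add_38 = add(mm_5, mm_7) was matched, and the two matmuls plus the add were folded into one fused template. Functionally it is nothing more than:

import torch

def mm_plus_mm_reference(mat1, mat2, mat3, mat4):
    # same result as the fused node; the template just avoids the two
    # intermediate buffers and an extra kernel launch
    return torch.mm(mat1, mat2) + torch.mm(mat3, mat4)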
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_11: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_105, [2048], [1]); getitem_105 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1019 in forward, code: mul_69 = torch.ops.aten.mul.Tensor(add_38, getitem_169); add_38 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_71: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(mm_plus_mm_35, contiguous_view_as_strided_11); contiguous_view_as_strided_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1023 in forward, code: convert_element_type_113 = torch.ops.prims.convert_element_type.default(mul_69, torch.float32); mul_69 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_150: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_71, torch.float32); mul_71 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1025 in forward, code: mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_113, getitem_453); convert_element_type_113 = getitem_453 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_73: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_150, arg453_1) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:934 in forward, code: convert_element_type_104 = torch.ops.prims.convert_element_type.default(add_34, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_104: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_34, torch.float32); add_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1024 in forward, code: mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_113, convert_element_type_104) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_72: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_150, convert_element_type_104); convert_element_type_150 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1026 in forward, code: sum_4 = torch.ops.aten.sum.dim_IntList(mul_70, [1], True); mul_70 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_5: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_72, [1], True); mul_72 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1030 in forward, code: mul_72 = torch.ops.aten.mul.Scalar(sum_4, -0.5); sum_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_74: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_5, -0.5); sum_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1029 in forward, code: pow_3 = torch.ops.aten.pow.Tensor_Scalar(alias_77, 3); alias_77 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_3: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg453_1, 3) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1031 in forward, code: mul_73 = torch.ops.aten.mul.Tensor(mul_72, pow_3); mul_72 = pow_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_75: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_74, pow_3); mul_74 = pow_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1032 in forward, code: expand_1 = torch.ops.aten.expand.default(mul_73, [2048, 2048]); mul_73 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand_1: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_75, [2048, 2048]); mul_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1033 in forward, code: div_1 = torch.ops.aten.div.Scalar(expand_1, 2048); expand_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_2: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand_1, 2048); expand_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1034 in forward, code: pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_104, 1.0); convert_element_type_104 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_4: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_104, 1.0) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1035 in forward, code: mul_74 = torch.ops.aten.mul.Scalar(pow_4, 2.0); pow_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_76: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_4, 2.0); pow_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1036 in forward, code: mul_75 = torch.ops.aten.mul.Tensor(div_1, mul_74); div_1 = mul_74 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_77: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_2, mul_76); div_2 = mul_76 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1037 in forward, code: add_39 = torch.ops.aten.add.Tensor(mul_71, mul_75); mul_71 = mul_75 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_39: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_73, mul_77); mul_73 = mul_77 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1038 in forward, code: convert_element_type_114 = torch.ops.prims.convert_element_type.default(add_39, torch.bfloat16); add_39 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_151: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_39, torch.bfloat16); add_39 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1039 in forward, code: add_40 = torch.ops.aten.add.Tensor(trace_wrapped_1, convert_element_type_114); trace_wrapped_1 = convert_element_type_114 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_40: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(convert_element_type_128, convert_element_type_151); convert_element_type_151 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
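add_40 is the residual connection doing its backward job: the gradient entering the block is the sum of the path that skipped the FFN (convert_element_type_128) and the path through ffn_norm + FFN (convert_element_type_151). In miniature, with tanh standing in for the whole branch (purely illustrative):

import torch

x = torch.randn(8, requires_grad=True)
(x + torch.tanh(x)).sum().backward()          # skip path + branch path
torch.testing.assert_close(x.grad, 1 + (1 - torch.tanh(x.detach()) ** 2))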
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_6: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_55, [2048, 2048], [2048, 1]); getitem_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1045 in forward, code: mm_9 = torch.ops.aten.mm.default(add_40, permute_23); permute_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_9: "bf16[2048, 2048]" = torch.ops.aten.mm.default(add_40, contiguous_view_as_strided_6); contiguous_view_as_strided_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1048 in forward, code: view_7 = torch.ops.aten.view.default(mm_9, [1, 2048, 16, 128]); mm_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_21: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_9, [1, 2048, 16, 128]); mm_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1049 in forward, code: permute_25 = torch.ops.aten.permute.default(view_7, [0, 2, 1, 3]); view_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_35: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_21, [0, 2, 1, 3]); view_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1050 in forward, code: _scaled_dot_product_flash_attention_backward = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_25, getitem_445, getitem_446, getitem_447, getitem_461, getitem_448, None, None, 2048, 2048, 0.0, True, getitem_449, getitem_450, scale = 0.08838834764831843); permute_25 = getitem_445 = getitem_446 = getitem_447 = getitem_461 = getitem_448 = getitem_449 = getitem_450 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] _scaled_dot_product_flash_attention_backward = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_35, arg445_1, arg446_1, arg447_1, arg461_1, arg448_1, None, None, 2048, 2048, 0.0, True, arg449_1, arg450_1, scale = 0.08838834764831843); permute_35 = arg445_1 = arg446_1 = arg447_1 = arg461_1 = arg448_1 = arg449_1 = arg450_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_108: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward[0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_109: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward[1] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_110: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward[2]; _scaled_dot_product_flash_attention_backward = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
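_scaled_dot_product_flash_attention_backward above is the autograd node for the public SDPA call, and getitem_108/109/110 are dQ, dK and dV. The recorded arguments pin the forward down: 1 batch, 16 heads, 2048 tokens, head dim 128, dropout 0.0, is_causal=True, scale = 1/sqrt(128) ≈ 0.0884. A minimal repro of the same backward, on a CUDA box with flash attention available:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 16, 2048, 128, device="cuda", dtype=torch.bfloat16,
                       requires_grad=True) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True,
                                     scale=0.08838834764831843)  # 1/sqrt(128)
out.sum().backward()          # q.grad, k.grad, v.grad ~ getitem_108..110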
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_106, 8, '0'); getitem_106 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
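all_gather_into_tensor plus wait_tensor is the traceable, functional form of an async all-gather: the first op enqueues the collective and returns immediately, and wait_tensor is the explicit synchronization point, which is what lets the compiler overlap this parameter prefetch with the surrounding backward compute. Schematically (requires an initialized process group registered under the name '0', as FSDP has done here; not runnable standalone):

import torch

def prefetch_flat_params(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    t = torch.ops._c10d_functional.all_gather_into_tensor.default(
        shard, world_size, "0")                       # async, returns at once
    return torch.ops._c10d_functional.wait_tensor.default(t)  # sync point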
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_31: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_15 = torch.ops.aten.split_with_sizes.default(view_31, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_120: "bf16[8, 524288]" = split_with_sizes_15[0]; split_with_sizes_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_32: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_32, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_130: "bf16[8, 524288]" = split_with_sizes_16[1]; split_with_sizes_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_33: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_17 = torch.ops.aten.split_with_sizes.default(view_33, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_140: "bf16[8, 524288]" = split_with_sizes_17[2]; split_with_sizes_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_34: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_18 = torch.ops.aten.split_with_sizes.default(view_34, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_34 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_150: "bf16[8, 524288]" = split_with_sizes_18[3]; split_with_sizes_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_35: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_19 = torch.ops.aten.split_with_sizes.default(view_35, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_160: "bf16[8, 1441792]" = split_with_sizes_19[4]; split_with_sizes_19 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_36: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_20 = torch.ops.aten.split_with_sizes.default(view_36, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_170: "bf16[8, 1441792]" = split_with_sizes_20[5]; split_with_sizes_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_37: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_21 = torch.ops.aten.split_with_sizes.default(view_37, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_180: "bf16[8, 1441792]" = split_with_sizes_21[6]; split_with_sizes_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_38: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_22 = torch.ops.aten.split_with_sizes.default(view_38, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_38 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_190: "bf16[8, 256]" = split_with_sizes_22[7]; split_with_sizes_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_39: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]); wait_tensor_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_23 = torch.ops.aten.split_with_sizes.default(view_39, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_39 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_200: "bf16[8, 256]" = split_with_sizes_23[8]; split_with_sizes_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
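The nine reshape/split_with_sizes pairs above are FSDP2's all-gather copy-out: the flat 51,384,320-element gather buffer is viewed rank-major as [8, 6423040] and sliced once per parameter (four 524288-element attention shards, three 1441792-element FFN shards, two 256-element norm shards); each slice is then restrided to its unsharded shape by the contiguous_view_as_strided calls scattered through the dump. The bookkeeping, using the exact sizes above:

import torch

world_size = 8
splits = [524288] * 4 + [1441792] * 3 + [256] * 2
assert sum(splits) == 6423040                       # per-rank flat numel
gathered = torch.empty(world_size * sum(splits), dtype=torch.bfloat16)

per_param = torch.split(gathered.view(world_size, -1), splits, dim=1)
# per_param[i] is the [8, numel_i] stack of rank shards for parameter i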
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_174: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg502_1, torch.bfloat16); arg502_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_175: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg503_1, torch.bfloat16); arg503_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_176: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg504_1, torch.bfloat16); arg504_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_177: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg505_1, torch.bfloat16); arg505_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_178: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg506_1, torch.bfloat16); arg506_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_179: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg507_1, torch.bfloat16); arg507_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_180: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg508_1, torch.bfloat16); arg508_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_181: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg509_1, torch.bfloat16); arg509_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_182: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg510_1, torch.bfloat16); arg510_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
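convert_element_type_174..182 is FSDP2's mixed-precision cast (_to_dtype_if_needed): fp32 sharded parameters are cast to bf16 right before being packed for the next all-gather, which suggests param_dtype=torch.bfloat16 in this job's mixed-precision policy (inferred from the dump, not shown in it):

import torch

fp32_shard = torch.randn(524288)                   # an arg502_1-like shard
bf16_gather_input = fp32_shard.to(torch.bfloat16)  # tensor.to(dtype)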
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_3 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_174, convert_element_type_175, convert_element_type_176, convert_element_type_177, convert_element_type_178, convert_element_type_179, convert_element_type_180, convert_element_type_181, convert_element_type_182]); convert_element_type_174 = convert_element_type_175 = convert_element_type_176 = convert_element_type_177 = convert_element_type_178 = convert_element_type_179 = convert_element_type_180 = convert_element_type_181 = convert_element_type_182 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_201: "bf16[6423040]" = all_gather_copy_in_3[0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_202: "bf16[51384320]" = all_gather_copy_in_3[1]; all_gather_copy_in_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
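all_gather_copy_in then packs those nine bf16 shards into one flat input (getitem_201, 6,423,040 elements) and pre-allocates the 8x larger output buffer (getitem_202, 51,384,320 elements), so the whole layer group moves in a single collective. The packing step is roughly the eager equivalent of:

import torch

shards = [torch.randn(n, dtype=torch.bfloat16)
          for n in [524288] * 4 + [1441792] * 3 + [256] * 2]
flat_in = torch.cat(shards)                          # getitem_201-like
flat_out = flat_in.new_empty(8 * flat_in.numel())    # getitem_202-like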
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1054 in forward, code: permute_26 = torch.ops.aten.permute.default(getitem_484, [0, 2, 1, 3]); getitem_484 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_36: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_110, [0, 2, 1, 3]); getitem_110 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1072 in forward, code: view_12 = torch.ops.aten.view.default(permute_26, [2048, 2048]); permute_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_26: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_36, [2048, 2048]); permute_36 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_5: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_45, [2048, 2048], [2048, 1]); getitem_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1055 in forward, code: permute_27 = torch.ops.aten.permute.default(getitem_483, [0, 2, 1, 3]); getitem_483 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_37: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_109, [0, 2, 1, 3]); getitem_109 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1057 in forward, code: convert_element_type_115 = torch.ops.prims.convert_element_type.default(permute_27, torch.float32); permute_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_156: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(permute_37, torch.float32); permute_37 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1059 in forward, code: view_8 = torch.ops.aten.view.default(convert_element_type_115, [1, 2048, 16, 64, 2]); convert_element_type_115 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_22: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_156, [1, 2048, 16, 64, 2]); convert_element_type_156 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1060 in forward, code: view_as_complex = torch.ops.aten.view_as_complex.default(view_8); view_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_complex: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_22); view_22 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1062 in forward, code: mul_76 = torch.ops.aten.mul.Tensor(view_as_complex, clone); view_as_complex = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_78: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex, arg462_1); view_as_complex = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1066 in forward, code: view_as_real = torch.ops.aten.view_as_real.default(mul_76); mul_76 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_real: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_78); mul_78 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1067 in forward, code: view_10 = torch.ops.aten.view.default(view_as_real, [1, 2048, 16, 128]); view_as_real = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_24: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real, [1, 2048, 16, 128]); view_as_real = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1068 in forward, code: convert_element_type_117 = torch.ops.prims.convert_element_type.default(view_10, torch.bfloat16); view_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_158: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_24, torch.bfloat16); view_24 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1073 in forward, code: view_13 = torch.ops.aten.view.default(convert_element_type_117, [2048, 2048]); convert_element_type_117 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_27: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(convert_element_type_158, [2048, 2048]); convert_element_type_158 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
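view_22 → view_as_complex → mul_78 → view_as_real → view_24 above is the rotary-embedding half of the key gradient: RoPE is an elementwise complex multiply, so its backward has the same structure, another elementwise complex multiply by the saved factor (arg462_1; whether that holds freqs_cis or its conjugate is not visible in this excerpt). The shared shape dance, in the style of the Llama reference code:

import torch

def rotate(x: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
    # x: [bsz, seq, heads, head_dim] real; factor: complex64, broadcastable
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(xc * factor).reshape(*x.shape).type_as(x)

grad = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
factor = torch.polar(torch.ones(2048, 1, 64), torch.randn(2048, 1, 64))
dk = rotate(grad, factor)  # bf16 [1, 2048, 16, 128]: convert_element_type_158
                           # before the final [2048, 2048] reshape (view_27)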
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_4: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_35, [2048, 2048], [2048, 1]); getitem_35 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1089 in forward, code: add_41 = torch.ops.aten.add.Tensor(mm_11, mm_13); mm_11 = mm_13 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_plus_mm_34: "bf16[2048, 2048]" = torch__inductor_fx_passes_post_grad_mm_plus_mm_1(mat1 = view_26, mat2 = contiguous_view_as_strided_5, mat3 = view_27, mat4 = contiguous_view_as_strided_4); contiguous_view_as_strided_5 = contiguous_view_as_strided_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1056 in forward, code: permute_28 = torch.ops.aten.permute.default(getitem_482, [0, 2, 1, 3]); getitem_482 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_38: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]); getitem_108 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1058 in forward, code: convert_element_type_116 = torch.ops.prims.convert_element_type.default(permute_28, torch.float32); permute_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_157: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(permute_38, torch.float32); permute_38 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1063 in forward, code: view_9 = torch.ops.aten.view.default(convert_element_type_116, [1, 2048, 16, 64, 2]); convert_element_type_116 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_23: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_157, [1, 2048, 16, 64, 2]); convert_element_type_157 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1064 in forward, code: view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_9); view_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_complex_1: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_23); view_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1065 in forward, code: mul_77 = torch.ops.aten.mul.Tensor(view_as_complex_1, clone); view_as_complex_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_79: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_1, arg462_1); view_as_complex_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1069 in forward, code: view_as_real_1 = torch.ops.aten.view_as_real.default(mul_77); mul_77 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_real_1: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_79); mul_79 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1070 in forward, code: view_11 = torch.ops.aten.view.default(view_as_real_1, [1, 2048, 16, 128]); view_as_real_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_25: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_1, [1, 2048, 16, 128]); view_as_real_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1071 in forward, code: convert_element_type_118 = torch.ops.prims.convert_element_type.default(view_11, torch.bfloat16); view_11 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_159: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_25, torch.bfloat16); view_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1074 in forward, code: view_14 = torch.ops.aten.view.default(convert_element_type_118, [2048, 2048]); convert_element_type_118 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_28: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(convert_element_type_159, [2048, 2048]); convert_element_type_159 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_3: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_25, [2048, 2048], [2048, 1]); getitem_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1097 in forward, code: mm_15 = torch.ops.aten.mm.default(view_14, permute_42); view_14 = permute_42 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_15: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_28, contiguous_view_as_strided_3); contiguous_view_as_strided_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1098 in forward, code: add_42 = torch.ops.aten.add.Tensor(add_41, mm_15); add_41 = mm_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_42: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(mm_plus_mm_34, mm_15); mm_plus_mm_34 = mm_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_10: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_95, [2048], [1]); getitem_95 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1102 in forward, code: mul_79 = torch.ops.aten.mul.Tensor(add_42, getitem_168); add_42 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_81: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(add_42, contiguous_view_as_strided_10); contiguous_view_as_strided_10 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1106 in forward, code: convert_element_type_119 = torch.ops.prims.convert_element_type.default(mul_79, torch.float32); mul_79 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_172: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_81, torch.float32); mul_81 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1108 in forward, code: mul_81 = torch.ops.aten.mul.Tensor(convert_element_type_119, getitem_443); convert_element_type_119 = getitem_443 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_83: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_172, arg443_1) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:928 in forward, code: convert_element_type_102 = torch.ops.prims.convert_element_type.default(add_33, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_102: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1107 in forward, code: mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_119, convert_element_type_102) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_82: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_172, convert_element_type_102); convert_element_type_172 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1109 in forward, code: sum_6 = torch.ops.aten.sum.dim_IntList(mul_80, [1], True); mul_80 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_7: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_82, [1], True); mul_82 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1113 in forward, code: mul_82 = torch.ops.aten.mul.Scalar(sum_6, -0.5); sum_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_84: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_7, -0.5); sum_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1112 in forward, code: pow_5 = torch.ops.aten.pow.Tensor_Scalar(alias_79, 3); alias_79 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_5: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg443_1, 3) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1114 in forward, code: mul_83 = torch.ops.aten.mul.Tensor(mul_82, pow_5); mul_82 = pow_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_85: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_84, pow_5); mul_84 = pow_5 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1115 in forward, code: expand_2 = torch.ops.aten.expand.default(mul_83, [2048, 2048]); mul_83 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand_2: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_85, [2048, 2048]); mul_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1116 in forward, code: div_2 = torch.ops.aten.div.Scalar(expand_2, 2048); expand_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_3: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand_2, 2048); expand_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1117 in forward, code: pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_102, 1.0); convert_element_type_102 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_6: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_102, 1.0) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1118 in forward, code: mul_84 = torch.ops.aten.mul.Scalar(pow_6, 2.0); pow_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_86: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_6, 2.0); pow_6 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1119 in forward, code: mul_85 = torch.ops.aten.mul.Tensor(div_2, mul_84); div_2 = mul_84 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_87: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_3, mul_86); div_3 = mul_86 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1120 in forward, code: add_43 = torch.ops.aten.add.Tensor(mul_81, mul_85); mul_81 = mul_85 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_43: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_83, mul_87); mul_83 = mul_87 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1121 in forward, code: convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_43, torch.bfloat16); add_43 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_173: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_43, torch.bfloat16); add_43 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1122 in forward, code: add_44 = torch.ops.aten.add.Tensor(add_40, convert_element_type_120); add_40 = convert_element_type_120 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_44: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_40, convert_element_type_173); convert_element_type_173 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
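Nodes mul_81 through add_43 above are the decomposed backward of RMSNorm, y = x * rsqrt(mean(x**2)) over the last dimension (N = 2048): the pow(x, 1.0) followed by the multiply by 2.0 is the decomposed derivative of x**2, and arg443_1 is the saved rsqrt statistic. The closed form is dx = dy*r - x * r**3 * sum(dy*x)/N, which this hedged float64 check confirms against autograd (eps omitted, as in the graph):

import torch

# Hedged check of the RMSNorm input gradient computed by the graph above.
x = torch.randn(2048, 2048, dtype=torch.float64, requires_grad=True)
r = x.pow(2).mean(-1, keepdim=True).rsqrt()          # saved as arg443_1
dy = torch.randn(2048, 2048, dtype=torch.float64)    # upstream gradient
(dx_auto,) = torch.autograd.grad(x * r, x, dy)
dx_manual = dy * r - x * r.pow(3) * (dy * x).sum(-1, keepdim=True) / x.shape[-1]
torch.testing.assert_close(dx_auto, dx_manual)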
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_17: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_170, [2048, 5632], [5632, 1]); getitem_170 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1129 in forward, code: mm_17 = torch.ops.aten.mm.default(trace_wrapped_2, permute_47); permute_47 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_17: "bf16[2048, 5632]" = torch.ops.aten.mm.default(add_44, contiguous_view_as_strided_17); contiguous_view_as_strided_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:923 in forward, code: convert_element_type_100 = torch.ops.prims.convert_element_type.default(getitem_439, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_100: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(arg439_1, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:924 in forward, code: sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_100) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_16: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_100) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:925 in forward, code: mul_50 = torch.ops.aten.mul.Tensor(convert_element_type_100, sigmoid_16); convert_element_type_100 = sigmoid_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_52: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_100, sigmoid_16); convert_element_type_100 = sigmoid_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:926 in forward, code: convert_element_type_101 = torch.ops.prims.convert_element_type.default(mul_50, torch.bfloat16); mul_50 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_101: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_52, torch.bfloat16); mul_52 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1132 in forward, code: mul_86 = torch.ops.aten.mul.Tensor(mm_17, convert_element_type_101); convert_element_type_101 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_88: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_17, convert_element_type_101); convert_element_type_101 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_18: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_180, [5632, 2048], [2048, 1]); getitem_180 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1133 in forward, code: mul_87 = torch.ops.aten.mul.Tensor(mm_17, getitem_440); mm_17 = getitem_440 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_89: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_17, arg440_1); mm_17 = arg440_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1142 in forward, code: sigmoid_19 = torch.ops.aten.sigmoid.default(getitem_439) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_19: "bf16[2048, 5632]" = torch.ops.aten.sigmoid.default(arg439_1) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1143 in forward, code: sub_1 = torch.ops.aten.sub.Tensor(full, sigmoid_19) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sub_2: "bf16[2048, 5632]" = torch.ops.aten.sub.Tensor(full_166, sigmoid_19) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1144 in forward, code: mul_88 = torch.ops.aten.mul.Tensor(getitem_439, sub_1); getitem_439 = sub_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_90: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(arg439_1, sub_2); arg439_1 = sub_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1145 in forward, code: add_45 = torch.ops.aten.add.Scalar(mul_88, 1); mul_88 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_45: "bf16[2048, 5632]" = torch.ops.aten.add.Scalar(mul_90, 1); mul_90 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1146 in forward, code: mul_89 = torch.ops.aten.mul.Tensor(sigmoid_19, add_45); sigmoid_19 = add_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_91: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(sigmoid_19, add_45); sigmoid_19 = add_45 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1147 in forward, code: mul_90 = torch.ops.aten.mul.Tensor(mul_87, mul_89); mul_87 = mul_89 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_92: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mul_89, mul_91); mul_89 = mul_91 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
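The sigmoid_19 / sub_2 / mul_90 / add_45 / mul_91 sequence above (full_166 is a ones tensor) evaluates the SiLU derivative, and mul_92 folds it into the upstream gradient. A hedged float64 check of the identity against autograd:

import torch

# Hedged check that the chain above is SiLU's derivative:
# d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
x = torch.randn(2048, 5632, dtype=torch.float64, requires_grad=True)
(d_auto,) = torch.autograd.grad(torch.nn.functional.silu(x).sum(), x)
s = torch.sigmoid(x)
torch.testing.assert_close(d_auto, s * (1 + x * (1 - s)))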
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_16: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_160, [5632, 2048], [2048, 1]); getitem_160 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1154 in forward, code: add_46 = torch.ops.aten.add.Tensor(mm_19, mm_21); mm_19 = mm_21 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_plus_mm_33: "bf16[2048, 2048]" = torch__inductor_fx_passes_post_grad_mm_plus_mm_2(mat1 = mul_88, mat2 = contiguous_view_as_strided_18, mat3 = mul_92, mat4 = contiguous_view_as_strided_16); contiguous_view_as_strided_18 = contiguous_view_as_strided_16 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
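mm_plus_mm_33 replaces the traced add_46 = mm_19 + mm_21: Inductor's post-grad pattern matcher rewrites add(mm(a, b), mm(c, d)) into a single fused node (hence the torch__inductor_fx_passes_post_grad_mm_plus_mm_2 call). Its semantics, as a hedged sketch:

import torch

# Hedged sketch of the semantics of Inductor's mm_plus_mm fusion:
# one fused node standing in for add(mm(a, b), mm(c, d)).
def mm_plus_mm(a, b, c, d):
    return a @ b + c @ d

a, c = (torch.randn(2048, 5632, dtype=torch.bfloat16) for _ in range(2))
b, d = (torch.randn(5632, 2048, dtype=torch.bfloat16) for _ in range(2))
out = mm_plus_mm(a, b, c, d)  # bf16[2048, 2048], like mm_plus_mm_33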
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_20: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_200, [2048], [1]); getitem_200 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1158 in forward, code: mul_92 = torch.ops.aten.mul.Tensor(add_46, getitem_160); add_46 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_94: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(mm_plus_mm_33, contiguous_view_as_strided_20); contiguous_view_as_strided_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1162 in forward, code: convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_92, torch.float32); mul_92 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_195: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_94, torch.float32); mul_94 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1164 in forward, code: mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_121, getitem_437); convert_element_type_121 = getitem_437 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_96: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_195, arg437_1) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:918 in forward, code: convert_element_type_98 = torch.ops.prims.convert_element_type.default(add_32, torch.float32) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_98: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_32, torch.float32); add_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1163 in forward, code: mul_93 = torch.ops.aten.mul.Tensor(convert_element_type_121, convert_element_type_98) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_95: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_195, convert_element_type_98); convert_element_type_195 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1165 in forward, code: sum_8 = torch.ops.aten.sum.dim_IntList(mul_93, [1], True); mul_93 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_9: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_95, [1], True); mul_95 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1169 in forward, code: mul_95 = torch.ops.aten.mul.Scalar(sum_8, -0.5); sum_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_97: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_9, -0.5); sum_9 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1168 in forward, code: pow_7 = torch.ops.aten.pow.Tensor_Scalar(alias_81, 3); alias_81 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_7: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg437_1, 3) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1170 in forward, code: mul_96 = torch.ops.aten.mul.Tensor(mul_95, pow_7); mul_95 = pow_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_98: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_97, pow_7); mul_97 = pow_7 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1171 in forward, code: expand_3 = torch.ops.aten.expand.default(mul_96, [2048, 2048]); mul_96 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand_3: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_98, [2048, 2048]); mul_98 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1172 in forward, code: div_3 = torch.ops.aten.div.Scalar(expand_3, 2048); expand_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_4: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand_3, 2048); expand_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1173 in forward, code: pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_98, 1.0); convert_element_type_98 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_8: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_98, 1.0) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1174 in forward, code: mul_97 = torch.ops.aten.mul.Scalar(pow_8, 2.0); pow_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_99: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_8, 2.0); pow_8 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1175 in forward, code: mul_98 = torch.ops.aten.mul.Tensor(div_3, mul_97); div_3 = mul_97 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_100: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_4, mul_99); div_4 = mul_99 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1176 in forward, code: add_47 = torch.ops.aten.add.Tensor(mul_94, mul_98); mul_94 = mul_98 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_47: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_96, mul_100); mul_96 = mul_100 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1177 in forward, code: convert_element_type_122 = torch.ops.prims.convert_element_type.default(add_47, torch.bfloat16); add_47 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_196: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_47, torch.bfloat16); add_47 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1178 in forward, code: add_48 = torch.ops.aten.add.Tensor(trace_wrapped_2, convert_element_type_122); trace_wrapped_2 = convert_element_type_122 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_48: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_44, convert_element_type_196); convert_element_type_196 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_15: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_150, [2048, 2048], [2048, 1]); getitem_150 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1184 in forward, code: mm_23 = torch.ops.aten.mm.default(add_48, permute_62); permute_62 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_23: "bf16[2048, 2048]" = torch.ops.aten.mm.default(add_48, contiguous_view_as_strided_15); contiguous_view_as_strided_15 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1187 in forward, code: view_17 = torch.ops.aten.view.default(mm_23, [1, 2048, 16, 128]); mm_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_41: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_23, [1, 2048, 16, 128]); mm_23 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1188 in forward, code: permute_64 = torch.ops.aten.permute.default(view_17, [0, 2, 1, 3]); view_17 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_88: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1189 in forward, code: _scaled_dot_product_flash_attention_backward_1 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_64, getitem_429, getitem_430, getitem_431, getitem_463, getitem_432, None, None, 2048, 2048, 0.0, True, getitem_433, getitem_434, scale = 0.08838834764831843); permute_64 = getitem_429 = getitem_430 = getitem_431 = getitem_463 = getitem_432 = getitem_433 = getitem_434 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] _scaled_dot_product_flash_attention_backward_1 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_88, arg429_1, arg430_1, arg431_1, arg463_1, arg432_1, None, None, 2048, 2048, 0.0, True, arg433_1, arg434_1, scale = 0.08838834764831843); permute_88 = arg429_1 = arg430_1 = arg431_1 = arg463_1 = arg432_1 = arg433_1 = arg434_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_203: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward_1[0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_204: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward_1[1] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_205: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward_1[2]; _scaled_dot_product_flash_attention_backward_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
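The _scaled_dot_product_flash_attention_backward node consumes the upstream gradient plus the saved q, k, v, output, and logsumexp, and returns the (dq, dk, dv) triple unpacked by getitem_203 through getitem_205. Its scale argument is the usual 1/sqrt(head_dim) with head_dim = 128:

import math

# The scale baked into the flash-attention backward node above:
assert math.isclose(128 ** -0.5, 0.08838834764831843)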
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_201, 8, '0'); getitem_201 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
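all_gather_into_tensor_3 and wait_tensor_3 are traced functional collectives: the all-gather is issued asynchronously and wait_tensor blocks before the result's first use. A hedged eager sketch using torch.distributed._functional_collectives (assuming an initialized default process group):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Hedged sketch: gather each rank's flat bf16 shard into one flat tensor,
# mirroring all_gather_into_tensor_3 -> wait_tensor_3 above.
def gather_flat(local_shard: torch.Tensor) -> torch.Tensor:
    out = funcol.all_gather_tensor(local_shard, gather_dim=0, group=dist.group.WORLD)
    return out  # the returned AsyncCollectiveTensor waits (wait_tensor) on first use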
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_51: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_25 = torch.ops.aten.split_with_sizes.default(view_51, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_51 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_215: "bf16[8, 524288]" = split_with_sizes_25[0]; split_with_sizes_25 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_52: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_52, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_52 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_225: "bf16[8, 524288]" = split_with_sizes_26[1]; split_with_sizes_26 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_53: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_27 = torch.ops.aten.split_with_sizes.default(view_53, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_53 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_235: "bf16[8, 524288]" = split_with_sizes_27[2]; split_with_sizes_27 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_54: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_28 = torch.ops.aten.split_with_sizes.default(view_54, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_54 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_245: "bf16[8, 524288]" = split_with_sizes_28[3]; split_with_sizes_28 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_55: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_29 = torch.ops.aten.split_with_sizes.default(view_55, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_55 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_255: "bf16[8, 1441792]" = split_with_sizes_29[4]; split_with_sizes_29 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_56: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_30 = torch.ops.aten.split_with_sizes.default(view_56, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_56 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_265: "bf16[8, 1441792]" = split_with_sizes_30[5]; split_with_sizes_30 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_57: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_31 = torch.ops.aten.split_with_sizes.default(view_57, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_57 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_275: "bf16[8, 1441792]" = split_with_sizes_31[6]; split_with_sizes_31 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_58: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_32 = torch.ops.aten.split_with_sizes.default(view_58, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_58 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_285: "bf16[8, 256]" = split_with_sizes_32[7]; split_with_sizes_32 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_59: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]); wait_tensor_3 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_33 = torch.ops.aten.split_with_sizes.default(view_59, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_59 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_295: "bf16[8, 256]" = split_with_sizes_33[8]; split_with_sizes_33 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
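The nine split sizes are this layer group's per-rank shard numels: four 2048x2048 attention projections (2048*2048/8 = 524288 each), three 2048x5632 or 5632x2048 MLP weights (2048*5632/8 = 1441792 each), and two 2048-element norm weights (2048/8 = 256 each), giving 6423040 elements per rank and 8 * 6423040 = 51384320 in the gathered buffer. A hedged recomputation of the bookkeeping:

import math

# Hedged recomputation of the split_with_sizes sizes seen above.
world_size = 8
param_shapes = [(2048, 2048)] * 4 + [(2048, 5632)] + [(5632, 2048)] * 2 + [(2048,)] * 2
split_sizes = [math.prod(shape) // world_size for shape in param_shapes]
assert split_sizes == [524288] * 4 + [1441792] * 3 + [256] * 2
per_rank = sum(split_sizes)               # 6423040, view_51's second dim
assert world_size * per_rank == 51384320  # wait_tensor_3's numel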
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_219: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg511_1, torch.bfloat16); arg511_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_220: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg512_1, torch.bfloat16); arg512_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_221: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg513_1, torch.bfloat16); arg513_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_222: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg514_1, torch.bfloat16); arg514_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_223: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg515_1, torch.bfloat16); arg515_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_224: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg516_1, torch.bfloat16); arg516_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_225: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg517_1, torch.bfloat16); arg517_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_226: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg518_1, torch.bfloat16); arg518_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_227: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg519_1, torch.bfloat16); arg519_1 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs) | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_4 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_219, convert_element_type_220, convert_element_type_221, convert_element_type_222, convert_element_type_223, convert_element_type_224, convert_element_type_225, convert_element_type_226, convert_element_type_227]); convert_element_type_219 = convert_element_type_220 = convert_element_type_221 = convert_element_type_222 = convert_element_type_223 = convert_element_type_224 = convert_element_type_225 = convert_element_type_226 = convert_element_type_227 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_296: "bf16[6423040]" = all_gather_copy_in_4[0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_297: "bf16[51384320]" = all_gather_copy_in_4[1]; all_gather_copy_in_4 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
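Before the next all-gather, the fp32 shards arg511_1 through arg519_1 are cast to bf16 (_to_dtype_if_needed) and packed into one flat input buffer by torch.ops.fsdp.all_gather_copy_in (another op from the author's branch), which returns the 6423040-element per-rank input and the 51384320-element output buffer. A hedged eager equivalent:

import torch

# Hedged eager equivalent of the all_gather_copy_in node: cast each shard to
# bf16 and pack the pieces into one flat per-rank all-gather input buffer.
def all_gather_copy_in(shards, split_sizes, world_size, dtype=torch.bfloat16):
    inp = torch.empty(sum(split_sizes), dtype=dtype)
    for dst, shard in zip(inp.split(split_sizes), shards):
        dst.copy_(shard.to(dtype))        # the convert_element_type_* casts
    out = torch.empty(world_size * inp.numel(), dtype=dtype)
    return inp, out                       # like getitem_296 / getitem_297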
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1193 in forward, code: permute_65 = torch.ops.aten.permute.default(getitem_487, [0, 2, 1, 3]); getitem_487 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_89: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_205, [0, 2, 1, 3]); getitem_205 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1210 in forward, code: view_22 = torch.ops.aten.view.default(permute_65, [2048, 2048]); permute_65 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_46: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_89, [2048, 2048]); permute_89 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided( | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_14: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_140, [2048, 2048], [2048, 1]); getitem_140 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1194 in forward, code: permute_66 = torch.ops.aten.permute.default(getitem_486, [0, 2, 1, 3]); getitem_486 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_90: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_204, [0, 2, 1, 3]); getitem_204 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1196 in forward, code: convert_element_type_123 = torch.ops.prims.convert_element_type.default(permute_66, torch.float32); permute_66 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_201: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(permute_90, torch.float32); permute_90 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1198 in forward, code: view_18 = torch.ops.aten.view.default(convert_element_type_123, [1, 2048, 16, 64, 2]); convert_element_type_123 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_42: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_201, [1, 2048, 16, 64, 2]); convert_element_type_201 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1199 in forward, code: view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_18); view_18 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_complex_2: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_42); view_42 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1200 in forward, code: mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_2, clone); view_as_complex_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_101: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_2, arg462_1); view_as_complex_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1204 in forward, code: view_as_real_2 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_real_2: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_101); mul_101 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1205 in forward, code: view_20 = torch.ops.aten.view.default(view_as_real_2, [1, 2048, 16, 128]); view_as_real_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_44: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_2, [1, 2048, 16, 128]); view_as_real_2 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1206 in forward, code: convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_20, torch.bfloat16); view_20 = None | |
| [rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inducto |