+ export USE_LIBUV=1
+ USE_LIBUV=1
+ TRAINER_DIR=/home/willfeng/local/torchtrain
+ NGPU=8
+ LOG_RANK=0
+ CONFIG_FILE=./train_configs/llama_1b_full_graph_fsdp.toml
+ torchrun --nproc_per_node=8 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0 --role rank --tee 3 train.py --job.config_file ./train_configs/llama_1b_full_graph_fsdp.toml
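
(The launch runs 8 ranks on a single node; --local-ranks-filter 0 restricts log output to local rank 0, and --role rank together with --tee 3 produce the "[rank0]:" prefixes on the lines below, teeing both stdout and stderr to the console.)
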
W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757]
W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] *****************************************
W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] *****************************************
[rank0]:2024-03-28 18:03:19,030 - root - INFO - Starting job: LLaMA 7B training
[rank0]:2024-03-28 18:03:21,163 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:2024-03-28 18:03:21,166 - root - INFO - Building 1-D device mesh with ['dp'], [8]
[rank0]:2024-03-28 18:03:21,183 - root - INFO - Building sentencepiece tokenizer locally from ./torchtrain/datasets/tokenizer/tokenizer.model
[rank0]:2024-03-28 18:03:21,199 - root - INFO - SentencePieceTokenizer built: #words 32000, BOS ID 1, EOS ID 2
[rank0]:2024-03-28 18:03:21,200 - root - INFO - Preparing alpaca dataset from HuggingFace
[rank0]:2024-03-28 18:03:24,206 - root - INFO - Building llama 1B with ModelArgs(dim=2048, n_layers=18, n_heads=16, n_kv_heads=None, vocab_size=32000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, max_batch_size=32, max_seq_len=32768, depth_init=True)
[rank0]:2024-03-28 18:03:24,491 - root - INFO - llama 1B size: 1,055,991,808 total parameters
[rank0]:2024-03-28 18:03:24,491 - root - INFO - GPU capacity: NVIDIA PG509-210 (0) with 79.15GiB memory
[rank0]:2024-03-28 18:03:24,557 - root - INFO - Applied FSDP to the model
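
The reported size checks out against the ModelArgs above. A quick arithmetic sanity check in Python (the FFN hidden size of 5632 is taken from the weight shapes in the traced graph below):

    dim, n_layers, vocab, ffn = 2048, 18, 32000, 5632
    per_layer = 4 * dim * dim + 3 * dim * ffn + 2 * dim   # wq/wk/wv/wo + w1/w2/w3 + two RMSNorm weights
    total = 2 * vocab * dim + n_layers * per_layer + dim  # tok embeddings + output proj + layers + final norm
    assert total == 1_055_991_808
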
[rank0]:2024-03-28 18:03:30,309 - root - INFO - Gradient scaling not enabled
[rank0]:2024-03-28 18:03:30,309 - root - INFO - Metrics logging active. Tensorboard logs will be saved at ./outputs/llama_1b_full_graph_fsdp/tb/20240328-1803
[rank0]:2024-03-28 18:03:30,310 - root - INFO - Compiling model with torch.compile
[rank0]:2024-03-28 18:03:30,311 - root - INFO - Compiling model with torch.compile + full-graph compile FSDP
[rank0]:2024-03-28 18:03:30,372 - root - INFO - Profiling active. Traces will be saved at ./outputs/llama_1b_full_graph_fsdp/profiling/traces
[rank0]:2024-03-28 18:03:40,719 - root - INFO - step: 1 loss: 10.8856 memory: 6.99GiB(8.84%) wps: 198 mfu: 0.46%
[rank0]:2024-03-28 18:03:40,720 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:05:00
[rank0]:[rank0]:W2024-03-28 18:03:42,059.059000 139804330415936 torch/_logging/_internal.py:1042] [0/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:/data/users/willfeng/pytorch_yf225/torch/_inductor/lowering.py:1788: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] TRACED GRAPH
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] ===== after FSDP FX passes: =====
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] def forward(self, primals_1: "i64[1, 2048]", primals_2: "f32[8192000]", primals_3: "f32[256]", primals_4: "f32[8192000]", primals_5: "bf16[32000, 2048]", primals_6: "bf16[2048]", primals_7: "bf16[32000, 2048]", primals_8: "c64[65536, 64]", primals_9: "f32[524288]", primals_10: "f32[524288]", primals_11: "f32[524288]", primals_12: "f32[524288]", primals_13: "f32[1441792]", primals_14: "f32[1441792]", primals_15: "f32[1441792]", primals_16: "f32[256]", primals_17: "f32[256]", primals_18: "bf16[2048, 2048]", primals_19: "bf16[2048, 2048]", primals_20: "bf16[2048, 2048]", primals_21: "bf16[2048, 2048]", primals_22: "bf16[5632, 2048]", primals_23: "bf16[2048, 5632]", primals_24: "bf16[5632, 2048]", primals_25: "bf16[2048]", primals_26: "bf16[2048]", primals_27, primals_28: "f32[524288]", primals_29: "f32[524288]", primals_30: "f32[524288]", primals_31: "f32[524288]", primals_32: "f32[1441792]", primals_33: "f32[1441792]", primals_34: "f32[1441792]", primals_35: "f32[256]", primals_36: "f32[256]", primals_37: "bf16[2048, 2048]", primals_38: "bf16[2048, 2048]", primals_39: "bf16[2048, 2048]", primals_40: "bf16[2048, 2048]", primals_41: "bf16[5632, 2048]", primals_42: "bf16[2048, 5632]", primals_43: "bf16[5632, 2048]", primals_44: "bf16[2048]", primals_45: "bf16[2048]", primals_46: "f32[524288]", primals_47: "f32[524288]", primals_48: "f32[524288]", primals_49: "f32[524288]", primals_50: "f32[1441792]", primals_51: "f32[1441792]", primals_52: "f32[1441792]", primals_53: "f32[256]", primals_54: "f32[256]", primals_55: "bf16[2048, 2048]", primals_56: "bf16[2048, 2048]", primals_57: "bf16[2048, 2048]", primals_58: "bf16[2048, 2048]", primals_59: "bf16[5632, 2048]", primals_60: "bf16[2048, 5632]", primals_61: "bf16[5632, 2048]", primals_62: "bf16[2048]", primals_63: "bf16[2048]", primals_64: "f32[524288]", primals_65: "f32[524288]", primals_66: "f32[524288]", primals_67: "f32[524288]", primals_68: "f32[1441792]", primals_69: "f32[1441792]", primals_70: "f32[1441792]", primals_71: "f32[256]", primals_72: "f32[256]", primals_73: "bf16[2048, 2048]", primals_74: "bf16[2048, 2048]", primals_75: "bf16[2048, 2048]", primals_76: "bf16[2048, 2048]", primals_77: "bf16[5632, 2048]", primals_78: "bf16[2048, 5632]", primals_79: "bf16[5632, 2048]", primals_80: "bf16[2048]", primals_81: "bf16[2048]", primals_82: "f32[524288]", primals_83: "f32[524288]", primals_84: "f32[524288]", primals_85: "f32[524288]", primals_86: "f32[1441792]", primals_87: "f32[1441792]", primals_88: "f32[1441792]", primals_89: "f32[256]", primals_90: "f32[256]", primals_91: "bf16[2048, 2048]", primals_92: "bf16[2048, 2048]", primals_93: "bf16[2048, 2048]", primals_94: "bf16[2048, 2048]", primals_95: "bf16[5632, 2048]", primals_96: "bf16[2048, 5632]", primals_97: "bf16[5632, 2048]", primals_98: "bf16[2048]", primals_99: "bf16[2048]", primals_100: "f32[524288]", primals_101: "f32[524288]", primals_102: "f32[524288]", primals_103: "f32[524288]", primals_104: "f32[1441792]", primals_105: "f32[1441792]", primals_106: "f32[1441792]", primals_107: "f32[256]", primals_108: "f32[256]", primals_109: "bf16[2048, 2048]", primals_110: "bf16[2048, 2048]", primals_111: "bf16[2048, 2048]", primals_112: "bf16[2048, 2048]", primals_113: "bf16[5632, 2048]", primals_114: "bf16[2048, 5632]", primals_115: "bf16[5632, 2048]", primals_116: "bf16[2048]", primals_117: "bf16[2048]", primals_118: "f32[524288]", primals_119: "f32[524288]", primals_120: "f32[524288]", primals_121: "f32[524288]", primals_122: "f32[1441792]", primals_123: "f32[1441792]", primals_124: "f32[1441792]", primals_125: "f32[256]", primals_126: "f32[256]", primals_127: "bf16[2048, 2048]", primals_128: "bf16[2048, 2048]", primals_129: "bf16[2048, 2048]", primals_130: "bf16[2048, 2048]", primals_131: "bf16[5632, 2048]", primals_132: "bf16[2048, 5632]", primals_133: "bf16[5632, 2048]", primals_134: "bf16[2048]", primals_135: "bf16[2048]", primals_136: "f32[524288]", primals_137: "f32[524288]", primals_138: "f32[524288]", primals_139: "f32[524288]", primals_140: "f32[1441792]", primals_141: "f32[1441792]", primals_142: "f32[1441792]", primals_143: "f32[256]", primals_144: "f32[256]", primals_145: "bf16[2048, 2048]", primals_146: "bf16[2048, 2048]", primals_147: "bf16[2048, 2048]", primals_148: "bf16[2048, 2048]", primals_149: "bf16[5632, 2048]", primals_150: "bf16[2048, 5632]", primals_151: "bf16[5632, 2048]", primals_152: "bf16[2048]", primals_153: "bf16[2048]", primals_154: "f32[524288]", primals_155: "f32[524288]", primals_156: "f32[524288]", primals_157: "f32[524288]", primals_158: "f32[1441792]", primals_159: "f32[1441792]", primals_160: "f32[1441792]", primals_161: "f32[256]", primals_162: "f32[256]", primals_163: "bf16[2048, 2048]", primals_164: "bf16[2048, 2048]", primals_165: "bf16[2048, 2048]", primals_166: "bf16[2048, 2048]", primals_167: "bf16[5632, 2048]", primals_168: "bf16[2048, 5632]", primals_169: "bf16[5632, 2048]", primals_170: "bf16[2048]", primals_171: "bf16[2048]", primals_172: "f32[524288]", primals_173: "f32[524288]", primals_174: "f32[524288]", primals_175: "f32[524288]", primals_176: "f32[1441792]", primals_177: "f32[1441792]", primals_178: "f32[1441792]", primals_179: "f32[256]", primals_180: "f32[256]", primals_181: "bf16[2048, 2048]", primals_182: "bf16[2048, 2048]", primals_183: "bf16[2048, 2048]", primals_184: "bf16[2048, 2048]", primals_185: "bf16[5632, 2048]", primals_186: "bf16[2048, 5632]", primals_187: "bf16[5632, 2048]", primals_188: "bf16[2048]", primals_189: "bf16[2048]", primals_190: "f32[524288]", primals_191: "f32[524288]", primals_192: "f32[524288]", primals_193: "f32[524288]", primals_194: "f32[1441792]", primals_195: "f32[1441792]", primals_196: "f32[1441792]", primals_197: "f32[256]", primals_198: "f32[256]", primals_199: "bf16[2048, 2048]", primals_200: "bf16[2048, 2048]", primals_201: "bf16[2048, 2048]", primals_202: "bf16[2048, 2048]", primals_203: "bf16[5632, 2048]", primals_204: "bf16[2048, 5632]", primals_205: "bf16[5632, 2048]", primals_206: "bf16[2048]", primals_207: "bf16[2048]", primals_208: "f32[524288]", primals_209: "f32[524288]", primals_210: "f32[524288]", primals_211: "f32[524288]", primals_212: "f32[1441792]", primals_213: "f32[1441792]", primals_214: "f32[1441792]", primals_215: "f32[256]", primals_216: "f32[256]", primals_217: "bf16[2048, 2048]", primals_218: "bf16[2048, 2048]", primals_219: "bf16[2048, 2048]", primals_220: "bf16[2048, 2048]", primals_221: "bf16[5632, 2048]", primals_222: "bf16[2048, 5632]", primals_223: "bf16[5632, 2048]", primals_224: "bf16[2048]", primals_225: "bf16[2048]", primals_226: "f32[524288]", primals_227: "f32[524288]", primals_228: "f32[524288]", primals_229: "f32[524288]", primals_230: "f32[1441792]", primals_231: "f32[1441792]", primals_232: "f32[1441792]", primals_233: "f32[256]", primals_234: "f32[256]", primals_235: "bf16[2048, 2048]", primals_236: "bf16[2048, 2048]", primals_237: "bf16[2048, 2048]", primals_238: "bf16[2048, 2048]", primals_239: "bf16[5632, 2048]", primals_240: "bf16[2048, 5632]", primals_241: "bf16[5632, 2048]", primals_242: "bf16[2048]", primals_243: "bf16[2048]", primals_244: "f32[524288]", primals_245: "f32[524288]", primals_246: "f32[524288]", primals_247: "f32[524288]", primals_248: "f32[1441792]", primals_249: "f32[1441792]", primals_250: "f32[1441792]", primals_251: "f32[256]", primals_252: "f32[256]", primals_253: "bf16[2048, 2048]", primals_254: "bf16[2048, 2048]", primals_255: "bf16[2048, 2048]", primals_256: "bf16[2048, 2048]", primals_257: "bf16[5632, 2048]", primals_258: "bf16[2048, 5632]", primals_259: "bf16[5632, 2048]", primals_260: "bf16[2048]", primals_261: "bf16[2048]", primals_262: "f32[524288]", primals_263: "f32[524288]", primals_264: "f32[524288]", primals_265: "f32[524288]", primals_266: "f32[1441792]", primals_267: "f32[1441792]", primals_268: "f32[1441792]", primals_269: "f32[256]", primals_270: "f32[256]", primals_271: "bf16[2048, 2048]", primals_272: "bf16[2048, 2048]", primals_273: "bf16[2048, 2048]", primals_274: "bf16[2048, 2048]", primals_275: "bf16[5632, 2048]", primals_276: "bf16[2048, 5632]", primals_277: "bf16[5632, 2048]", primals_278: "bf16[2048]", primals_279: "bf16[2048]", primals_280: "f32[524288]", primals_281: "f32[524288]", primals_282: "f32[524288]", primals_283: "f32[524288]", primals_284: "f32[1441792]", primals_285: "f32[1441792]", primals_286: "f32[1441792]", primals_287: "f32[256]", primals_288: "f32[256]", primals_289: "bf16[2048, 2048]", primals_290: "bf16[2048, 2048]", primals_291: "bf16[2048, 2048]", primals_292: "bf16[2048, 2048]", primals_293: "bf16[5632, 2048]", primals_294: "bf16[2048, 5632]", primals_295: "bf16[5632, 2048]", primals_296: "bf16[2048]", primals_297: "bf16[2048]", primals_298: "f32[524288]", primals_299: "f32[524288]", primals_300: "f32[524288]", primals_301: "f32[524288]", primals_302: "f32[1441792]", primals_303: "f32[1441792]", primals_304: "f32[1441792]", primals_305: "f32[256]", primals_306: "f32[256]", primals_307: "bf16[2048, 2048]", primals_308: "bf16[2048, 2048]", primals_309: "bf16[2048, 2048]", primals_310: "bf16[2048, 2048]", primals_311: "bf16[5632, 2048]", primals_312: "bf16[2048, 5632]", primals_313: "bf16[5632, 2048]", primals_314: "bf16[2048]", primals_315: "bf16[2048]", primals_316: "f32[524288]", primals_317: "f32[524288]", primals_318: "f32[524288]", primals_319: "f32[524288]", primals_320: "f32[1441792]", primals_321: "f32[1441792]", primals_322: "f32[1441792]", primals_323: "f32[256]", primals_324: "f32[256]", primals_325: "bf16[2048, 2048]", primals_326: "bf16[2048, 2048]", primals_327: "bf16[2048, 2048]", primals_328: "bf16[2048, 2048]", primals_329: "bf16[5632, 2048]", primals_330: "bf16[2048, 5632]", primals_331: "bf16[5632, 2048]", primals_332: "bf16[2048]", primals_333: "bf16[2048]"):
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(primals_2, torch.bfloat16); primals_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_1: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_3, torch.bfloat16); primals_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_2: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(primals_4, torch.bfloat16); primals_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(16384256, 8, 0, torch.bfloat16, device(type='cuda', index=0), [8192000, 256, 8192000], [convert_element_type, convert_element_type_1, convert_element_type_2]); convert_element_type = convert_element_type_1 = convert_element_type_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem: "bf16[16384256]" = all_gather_copy_in[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1: "bf16[131074048]" = all_gather_copy_in[1]; all_gather_copy_in = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor: "bf16[131074048]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, 8, '0'); getitem = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor: "bf16[131074048]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_1: "bf16[8, 16384256]" = torch.ops.aten.reshape.default(wait_tensor, [8, -1]); wait_tensor = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(view_1, [8192000, 256, 8192000], 1); view_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_5: "bf16[8, 8192000]" = split_with_sizes_1[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided: "bf16[32000, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_5, [32000, 2048], [2048, 1]); getitem_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_9: "bf16[8, 256]" = split_with_sizes_1[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_1: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_9, [2048], [1]); getitem_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_13: "bf16[8, 8192000]" = split_with_sizes_1[2]; split_with_sizes_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_2: "bf16[32000, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_13, [32000, 2048], [2048, 1]); getitem_13 = None
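
The block above is FSDP2's gather for the first parameter group (the token embedding, the final norm weight, and the output projection, each sharded 8 ways): the bf16 shards are packed into one flat buffer (8,192,000 + 256 + 8,192,000 = 16,384,256 elements), a single all_gather_into_tensor returns 8x that (131,074,048), and the result is split and re-viewed into full [32000, 2048] and [2048] parameters. A minimal eager-mode sketch of the same pattern, using a hypothetical gather_params helper rather than the torch.ops.fsdp kernels, and assuming dim-0 sharding as FSDP2 uses:

    import torch
    import torch.distributed as dist

    def gather_params(shards, shapes, group=None):
        # 1) copy-in: pack this rank's flat shards into one buffer
        flat = torch.cat([s.reshape(-1) for s in shards])
        world_size = dist.get_world_size(group)
        out = torch.empty(world_size * flat.numel(), dtype=flat.dtype, device=flat.device)
        # 2) one collective for the whole group of parameters
        dist.all_gather_into_tensor(out, flat, group=group)
        # 3) copy-out: split the [world_size, shard_numel] rows back into full parameters
        cols = out.view(world_size, -1).split([s.numel() for s in shards], dim=1)
        return [c.contiguous().view(shape) for c, shape in zip(cols, shapes)]
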
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/sparse.py:163 in forward, code: return F.embedding(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] embedding: "bf16[1, 2048, 2048]" = torch.ops.aten.embedding.default(contiguous_view_as_strided, primals_1); contiguous_view_as_strided = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:338 in forward, code: freqs_cis = self.freqs_cis[0:seqlen]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] slice_1: "c64[2048, 64]" = torch.ops.aten.slice.Tensor(primals_8, 0, 0, 2048); primals_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:484 in forward, code: h = h.view(-1, self.model_args.dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_4: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(embedding, [-1, 2048])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_3: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_9, torch.bfloat16); primals_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_4: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_10, torch.bfloat16); primals_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_5: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_11, torch.bfloat16); primals_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_6: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_12, torch.bfloat16); primals_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_7: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_13, torch.bfloat16); primals_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_8: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_14, torch.bfloat16); primals_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_9: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_15, torch.bfloat16); primals_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_10: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_16, torch.bfloat16); primals_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_11: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_17, torch.bfloat16); primals_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_1 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_3, convert_element_type_4, convert_element_type_5, convert_element_type_6, convert_element_type_7, convert_element_type_8, convert_element_type_9, convert_element_type_10, convert_element_type_11]); convert_element_type_3 = convert_element_type_4 = convert_element_type_5 = convert_element_type_6 = convert_element_type_7 = convert_element_type_8 = convert_element_type_9 = convert_element_type_10 = convert_element_type_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_14: "bf16[6423040]" = all_gather_copy_in_1[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_15: "bf16[51384320]" = all_gather_copy_in_1[1]; all_gather_copy_in_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_14, 8, '0'); getitem_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_6: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]); wait_tensor_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_5 = torch.ops.aten.split_with_sizes.default(view_6, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_25: "bf16[8, 524288]" = split_with_sizes_5[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_3: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_25, [2048, 2048], [2048, 1]); getitem_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_35: "bf16[8, 524288]" = split_with_sizes_5[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_4: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_35, [2048, 2048], [2048, 1]); getitem_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_45: "bf16[8, 524288]" = split_with_sizes_5[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_5: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_45, [2048, 2048], [2048, 1]); getitem_45 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_55: "bf16[8, 524288]" = split_with_sizes_5[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_6: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_55, [2048, 2048], [2048, 1]); getitem_55 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_65: "bf16[8, 1441792]" = split_with_sizes_5[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_7: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_65, [5632, 2048], [2048, 1]); getitem_65 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_75: "bf16[8, 1441792]" = split_with_sizes_5[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_8: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_75, [2048, 5632], [5632, 1]); getitem_75 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_85: "bf16[8, 1441792]" = split_with_sizes_5[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_9: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_85, [5632, 2048], [2048, 1]); getitem_85 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_95: "bf16[8, 256]" = split_with_sizes_5[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_10: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_95, [2048], [1]); getitem_95 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_105: "bf16[8, 256]" = split_with_sizes_5[8]; split_with_sizes_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_11: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_105, [2048], [1]); getitem_105 = None
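
This second gather covers exactly one transformer block, sharded 8 ways: four 2048x2048 attention weights (524,288 elements per shard), three feed-forward weights (1,441,792 per shard), and two RMSNorm weights (256 per shard). That is 6,423,040 elements per rank and 51,384,320 across the group, matching the per-layer parameter count from the sanity check above:

    assert 4 * 524_288 + 3 * 1_441_792 + 2 * 256 == 6_423_040               # per-rank copy-in numel
    assert 8 * 6_423_040 == 4 * 2048 * 2048 + 3 * 2048 * 5632 + 2 * 2048    # one full layer: 51,384,320
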
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_12: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(view_4, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_1: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_12, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_1, [-1], True); pow_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean, 1e-05); mean = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add); add = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_12, rsqrt); convert_element_type_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_13: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul, torch.bfloat16); mul = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_1: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_13, contiguous_view_as_strided_10); convert_element_type_13 = contiguous_view_as_strided_10 = None
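
The pow/mean/rsqrt/mul sequence above is the RMSNorm quoted in the source comments, with an explicit fp32 round-trip around the bf16 activations. In eager form:

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # upcast to fp32, normalize, cast back to the input dtype, then scale
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
        return normed.type_as(x) * weight
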
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_1: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_3, [1, 0]); contiguous_view_as_strided_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_1, permute_1); permute_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_3: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_4, [1, 0]); contiguous_view_as_strided_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_1: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_1, permute_3); permute_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_5: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_5, [1, 0]); contiguous_view_as_strided_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_2: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_1, permute_5); permute_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_15: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm, [1, 2048, 16, 128]); mm = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_16: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_1, [1, 2048, 16, 128]); mm_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_17: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_2, [1, 2048, 16, 128]); mm_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_20: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_15, torch.float32); view_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_18: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_20, [1, 2048, 16, -1, 2]); convert_element_type_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_18); view_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_21: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_16, torch.float32); view_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_19: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_21, [1, 2048, 16, -1, 2]); convert_element_type_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_1: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_19); view_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:122 in reshape_for_broadcast, code: return freqs_cis.view(*shape)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_20: "c64[1, 2048, 1, 64]" = torch.ops.aten.reshape.default(slice_1, [1, 2048, 1, 64]); slice_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_2: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex, view_20); view_as_complex = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_2); mul_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_21: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real, [1, 2048, 16, 128]); view_as_real = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_3: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_1, view_20); view_as_complex_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_1: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_3); mul_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_22: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_1, [1, 2048, 16, 128]); view_as_real_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_22: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_21, torch.bfloat16); view_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_23: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_22, torch.bfloat16); view_22 = None
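
These ops are the apply_rotary_emb from the quoted model source: the bf16 activations are upcast to fp32, viewed as complex64 pairs, multiplied by the precomputed freqs_cis, and cast back to bf16. The c64 tensors here are what triggered the earlier Inductor warning about complex operators falling back to eager. A sketch following the traced shapes:

    import torch

    def apply_rotary_emb(xq, xk, freqs_cis):
        # bf16 -> fp32 -> complex multiply -> fp32 -> bf16, as in the traced graph
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        freqs = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])  # broadcast over batch and heads
        xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)
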
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_6: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_22, [0, 2, 1, 3]); convert_element_type_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_7: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_23, [0, 2, 1, 3]); convert_element_type_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_8: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_17, [0, 2, 1, 3]); view_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_6, permute_7, permute_8, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_106: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_107: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_112: "i64[]" = _scaled_dot_product_flash_attention[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_113: "i64[]" = _scaled_dot_product_flash_attention[7]; _scaled_dot_product_flash_attention = None
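
The attention call lowers to the flash-attention kernel with is_causal=True; the traced scale is the SDPA default 1/sqrt(head_dim), with head_dim = dim / n_heads = 2048 / 16 = 128:

    import math
    assert math.isclose(1 / math.sqrt(128), 0.08838834764831843)
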
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_9: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_106, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_23: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_9, [2048, -1]); permute_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_11: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_6, [1, 0]); contiguous_view_as_strided_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_3: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_23, permute_11); permute_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_1: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(view_4, mm_3); view_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_26: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_1, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_2: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_26, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_1: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_2, [-1], True); pow_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_2: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_1, 1e-05); mean_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_1: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_2); add_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_4: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_26, rsqrt_1); convert_element_type_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_27: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_4, torch.bfloat16); mul_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_5: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_27, contiguous_view_as_strided_11); convert_element_type_27 = contiguous_view_as_strided_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_13: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_7, [1, 0]); contiguous_view_as_strided_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_4: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_5, permute_13); permute_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_30: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_4, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_30)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_6: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_30, sigmoid); convert_element_type_30 = sigmoid = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_31: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_6, torch.bfloat16); mul_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_15: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_9, [1, 0]); contiguous_view_as_strided_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_5: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_5, permute_15); permute_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_7: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_31, mm_5); convert_element_type_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_17: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_8, [1, 0]); contiguous_view_as_strided_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_6: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_7, permute_17); permute_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_3: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_1, mm_6); add_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
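The ops above (mm_4, sigmoid, mul_6, mm_5, mul_7, mm_6, and the residual add_3) are the SwiGLU feed-forward traced from model.py:299 plus the residual from model.py:411. A minimal eager-mode sketch of that module, with dim=2048 and hidden_dim=5632 read off the tensor annotations (field names follow the traced source line; treat this as an illustration, not the exact torchtrain implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """SwiGLU MLP: w2(silu(w1(x)) * w3(x)), as in the traced graph above."""
    def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate proj -> mm_4
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up proj   -> mm_5
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down proj -> mm_6

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.silu(x) == x * sigmoid(x): the graph spells this out as a
        # sigmoid + mul pair, computed in fp32 and cast back to bf16.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```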
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_36: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_28, torch.bfloat16); primals_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_37: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_29, torch.bfloat16); primals_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_38: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_30, torch.bfloat16); primals_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_39: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_31, torch.bfloat16); primals_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_40: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_32, torch.bfloat16); primals_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_41: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_33, torch.bfloat16); primals_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_42: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_34, torch.bfloat16); primals_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_43: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_35, torch.bfloat16); primals_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_44: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_36, torch.bfloat16); primals_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_2 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_36, convert_element_type_37, convert_element_type_38, convert_element_type_39, convert_element_type_40, convert_element_type_41, convert_element_type_42, convert_element_type_43, convert_element_type_44]); convert_element_type_36 = convert_element_type_37 = convert_element_type_38 = convert_element_type_39 = convert_element_type_40 = convert_element_type_41 = convert_element_type_42 = convert_element_type_43 = convert_element_type_44 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_115: "bf16[6423040]" = all_gather_copy_in_2[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_116: "bf16[51384320]" = all_gather_copy_in_2[1]; all_gather_copy_in_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
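Here FSDP2 stages the next layer's all-gather: each sharded parameter is cast to bf16 (`_to_dtype_if_needed`), then `fsdp.all_gather_copy_in` packs the nine shards (four 524288-element attention shards, three 1441792-element FFN shards, two 256-element norm shards; 4*524288 + 3*1441792 + 2*256 = 6,423,040) into one flat input buffer and produces the 8 * 6,423,040 = 51,384,320-element output buffer. A rough eager approximation of the packing, with hypothetical names (`shards`, `pack_for_all_gather`); the real op also carves the input out of the output buffer in place, which this sketch skips:

```python
import torch

def pack_for_all_gather(shards, world_size, dtype=torch.bfloat16, device="cuda"):
    # Concatenate this rank's per-parameter shards into one flat buffer so a
    # single collective can fetch the whole layer's parameters at once.
    flat_shards = [s.reshape(-1).to(dtype) for s in shards]   # _to_dtype_if_needed
    all_gather_input = torch.cat(flat_shards)                 # [6_423_040] here
    numel = all_gather_input.numel()
    all_gather_output = torch.empty(world_size * numel, dtype=dtype, device=device)
    return all_gather_input, all_gather_output
```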
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_115, 8, '0'); getitem_115 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
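The packed buffer then goes through a single `all_gather_into_tensor` over the 8-rank group, and `wait_tensor` marks where the graph must block on the collective before the gathered bytes are read. In terms of the public c10d API this step looks roughly like the following (the function name is made up for illustration):

```python
import torch.distributed as dist

def gather_layer_params(all_gather_input, all_gather_output, group=None):
    # output[rank * numel : (rank + 1) * numel] receives each rank's input;
    # async_op=True lets other work overlap the collective until wait().
    work = dist.all_gather_into_tensor(all_gather_output, all_gather_input,
                                       group=group, async_op=True)
    work.wait()  # the traced wait_tensor serves this purpose for the functional op
    return all_gather_output
```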
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_25: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]); wait_tensor_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_15 = torch.ops.aten.split_with_sizes.default(view_25, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_126: "bf16[8, 524288]" = split_with_sizes_15[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_12: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_126, [2048, 2048], [2048, 1]); getitem_126 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_136: "bf16[8, 524288]" = split_with_sizes_15[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_13: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_136, [2048, 2048], [2048, 1]); getitem_136 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_146: "bf16[8, 524288]" = split_with_sizes_15[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_14: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_146, [2048, 2048], [2048, 1]); getitem_146 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_156: "bf16[8, 524288]" = split_with_sizes_15[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_15: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_156, [2048, 2048], [2048, 1]); getitem_156 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_166: "bf16[8, 1441792]" = split_with_sizes_15[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_16: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_166, [5632, 2048], [2048, 1]); getitem_166 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_176: "bf16[8, 1441792]" = split_with_sizes_15[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_17: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_176, [2048, 5632], [5632, 1]); getitem_176 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_186: "bf16[8, 1441792]" = split_with_sizes_15[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_18: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_186, [5632, 2048], [2048, 1]); getitem_186 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_196: "bf16[8, 256]" = split_with_sizes_15[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_19: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_196, [2048], [1]); getitem_196 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_206: "bf16[8, 256]" = split_with_sizes_15[8]; split_with_sizes_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_20: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_206, [2048], [1]); getitem_206 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
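`foreach_all_gather_copy_out` then unpacks: the gathered buffer is viewed as [world_size, 6423040], split back per parameter, and each [8, shard_numel] slice is re-viewed with the original parameter's shape and strides, e.g. 8 * 524288 = 4,194,304 = 2048 * 2048 for an attention weight and 8 * 256 = 2048 for a norm weight. A simplified sketch with hypothetical names (the traced `fsdp.contiguous_view_as_strided` yields a strided view and avoids the copy that `reshape` may make here):

```python
import torch

def unpack_gathered(gathered: torch.Tensor, world_size: int,
                    split_sizes: list, shapes: list):
    per_rank = gathered.view(world_size, -1)      # view_25: [8, 6_423_040]
    splits = per_rank.split(split_sizes, dim=1)   # split_with_sizes_15
    # Flattening the world_size shard rows in order recovers each full
    # unsharded parameter, since FSDP shards evenly along dim 0.
    return [s.reshape(shape) for s, shape in zip(splits, shapes)]
```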
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_45: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_3, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_3: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_45, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_2: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_3, [-1], True); pow_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_4: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_2, 1e-05); mean_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_2: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_4); add_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_8: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_45, rsqrt_2); convert_element_type_45 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_46: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_8, torch.bfloat16); mul_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_9: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_46, contiguous_view_as_strided_19); convert_element_type_46 = contiguous_view_as_strided_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
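convert_element_type_45 through mul_9 are RMSNorm: square-mean in fp32, rsqrt with eps=1e-05, cast back to bf16, then scale by the all-gathered weight. A sketch of that module, reconstructed from the traced source lines at model.py:61/74/75:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Reconstructed from the traced source lines at model.py:61/74/75."""
    def __init__(self, dim: int = 2048, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # pow_3 -> mean_2 -> add_4 (eps) -> rsqrt_2 -> mul_8 above
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)  # fp32 compute, bf16 cast back
        return output * self.weight                # mul_9
```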
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_19: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_12, [1, 0]); contiguous_view_as_strided_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_7: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_9, permute_19); permute_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_21: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_13, [1, 0]); contiguous_view_as_strided_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_8: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_9, permute_21); permute_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_23: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_14, [1, 0]); contiguous_view_as_strided_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_9: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_9, permute_23); permute_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
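One normalized activation (mul_9) feeds three matmuls, the q/k/v projections. Each bias-free nn.Linear lowers to the permute + mm pair seen above, since F.linear(x, W) computes x @ W.T. A short demonstration (shapes taken from the graph; fp32 so it runs on CPU):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2048, 2048)    # flattened (bsz * seqlen, dim) activations
w_q = torch.randn(2048, 2048)  # projection weight, stored [out_features, in_features]

xq = F.linear(x, w_q)                      # what the model code calls
xq_graph = torch.mm(x, w_q.permute(1, 0))  # what the graph records (permute_19 + mm_7)
print(torch.allclose(xq, xq_graph))        # True: same computation, two spellings
```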
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_34: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_7, [1, 2048, 16, 128]); mm_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_35: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_8, [1, 2048, 16, 128]); mm_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_36: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_9, [1, 2048, 16, 128]); mm_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_53: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_34, torch.float32); view_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_37: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_53, [1, 2048, 16, -1, 2]); convert_element_type_53 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_2: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_37); view_37 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_54: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_35, torch.float32); view_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_38: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_54, [1, 2048, 16, -1, 2]); convert_element_type_54 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_3: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_38); view_38 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_10: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_2, view_20); view_as_complex_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_2: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_10); mul_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_40: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_2, [1, 2048, 16, 128]); view_as_real_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_11: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_3, view_20); view_as_complex_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_3: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_11); mul_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_41: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_3, [1, 2048, 16, 128]); view_as_real_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_55: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_40, torch.bfloat16); view_40 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_56: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_41, torch.bfloat16); view_41 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
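The complex-number sequence above is rotary position embedding, and the traced comments give the source verbatim (model.py:146-151): pair the last dim into (real, imag), multiply by the complex freqs_cis (view_20 here), and view back to real. Reassembled from those lines:

```python
import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # view_37 / view_as_complex_2 ... view_40 above; the fp32 round-trip
    # comes from .float() / .type_as(). freqs_cis must broadcast against
    # (bsz, seqlen, n_heads, head_dim // 2) complex values.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```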
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_24: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_55, [0, 2, 1, 3]); convert_element_type_55 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_25: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_56, [0, 2, 1, 3]); convert_element_type_56 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_26: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_36, [0, 2, 1, 3]); view_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_1 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_24, permute_25, permute_26, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_207: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_1[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_208: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_1[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_213: "i64[]" = _scaled_dot_product_flash_attention_1[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_214: "i64[]" = _scaled_dot_product_flash_attention_1[7]; _scaled_dot_product_flash_attention_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
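Attention is a single fused node: _scaled_dot_product_flash_attention with is_causal=True, dropout 0.0, and scale 0.08838834764831843 = 1/sqrt(128), the default 1/sqrt(head_dim). getitem_207 is the attention output; the logsumexp (getitem_208) and the two scalar i64 outputs (the RNG state) are stashed for the backward pass. At the source level this is just:

```python
import math
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 16, 2048, 128
xq = torch.randn(bsz, n_heads, seqlen, head_dim)
xk = torch.randn(bsz, n_heads, seqlen, head_dim)
xv = torch.randn(bsz, n_heads, seqlen, head_dim)

# scale defaults to 1 / sqrt(head_dim), matching the value in the graph
assert math.isclose(1 / math.sqrt(head_dim), 0.08838834764831843)
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
```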
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_27: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_207, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_42: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_27, [2048, -1]); permute_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_29: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_15, [1, 0]); contiguous_view_as_strided_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_10: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_42, permute_29); permute_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_5: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_3, mm_10); add_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
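mm_10 plus add_5 close the attention branch, mirroring the add_3/add_7 feed-forward adds: per the traced comments at model.py:410-411, every layer is a standard pre-norm residual pair. As a skeleton (the constructor signature is invented for brevity):

```python
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm residual structure traced above (model.py:410-411)."""
    def __init__(self, attention, feed_forward, attention_norm, ffn_norm):
        super().__init__()
        self.attention, self.feed_forward = attention, feed_forward
        self.attention_norm, self.ffn_norm = attention_norm, ffn_norm

    def forward(self, x, freqs_cis):
        h = x + self.attention(self.attention_norm(x), freqs_cis)  # add_5
        return h + self.feed_forward(self.ffn_norm(h))             # add_7
```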
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_59: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_5, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_4: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_59, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_3: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_4, [-1], True); pow_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_6: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_3, 1e-05); mean_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_3: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_6); add_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_12: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_59, rsqrt_3); convert_element_type_59 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_60: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_12, torch.bfloat16); mul_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_13: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_60, contiguous_view_as_strided_20); convert_element_type_60 = contiguous_view_as_strided_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_31: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_16, [1, 0]); contiguous_view_as_strided_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_11: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_13, permute_31); permute_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_63: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_11, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_1: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_63)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_14: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_63, sigmoid_1); convert_element_type_63 = sigmoid_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_64: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_14, torch.bfloat16); mul_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_33: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_18, [1, 0]); contiguous_view_as_strided_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_12: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_13, permute_33); permute_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_15: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_64, mm_12); convert_element_type_64 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_35: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_17, [1, 0]); contiguous_view_as_strided_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_13: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_15, permute_35); permute_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_7: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_5, mm_13); add_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_69: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_46, torch.bfloat16); primals_46 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_70: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_47, torch.bfloat16); primals_47 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_71: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_48, torch.bfloat16); primals_48 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_72: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_49, torch.bfloat16); primals_49 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_73: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_50, torch.bfloat16); primals_50 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_74: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_51, torch.bfloat16); primals_51 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_75: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_52, torch.bfloat16); primals_52 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_76: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_53, torch.bfloat16); primals_53 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_77: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_54, torch.bfloat16); primals_54 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_3 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_69, convert_element_type_70, convert_element_type_71, convert_element_type_72, convert_element_type_73, convert_element_type_74, convert_element_type_75, convert_element_type_76, convert_element_type_77]); convert_element_type_69 = convert_element_type_70 = convert_element_type_71 = convert_element_type_72 = convert_element_type_73 = convert_element_type_74 = convert_element_type_75 = convert_element_type_76 = convert_element_type_77 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_216: "bf16[6423040]" = all_gather_copy_in_3[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_217: "bf16[51384320]" = all_gather_copy_in_3[1]; all_gather_copy_in_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_216, 8, '0'); getitem_216 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_44: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]); wait_tensor_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_25 = torch.ops.aten.split_with_sizes.default(view_44, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_44 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_227: "bf16[8, 524288]" = split_with_sizes_25[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_21: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_227, [2048, 2048], [2048, 1]); getitem_227 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_237: "bf16[8, 524288]" = split_with_sizes_25[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_22: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_237, [2048, 2048], [2048, 1]); getitem_237 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_247: "bf16[8, 524288]" = split_with_sizes_25[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_23: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_247, [2048, 2048], [2048, 1]); getitem_247 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_257: "bf16[8, 524288]" = split_with_sizes_25[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_24: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_257, [2048, 2048], [2048, 1]); getitem_257 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_267: "bf16[8, 1441792]" = split_with_sizes_25[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_25: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_267, [5632, 2048], [2048, 1]); getitem_267 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_277: "bf16[8, 1441792]" = split_with_sizes_25[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_26: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_277, [2048, 5632], [5632, 1]); getitem_277 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_287: "bf16[8, 1441792]" = split_with_sizes_25[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_27: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_287, [5632, 2048], [2048, 1]); getitem_287 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_297: "bf16[8, 256]" = split_with_sizes_25[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_28: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_297, [2048], [1]); getitem_297 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_307: "bf16[8, 256]" = split_with_sizes_25[8]; split_with_sizes_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_29: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_307, [2048], [1]); getitem_307 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_78: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_7, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_5: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_78, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_4: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_5, [-1], True); pow_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_8: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_4, 1e-05); mean_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_4: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_8); add_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_16: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_78, rsqrt_4); convert_element_type_78 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_79: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_16, torch.bfloat16); mul_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_17: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_79, contiguous_view_as_strided_28); convert_element_type_79 = contiguous_view_as_strided_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_37: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_21, [1, 0]); contiguous_view_as_strided_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_14: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_17, permute_37); permute_37 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_39: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_22, [1, 0]); contiguous_view_as_strided_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_15: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_17, permute_39); permute_39 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_41: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_23, [1, 0]); contiguous_view_as_strided_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_16: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_17, permute_41); permute_41 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_53: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_14, [1, 2048, 16, 128]); mm_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_54: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_15, [1, 2048, 16, 128]); mm_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_55: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_16, [1, 2048, 16, 128]); mm_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_86: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_53, torch.float32); view_53 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_56: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_86, [1, 2048, 16, -1, 2]); convert_element_type_86 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_4: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_56); view_56 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_87: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_54, torch.float32); view_54 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_57: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_87, [1, 2048, 16, -1, 2]); convert_element_type_87 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_5: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_57); view_57 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_18: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_4, view_20); view_as_complex_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_4: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_18); mul_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_59: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_4, [1, 2048, 16, 128]); view_as_real_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_19: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_5, view_20); view_as_complex_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_5: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_19); mul_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_60: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_5, [1, 2048, 16, 128]); view_as_real_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_88: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_59, torch.bfloat16); view_59 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_89: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_60, torch.bfloat16); view_60 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
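Nodes view_56 through convert_element_type_89 are the straight-line trace of apply_rotary_emb; the "# File:" comments quote the original source (model.py:146-151), which reassembles to roughly this eager PyTorch:

    import torch

    def apply_rotary_emb(xq, xk, freqs_cis):
        # Pair the last dim into (real, imag) and reinterpret as complex.
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # Complex multiply by freqs_cis (broadcastable, e.g. [1, S, 1, D/2])
        # rotates each pair; view_as_real + flatten undoes the packing.
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)

These complex ops are also what triggered the earlier Inductor warning about complex operators falling back to slower codegen.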
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_42: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_88, [0, 2, 1, 3]); convert_element_type_88 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_43: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_89, [0, 2, 1, 3]); convert_element_type_89 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_44: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_55, [0, 2, 1, 3]); view_55 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_2 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_42, permute_43, permute_44, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_308: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_2[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_309: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_2[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_314: "i64[]" = _scaled_dot_product_flash_attention_2[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_315: "i64[]" = _scaled_dot_product_flash_attention_2[7]; _scaled_dot_product_flash_attention_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
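The scale=0.08838834764831843 baked into the flash-attention call is the default 1/sqrt(head_dim) with head_dim=128; getitem_309 is the fp32 logsumexp and getitem_314/315 look like the philox RNG seed/offset, all saved for the backward. The eager form, per the quoted source line, is a single call:

    import math
    import torch.nn.functional as F

    # Default softmax scale for head_dim=128 is 1/sqrt(128).
    assert math.isclose(1 / math.sqrt(128), 0.08838834764831843)
    # output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)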
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_45: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_308, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_61: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_45, [2048, -1]); permute_45 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_47: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_24, [1, 0]); contiguous_view_as_strided_24 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_17: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_61, permute_47); permute_47 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_9: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_7, mm_17); add_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_92: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_9, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_6: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_92, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_5: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_6, [-1], True); pow_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_10: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_5, 1e-05); mean_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_5: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_10); add_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_20: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_92, rsqrt_5); convert_element_type_92 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_93: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_20, torch.bfloat16); mul_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_21: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_93, contiguous_view_as_strided_29); convert_element_type_93 = contiguous_view_as_strided_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
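pow_6 / mean_5 / add_10 / rsqrt_5 / mul_20 plus the surrounding dtype casts are one RMSNorm application: upcast to fp32, divide by the root mean square, cast back to bf16, scale by the (all-gathered) weight. Matching the quoted model.py source, with eps=1e-5 as seen in add_10:

    import torch

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim, eps=1e-5):
            super().__init__()
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            # Normalize in fp32 for stability, then cast back to input dtype.
            output = self._norm(x.float()).type_as(x)
            return output * self.weight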
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_49: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_25, [1, 0]); contiguous_view_as_strided_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_18: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_21, permute_49); permute_49 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_96: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_18, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_2: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_96)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_22: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_96, sigmoid_2); convert_element_type_96 = sigmoid_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_97: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_22, torch.bfloat16); mul_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_51: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_27, [1, 0]); contiguous_view_as_strided_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_19: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_21, permute_51); permute_51 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_23: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_97, mm_19); convert_element_type_97 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_53: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_26, [1, 0]); contiguous_view_as_strided_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_20: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_23, permute_53); permute_53 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_11: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_9, mm_20); add_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
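mm_18 (w1), the fp32 sigmoid/mul pair (SiLU), mm_19 (w3), mul_23, and mm_20 (w2) together are the SwiGLU feed-forward quoted from model.py:299, with hidden dim 5632 read off the trace shapes; add_9 and add_11 on either side are the block's two residual adds. In eager form:

    import torch
    import torch.nn.functional as F

    class FeedForward(torch.nn.Module):
        def __init__(self, dim=2048, hidden_dim=5632):
            super().__init__()
            self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            # SwiGLU: SiLU-gate w1(x), modulate by w3(x), project back via w2.
            return self.w2(F.silu(self.w1(x)) * self.w3(x))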
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_102: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_64, torch.bfloat16); primals_64 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_103: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_65, torch.bfloat16); primals_65 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_104: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_66, torch.bfloat16); primals_66 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_105: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_67, torch.bfloat16); primals_67 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_106: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_68, torch.bfloat16); primals_68 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_107: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_69, torch.bfloat16); primals_69 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_108: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_70, torch.bfloat16); primals_70 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_109: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_71, torch.bfloat16); primals_71 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_110: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_72, torch.bfloat16); primals_72 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_4 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_102, convert_element_type_103, convert_element_type_104, convert_element_type_105, convert_element_type_106, convert_element_type_107, convert_element_type_108, convert_element_type_109, convert_element_type_110]); convert_element_type_102 = convert_element_type_103 = convert_element_type_104 = convert_element_type_105 = convert_element_type_106 = convert_element_type_107 = convert_element_type_108 = convert_element_type_109 = convert_element_type_110 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_317: "bf16[6423040]" = all_gather_copy_in_4[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_318: "bf16[51384320]" = all_gather_copy_in_4[1]; all_gather_copy_in_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_4: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_317, 8, '0'); getitem_317 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_4: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_4); all_gather_into_tensor_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
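The convert_element_type casts above are FSDP's mixed-precision cast of the fp32 sharded params to bf16 before communication, and the numbers in all_gather_copy_in check out against the next layer's parameter shards: each of the 8 ranks contributes a 6,423,040-element bf16 shard (four 2048x2048/8 attention weights, three 5632x2048/8 FFN weights, two 2048/8 norm weights), and the gathered output is 8x that. A quick arithmetic check:

    split_sizes = [524288, 524288, 524288, 524288,
                   1441792, 1441792, 1441792, 256, 256]
    world_size = 8
    # Per-rank shard numels: 2048*2048/8, 5632*2048/8, 2048/8.
    assert 2048 * 2048 // world_size == 524288
    assert 5632 * 2048 // world_size == 1441792
    assert 2048 // world_size == 256
    assert sum(split_sizes) == 6423040                # all-gather input numel
    assert sum(split_sizes) * world_size == 51384320  # all-gather output numel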
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_63: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_4, [8, -1]); wait_tensor_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_35 = torch.ops.aten.split_with_sizes.default(view_63, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_63 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_328: "bf16[8, 524288]" = split_with_sizes_35[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_30: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_328, [2048, 2048], [2048, 1]); getitem_328 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_338: "bf16[8, 524288]" = split_with_sizes_35[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_31: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_338, [2048, 2048], [2048, 1]); getitem_338 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_348: "bf16[8, 524288]" = split_with_sizes_35[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_32: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_348, [2048, 2048], [2048, 1]); getitem_348 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_358: "bf16[8, 524288]" = split_with_sizes_35[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_33: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_358, [2048, 2048], [2048, 1]); getitem_358 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_368: "bf16[8, 1441792]" = split_with_sizes_35[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_34: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_368, [5632, 2048], [2048, 1]); getitem_368 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_378: "bf16[8, 1441792]" = split_with_sizes_35[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_35: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_378, [2048, 5632], [5632, 1]); getitem_378 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_388: "bf16[8, 1441792]" = split_with_sizes_35[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_36: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_388, [5632, 2048], [2048, 1]); getitem_388 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_398: "bf16[8, 256]" = split_with_sizes_35[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_37: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_398, [2048], [1]); getitem_398 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_408: "bf16[8, 256]" = split_with_sizes_35[8]; split_with_sizes_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_38: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_408, [2048], [1]); getitem_408 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
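foreach_all_gather_copy_out then slices the gathered [8, 6423040] buffer column-wise per parameter and materializes each slice as a contiguous tensor of the full parameter shape via torch.ops.fsdp.contiguous_view_as_strided. That op is internal to this FSDP prototype, so the following is only a rough functional analogue of the layout math, not its actual implementation:

    import torch

    world_size = 8
    split_sizes = [524288, 524288, 524288, 524288,
                   1441792, 1441792, 1441792, 256, 256]
    param_shapes = [(2048, 2048), (2048, 2048), (2048, 2048), (2048, 2048),
                    (5632, 2048), (2048, 5632), (5632, 2048), (2048,), (2048,)]

    gathered = torch.empty(world_size * sum(split_sizes), dtype=torch.bfloat16)
    # Row r holds rank r's shard; splitting along dim 1 yields one column block
    # per parameter, whose numel (8 * shard size) equals the full param's numel.
    blocks = gathered.view(world_size, -1).split(split_sizes, dim=1)
    params = [b.contiguous().view(s) for b, s in zip(blocks, param_shapes)]
    for p, s in zip(params, param_shapes):
        assert p.shape == torch.Size(s) and p.is_contiguous()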
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_111: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_11, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_7: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_111, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_6: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_7, [-1], True); pow_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_12: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_6, 1e-05); mean_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_6: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_12); add_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_24: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_111, rsqrt_6); convert_element_type_111 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_112: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_24, torch.bfloat16); mul_24 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_25: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_112, contiguous_view_as_strided_37); convert_element_type_112 = contiguous_view_as_strided_37 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_55: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_30, [1, 0]); contiguous_view_as_strided_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_21: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_25, permute_55); permute_55 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_57: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_31, [1, 0]); contiguous_view_as_strided_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_22: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_25, permute_57); permute_57 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_59: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_32, [1, 0]); contiguous_view_as_strided_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_23: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_25, permute_59); permute_59 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_72: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_21, [1, 2048, 16, 128]); mm_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_73: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_22, [1, 2048, 16, 128]); mm_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_74: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_23, [1, 2048, 16, 128]); mm_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_119: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_72, torch.float32); view_72 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_75: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_119, [1, 2048, 16, -1, 2]); convert_element_type_119 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_6: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_75); view_75 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_120: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_73, torch.float32); view_73 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_76: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_120, [1, 2048, 16, -1, 2]); convert_element_type_120 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_7: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_76); view_76 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_26: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_6, view_20); view_as_complex_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_6: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_26); mul_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_78: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_6, [1, 2048, 16, 128]); view_as_real_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_27: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_7, view_20); view_as_complex_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_7: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_27); mul_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_79: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_7, [1, 2048, 16, 128]); view_as_real_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_121: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_78, torch.bfloat16); view_78 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_122: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_79, torch.bfloat16); view_79 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_60: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_121, [0, 2, 1, 3]); convert_element_type_121 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_61: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_122, [0, 2, 1, 3]); convert_element_type_122 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_62: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_74, [0, 2, 1, 3]); view_74 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_3 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_60, permute_61, permute_62, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_409: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_3[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_410: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_3[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_415: "i64[]" = _scaled_dot_product_flash_attention_3[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_416: "i64[]" = _scaled_dot_product_flash_attention_3[7]; _scaled_dot_product_flash_attention_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_63: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_409, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_80: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_63, [2048, -1]); permute_63 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_65: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_33, [1, 0]); contiguous_view_as_strided_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_24: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_80, permute_65); permute_65 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_13: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_11, mm_24); add_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_125: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_13, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_8: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_125, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_7: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_8, [-1], True); pow_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_14: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_7, 1e-05); mean_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_7: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_14); add_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_28: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_125, rsqrt_7); convert_element_type_125 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_126: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_28, torch.bfloat16); mul_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_29: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_126, contiguous_view_as_strided_38); convert_element_type_126 = contiguous_view_as_strided_38 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_67: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_34, [1, 0]); contiguous_view_as_strided_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_25: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_29, permute_67); permute_67 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_129: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_25, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_3: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_129)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_30: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_129, sigmoid_3); convert_element_type_129 = sigmoid_3 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_130: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_30, torch.bfloat16); mul_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_69: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_36, [1, 0]); contiguous_view_as_strided_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_26: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_29, permute_69); permute_69 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_31: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_130, mm_26); convert_element_type_130 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_71: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_35, [1, 0]); contiguous_view_as_strided_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_27: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_31, permute_71); permute_71 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_15: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_13, mm_27); add_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
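From here the trace repeats: this second layer is node-for-node identical to the first, with only the primals indices and node counters advancing, and each repetition matches the block structure quoted from model.py:410-411. As a skeleton (attention, feed_forward, and the norms as sketched above):

    # Pre-norm transformer block, per the source lines quoted in the trace:
    #   h   = x + self.attention(self.attention_norm(x), freqs_cis)   # :410
    #   out = h + self.feed_forward(self.ffn_norm(h))                 # :411
    def forward(self, x, freqs_cis):
        h = x + self.attention(self.attention_norm(x), freqs_cis)
        return h + self.feed_forward(self.ffn_norm(h))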
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_135: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_82, torch.bfloat16); primals_82 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_136: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_83, torch.bfloat16); primals_83 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_137: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_84, torch.bfloat16); primals_84 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_138: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_85, torch.bfloat16); primals_85 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_139: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_86, torch.bfloat16); primals_86 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_140: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_87, torch.bfloat16); primals_87 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_141: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_88, torch.bfloat16); primals_88 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_142: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_89, torch.bfloat16); primals_89 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_143: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_90, torch.bfloat16); primals_90 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
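The nine convert_element_type_* nodes above are FSDP2's mixed-precision cast: the next block's fp32 sharded parameters become bf16 before being packed for all-gather. Only the `return tensor.to(dtype)` line appears in the trace; the guard below is an assumed reconstruction of the quoted helper, not verified source:

import torch
from typing import Optional

def _to_dtype_if_needed(tensor: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.Tensor:
    # Assumed shape of the _fsdp_common.py:140 helper: cast only when dtypes differ.
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor

The casts exist because the job shards parameters in fp32 but communicates in bf16, i.e. something along the lines of MixedPrecisionPolicy(param_dtype=torch.bfloat16) in the FSDP2 API (hedged; the exact configuration is not shown in this log).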
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_5 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_135, convert_element_type_136, convert_element_type_137, convert_element_type_138, convert_element_type_139, convert_element_type_140, convert_element_type_141, convert_element_type_142, convert_element_type_143]); convert_element_type_135 = convert_element_type_136 = convert_element_type_137 = convert_element_type_138 = convert_element_type_139 = convert_element_type_140 = convert_element_type_141 = convert_element_type_142 = convert_element_type_143 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_418: "bf16[6423040]" = all_gather_copy_in_5[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_419: "bf16[51384320]" = all_gather_copy_in_5[1]; all_gather_copy_in_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
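The literals in all_gather_copy_in_5 check out against one transformer block's nine sharded parameters. A worked verification using only numbers visible in the graph:

# wq/wk/wv/wo shards:  (2048 * 2048) / 8 = 524288 each
# w1/w2/w3 shards:     (5632 * 2048) / 8 = 1441792 each
# two RMSNorm weights: 2048 / 8 = 256 each
split_sizes = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
assert sum(split_sizes) == 6423040        # packed per-rank input (getitem_418)
assert 8 * sum(split_sizes) == 51384320   # world_size x shard = output buffer (getitem_419)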
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_5: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_418, 8, '0'); getitem_418 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_5: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_5); all_gather_into_tensor_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
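all_gather_into_tensor_5 plus wait_tensor_5 is the async functional-collective pair: the kernel is issued, and the wait node marks where its result must be materialized. A hedged eager sketch using the private torch.distributed._functional_collectives API the trace quotes (assumes torchrun has already initialized the default process group; gather_flat_shard is a hypothetical helper name):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def gather_flat_shard(shard: torch.Tensor) -> torch.Tensor:
    # Issues _c10d_functional.all_gather_into_tensor (asynchronous)...
    out = funcol.all_gather_tensor(shard, gather_dim=0, group=dist.group.WORLD)
    # ...and wait_tensor blocks until the gathered buffer is ready, which is
    # the wait_tensor_* node the compiler can sink past independent compute.
    return funcol.wait_tensor(out)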
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_82: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_5, [8, -1]); wait_tensor_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_45 = torch.ops.aten.split_with_sizes.default(view_82, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_82 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_429: "bf16[8, 524288]" = split_with_sizes_45[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_39: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_429, [2048, 2048], [2048, 1]); getitem_429 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_439: "bf16[8, 524288]" = split_with_sizes_45[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_40: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_439, [2048, 2048], [2048, 1]); getitem_439 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_449: "bf16[8, 524288]" = split_with_sizes_45[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_41: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_449, [2048, 2048], [2048, 1]); getitem_449 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_459: "bf16[8, 524288]" = split_with_sizes_45[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_42: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_459, [2048, 2048], [2048, 1]); getitem_459 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_469: "bf16[8, 1441792]" = split_with_sizes_45[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_43: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_469, [5632, 2048], [2048, 1]); getitem_469 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_479: "bf16[8, 1441792]" = split_with_sizes_45[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_44: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_479, [2048, 5632], [5632, 1]); getitem_479 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_489: "bf16[8, 1441792]" = split_with_sizes_45[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_45: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_489, [5632, 2048], [2048, 1]); getitem_489 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_499: "bf16[8, 256]" = split_with_sizes_45[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_46: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_499, [2048], [1]); getitem_499 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_509: "bf16[8, 256]" = split_with_sizes_45[8]; split_with_sizes_45 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_47: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_509, [2048], [1]); getitem_509 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
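The copy-out reverses the packing: the flat bf16[51384320] buffer is viewed as (world_size, shard_numel), split back into the nine parameters, and each (8, n) chunk is re-materialized at its original shape. torch.ops.fsdp.contiguous_view_as_strided is a custom op from this FSDP2 prototype; the snippet below is only a plain-eager approximation of the shape bookkeeping, not that op's storage behavior:

import torch

world_size = 8
split_sizes = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
param_shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048), (2048,), (2048,)]

gathered = torch.empty(world_size * sum(split_sizes), dtype=torch.bfloat16)  # wait_tensor_5
chunks = gathered.view(world_size, -1).split(split_sizes, dim=1)             # split_with_sizes_45
# Flattening each (8, n) chunk in rank order recovers the full parameter,
# because FSDP sharded every weight along dim 0.
params = [c.contiguous().view(shape) for c, shape in zip(chunks, param_shapes)]
assert params[0].shape == (2048, 2048) and params[-1].shape == (2048,)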
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_144: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_15, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_9: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_144, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_8: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_9, [-1], True); pow_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_16: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_8, 1e-05); mean_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_8: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_16); add_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_32: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_144, rsqrt_8); convert_element_type_144 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_145: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_32, torch.bfloat16); mul_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_33: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_145, contiguous_view_as_strided_46); convert_element_type_145 = contiguous_view_as_strided_46 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
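convert_element_type_144 through mul_33 are one RMSNorm application: upcast to fp32, normalize by the root-mean-square, downcast, then scale by the freshly all-gathered weight. The sketch below is reconstructed directly from the quoted model.py lines 61/74/75:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int = 2048, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # contiguous_view_as_strided_46 above

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # pow_9 / mean_8 / add_16 / rsqrt_8 / mul_32 in the graph
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # statistics in fp32, result cast back to the input dtype (bf16 here)
        output = self._norm(x.float()).type_as(x)
        return output * self.weight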
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_73: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_39, [1, 0]); contiguous_view_as_strided_39 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_28: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_33, permute_73); permute_73 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_75: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_40, [1, 0]); contiguous_view_as_strided_40 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_29: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_33, permute_75); permute_75 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_77: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_41, [1, 0]); contiguous_view_as_strided_41 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_30: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_33, permute_77); permute_77 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_91: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_28, [1, 2048, 16, 128]); mm_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_92: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_29, [1, 2048, 16, 128]); mm_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_93: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_30, [1, 2048, 16, 128]); mm_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
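mm_28/mm_29/mm_30 share one input (mul_33, the normed activations) against the transposed wq/wk/wv weights, and the three views split the results into attention heads. A small sketch at the shapes in the graph (n_kv_heads equals n_heads here, as the identical 16-head views for xq and xk show):

import torch
import torch.nn as nn

bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128
dim = n_heads * head_dim  # 2048
wq = nn.Linear(dim, dim, bias=False)
wk = nn.Linear(dim, dim, bias=False)
wv = nn.Linear(dim, dim, bias=False)

x = torch.randn(bsz * seqlen, dim)               # mul_33 above
xq = wq(x).view(bsz, seqlen, n_heads, head_dim)  # model.py:235 / view_91
xk = wk(x).view(bsz, seqlen, n_heads, head_dim)  # model.py:236 / view_92
xv = wv(x).view(bsz, seqlen, n_heads, head_dim)  # model.py:237 / view_93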
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_152: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_91, torch.float32); view_91 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_94: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_152, [1, 2048, 16, -1, 2]); convert_element_type_152 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_8: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_94); view_94 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_153: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_92, torch.float32); view_92 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_95: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_153, [1, 2048, 16, -1, 2]); convert_element_type_153 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_9: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_95); view_95 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_34: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_8, view_20); view_as_complex_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_8: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_34); mul_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_97: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_8, [1, 2048, 16, 128]); view_as_real_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_35: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_9, view_20); view_as_complex_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_9: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_35); mul_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_98: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_9, [1, 2048, 16, 128]); view_as_real_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_154: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_97, torch.bfloat16); view_97 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_155: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_98, torch.bfloat16); view_98 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
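view_as_complex_8/9 and the complex multiplies are the rotary position embedding: pairs of channels are treated as complex numbers and rotated by freqs_cis (the c64 view_20 operand). Reconstructed from the quoted model.py lines 146-151; freqs_cis is assumed already broadcastable to (bsz, seqlen, n_heads, head_dim // 2):

import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # (..., head_dim) -> (..., head_dim // 2) complex pairs
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # the complex multiply rotates each pair; view_as_real + flatten undoes the pairing
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # cast back to the activation dtype (bf16), as convert_element_type_154/155 do
    return xq_out.type_as(xq), xk_out.type_as(xk)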
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_78: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_154, [0, 2, 1, 3]); convert_element_type_154 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_79: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_155, [0, 2, 1, 3]); convert_element_type_155 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_80: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_93, [0, 2, 1, 3]); view_93 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_4 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_78, permute_79, permute_80, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_510: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_4[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_511: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_4[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_516: "i64[]" = _scaled_dot_product_flash_attention_4[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_517: "i64[]" = _scaled_dot_product_flash_attention_4[7]; _scaled_dot_product_flash_attention_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
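The attention itself lowers to a single flash-attention op. Note that scale = 0.08838834764831843 is just 1/sqrt(head_dim) = 1/sqrt(128), the SDPA default, and getitem_511 is the fp32 logsumexp saved for the backward pass. Eager equivalent, per the quoted model.py:254:

import torch
import torch.nn.functional as F

xq = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16)
xk = torch.randn_like(xq)
xv = torch.randn_like(xq)
# On CUDA with these bf16 shapes this dispatches to
# _scaled_dot_product_flash_attention, matching the graph node above.
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)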
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_81: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_510, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_99: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_81, [2048, -1]); permute_81 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_83: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_42, [1, 0]); contiguous_view_as_strided_42 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_31: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_99, permute_83); permute_83 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_17: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_15, mm_31); add_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
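permute_81/view_99 fold the heads back into the model dimension, mm_31 is the output projection, and add_17 is the attention residual (model.py:410). A sketch at the same shapes (wo is the projection whose transposed weight is contiguous_view_as_strided_42 above):

import torch
import torch.nn as nn

bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128
wo = nn.Linear(n_heads * head_dim, n_heads * head_dim, bias=False)

attn_out = torch.randn(bsz, n_heads, seqlen, head_dim)     # getitem_510
x = torch.randn(bsz * seqlen, n_heads * head_dim)          # residual stream
flat = attn_out.transpose(1, 2).reshape(bsz * seqlen, -1)  # permute_81 + view_99
h = x + wo(flat)                                           # mm_31 + add_17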
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_158: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_17, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_10: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_158, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_9: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_10, [-1], True); pow_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_18: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_9, 1e-05); mean_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_9: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_18); add_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_36: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_158, rsqrt_9); convert_element_type_158 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_159: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_36, torch.bfloat16); mul_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_37: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_159, contiguous_view_as_strided_47); convert_element_type_159 = contiguous_view_as_strided_47 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_85: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_43, [1, 0]); contiguous_view_as_strided_43 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_32: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_37, permute_85); permute_85 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_162: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_32, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_4: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_162)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_38: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_162, sigmoid_4); convert_element_type_162 = sigmoid_4 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_163: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_38, torch.bfloat16); mul_38 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
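convert_element_type_162, sigmoid_4, and mul_38 are how F.silu decomposes under the surrounding fp32/bf16 dtype handling: upcast, silu(x) = x * sigmoid(x), downcast. A quick numerical check of that identity:

import torch
import torch.nn.functional as F

x32 = torch.randn(2048, 5632, dtype=torch.bfloat16).float()
assert torch.allclose(x32 * torch.sigmoid(x32), F.silu(x32), atol=1e-6)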
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_87: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_45, [1, 0]); contiguous_view_as_strided_45 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_33: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_37, permute_87); permute_87 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_39: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_163, mm_33); convert_element_type_163 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_89: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_44, [1, 0]); contiguous_view_as_strided_44 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_34: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_39, permute_89); permute_89 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_19: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_17, mm_34); add_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
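add_19 closes the second residual, completing one TransformerBlock; everything from the dtype casts below onward is the same pattern repeated for the next layer. The block structure the graph keeps repeating, reconstructed from the quoted model.py lines 410-411 (the submodules are the pieces sketched earlier):

import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, attention, feed_forward, attention_norm, ffn_norm):
        super().__init__()
        self.attention = attention          # wq/wk/wv, rotary, SDPA, wo
        self.feed_forward = feed_forward    # SwiGLU FFN
        self.attention_norm = attention_norm
        self.ffn_norm = ffn_norm

    def forward(self, x, freqs_cis):
        h = x + self.attention(self.attention_norm(x), freqs_cis)  # add_17
        out = h + self.feed_forward(self.ffn_norm(h))              # add_19
        return out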
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_168: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_100, torch.bfloat16); primals_100 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_169: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_101, torch.bfloat16); primals_101 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_170: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_102, torch.bfloat16); primals_102 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_171: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_103, torch.bfloat16); primals_103 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_172: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_104, torch.bfloat16); primals_104 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_173: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_105, torch.bfloat16); primals_105 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_174: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_106, torch.bfloat16); primals_106 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_175: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_107, torch.bfloat16); primals_107 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_176: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_108, torch.bfloat16); primals_108 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_6 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_168, convert_element_type_169, convert_element_type_170, convert_element_type_171, convert_element_type_172, convert_element_type_173, convert_element_type_174, convert_element_type_175, convert_element_type_176]); convert_element_type_168 = convert_element_type_169 = convert_element_type_170 = convert_element_type_171 = convert_element_type_172 = convert_element_type_173 = convert_element_type_174 = convert_element_type_175 = convert_element_type_176 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_519: "bf16[6423040]" = all_gather_copy_in_6[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_520: "bf16[51384320]" = all_gather_copy_in_6[1]; all_gather_copy_in_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_6: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_519, 8, '0'); getitem_519 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_6: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_6); all_gather_into_tensor_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_101: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_6, [8, -1]); wait_tensor_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_55 = torch.ops.aten.split_with_sizes.default(view_101, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_101 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_530: "bf16[8, 524288]" = split_with_sizes_55[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_48: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_530, [2048, 2048], [2048, 1]); getitem_530 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_540: "bf16[8, 524288]" = split_with_sizes_55[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_49: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_540, [2048, 2048], [2048, 1]); getitem_540 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_550: "bf16[8, 524288]" = split_with_sizes_55[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_50: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_550, [2048, 2048], [2048, 1]); getitem_550 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_560: "bf16[8, 524288]" = split_with_sizes_55[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_51: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_560, [2048, 2048], [2048, 1]); getitem_560 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_570: "bf16[8, 1441792]" = split_with_sizes_55[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_52: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_570, [5632, 2048], [2048, 1]); getitem_570 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_580: "bf16[8, 1441792]" = split_with_sizes_55[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_53: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_580, [2048, 5632], [5632, 1]); getitem_580 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_590: "bf16[8, 1441792]" = split_with_sizes_55[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_54: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_590, [5632, 2048], [2048, 1]); getitem_590 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_600: "bf16[8, 256]" = split_with_sizes_55[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_55: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_600, [2048], [1]); getitem_600 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_610: "bf16[8, 256]" = split_with_sizes_55[8]; split_with_sizes_55 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_56: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_610, [2048], [1]); getitem_610 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_177: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_19, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_11: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_177, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_10: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_11, [-1], True); pow_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_20: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_10, 1e-05); mean_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_10: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_20); add_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_40: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_177, rsqrt_10); convert_element_type_177 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_178: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_40, torch.bfloat16); mul_40 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_41: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_178, contiguous_view_as_strided_55); convert_element_type_178 = contiguous_view_as_strided_55 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_91: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_48, [1, 0]); contiguous_view_as_strided_48 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_35: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_41, permute_91); permute_91 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_93: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_49, [1, 0]); contiguous_view_as_strided_49 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_36: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_41, permute_93); permute_93 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_95: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_50, [1, 0]); contiguous_view_as_strided_50 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_37: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_41, permute_95); permute_95 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_110: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_35, [1, 2048, 16, 128]); mm_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_111: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_36, [1, 2048, 16, 128]); mm_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_112: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_37, [1, 2048, 16, 128]); mm_37 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_185: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_110, torch.float32); view_110 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_113: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_185, [1, 2048, 16, -1, 2]); convert_element_type_185 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_10: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_113); view_113 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_186: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_111, torch.float32); view_111 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_114: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_186, [1, 2048, 16, -1, 2]); convert_element_type_186 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_11: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_114); view_114 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_42: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_10, view_20); view_as_complex_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_10: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_42); mul_42 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_116: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_10, [1, 2048, 16, 128]); view_as_real_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_43: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_11, view_20); view_as_complex_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_11: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_43); mul_43 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_117: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_11, [1, 2048, 16, 128]); view_as_real_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_187: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_116, torch.bfloat16); view_116 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_188: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_117, torch.bfloat16); view_117 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
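
The block above is the complex-number rotary embedding from apply_rotary_emb (model.py:146-151, quoted in the trace), decomposed into an fp32 upcast, a reshape into (real, imag) pairs, view_as_complex, a complex multiply with freqs_cis (view_20), and view_as_real. Inductor does not codegen complex-dtype ops (hence the earlier eager-fallback warning), so these c64 nodes survive into the post-grad graph as-is. A minimal eager sketch reconstructed from the quoted source lines (freqs_cis is assumed precomputed as a complex64 tensor broadcastable against the c64[1, 2048, 16, 64] operands):

    import torch

    def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
        # bf16[1, 2048, 16, 128] -> f32 -> [1, 2048, 16, 64, 2] -> c64[1, 2048, 16, 64]
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # the complex multiply rotates each (real, imag) pair; flatten(3) restores head_dim
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        # cast back to the input dtype (bf16 here), matching convert_element_type_187/188
        return xq_out.type_as(xq), xk_out.type_as(xk)
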
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_96: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_187, [0, 2, 1, 3]); convert_element_type_187 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_97: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_188, [0, 2, 1, 3]); convert_element_type_188 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_98: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_112, [0, 2, 1, 3]); view_112 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_5 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_96, permute_97, permute_98, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_611: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_5[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_612: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_5[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_617: "i64[]" = _scaled_dot_product_flash_attention_5[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_618: "i64[]" = _scaled_dot_product_flash_attention_5[7]; _scaled_dot_product_flash_attention_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
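
The attention itself lowers to a single _scaled_dot_product_flash_attention node; the extra outputs (the f32[1, 16, 2048] logsumexp and the i64 RNG state items) are kept for the backward pass. The scale 0.08838834764831843 is just the SDPA default 1/sqrt(head_dim) with head_dim = 128. A quick check (assumes a CUDA device so the flash backend is eligible, as in this run):

    import math
    import torch
    import torch.nn.functional as F

    bs, n_heads, seqlen, head_dim = 1, 16, 2048, 128
    assert abs(1 / math.sqrt(head_dim) - 0.08838834764831843) < 1e-12

    xq = torch.randn(bs, n_heads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
    xk, xv = torch.randn_like(xq), torch.randn_like(xq)
    # same call as model.py:254; scale defaults to 1/sqrt(head_dim) when omitted
    out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
    assert out.shape == (bs, n_heads, seqlen, head_dim)
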
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_99: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_611, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_118: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_99, [2048, -1]); permute_99 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_101: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_51, [1, 0]); contiguous_view_as_strided_51 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_38: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_118, permute_101); permute_101 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
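
With no bias, F.linear decomposes in the post-grad graph to a transpose of the all-gathered weight plus a plain matmul (permute_101 + mm_38 above). A sketch of the equivalence with the shapes from this trace (wo: [2048, 2048], input flattened to (bsz * seqlen, dim)):

    import torch

    x = torch.randn(2048, 2048, dtype=torch.bfloat16)  # (bsz * seqlen, dim)
    w = torch.randn(2048, 2048, dtype=torch.bfloat16)  # weight, [out_features, in_features]

    y_linear = torch.nn.functional.linear(x, w)  # what model.py calls
    y_mm = torch.mm(x, w.permute(1, 0))          # what the graph records
    assert torch.allclose(y_linear.float(), y_mm.float())
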
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_21: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_19, mm_38); add_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_191: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_21, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_12: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_191, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_11: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_12, [-1], True); pow_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_22: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_11, 1e-05); mean_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_11: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_22); add_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_44: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_191, rsqrt_11); convert_element_type_191 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_192: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_44, torch.bfloat16); mul_44 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_45: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_192, contiguous_view_as_strided_56); convert_element_type_192 = contiguous_view_as_strided_56 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
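
The pow -> mean -> add-eps -> rsqrt -> mul chain above is the RMSNorm from model.py:61/74/75: the normalization runs in fp32 and is cast back to bf16 before the multiply with the all-gathered norm weight (contiguous_view_as_strided_56). Reconstructed from the quoted source lines:

    import torch

    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-5):
            super().__init__()
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def _norm(self, x):
            # matches pow_12 / mean_11 / add_22 / rsqrt_11 / mul_44 above
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            output = self._norm(x.float()).type_as(x)  # fp32 math, cast back to bf16
            return output * self.weight
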
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_103: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_52, [1, 0]); contiguous_view_as_strided_52 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_39: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_45, permute_103); permute_103 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_195: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_39, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_5: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_195)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_46: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_195, sigmoid_5); convert_element_type_195 = sigmoid_5 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_196: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_46, torch.bfloat16); mul_46 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_105: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_54, [1, 0]); contiguous_view_as_strided_54 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_40: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_45, permute_105); permute_105 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_47: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_196, mm_40); convert_element_type_196 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_107: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_53, [1, 0]); contiguous_view_as_strided_53 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_41: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_47, permute_107); permute_107 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_23: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_21, mm_41); add_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
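
mm_39/mm_40/mm_41 plus the sigmoid/mul pair implement the SwiGLU feed-forward from model.py:299, with dim = 2048 and hidden_dim = 5632 per the shapes above; F.silu(x) is decomposed into x * sigmoid(x) computed in fp32 (convert_element_type_195 through mul_46) before casting back to bf16. A sketch of the module as quoted:

    import torch
    import torch.nn.functional as F

    class FeedForward(torch.nn.Module):
        def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
            super().__init__()
            self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)
            self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)

        def forward(self, x):
            # silu(x) == x * sigmoid(x); the graph runs it in fp32, then casts to bf16
            return self.w2(F.silu(self.w1(x)) * self.w3(x))
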
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_201: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_118, torch.bfloat16); primals_118 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_202: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_119, torch.bfloat16); primals_119 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_203: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_120, torch.bfloat16); primals_120 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_204: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_121, torch.bfloat16); primals_121 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_205: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_122, torch.bfloat16); primals_122 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_206: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_123, torch.bfloat16); primals_123 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_207: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_124, torch.bfloat16); primals_124 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_208: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_125, torch.bfloat16); primals_125 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_209: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_126, torch.bfloat16); primals_126 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
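
These nine casts convert the next TransformerBlock's fp32 sharded parameters (primals_118..126) to bf16 ahead of the all-gather, per the mixed-precision param-dtype policy. The shard sizes are the full parameter numels divided by world_size = 8: 524288 = 2048*2048/8 for each of the four attention projections, 1441792 = 5632*2048/8 for each of the three FFN weights, and 256 = 2048/8 for the two RMSNorm weights. A sketch of what _to_dtype_if_needed does per the traced line (the None-guard is an assumption about the full function body; the trace only shows the tensor.to(dtype) return):

    import torch

    def _to_dtype_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        # cast the sharded fp32 param to the all-gather dtype (bf16 under this policy)
        if dtype is not None and tensor.dtype != dtype:
            return tensor.to(dtype)
        return tensor

    wq_shard = torch.randn(2048 * 2048 // 8)  # 524288 fp32 elements on this rank
    assert _to_dtype_if_needed(wq_shard, torch.bfloat16).dtype == torch.bfloat16
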
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_7 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_201, convert_element_type_202, convert_element_type_203, convert_element_type_204, convert_element_type_205, convert_element_type_206, convert_element_type_207, convert_element_type_208, convert_element_type_209]); convert_element_type_201 = convert_element_type_202 = convert_element_type_203 = convert_element_type_204 = convert_element_type_205 = convert_element_type_206 = convert_element_type_207 = convert_element_type_208 = convert_element_type_209 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_620: "bf16[6423040]" = all_gather_copy_in_7[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_621: "bf16[51384320]" = all_gather_copy_in_7[1]; all_gather_copy_in_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
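
fsdp.all_gather_copy_in packs the nine bf16 shards into one contiguous staging buffer (getitem_620) and allocates the gather destination (getitem_621). The numels check out: 4 x 524288 + 3 x 1441792 + 2 x 256 = 6,423,040 per rank, times world_size 8 = 51,384,320 for the output:

    split_sizes = [524288, 524288, 524288, 524288,
                   1441792, 1441792, 1441792, 256, 256]
    world_size = 8

    per_rank = sum(split_sizes)
    assert per_rank == 6423040                 # all_gather input numel (getitem_620)
    assert per_rank * world_size == 51384320   # all_gather output numel (getitem_621)
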
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_7: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_620, 8, '0'); getitem_620 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_7: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_7); all_gather_into_tensor_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
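
The gather goes through the traceable functional collectives: all_gather_into_tensor returns immediately and wait_tensor is the explicit synchronization point, which leaves the compiler free to schedule independent compute between issue and wait. A minimal eager mirror of the two traced ops (assumes an initialized default process group registered under the name '0' with 8 ranks, as torchrun sets up in this run):

    import torch

    shard = torch.randn(6423040, dtype=torch.bfloat16, device="cuda")
    # async: kicks off the NCCL all-gather over the 8-rank group named '0', returns at once
    out = torch.ops._c10d_functional.all_gather_into_tensor(shard, 8, "0")
    # sync point: only after wait_tensor is the gathered buffer safe to read
    out = torch.ops._c10d_functional.wait_tensor(out)
    assert out.numel() == 8 * shard.numel()
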
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_120: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_7, [8, -1]); wait_tensor_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_65 = torch.ops.aten.split_with_sizes.default(view_120, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_120 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_631: "bf16[8, 524288]" = split_with_sizes_65[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_57: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_631, [2048, 2048], [2048, 1]); getitem_631 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_641: "bf16[8, 524288]" = split_with_sizes_65[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_58: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_641, [2048, 2048], [2048, 1]); getitem_641 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_651: "bf16[8, 524288]" = split_with_sizes_65[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_59: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_651, [2048, 2048], [2048, 1]); getitem_651 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_661: "bf16[8, 524288]" = split_with_sizes_65[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_60: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_661, [2048, 2048], [2048, 1]); getitem_661 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_671: "bf16[8, 1441792]" = split_with_sizes_65[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_61: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_671, [5632, 2048], [2048, 1]); getitem_671 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_681: "bf16[8, 1441792]" = split_with_sizes_65[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_62: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_681, [2048, 5632], [5632, 1]); getitem_681 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_691: "bf16[8, 1441792]" = split_with_sizes_65[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_63: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_691, [5632, 2048], [2048, 1]); getitem_691 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_701: "bf16[8, 256]" = split_with_sizes_65[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_64: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_701, [2048], [1]); getitem_701 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_711: "bf16[8, 256]" = split_with_sizes_65[8]; split_with_sizes_65 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_65: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_711, [2048], [1]); getitem_711 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
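
The copy-out above is the inverse of the pack: reshape the gathered flat buffer to [world_size, per_rank_numel], split it per parameter along dim 1, and reinterpret each [8, shard_numel] slice as the full unsharded parameter via fsdp.contiguous_view_as_strided. The sizes confirm the round trip:

    world_size = 8
    assert world_size * 524288 == 2048 * 2048   # wq/wk/wv/wo -> bf16[2048, 2048]
    assert world_size * 1441792 == 5632 * 2048  # w1/w2/w3 -> bf16[5632, 2048] or [2048, 5632]
    assert world_size * 256 == 2048             # attention_norm/ffn_norm -> bf16[2048]
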
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_210: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_23, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_13: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_210, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_12: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_13, [-1], True); pow_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_24: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_12, 1e-05); mean_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_12: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_24); add_24 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_48: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_210, rsqrt_12); convert_element_type_210 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_211: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_48, torch.bfloat16); mul_48 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_49: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_211, contiguous_view_as_strided_64); convert_element_type_211 = contiguous_view_as_strided_64 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_109: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_57, [1, 0]); contiguous_view_as_strided_57 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_42: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_49, permute_109); permute_109 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_111: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_58, [1, 0]); contiguous_view_as_strided_58 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_43: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_49, permute_111); permute_111 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_113: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_59, [1, 0]); contiguous_view_as_strided_59 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_44: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_49, permute_113); permute_113 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_129: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_42, [1, 2048, 16, 128]); mm_42 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_130: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_43, [1, 2048, 16, 128]); mm_43 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_131: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_44, [1, 2048, 16, 128]); mm_44 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_218: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_129, torch.float32); view_129 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_132: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_218, [1, 2048, 16, -1, 2]); convert_element_type_218 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_12: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_132); view_132 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_219: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_130, torch.float32); view_130 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_133: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_219, [1, 2048, 16, -1, 2]); convert_element_type_219 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_13: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_133); view_133 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_50: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_12, view_20); view_as_complex_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_12: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_50); mul_50 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_135: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_12, [1, 2048, 16, 128]); view_as_real_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_51: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_13, view_20); view_as_complex_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_13: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_51); mul_51 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_136: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_13, [1, 2048, 16, 128]); view_as_real_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_220: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_135, torch.bfloat16); view_135 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_221: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_136, torch.bfloat16); view_136 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_114: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_220, [0, 2, 1, 3]); convert_element_type_220 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_115: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_221, [0, 2, 1, 3]); convert_element_type_221 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_116: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_131, [0, 2, 1, 3]); view_131 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_6 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_114, permute_115, permute_116, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_712: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_6[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_713: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_6[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_718: "i64[]" = _scaled_dot_product_flash_attention_6[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_719: "i64[]" = _scaled_dot_product_flash_attention_6[7]; _scaled_dot_product_flash_attention_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_117: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_712, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_137: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_117, [2048, -1]); permute_117 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_119: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_60, [1, 0]); contiguous_view_as_strided_60 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_45: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_137, permute_119); permute_119 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_25: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_23, mm_45); add_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_224: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_25, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_14: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_224, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_13: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_14, [-1], True); pow_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_26: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_13, 1e-05); mean_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_13: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_26); add_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_52: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_224, rsqrt_13); convert_element_type_224 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_225: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_52, torch.bfloat16); mul_52 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_53: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_225, contiguous_view_as_strided_65); convert_element_type_225 = contiguous_view_as_strided_65 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_121: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_61, [1, 0]); contiguous_view_as_strided_61 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_46: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_53, permute_121); permute_121 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_228: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_46, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_6: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_228)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_54: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_228, sigmoid_6); convert_element_type_228 = sigmoid_6 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_229: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_54, torch.bfloat16); mul_54 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_123: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_63, [1, 0]); contiguous_view_as_strided_63 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_47: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_53, permute_123); permute_123 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_55: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_229, mm_47); convert_element_type_229 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_125: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_62, [1, 0]); contiguous_view_as_strided_62 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_48: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_55, permute_125); permute_125 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_27: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_25, mm_48); add_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
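
From here the pattern repeats once per TransformerBlock. Note the ordering: each block's cast + copy-in + all-gather is issued immediately after the previous block's residual add (add_27 above), one block ahead of where the gathered parameters are consumed, so the communication for block i+1 can overlap the compute of block i. A schematic of that schedule (issue_all_gather/wait are hypothetical stand-ins for the traced all_gather_copy_in / all_gather_into_tensor / wait_tensor sequence, not real FSDP APIs):

    # Schematic only: shows the issue-ahead ordering visible in this graph.
    def forward(blocks, h):
        handle = issue_all_gather(blocks[0])              # prefetch block 0's params
        for i, block in enumerate(blocks):
            params = wait(handle)                         # sync right before use
            if i + 1 < len(blocks):
                handle = issue_all_gather(blocks[i + 1])  # overlap with block i compute
            h = block(h, params)
        return h
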
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_234: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_136, torch.bfloat16); primals_136 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_235: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_137, torch.bfloat16); primals_137 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_236: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_138, torch.bfloat16); primals_138 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_237: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_139, torch.bfloat16); primals_139 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_238: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_140, torch.bfloat16); primals_140 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_239: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_141, torch.bfloat16); primals_141 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_240: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_142, torch.bfloat16); primals_142 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_241: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_143, torch.bfloat16); primals_143 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_242: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_144, torch.bfloat16); primals_144 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_8 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_234, convert_element_type_235, convert_element_type_236, convert_element_type_237, convert_element_type_238, convert_element_type_239, convert_element_type_240, convert_element_type_241, convert_element_type_242]); convert_element_type_234 = convert_element_type_235 = convert_element_type_236 = convert_element_type_237 = convert_element_type_238 = convert_element_type_239 = convert_element_type_240 = convert_element_type_241 = convert_element_type_242 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_721: "bf16[6423040]" = all_gather_copy_in_8[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_722: "bf16[51384320]" = all_gather_copy_in_8[1]; all_gather_copy_in_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_8: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_721, 8, '0'); getitem_721 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_8: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_8); all_gather_into_tensor_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
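The gather itself is issued through the functional-collectives ops and completed by an explicit wait_tensor, which is what lets the scheduler overlap the communication with independent compute. A hedged equivalent using the public torch.distributed API rather than the traced _c10d_functional ops:

import torch.distributed as dist

def gather(inp, out):
    # Kick off the all-gather asynchronously; NCCL fills `out` in rank order.
    work = dist.all_gather_into_tensor(out, inp, async_op=True)
    # Block only at the point where the unsharded parameters are needed.
    work.wait()
    return out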
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_139: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_8, [8, -1]); wait_tensor_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_75 = torch.ops.aten.split_with_sizes.default(view_139, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_139 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_732: "bf16[8, 524288]" = split_with_sizes_75[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_66: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_732, [2048, 2048], [2048, 1]); getitem_732 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_742: "bf16[8, 524288]" = split_with_sizes_75[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_67: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_742, [2048, 2048], [2048, 1]); getitem_742 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_752: "bf16[8, 524288]" = split_with_sizes_75[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_68: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_752, [2048, 2048], [2048, 1]); getitem_752 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_762: "bf16[8, 524288]" = split_with_sizes_75[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_69: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_762, [2048, 2048], [2048, 1]); getitem_762 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_772: "bf16[8, 1441792]" = split_with_sizes_75[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_70: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_772, [5632, 2048], [2048, 1]); getitem_772 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_782: "bf16[8, 1441792]" = split_with_sizes_75[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_71: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_782, [2048, 5632], [5632, 1]); getitem_782 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_792: "bf16[8, 1441792]" = split_with_sizes_75[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_72: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_792, [5632, 2048], [2048, 1]); getitem_792 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_802: "bf16[8, 256]" = split_with_sizes_75[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_73: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_802, [2048], [1]); getitem_802 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_812: "bf16[8, 256]" = split_with_sizes_75[8]; split_with_sizes_75 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_74: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_812, [2048], [1]); getitem_812 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
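On the copy-out side the gathered flat buffer is viewed as (world_size, shard_numel), split back into the nine per-parameter segments, and each segment is reinterpreted with its original shape and strides; judging by their later uses, the nine parameters of one transformer block are wq/wk/wv/wo (2048x2048), the feed-forward w1 (5632x2048), w2 (2048x5632), w3 (5632x2048), and the two norm weights (2048). A plain-aten sketch, assuming a contiguous copy plus reshape in place of the fused fsdp.contiguous_view_as_strided op:

import torch

world_size, shard_numel = 8, 6423040
gathered = torch.empty(world_size * shard_numel, dtype=torch.bfloat16)
split_sizes = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048), (2048,), (2048,)]

segments = gathered.view(world_size, -1).split(split_sizes, dim=1)
# Each (world_size, n) segment holds one parameter's elements in rank order;
# a contiguous copy plus reshape recovers the full unsharded weight.
params = [seg.contiguous().view(shape) for seg, shape in zip(segments, shapes)]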
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_243: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_27, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_15: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_243, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_14: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_15, [-1], True); pow_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_28: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_14, 1e-05); mean_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_14: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_28); add_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_56: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_243, rsqrt_14); convert_element_type_243 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_244: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_56, torch.bfloat16); mul_56 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_57: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_244, contiguous_view_as_strided_73); convert_element_type_244 = contiguous_view_as_strided_73 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
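The ops above are the RMSNorm from model.py:61 and model.py:74-75 traced step by step: upcast to fp32, mean of squares along the last dim, rsqrt with eps = 1e-5, downcast to bf16, then the elementwise scale by the gathered norm weight (contiguous_view_as_strided_73). The same computation in eager form:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Normalize in fp32 for numerical stability, then return to the input dtype.
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * weight

x = torch.randn(2048, 2048, dtype=torch.bfloat16)
w = torch.ones(2048, dtype=torch.bfloat16)
out = rms_norm(x, w)  # corresponds to mul_57 above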
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_127: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_66, [1, 0]); contiguous_view_as_strided_66 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_49: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_57, permute_127); permute_127 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_129: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_67, [1, 0]); contiguous_view_as_strided_67 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_50: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_57, permute_129); permute_129 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_131: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_68, [1, 0]); contiguous_view_as_strided_68 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_51: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_57, permute_131); permute_131 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
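Each F.linear against a gathered weight is traced as a permute plus a bare mm; the three mms above are one block's q/k/v projections, all reading the same normalized activations (mul_57). An eager equivalent, with variable names assumed from the model:

import torch
import torch.nn.functional as F

h  = torch.randn(2048, 2048, dtype=torch.bfloat16)  # attention_norm output
wq = torch.randn(2048, 2048, dtype=torch.bfloat16)  # gathered weights
wk = torch.randn(2048, 2048, dtype=torch.bfloat16)
wv = torch.randn(2048, 2048, dtype=torch.bfloat16)

# F.linear(h, w) == h @ w.t(), i.e. each permute + mm pair in the graph.
xq, xk, xv = F.linear(h, wq), F.linear(h, wk), F.linear(h, wv)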
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_148: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_49, [1, 2048, 16, 128]); mm_49 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_149: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_50, [1, 2048, 16, 128]); mm_50 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_150: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_51, [1, 2048, 16, 128]); mm_51 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_251: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_148, torch.float32); view_148 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_151: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_251, [1, 2048, 16, -1, 2]); convert_element_type_251 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_14: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_151); view_151 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_252: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_149, torch.float32); view_149 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_152: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_252, [1, 2048, 16, -1, 2]); convert_element_type_252 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_15: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_152); view_152 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_58: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_14, view_20); view_as_complex_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_14: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_58); mul_58 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_154: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_14, [1, 2048, 16, 128]); view_as_real_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_59: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_15, view_20); view_as_complex_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_15: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_59); mul_59 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_155: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_15, [1, 2048, 16, 128]); view_as_real_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_253: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_154, torch.bfloat16); view_154 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_254: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_155, torch.bfloat16); view_155 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
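The span above is apply_rotary_emb (model.py:146-151) traced through complex tensors: q and k are upcast to fp32, the 128-wide head dim is viewed as 64 complex pairs, rotated by the precomputed freqs_cis (view_20), and viewed back as real before the bf16 downcast. A self-contained sketch with toy rotation angles:

import torch

def apply_rotary_emb(xq, xk, freqs_cis):
    # freqs_cis: complex64 of shape (seqlen, head_dim // 2), broadcast over
    # batch and heads via the reshape below.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    fc = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])
    xq_out = torch.view_as_real(xq_ * fc).flatten(3)
    xk_out = torch.view_as_real(xk_ * fc).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

xq = torch.randn(1, 2048, 16, 128, dtype=torch.bfloat16)
xk = torch.randn_like(xq)
freqs_cis = torch.polar(torch.ones(2048, 64), torch.rand(2048, 64))  # toy angles
xq_r, xk_r = apply_rotary_emb(xq, xk, freqs_cis)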
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_132: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_253, [0, 2, 1, 3]); convert_element_type_253 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_133: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_254, [0, 2, 1, 3]); convert_element_type_254 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_134: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_150, [0, 2, 1, 3]); view_150 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_7 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_132, permute_133, permute_134, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_813: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_7[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_814: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_7[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_819: "i64[]" = _scaled_dot_product_flash_attention_7[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_820: "i64[]" = _scaled_dot_product_flash_attention_7[7]; _scaled_dot_product_flash_attention_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
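After the transposes to (bs, n_heads, seqlen, head_dim), attention runs as one fused flash-attention kernel with causal masking; the scale 0.08838834764831843 is 1/sqrt(head_dim) = 1/sqrt(128). An eager equivalent of model.py:254, passing the default scale explicitly to match the graph:

import math
import torch
import torch.nn.functional as F

xq = torch.randn(1, 16, 2048, 128, dtype=torch.bfloat16)
xk = torch.randn_like(xq)
xv = torch.randn_like(xq)

output = F.scaled_dot_product_attention(
    xq, xk, xv, is_causal=True, scale=1.0 / math.sqrt(128)
)  # getitem_813 above; the other getitems (logsumexp, rng state) are saved for backward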
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_135: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_813, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_156: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_135, [2048, -1]); permute_135 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_137: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_69, [1, 0]); contiguous_view_as_strided_69 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_52: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_156, permute_137); permute_137 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_29: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_27, mm_52); add_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_257: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_29, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_16: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_257, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_15: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_16, [-1], True); pow_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_30: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_15, 1e-05); mean_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_15: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_30); add_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_60: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_257, rsqrt_15); convert_element_type_257 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_258: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_60, torch.bfloat16); mul_60 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_61: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_258, contiguous_view_as_strided_74); convert_element_type_258 = contiguous_view_as_strided_74 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_139: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_70, [1, 0]); contiguous_view_as_strided_70 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_53: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_61, permute_139); permute_139 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_261: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_53, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_7: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_261)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_62: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_261, sigmoid_7); convert_element_type_261 = sigmoid_7 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_262: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_62, torch.bfloat16); mul_62 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_141: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_72, [1, 0]); contiguous_view_as_strided_72 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_54: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_61, permute_141); permute_141 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_63: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_262, mm_54); convert_element_type_262 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_143: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_71, [1, 0]); contiguous_view_as_strided_71 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_55: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_63, permute_143); permute_143 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_31: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_29, mm_55); add_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
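The rest of the block is the SwiGLU feed-forward from model.py:299 (w1 and w3 project 2048 -> 5632, the SiLU is computed through an fp32 upcast as sigmoid-then-multiply, and w2 projects back to 2048) followed by the second residual add. In eager form:

import torch
import torch.nn.functional as F

h  = torch.randn(2048, 2048, dtype=torch.bfloat16)  # post-attention residual (add_29)
nh = torch.randn(2048, 2048, dtype=torch.bfloat16)  # ffn_norm(h) (mul_61)
w1 = torch.randn(5632, 2048, dtype=torch.bfloat16)
w2 = torch.randn(2048, 5632, dtype=torch.bfloat16)
w3 = torch.randn(5632, 2048, dtype=torch.bfloat16)

def feed_forward(x):
    # F.silu(t) == t * sigmoid(t); the graph upcasts to fp32 around the
    # sigmoid, while plain F.silu stays in bf16, close enough for a sketch.
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

out = h + feed_forward(nh)  # the residual add (add_31)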
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_267: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_154, torch.bfloat16); primals_154 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_268: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_155, torch.bfloat16); primals_155 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_269: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_156, torch.bfloat16); primals_156 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_270: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_157, torch.bfloat16); primals_157 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_271: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_158, torch.bfloat16); primals_158 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_272: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_159, torch.bfloat16); primals_159 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_273: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_160, torch.bfloat16); primals_160 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_274: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_161, torch.bfloat16); primals_161 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_275: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_162, torch.bfloat16); primals_162 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_9 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_267, convert_element_type_268, convert_element_type_269, convert_element_type_270, convert_element_type_271, convert_element_type_272, convert_element_type_273, convert_element_type_274, convert_element_type_275]); convert_element_type_267 = convert_element_type_268 = convert_element_type_269 = convert_element_type_270 = convert_element_type_271 = convert_element_type_272 = convert_element_type_273 = convert_element_type_274 = convert_element_type_275 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_822: "bf16[6423040]" = all_gather_copy_in_9[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_823: "bf16[51384320]" = all_gather_copy_in_9[1]; all_gather_copy_in_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_9: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_822, 8, '0'); getitem_822 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_9: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_9); all_gather_into_tensor_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_158: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_9, [8, -1]); wait_tensor_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_85 = torch.ops.aten.split_with_sizes.default(view_158, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_158 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_833: "bf16[8, 524288]" = split_with_sizes_85[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_75: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_833, [2048, 2048], [2048, 1]); getitem_833 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_843: "bf16[8, 524288]" = split_with_sizes_85[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_76: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_843, [2048, 2048], [2048, 1]); getitem_843 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_853: "bf16[8, 524288]" = split_with_sizes_85[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_77: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_853, [2048, 2048], [2048, 1]); getitem_853 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_863: "bf16[8, 524288]" = split_with_sizes_85[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_78: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_863, [2048, 2048], [2048, 1]); getitem_863 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_873: "bf16[8, 1441792]" = split_with_sizes_85[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_79: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_873, [5632, 2048], [2048, 1]); getitem_873 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_883: "bf16[8, 1441792]" = split_with_sizes_85[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_80: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_883, [2048, 5632], [5632, 1]); getitem_883 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_893: "bf16[8, 1441792]" = split_with_sizes_85[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_81: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_893, [5632, 2048], [2048, 1]); getitem_893 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_903: "bf16[8, 256]" = split_with_sizes_85[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_82: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_903, [2048], [1]); getitem_903 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_913: "bf16[8, 256]" = split_with_sizes_85[8]; split_with_sizes_85 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_83: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_913, [2048], [1]); getitem_913 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_276: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_31, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_17: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_276, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_16: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_17, [-1], True); pow_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_32: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_16, 1e-05); mean_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_16: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_32); add_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_64: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_276, rsqrt_16); convert_element_type_276 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_277: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_64, torch.bfloat16); mul_64 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_65: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_277, contiguous_view_as_strided_82); convert_element_type_277 = contiguous_view_as_strided_82 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_145: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_75, [1, 0]); contiguous_view_as_strided_75 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_56: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_65, permute_145); permute_145 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_147: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_76, [1, 0]); contiguous_view_as_strided_76 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_57: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_65, permute_147); permute_147 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_149: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_77, [1, 0]); contiguous_view_as_strided_77 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_58: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_65, permute_149); permute_149 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_167: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_56, [1, 2048, 16, 128]); mm_56 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_168: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_57, [1, 2048, 16, 128]); mm_57 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_169: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_58, [1, 2048, 16, 128]); mm_58 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_284: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_167, torch.float32); view_167 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_170: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_284, [1, 2048, 16, -1, 2]); convert_element_type_284 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_16: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_170); view_170 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_285: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_168, torch.float32); view_168 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_171: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_285, [1, 2048, 16, -1, 2]); convert_element_type_285 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_17: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_171); view_171 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_66: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_16, view_20); view_as_complex_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_16: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_66); mul_66 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_173: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_16, [1, 2048, 16, 128]); view_as_real_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_67: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_17, view_20); view_as_complex_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_17: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_67); mul_67 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_174: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_17, [1, 2048, 16, 128]); view_as_real_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_286: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_173, torch.bfloat16); view_173 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_287: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_174, torch.bfloat16); view_174 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_150: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_286, [0, 2, 1, 3]); convert_element_type_286 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_151: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_287, [0, 2, 1, 3]); convert_element_type_287 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_152: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_169, [0, 2, 1, 3]); view_169 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_8 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_150, permute_151, permute_152, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_914: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_8[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_915: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_8[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_920: "i64[]" = _scaled_dot_product_flash_attention_8[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_921: "i64[]" = _scaled_dot_product_flash_attention_8[7]; _scaled_dot_product_flash_attention_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
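The _scaled_dot_product_flash_attention node above is the flash-attention lowering of F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) at model.py:254; getitem_915 is the logsumexp and getitem_920/getitem_921 appear to be the philox RNG seed/offset saved for the backward pass. A minimal eager-mode sketch of the traced call, with shapes read off the graph and random stand-in inputs (the real run uses bf16 on CUDA):

import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 16, 2048, 128
xq = torch.randn(bsz, n_heads, seqlen, head_dim)  # permute_150 in the graph
xk = torch.randn(bsz, n_heads, seqlen, head_dim)  # permute_151
xv = torch.randn(bsz, n_heads, seqlen, head_dim)  # permute_152

# scale=0.08838834764831843 in the graph matches the default 1/sqrt(head_dim) = 1/sqrt(128)
out = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=0.0, is_causal=True)
assert out.shape == (bsz, n_heads, seqlen, head_dim)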
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_153: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_914, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_175: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_153, [2048, -1]); permute_153 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_155: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_78, [1, 0]); contiguous_view_as_strided_78 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_59: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_175, permute_155); permute_155 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_33: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_31, mm_59); add_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
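After attention, the graph merges the heads back together and applies the output projection plus the residual add from model.py:410. A sketch under the same shapes (wo here is a stand-in for the attention module's output projection, not a name from the trace):

import torch
import torch.nn as nn

bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128
dim = n_heads * head_dim  # 2048

wo = nn.Linear(dim, dim, bias=False)                # stand-in output projection
x = torch.randn(bsz * seqlen, dim)                  # residual stream (add_31 above)
attn = torch.randn(bsz, n_heads, seqlen, head_dim)  # SDPA output (getitem_914)

# permute_153 + view_175: (bs, heads, seq, hd) -> (bs*seq, heads*hd); reshape
# materializes the head-contiguous layout that a plain view() would reject here
merged = attn.transpose(1, 2).reshape(bsz * seqlen, -1)
h = x + wo(merged)                                  # mm_59 + add_33 (model.py:410)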
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_290: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_33, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_18: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_290, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_17: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_18, [-1], True); pow_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_34: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_17, 1e-05); mean_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_17: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_34); add_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_68: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_290, rsqrt_17); convert_element_type_290 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_291: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_68, torch.bfloat16); mul_68 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_69: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_291, contiguous_view_as_strided_83); convert_element_type_291 = contiguous_view_as_strided_83 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
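The pow/mean/rsqrt/mul sequence above is RMSNorm; reassembling the traced source lines (model.py:61, 74-75) gives the module below. The graph confirms the normalization runs in fp32 and only the final weight multiply happens in bf16:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):  # eps matches the 1e-05 in add_34
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # model.py:61 -> pow_18, mean_17, add_34, rsqrt_17, mul_68 in the graph
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # model.py:74-75: normalize in fp32, cast back to the input dtype, then scale
        output = self._norm(x.float()).type_as(x)
        return output * self.weight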
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_157: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_79, [1, 0]); contiguous_view_as_strided_79 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_60: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_69, permute_157); permute_157 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_294: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_60, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_8: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_294)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_70: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_294, sigmoid_8); convert_element_type_294 = sigmoid_8 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_295: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_70, torch.bfloat16); mul_70 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_159: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_81, [1, 0]); contiguous_view_as_strided_81 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_61: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_69, permute_159); permute_159 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_71: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_295, mm_61); convert_element_type_295 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_161: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_80, [1, 0]); contiguous_view_as_strided_80 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_62: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_71, permute_161); permute_161 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_35: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_33, mm_62); add_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
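mm_60/mm_61/mm_62 and the sigmoid-mul in between are the SwiGLU feed-forward from model.py:299 (F.silu decomposes into sigmoid + mul in fp32 before casting back to bf16), followed by the residual add at model.py:411. A self-contained sketch with the shapes from the graph:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, hidden_dim = 2048, 5632  # from the [2048, 5632] mm shapes above

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # mm_60
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # mm_62
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # mm_61

    def forward(self, x):
        # model.py:299; silu(x) = x * sigmoid(x), traced as sigmoid_8 + mul_70
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

h = torch.randn(2048, dim)
out = h + FeedForward(dim, hidden_dim)(h)  # add_35 (model.py:411)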
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_300: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_172, torch.bfloat16); primals_172 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_301: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_173, torch.bfloat16); primals_173 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_302: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_174, torch.bfloat16); primals_174 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_303: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_175, torch.bfloat16); primals_175 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_304: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_176, torch.bfloat16); primals_176 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_305: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_177, torch.bfloat16); primals_177 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_306: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_178, torch.bfloat16); primals_178 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_307: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_179, torch.bfloat16); primals_179 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_308: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_180, torch.bfloat16); primals_180 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_10 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_300, convert_element_type_301, convert_element_type_302, convert_element_type_303, convert_element_type_304, convert_element_type_305, convert_element_type_306, convert_element_type_307, convert_element_type_308]); convert_element_type_300 = convert_element_type_301 = convert_element_type_302 = convert_element_type_303 = convert_element_type_304 = convert_element_type_305 = convert_element_type_306 = convert_element_type_307 = convert_element_type_308 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_923: "bf16[6423040]" = all_gather_copy_in_10[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_924: "bf16[51384320]" = all_gather_copy_in_10[1]; all_gather_copy_in_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
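Interleaved with this layer's compute, the graph already kicks off the all-gather for the next layer's parameters: each fp32 shard is cast to bf16 (_to_dtype_if_needed) and packed into one flat buffer by torch.ops.fsdp.all_gather_copy_in, a private FSDP op. Below is a conceptual stand-in, not the real implementation; only the split sizes and totals are taken from the graph:

import torch

world_size = 8
# 4 x (2048x2048)/8 attention weights, 3 x (5632x2048 or 2048x5632)/8 FFN weights,
# 2 x (2048)/8 norm weights, per the contiguous_view_as_strided shapes further down
inp_split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
numel = sum(inp_split_sizes)                         # 6423040, as in the graph

shards = [torch.randn(n) for n in inp_split_sizes]   # local fp32 shards (primals_172..180)
bf16 = [s.to(torch.bfloat16) for s in shards]        # _to_dtype_if_needed casts

all_gather_input = torch.cat(bf16)                   # getitem_923: bf16[6423040]
all_gather_output = torch.empty(world_size * numel,  # getitem_924: bf16[51384320]
                                dtype=torch.bfloat16)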
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_10: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_923, 8, '0'); getitem_923 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_10: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_10); all_gather_into_tensor_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
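The packed buffer then goes through the compile-friendly functional collectives: all_gather_into_tensor returns without blocking, and wait_tensor is the explicit synchronization point (wait_tensor_10 above). A hedged sketch using the private funcol wrappers the trace points at; it only runs under an initialized process group (e.g. launched with torchrun):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def gather_shard(flat_shard: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # all_gather_tensor issues the collective and returns immediately
    # (_functional_collectives.py:229 in the trace); wait_tensor blocks until
    # the NCCL result has landed (_functional_collectives.py:144).
    gathered = funcol.all_gather_tensor(flat_shard, gather_dim=0, group=group)
    return funcol.wait_tensor(gathered)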
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_177: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_10, [8, -1]); wait_tensor_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_95 = torch.ops.aten.split_with_sizes.default(view_177, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_177 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_934: "bf16[8, 524288]" = split_with_sizes_95[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_84: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_934, [2048, 2048], [2048, 1]); getitem_934 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_944: "bf16[8, 524288]" = split_with_sizes_95[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_85: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_944, [2048, 2048], [2048, 1]); getitem_944 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_954: "bf16[8, 524288]" = split_with_sizes_95[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_86: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_954, [2048, 2048], [2048, 1]); getitem_954 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_964: "bf16[8, 524288]" = split_with_sizes_95[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_87: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_964, [2048, 2048], [2048, 1]); getitem_964 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_974: "bf16[8, 1441792]" = split_with_sizes_95[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_88: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_974, [5632, 2048], [2048, 1]); getitem_974 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_984: "bf16[8, 1441792]" = split_with_sizes_95[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_89: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_984, [2048, 5632], [5632, 1]); getitem_984 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_994: "bf16[8, 1441792]" = split_with_sizes_95[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_90: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_994, [5632, 2048], [2048, 1]); getitem_994 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1004: "bf16[8, 256]" = split_with_sizes_95[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_91: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1004, [2048], [1]); getitem_1004 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1014: "bf16[8, 256]" = split_with_sizes_95[8]; split_with_sizes_95 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_92: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1014, [2048], [1]); getitem_1014 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
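foreach_all_gather_copy_out then reverses the packing: the gathered buffer is viewed as [world_size, shard_numel], split per parameter, and each block is reinterpreted with the parameter's unsharded shape/stride via torch.ops.fsdp.contiguous_view_as_strided (another private op). A conceptual stand-in follows; reshape copies here, whereas the real op produces a strided view without copying:

import torch

world_size = 8
inp_split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
# unsharded shapes, read off the contiguous_view_as_strided_84..92 calls above
param_shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048),
                (2048,), (2048,)]

gathered = torch.empty(world_size * sum(inp_split_sizes), dtype=torch.bfloat16)
rows = gathered.view(world_size, -1)                # view_177
splits = torch.split(rows, inp_split_sizes, dim=1)  # split_with_sizes_95
params = [s.reshape(shape) for s, shape in zip(splits, param_shapes)]
assert params[0].shape == (2048, 2048)              # contiguous_view_as_strided_84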
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_309: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_35, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_19: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_309, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_18: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_19, [-1], True); pow_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_36: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_18, 1e-05); mean_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_18: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_36); add_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_72: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_309, rsqrt_18); convert_element_type_309 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_310: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_72, torch.bfloat16); mul_72 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_73: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_310, contiguous_view_as_strided_91); convert_element_type_310 = contiguous_view_as_strided_91 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_163: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_84, [1, 0]); contiguous_view_as_strided_84 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_63: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_73, permute_163); permute_163 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_165: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_85, [1, 0]); contiguous_view_as_strided_85 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_64: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_73, permute_165); permute_165 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_167: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_86, [1, 0]); contiguous_view_as_strided_86 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_65: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_73, permute_167); permute_167 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_186: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_63, [1, 2048, 16, 128]); mm_63 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_187: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_64, [1, 2048, 16, 128]); mm_64 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_188: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_65, [1, 2048, 16, 128]); mm_65 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
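With the layer's parameters materialized, the three mm_63/mm_64/mm_65 nodes are the Q/K/V projections of the same normalized input, followed by the per-head reshapes at model.py:235-237. A sketch with stand-in nn.Linear weights:

import torch
import torch.nn as nn

bsz, seqlen, n_heads, head_dim = 1, 2048, 16, 128
dim = n_heads * head_dim

wq, wk, wv = (nn.Linear(dim, dim, bias=False) for _ in range(3))
x = torch.randn(bsz * seqlen, dim)               # attention_norm output (mul_73)

xq = wq(x).view(bsz, seqlen, n_heads, head_dim)  # mm_63 + view_186
xk = wk(x).view(bsz, seqlen, n_heads, head_dim)  # mm_64 + view_187
xv = wv(x).view(bsz, seqlen, n_heads, head_dim)  # mm_65 + view_188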
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_317: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_186, torch.float32); view_186 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_189: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_317, [1, 2048, 16, -1, 2]); convert_element_type_317 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_18: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_189); view_189 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_318: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_187, torch.float32); view_187 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_190: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_318, [1, 2048, 16, -1, 2]); convert_element_type_318 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_19: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_190); view_190 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_74: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_18, view_20); view_as_complex_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_18: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_74); mul_74 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_192: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_18, [1, 2048, 16, 128]); view_as_real_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_75: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_19, view_20); view_as_complex_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_19: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_75); mul_75 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_193: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_19, [1, 2048, 16, 128]); view_as_real_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_319: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_192, torch.bfloat16); view_192 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_320: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_193, torch.bfloat16); view_193 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
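view_as_complex, the complex multiply against freqs_cis (view_20), and view_as_real above are rotary position embeddings; stitching the traced source lines (model.py:146-151) back together gives the function below. freqs_cis is assumed to be pre-broadcast to a complex64 tensor of shape [1, seqlen, 1, head_dim // 2]:

import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # model.py:146-147: pair adjacent values and view each pair as a complex number
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # model.py:149-150: the rotation is a complex multiply (mul_74/mul_75),
    # then back to interleaved real pairs (view_as_real_18/19) and flattened
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # model.py:151: cast back to the input dtype (bf16 in this run)
    return xq_out.type_as(xq), xk_out.type_as(xk)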
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_168: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_319, [0, 2, 1, 3]); convert_element_type_319 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_169: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_320, [0, 2, 1, 3]); convert_element_type_320 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_170: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_188, [0, 2, 1, 3]); view_188 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_9 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_168, permute_169, permute_170, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1015: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_9[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1016: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_9[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1021: "i64[]" = _scaled_dot_product_flash_attention_9[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1022: "i64[]" = _scaled_dot_product_flash_attention_9[7]; _scaled_dot_product_flash_attention_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_171: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1015, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_194: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_171, [2048, -1]); permute_171 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_173: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_87, [1, 0]); contiguous_view_as_strided_87 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_66: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_194, permute_173); permute_173 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_37: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_35, mm_66); add_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_323: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_37, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_20: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_323, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_19: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_20, [-1], True); pow_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_38: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_19, 1e-05); mean_19 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_19: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_38); add_38 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_76: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_323, rsqrt_19); convert_element_type_323 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_324: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_76, torch.bfloat16); mul_76 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_77: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_324, contiguous_view_as_strided_92); convert_element_type_324 = contiguous_view_as_strided_92 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_175: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_88, [1, 0]); contiguous_view_as_strided_88 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_67: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_77, permute_175); permute_175 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_327: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_67, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_9: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_327)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_78: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_327, sigmoid_9); convert_element_type_327 = sigmoid_9 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_328: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_78, torch.bfloat16); mul_78 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_177: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_90, [1, 0]); contiguous_view_as_strided_90 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_68: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_77, permute_177); permute_177 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_79: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_328, mm_68); convert_element_type_328 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_179: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_89, [1, 0]); contiguous_view_as_strided_89 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_69: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_79, permute_179); permute_179 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_39: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_37, mm_69); add_37 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
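add_37 and add_39 close out another transformer block; this whole section is the graph unrolled layer by layer, each layer following the two residual lines at model.py:410-411. A skeleton of that per-layer pattern (submodules injected for brevity):

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, attention, feed_forward, attention_norm, ffn_norm):
        super().__init__()
        self.attention = attention
        self.feed_forward = feed_forward
        self.attention_norm = attention_norm
        self.ffn_norm = ffn_norm

    def forward(self, x, freqs_cis):
        h = x + self.attention(self.attention_norm(x), freqs_cis)  # model.py:410
        return h + self.feed_forward(self.ffn_norm(h))             # model.py:411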
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_333: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_190, torch.bfloat16); primals_190 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_334: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_191, torch.bfloat16); primals_191 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_335: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_192, torch.bfloat16); primals_192 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_336: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_193, torch.bfloat16); primals_193 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_337: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_194, torch.bfloat16); primals_194 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_338: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_195, torch.bfloat16); primals_195 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_339: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_196, torch.bfloat16); primals_196 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_340: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_197, torch.bfloat16); primals_197 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_341: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_198, torch.bfloat16); primals_198 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_11 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_333, convert_element_type_334, convert_element_type_335, convert_element_type_336, convert_element_type_337, convert_element_type_338, convert_element_type_339, convert_element_type_340, convert_element_type_341]); convert_element_type_333 = convert_element_type_334 = convert_element_type_335 = convert_element_type_336 = convert_element_type_337 = convert_element_type_338 = convert_element_type_339 = convert_element_type_340 = convert_element_type_341 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1024: "bf16[6423040]" = all_gather_copy_in_11[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1025: "bf16[51384320]" = all_gather_copy_in_11[1]; all_gather_copy_in_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_11: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1024, 8, '0'); getitem_1024 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_11: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_11); all_gather_into_tensor_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_196: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_11, [8, -1]); wait_tensor_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_105 = torch.ops.aten.split_with_sizes.default(view_196, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_196 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1035: "bf16[8, 524288]" = split_with_sizes_105[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_93: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1035, [2048, 2048], [2048, 1]); getitem_1035 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1045: "bf16[8, 524288]" = split_with_sizes_105[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_94: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1045, [2048, 2048], [2048, 1]); getitem_1045 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1055: "bf16[8, 524288]" = split_with_sizes_105[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_95: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1055, [2048, 2048], [2048, 1]); getitem_1055 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1065: "bf16[8, 524288]" = split_with_sizes_105[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_96: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1065, [2048, 2048], [2048, 1]); getitem_1065 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1075: "bf16[8, 1441792]" = split_with_sizes_105[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_97: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1075, [5632, 2048], [2048, 1]); getitem_1075 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1085: "bf16[8, 1441792]" = split_with_sizes_105[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_98: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1085, [2048, 5632], [5632, 1]); getitem_1085 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1095: "bf16[8, 1441792]" = split_with_sizes_105[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_99: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1095, [5632, 2048], [2048, 1]); getitem_1095 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1105: "bf16[8, 256]" = split_with_sizes_105[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_100: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1105, [2048], [1]); getitem_1105 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1115: "bf16[8, 256]" = split_with_sizes_105[8]; split_with_sizes_105 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_101: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1115, [2048], [1]); getitem_1115 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
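
The copy-out above unflattens one FSDP all-gather: the flat bf16[51384320] result is viewed as [8, 6423040] (one row per rank), split column-wise by each parameter's per-rank shard numel, and each split is re-strided into a full unsharded parameter via torch.ops.fsdp.contiguous_view_as_strided. The nine splits are one transformer block's wq/wk/wv/wo ([2048, 2048] each), w1/w2/w3 ([5632, 2048] / [2048, 5632] / [5632, 2048]), and the two [2048] norm weights. A quick check that the sizes in the graph are consistent:

    # Per-rank shard numels for the nine parameters of one block.
    world_size = 8
    full_numels = [2048 * 2048] * 4 + [5632 * 2048] * 3 + [2048] * 2
    shard_numels = [n // world_size for n in full_numels]
    assert shard_numels == [524288] * 4 + [1441792] * 3 + [256] * 2
    assert sum(shard_numels) == 6423040                  # per-rank all-gather input
    assert world_size * sum(shard_numels) == 51384320    # all-gather output numel
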
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_342: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_39, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_21: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_342, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_20: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_21, [-1], True); pow_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_40: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_20, 1e-05); mean_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_20: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_40); add_40 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_80: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_342, rsqrt_20); convert_element_type_342 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_343: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_80, torch.bfloat16); mul_80 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_81: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_343, contiguous_view_as_strided_100); convert_element_type_343 = contiguous_view_as_strided_100 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
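
The span above is one RMSNorm (attention_norm) as Inductor sees it after decomposition: upcast to float32, x * rsqrt(mean(x^2) + eps), downcast to bf16, then an elementwise scale by the all-gathered weight. Restated in eager form from the model code quoted in the trace:

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # Matches model.py:61/74/75: norm in float32, cast back, scale by weight.
        xf = x.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
        return normed.type_as(x) * weight
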
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_181: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_93, [1, 0]); contiguous_view_as_strided_93 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_70: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_81, permute_181); permute_181 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_183: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_94, [1, 0]); contiguous_view_as_strided_94 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_71: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_81, permute_183); permute_183 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_185: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_95, [1, 0]); contiguous_view_as_strided_95 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_72: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_81, permute_185); permute_185 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
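
Each of the three mm nodes above is one bias-free F.linear: with no bias, F.linear(x, W) reduces to a transpose of the weight followed by a matmul, which is why every all-gathered weight appears as a permute([1, 0]) feeding an aten.mm. All three projections consume the same normalized activation (mul_81) to produce Q, K, and V. The equivalence, checked in eager code:

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)
    w = torch.randn(16, 8)
    assert torch.allclose(F.linear(x, w), x @ w.t())
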
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_205: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_70, [1, 2048, 16, 128]); mm_70 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_206: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_71, [1, 2048, 16, 128]); mm_71 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_207: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_72, [1, 2048, 16, 128]); mm_72 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_350: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_205, torch.float32); view_205 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_208: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_350, [1, 2048, 16, -1, 2]); convert_element_type_350 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_20: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_208); view_208 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_351: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_206, torch.float32); view_206 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_209: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_351, [1, 2048, 16, -1, 2]); convert_element_type_351 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_21: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_209); view_209 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_82: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_20, view_20); view_as_complex_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_20: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_82); mul_82 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_211: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_20, [1, 2048, 16, 128]); view_as_real_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_83: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_21, view_20); view_as_complex_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_21: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_83); mul_83 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_212: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_21, [1, 2048, 16, 128]); view_as_real_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_352: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_211, torch.bfloat16); view_211 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_353: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_212, torch.bfloat16); view_212 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
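
The lines above are the complex-number rotary embedding, decomposed: the bf16 Q/K are upcast, reshaped into pairs, reinterpreted as complex64, multiplied elementwise by freqs_cis (view_20 in this graph), and converted back. The view_as_complex/view_as_real nodes are the complex ops behind the earlier "Torchinductor does not support code generation for complex operators" warning. The model code quoted in the trace, restated:

    import torch

    def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
        # xq, xk: [bsz, seqlen, n_heads, head_dim]; freqs_cis: complex64,
        # reshaped upstream so it broadcasts as [1, seqlen, 1, head_dim // 2].
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)
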
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_186: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_352, [0, 2, 1, 3]); convert_element_type_352 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_187: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_353, [0, 2, 1, 3]); convert_element_type_353 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_188: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_207, [0, 2, 1, 3]); view_207 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_10 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_186, permute_187, permute_188, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1116: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_10[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1117: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_10[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1122: "i64[]" = _scaled_dot_product_flash_attention_10[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1123: "i64[]" = _scaled_dot_product_flash_attention_10[7]; _scaled_dot_product_flash_attention_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
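
The attention itself stays a single fused _scaled_dot_product_flash_attention node. Its scale argument is the default 1/sqrt(head_dim), and of the returned tuple only the output ([0]), the logsumexp ([1]), and what appear to be the RNG state scalars ([6], [7]) are kept for the backward pass. Checking the scale against the config's head_dim of 128:

    import math
    # 0.08838834764831843 is the scale printed in the graph above.
    assert abs(1 / math.sqrt(128) - 0.08838834764831843) < 1e-15
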
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_189: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1116, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_213: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_189, [2048, -1]); permute_189 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_191: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_96, [1, 0]); contiguous_view_as_strided_96 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_73: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_213, permute_191); permute_191 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_41: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_39, mm_73); add_39 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_356: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_41, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_22: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_356, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_21: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_22, [-1], True); pow_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_42: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_21, 1e-05); mean_21 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_21: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_42); add_42 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_84: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_356, rsqrt_21); convert_element_type_356 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_357: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_84, torch.bfloat16); mul_84 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_85: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_357, contiguous_view_as_strided_101); convert_element_type_357 = contiguous_view_as_strided_101 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_193: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_97, [1, 0]); contiguous_view_as_strided_97 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_74: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_85, permute_193); permute_193 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_360: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_74, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_10: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_360)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_86: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_360, sigmoid_10); convert_element_type_360 = sigmoid_10 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_361: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_86, torch.bfloat16); mul_86 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_195: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_99, [1, 0]); contiguous_view_as_strided_99 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_75: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_85, permute_195); permute_195 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_87: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_361, mm_75); convert_element_type_361 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_197: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_98, [1, 0]); contiguous_view_as_strided_98 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_76: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_87, permute_197); permute_197 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_43: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_41, mm_76); add_41 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
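
This span is the SwiGLU feed-forward plus its residual add: w1 and w3 both consume the ffn_norm output, silu is decomposed into float32 x * sigmoid(x), and w2 projects back to the model dim before the add. Restated from the model code quoted in the trace:

    import torch
    import torch.nn.functional as F

    def feed_forward(x: torch.Tensor, w1: torch.Tensor,
                     w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
        # In this 1B config: w1, w3 are [5632, 2048]; w2 is [2048, 5632].
        return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)
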
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_366: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_208, torch.bfloat16); primals_208 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_367: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_209, torch.bfloat16); primals_209 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_368: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_210, torch.bfloat16); primals_210 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_369: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_211, torch.bfloat16); primals_211 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_370: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_212, torch.bfloat16); primals_212 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_371: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_213, torch.bfloat16); primals_213 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_372: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_214, torch.bfloat16); primals_214 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_373: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_215, torch.bfloat16); primals_215 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_374: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_216, torch.bfloat16); primals_216 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_12 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_366, convert_element_type_367, convert_element_type_368, convert_element_type_369, convert_element_type_370, convert_element_type_371, convert_element_type_372, convert_element_type_373, convert_element_type_374]); convert_element_type_366 = convert_element_type_367 = convert_element_type_368 = convert_element_type_369 = convert_element_type_370 = convert_element_type_371 = convert_element_type_372 = convert_element_type_373 = convert_element_type_374 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1125: "bf16[6423040]" = all_gather_copy_in_12[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1126: "bf16[51384320]" = all_gather_copy_in_12[1]; all_gather_copy_in_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
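
Here the graph prepares the next block's all-gather while the current block is still computing: the sharded float32 primals are cast to bf16 (the mixed-precision param dtype), and torch.ops.fsdp.all_gather_copy_in packs the nine shards into one flat 6423040-element input while also allocating the 8x-larger output buffer that all_gather_into_tensor will fill. A rough sketch of the op's semantics in plain tensor ops (the real op is a single fused kernel; the function and parameter names here are illustrative, not the op's actual signature):

    import torch

    def all_gather_copy_in_sketch(shards, inp_split_sizes, world_size, dtype, device):
        numel = sum(inp_split_sizes)                  # 6423040 in this graph
        all_gather_input = torch.empty(numel, dtype=dtype, device=device)
        torch.cat(shards, out=all_gather_input)       # pack shards back-to-back
        all_gather_output = torch.empty(world_size * numel, dtype=dtype, device=device)
        return all_gather_input, all_gather_output
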
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_12: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1125, 8, '0'); getitem_1125 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_12: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_12); all_gather_into_tensor_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_215: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_12, [8, -1]); wait_tensor_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_115 = torch.ops.aten.split_with_sizes.default(view_215, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_215 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1136: "bf16[8, 524288]" = split_with_sizes_115[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_102: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1136, [2048, 2048], [2048, 1]); getitem_1136 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1146: "bf16[8, 524288]" = split_with_sizes_115[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_103: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1146, [2048, 2048], [2048, 1]); getitem_1146 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1156: "bf16[8, 524288]" = split_with_sizes_115[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_104: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1156, [2048, 2048], [2048, 1]); getitem_1156 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1166: "bf16[8, 524288]" = split_with_sizes_115[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_105: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1166, [2048, 2048], [2048, 1]); getitem_1166 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1176: "bf16[8, 1441792]" = split_with_sizes_115[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_106: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1176, [5632, 2048], [2048, 1]); getitem_1176 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1186: "bf16[8, 1441792]" = split_with_sizes_115[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_107: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1186, [2048, 5632], [5632, 1]); getitem_1186 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1196: "bf16[8, 1441792]" = split_with_sizes_115[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_108: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1196, [5632, 2048], [2048, 1]); getitem_1196 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1206: "bf16[8, 256]" = split_with_sizes_115[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_109: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1206, [2048], [1]); getitem_1206 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1216: "bf16[8, 256]" = split_with_sizes_115[8]; split_with_sizes_115 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_110: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1216, [2048], [1]); getitem_1216 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_375: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_43, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_23: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_375, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_22: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_23, [-1], True); pow_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_44: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_22, 1e-05); mean_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_22: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_44); add_44 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_88: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_375, rsqrt_22); convert_element_type_375 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_376: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_88, torch.bfloat16); mul_88 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_89: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_376, contiguous_view_as_strided_109); convert_element_type_376 = contiguous_view_as_strided_109 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_199: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_102, [1, 0]); contiguous_view_as_strided_102 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_77: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_89, permute_199); permute_199 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_201: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_103, [1, 0]); contiguous_view_as_strided_103 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_78: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_89, permute_201); permute_201 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_203: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_104, [1, 0]); contiguous_view_as_strided_104 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_79: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_89, permute_203); permute_203 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_224: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_77, [1, 2048, 16, 128]); mm_77 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_225: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_78, [1, 2048, 16, 128]); mm_78 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_226: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_79, [1, 2048, 16, 128]); mm_79 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_383: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_224, torch.float32); view_224 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_227: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_383, [1, 2048, 16, -1, 2]); convert_element_type_383 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_22: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_227); view_227 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_384: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_225, torch.float32); view_225 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_228: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_384, [1, 2048, 16, -1, 2]); convert_element_type_384 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_23: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_228); view_228 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_90: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_22, view_20); view_as_complex_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_22: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_90); mul_90 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_230: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_22, [1, 2048, 16, 128]); view_as_real_22 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_91: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_23, view_20); view_as_complex_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_23: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_91); mul_91 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_231: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_23, [1, 2048, 16, 128]); view_as_real_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_385: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_230, torch.bfloat16); view_230 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_386: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_231, torch.bfloat16); view_231 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_204: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_385, [0, 2, 1, 3]); convert_element_type_385 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_205: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_386, [0, 2, 1, 3]); convert_element_type_386 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_206: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_226, [0, 2, 1, 3]); view_226 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_11 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_204, permute_205, permute_206, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1217: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_11[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1218: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_11[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1223: "i64[]" = _scaled_dot_product_flash_attention_11[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1224: "i64[]" = _scaled_dot_product_flash_attention_11[7]; _scaled_dot_product_flash_attention_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_207: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1217, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_232: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_207, [2048, -1]); permute_207 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_209: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_105, [1, 0]); contiguous_view_as_strided_105 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_80: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_232, permute_209); permute_209 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_45: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_43, mm_80); add_43 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_389: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_45, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_24: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_389, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_23: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_24, [-1], True); pow_24 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_46: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_23, 1e-05); mean_23 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_23: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_46); add_46 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_92: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_389, rsqrt_23); convert_element_type_389 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_390: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_92, torch.bfloat16); mul_92 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_93: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_390, contiguous_view_as_strided_110); convert_element_type_390 = contiguous_view_as_strided_110 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
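Note: the convert_element_type / pow / mean / add / rsqrt / mul chain above is the decomposition of the RMSNorm quoted in the # File comments (model.py:61 and 74-75). A minimal eager-mode equivalent of the same ops:

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
        # model.py:61 -- normalization is computed in float32 (convert_element_type_389 above)
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
        # model.py:74-75 -- cast back to the input dtype, then scale by the learned weight
        return normed.type_as(x) * weight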
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_211: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_106, [1, 0]); contiguous_view_as_strided_106 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_81: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_93, permute_211); permute_211 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_393: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_81, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_11: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_393)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_94: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_393, sigmoid_11); convert_element_type_393 = sigmoid_11 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_394: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_94, torch.bfloat16); mul_94 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_213: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_108, [1, 0]); contiguous_view_as_strided_108 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_82: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_93, permute_213); permute_213 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_95: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_394, mm_82); convert_element_type_394 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_215: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_107, [1, 0]); contiguous_view_as_strided_107 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_83: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_95, permute_215); permute_215 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_47: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_45, mm_83); add_45 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
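Note: taken together, mm_81 / sigmoid_11 / mul_94 (w1 followed by SiLU, which decomposes to t * sigmoid(t) in fp32), mm_82 (w3), mul_95, and mm_83 (w2) are the SwiGLU feed-forward from model.py:299, while add_45 and add_47 are the two pre-norm residual adds from model.py:410-411 (h = x + attention(...); out = h + feed_forward(...)). A minimal eager-mode sketch of the feed-forward, with weights stored as (out_features, in_features) as in nn.Linear:

    import torch
    import torch.nn.functional as F

    def feed_forward(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
        # model.py:299 -- F.silu(t) == t * torch.sigmoid(t), traced in fp32 above
        return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

    # shapes in this trace: x [2048, 2048], w1/w3 [5632, 2048], w2 [2048, 5632]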
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_399: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_226, torch.bfloat16); primals_226 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_400: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_227, torch.bfloat16); primals_227 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_401: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_228, torch.bfloat16); primals_228 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_402: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_229, torch.bfloat16); primals_229 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_403: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_230, torch.bfloat16); primals_230 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_404: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_231, torch.bfloat16); primals_231 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_405: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_232, torch.bfloat16); primals_232 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_406: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_233, torch.bfloat16); primals_233 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_407: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_234, torch.bfloat16); primals_234 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_13 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_399, convert_element_type_400, convert_element_type_401, convert_element_type_402, convert_element_type_403, convert_element_type_404, convert_element_type_405, convert_element_type_406, convert_element_type_407]); convert_element_type_399 = convert_element_type_400 = convert_element_type_401 = convert_element_type_402 = convert_element_type_403 = convert_element_type_404 = convert_element_type_405 = convert_element_type_406 = convert_element_type_407 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1226: "bf16[6423040]" = all_gather_copy_in_13[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1227: "bf16[51384320]" = all_gather_copy_in_13[1]; all_gather_copy_in_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_13: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1226, 8, '0'); getitem_1226 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_13: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_13); all_gather_into_tensor_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_234: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_13, [8, -1]); wait_tensor_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_125 = torch.ops.aten.split_with_sizes.default(view_234, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_234 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1237: "bf16[8, 524288]" = split_with_sizes_125[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_111: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1237, [2048, 2048], [2048, 1]); getitem_1237 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1247: "bf16[8, 524288]" = split_with_sizes_125[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_112: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1247, [2048, 2048], [2048, 1]); getitem_1247 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1257: "bf16[8, 524288]" = split_with_sizes_125[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_113: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1257, [2048, 2048], [2048, 1]); getitem_1257 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1267: "bf16[8, 524288]" = split_with_sizes_125[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_114: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1267, [2048, 2048], [2048, 1]); getitem_1267 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1277: "bf16[8, 1441792]" = split_with_sizes_125[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_115: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1277, [5632, 2048], [2048, 1]); getitem_1277 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1287: "bf16[8, 1441792]" = split_with_sizes_125[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_116: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1287, [2048, 5632], [5632, 1]); getitem_1287 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1297: "bf16[8, 1441792]" = split_with_sizes_125[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_117: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1297, [5632, 2048], [2048, 1]); getitem_1297 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1307: "bf16[8, 256]" = split_with_sizes_125[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_118: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1307, [2048], [1]); getitem_1307 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1317: "bf16[8, 256]" = split_with_sizes_125[8]; split_with_sizes_125 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_119: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1317, [2048], [1]); getitem_1317 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
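Note: the block above is one full FSDP all-gather for the next transformer block's parameters, issued here so that communication can overlap with the current block's compute: the fp32 shards (primals) are cast to bf16 (_to_dtype_if_needed), packed into one flat buffer (all_gather_copy_in), gathered asynchronously (all_gather_into_tensor), synchronized (wait_tensor), then split back out and re-viewed with each parameter's original shape and strides (split_with_sizes + contiguous_view_as_strided). The sizes are internally consistent, as this arithmetic check shows:

    # Per-rank shard numels for the 9 parameters of one block: the four
    # [2048, 2048] attention projections, the three FFN weights, and the
    # two [2048] norm weights.
    split_sizes = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
    world_size = 8

    shard_numel = sum(split_sizes)
    assert shard_numel == 6423040                  # all_gather_copy_in input numel above
    assert shard_numel * world_size == 51384320    # all-gather output / wait_tensor numel
    assert 524288 == 2048 * 2048 // world_size     # each rank holds 1/8 of a [2048, 2048] weight
    assert 1441792 == 5632 * 2048 // world_size    # ... of a [5632, 2048] or [2048, 5632] weight
    assert 256 == 2048 // world_size               # ... of a [2048] norm weight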
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_408: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_47, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_25: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_408, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_24: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_25, [-1], True); pow_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_48: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_24, 1e-05); mean_24 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_24: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_48); add_48 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_96: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_408, rsqrt_24); convert_element_type_408 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_409: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_96, torch.bfloat16); mul_96 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_97: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_409, contiguous_view_as_strided_118); convert_element_type_409 = contiguous_view_as_strided_118 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_217: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_111, [1, 0]); contiguous_view_as_strided_111 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_84: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_97, permute_217); permute_217 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_219: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_112, [1, 0]); contiguous_view_as_strided_112 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_85: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_97, permute_219); permute_219 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_221: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_113, [1, 0]); contiguous_view_as_strided_113 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_86: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_97, permute_221); permute_221 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_243: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_84, [1, 2048, 16, 128]); mm_84 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_244: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_85, [1, 2048, 16, 128]); mm_85 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_245: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_86, [1, 2048, 16, 128]); mm_86 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
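Note: the [1, 2048, 16, 128] shapes in view_243..view_245 above follow directly from the configuration visible in the trace: bsz = 1, seqlen = 2048, n_heads = 16, head_dim = dim // n_heads. A quick consistency check:

    bsz, seqlen = 1, 2048
    dim, n_heads = 2048, 16
    head_dim = dim // n_heads
    assert head_dim == 128
    assert (bsz * seqlen, dim) == (2048, 2048)                      # mm_84..mm_86 output shape
    assert (bsz, seqlen, n_heads, head_dim) == (1, 2048, 16, 128)   # view_243..view_245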
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_416: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_243, torch.float32); view_243 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_246: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_416, [1, 2048, 16, -1, 2]); convert_element_type_416 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_24: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_246); view_246 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_417: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_244, torch.float32); view_244 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_247: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_417, [1, 2048, 16, -1, 2]); convert_element_type_417 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_25: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_247); view_247 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_98: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_24, view_20); view_as_complex_24 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_24: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_98); mul_98 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_249: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_24, [1, 2048, 16, 128]); view_as_real_24 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_99: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_25, view_20); view_as_complex_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_25: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_250: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_25, [1, 2048, 16, 128]); view_as_real_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_418: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_249, torch.bfloat16); view_249 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_419: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_250, torch.bfloat16); view_250 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
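Note: the view_as_complex / complex multiply / view_as_real sequence above (the c64 nodes) is the rotary embedding from model.py:146-151, quoted in the # File comments. An eager-mode equivalent, assuming freqs_cis is a complex64 tensor broadcastable to [1, seqlen, 1, head_dim // 2] (view_20 in the trace):

    import torch

    def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
        # Pair up the last dim as (real, imag), then rotate by the complex phases in freqs_cis.
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # back to [..., head_dim]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)             # cast f32 results back to bf16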
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_222: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_418, [0, 2, 1, 3]); convert_element_type_418 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_223: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_419, [0, 2, 1, 3]); convert_element_type_419 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_224: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_245, [0, 2, 1, 3]); view_245 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_12 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_222, permute_223, permute_224, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1318: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_12[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1319: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_12[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1324: "i64[]" = _scaled_dot_product_flash_attention_12[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1325: "i64[]" = _scaled_dot_product_flash_attention_12[7]; _scaled_dot_product_flash_attention_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_225: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1318, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_251: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_225, [2048, -1]); permute_225 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_227: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_114, [1, 0]); contiguous_view_as_strided_114 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_87: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_251, permute_227); permute_227 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_49: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_47, mm_87); add_47 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_422: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_49, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_26: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_422, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_25: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_26, [-1], True); pow_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_50: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_25, 1e-05); mean_25 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_25: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_50); add_50 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_100: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_422, rsqrt_25); convert_element_type_422 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_423: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_100, torch.bfloat16); mul_100 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_101: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_423, contiguous_view_as_strided_119); convert_element_type_423 = contiguous_view_as_strided_119 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_229: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_115, [1, 0]); contiguous_view_as_strided_115 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_88: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_101, permute_229); permute_229 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_426: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_88, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_12: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_426)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_102: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_426, sigmoid_12); convert_element_type_426 = sigmoid_12 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_427: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_102, torch.bfloat16); mul_102 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_231: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_117, [1, 0]); contiguous_view_as_strided_117 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_89: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_101, permute_231); permute_231 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_103: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_427, mm_89); convert_element_type_427 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_233: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_116, [1, 0]); contiguous_view_as_strided_116 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_90: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_103, permute_233); permute_233 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_51: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_49, mm_90); add_49 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_432: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_244, torch.bfloat16); primals_244 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_433: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_245, torch.bfloat16); primals_245 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_434: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_246, torch.bfloat16); primals_246 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_435: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_247, torch.bfloat16); primals_247 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_436: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_248, torch.bfloat16); primals_248 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_437: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_249, torch.bfloat16); primals_249 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_438: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_250, torch.bfloat16); primals_250 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_439: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_251, torch.bfloat16); primals_251 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_440: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_252, torch.bfloat16); primals_252 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_14 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_432, convert_element_type_433, convert_element_type_434, convert_element_type_435, convert_element_type_436, convert_element_type_437, convert_element_type_438, convert_element_type_439, convert_element_type_440]); convert_element_type_432 = convert_element_type_433 = convert_element_type_434 = convert_element_type_435 = convert_element_type_436 = convert_element_type_437 = convert_element_type_438 = convert_element_type_439 = convert_element_type_440 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1327: "bf16[6423040]" = all_gather_copy_in_14[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1328: "bf16[51384320]" = all_gather_copy_in_14[1]; all_gather_copy_in_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_14: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1327, 8, '0'); getitem_1327 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_14: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_14); all_gather_into_tensor_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_253: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_14, [8, -1]); wait_tensor_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_135 = torch.ops.aten.split_with_sizes.default(view_253, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_253 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1338: "bf16[8, 524288]" = split_with_sizes_135[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_120: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1338, [2048, 2048], [2048, 1]); getitem_1338 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1348: "bf16[8, 524288]" = split_with_sizes_135[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_121: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1348, [2048, 2048], [2048, 1]); getitem_1348 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1358: "bf16[8, 524288]" = split_with_sizes_135[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_122: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1358, [2048, 2048], [2048, 1]); getitem_1358 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1368: "bf16[8, 524288]" = split_with_sizes_135[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_123: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1368, [2048, 2048], [2048, 1]); getitem_1368 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1378: "bf16[8, 1441792]" = split_with_sizes_135[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_124: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1378, [5632, 2048], [2048, 1]); getitem_1378 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1388: "bf16[8, 1441792]" = split_with_sizes_135[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_125: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1388, [2048, 5632], [5632, 1]); getitem_1388 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1398: "bf16[8, 1441792]" = split_with_sizes_135[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_126: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1398, [5632, 2048], [2048, 1]); getitem_1398 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1408: "bf16[8, 256]" = split_with_sizes_135[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_127: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1408, [2048], [1]); getitem_1408 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1418: "bf16[8, 256]" = split_with_sizes_135[8]; split_with_sizes_135 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_128: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1418, [2048], [1]); getitem_1418 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_441: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_51, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_27: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_441, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_26: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_27, [-1], True); pow_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_52: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_26, 1e-05); mean_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_26: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_52); add_52 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_104: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_441, rsqrt_26); convert_element_type_441 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_442: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_104, torch.bfloat16); mul_104 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_105: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_442, contiguous_view_as_strided_127); convert_element_type_442 = contiguous_view_as_strided_127 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
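The op run above (convert_element_type_441 through mul_105) is the bf16 RMSNorm from model.py:61/74/75: upcast the residual stream to f32, compute x * rsqrt(mean(x^2, dim=-1) + 1e-05), downcast back to bf16, then scale by the all-gathered norm weight (contiguous_view_as_strided_127). A minimal eager-mode sketch; the two method bodies are copied from the source comments embedded in the graph, while the class scaffolding (names, init) is assumed:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # pow_27 / mean_26 / add_52 / rsqrt_26 / mul_104 in the graph above
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # convert_element_type_441 upcasts to f32; convert_element_type_442
        # casts back to the input dtype (bf16 here) before mul_105 by weight.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight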
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_235: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_120, [1, 0]); contiguous_view_as_strided_120 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_91: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_105, permute_235); permute_235 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_237: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_121, [1, 0]); contiguous_view_as_strided_121 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_92: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_105, permute_237); permute_237 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_239: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_122, [1, 0]); contiguous_view_as_strided_122 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_93: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_105, permute_239); permute_239 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
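Each of the three mm ops above is one bias-free F.linear(input, weight), which the trace lowers to a transpose of the all-gathered weight plus a plain matmul, i.e. y = x @ W.T. A quick eager-mode equivalence check with random stand-in tensors (shapes match the graph, values are arbitrary):

import torch

x = torch.randn(2048, 2048)                     # the normed activations (mul_105)
w = torch.randn(2048, 2048)                     # one all-gathered wq/wk/wv weight, [out, in]
y_linear = torch.nn.functional.linear(x, w)     # F.linear(input, self.weight, bias=None)
y_traced = torch.mm(x, w.permute(1, 0))         # permute_235 + mm_91, as traced
assert torch.allclose(y_linear, y_traced)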
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_262: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_91, [1, 2048, 16, 128]); mm_91 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_263: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_92, [1, 2048, 16, 128]); mm_92 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_264: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_93, [1, 2048, 16, 128]); mm_93 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_449: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_262, torch.float32); view_262 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_265: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_449, [1, 2048, 16, -1, 2]); convert_element_type_449 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_26: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_265); view_265 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_450: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_263, torch.float32); view_263 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_266: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_450, [1, 2048, 16, -1, 2]); convert_element_type_450 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_27: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_266); view_266 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_106: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_26, view_20); view_as_complex_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_26: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_106); mul_106 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_268: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_26, [1, 2048, 16, 128]); view_as_real_26 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_107: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_27, view_20); view_as_complex_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_27: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_107); mul_107 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_269: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_27, [1, 2048, 16, 128]); view_as_real_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_451: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_268, torch.bfloat16); view_268 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_452: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_269, torch.bfloat16); view_269 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
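The complex-typed segment above is apply_rotary_emb (model.py:146-151): pair the last dim of xq/xk into (..., 64, 2), view as c64, rotate each pair with one complex multiply against the cached freqs_cis (mul_106/mul_107 against view_20), then unpack to real, flatten the pair dim, and downcast. A sketch whose function body is copied from the source comments in the graph; how the caller prepares freqs_cis into a shape broadcastable against the c64[1, 2048, 16, 64] queries/keys is assumed, not shown here:

import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # view_265 / view_as_complex_26 (and the xk twin): pair the head dim and
    # reinterpret as complex64.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # One complex multiply applies the rotation; view_as_real + flatten(3)
    # restore the bf16 layout [1, 2048, 16, 128].
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # convert_element_type_451/452: back to the activation dtype.
    return xq_out.type_as(xq), xk_out.type_as(xk)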
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_240: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_451, [0, 2, 1, 3]); convert_element_type_451 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_241: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_452, [0, 2, 1, 3]); convert_element_type_452 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_242: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_264, [0, 2, 1, 3]); view_264 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_13 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_240, permute_241, permute_242, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1419: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_13[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1420: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_13[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1425: "i64[]" = _scaled_dot_product_flash_attention_13[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1426: "i64[]" = _scaled_dot_product_flash_attention_13[7]; _scaled_dot_product_flash_attention_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
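The attention itself is a single fused _scaled_dot_product_flash_attention call; the scale constant 0.08838834764831843 baked into the graph is just the default 1/sqrt(head_dim) for head_dim = 128. A small eager-mode repro of the call site (model.py:254) with random inputs at the traced shapes:

import math
import torch
import torch.nn.functional as F

xq = torch.randn(1, 16, 2048, 128)  # (bs, n_heads, seqlen, head_dim)
xk = torch.randn(1, 16, 2048, 128)
xv = torch.randn(1, 16, 2048, 128)
out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
# With no explicit scale argument, SDPA uses 1/sqrt(head_dim); for 128 that
# is exactly the constant captured in the graph.
assert math.isclose(1.0 / math.sqrt(128), 0.08838834764831843)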
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_243: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1419, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_270: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_243, [2048, -1]); permute_243 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_245: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_123, [1, 0]); contiguous_view_as_strided_123 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_94: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_270, permute_245); permute_245 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_53: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_51, mm_94); add_51 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_455: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_53, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_28: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_455, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_27: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_28, [-1], True); pow_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_54: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_27, 1e-05); mean_27 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_27: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_54); add_54 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_108: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_455, rsqrt_27); convert_element_type_455 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_456: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_108, torch.bfloat16); mul_108 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_109: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_456, contiguous_view_as_strided_128); convert_element_type_456 = contiguous_view_as_strided_128 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_247: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_124, [1, 0]); contiguous_view_as_strided_124 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_95: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_109, permute_247); permute_247 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_459: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_95, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_13: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_459)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_110: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_459, sigmoid_13); convert_element_type_459 = sigmoid_13 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_460: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_110, torch.bfloat16); mul_110 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_249: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_126, [1, 0]); contiguous_view_as_strided_126 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_96: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_109, permute_249); permute_249 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_111: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_460, mm_96); convert_element_type_460 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_251: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_125, [1, 0]); contiguous_view_as_strided_125 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_97: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_111, permute_251); permute_251 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_55: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_53, mm_97); add_53 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
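The FFN segment above (mm_95 through mm_97, plus the f32 sigmoid/mul pair) is the SwiGLU block from model.py:299: F.silu is decomposed into sigmoid_13 times its input in f32 (convert_element_type_459 / mul_110) before casting back to bf16 and multiplying by w3(x). A sketch whose forward body is copied from the traced source comment; dim=2048 and hidden_dim=5632 are read off the mm shapes, while the module scaffolding is assumed:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # weight behind mm_95
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # weight behind mm_97
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # weight behind mm_96

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(t) == t * sigmoid(t); the bf16 graph performs the sigmoid/mul
        # in f32 and downcasts before the elementwise product with w3(x).
        return self.w2(F.silu(self.w1(x)) * self.w3(x))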
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_465: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_262, torch.bfloat16); primals_262 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_466: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_263, torch.bfloat16); primals_263 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_467: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_264, torch.bfloat16); primals_264 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_468: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_265, torch.bfloat16); primals_265 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_469: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_266, torch.bfloat16); primals_266 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_470: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_267, torch.bfloat16); primals_267 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_471: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_268, torch.bfloat16); primals_268 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_472: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_269, torch.bfloat16); primals_269 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_473: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_270, torch.bfloat16); primals_270 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_15 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_465, convert_element_type_466, convert_element_type_467, convert_element_type_468, convert_element_type_469, convert_element_type_470, convert_element_type_471, convert_element_type_472, convert_element_type_473]); convert_element_type_465 = convert_element_type_466 = convert_element_type_467 = convert_element_type_468 = convert_element_type_469 = convert_element_type_470 = convert_element_type_471 = convert_element_type_472 = convert_element_type_473 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1428: "bf16[6423040]" = all_gather_copy_in_15[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1429: "bf16[51384320]" = all_gather_copy_in_15[1]; all_gather_copy_in_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_15: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1428, 8, '0'); getitem_1428 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_15: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_15); all_gather_into_tensor_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_272: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_15, [8, -1]); wait_tensor_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_145 = torch.ops.aten.split_with_sizes.default(view_272, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_272 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1439: "bf16[8, 524288]" = split_with_sizes_145[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_129: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1439, [2048, 2048], [2048, 1]); getitem_1439 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1449: "bf16[8, 524288]" = split_with_sizes_145[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_130: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1449, [2048, 2048], [2048, 1]); getitem_1449 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1459: "bf16[8, 524288]" = split_with_sizes_145[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_131: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1459, [2048, 2048], [2048, 1]); getitem_1459 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1469: "bf16[8, 524288]" = split_with_sizes_145[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_132: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1469, [2048, 2048], [2048, 1]); getitem_1469 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1479: "bf16[8, 1441792]" = split_with_sizes_145[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_133: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1479, [5632, 2048], [2048, 1]); getitem_1479 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1489: "bf16[8, 1441792]" = split_with_sizes_145[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_134: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1489, [2048, 5632], [5632, 1]); getitem_1489 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1499: "bf16[8, 1441792]" = split_with_sizes_145[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_135: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1499, [5632, 2048], [2048, 1]); getitem_1499 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1509: "bf16[8, 256]" = split_with_sizes_145[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_136: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1509, [2048], [1]); getitem_1509 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1519: "bf16[8, 256]" = split_with_sizes_145[8]; split_with_sizes_145 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_137: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1519, [2048], [1]); getitem_1519 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
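The stretch from _to_dtype_if_needed down to contiguous_view_as_strided_137 is one full per-parameter FSDP all-gather for the next transformer block: the sharded f32 primals are cast to bf16, packed into a single 6,423,040-element flat input (all_gather_copy_in), all-gathered across the 8 ranks into a 51,384,320-element output, waited on (wait_tensor_15), reshaped to [8, 6423040] (view_272), split per parameter (split_with_sizes_145), and finally materialized as unsharded weights via torch.ops.fsdp.contiguous_view_as_strided. The size bookkeeping is internally consistent; a few asserts, assuming only world_size = 8 and the inp_split_sizes list from the all_gather_copy_in call above:

world_size = 8
inp_split_sizes = [524288, 524288, 524288, 524288,
                   1441792, 1441792, 1441792, 256, 256]
shard_numel = sum(inp_split_sizes)
assert shard_numel == 6423040                 # all_gather_copy_in input numel
assert shard_numel * world_size == 51384320   # all-gather output numel
# Each [8, n] slice of the [8, 6423040] view holds exactly one unsharded
# parameter's elements:
assert 8 * 524288 == 2048 * 2048    # wq/wk/wv/wo -> [2048, 2048]
assert 8 * 1441792 == 5632 * 2048   # w1/w2/w3    -> [5632, 2048] / [2048, 5632]
assert 8 * 256 == 2048              # norm weights -> [2048]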
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_474: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_55, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_29: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_474, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_28: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_29, [-1], True); pow_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_56: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_28, 1e-05); mean_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_28: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_56); add_56 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_112: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_474, rsqrt_28); convert_element_type_474 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_475: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_112, torch.bfloat16); mul_112 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_113: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_475, contiguous_view_as_strided_136); convert_element_type_475 = contiguous_view_as_strided_136 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_253: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_129, [1, 0]); contiguous_view_as_strided_129 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_98: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_113, permute_253); permute_253 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_255: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_130, [1, 0]); contiguous_view_as_strided_130 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_99: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_113, permute_255); permute_255 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_257: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_131, [1, 0]); contiguous_view_as_strided_131 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_100: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_113, permute_257); permute_257 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_281: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_98, [1, 2048, 16, 128]); mm_98 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_282: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_99, [1, 2048, 16, 128]); mm_99 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_283: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_100, [1, 2048, 16, 128]); mm_100 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_482: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_281, torch.float32); view_281 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_284: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_482, [1, 2048, 16, -1, 2]); convert_element_type_482 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_28: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_284); view_284 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_483: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_282, torch.float32); view_282 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_285: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_483, [1, 2048, 16, -1, 2]); convert_element_type_483 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_29: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_285); view_285 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_114: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_28, view_20); view_as_complex_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_28: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_114); mul_114 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_287: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_28, [1, 2048, 16, 128]); view_as_real_28 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_115: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_29, view_20); view_as_complex_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_29: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_115); mul_115 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_288: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_29, [1, 2048, 16, 128]); view_as_real_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_484: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_287, torch.bfloat16); view_287 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_485: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_288, torch.bfloat16); view_288 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_258: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_484, [0, 2, 1, 3]); convert_element_type_484 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_259: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_485, [0, 2, 1, 3]); convert_element_type_485 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_260: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_283, [0, 2, 1, 3]); view_283 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_14 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_258, permute_259, permute_260, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1520: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_14[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1521: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_14[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1526: "i64[]" = _scaled_dot_product_flash_attention_14[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1527: "i64[]" = _scaled_dot_product_flash_attention_14[7]; _scaled_dot_product_flash_attention_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_261: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1520, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_289: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_261, [2048, -1]); permute_261 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_263: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_132, [1, 0]); contiguous_view_as_strided_132 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_101: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_289, permute_263); permute_263 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_57: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_55, mm_101); add_55 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_488: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_57, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_30: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_488, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_29: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_30, [-1], True); pow_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_58: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_29, 1e-05); mean_29 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_29: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_58); add_58 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_116: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_488, rsqrt_29); convert_element_type_488 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_489: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_116, torch.bfloat16); mul_116 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_117: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_489, contiguous_view_as_strided_137); convert_element_type_489 = contiguous_view_as_strided_137 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
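The convert -> pow -> mean -> add eps -> rsqrt -> mul -> convert chain above is the RMSNorm pattern from the cited model.py:61 and 74-75 lines. A minimal eager sketch, assuming a learnable weight of the hidden size (2048 here) and eps = 1e-5:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # upcast to fp32 for the reduction, as the convert_element_type ops above do
    x_f32 = x.float()
    normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + eps)
    # cast back to the input dtype (bf16 here), then scale by the learned weight
    return normed.type_as(x) * weight
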
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_265: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_133, [1, 0]); contiguous_view_as_strided_133 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_102: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_117, permute_265); permute_265 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_492: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_102, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_14: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_492)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_118: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_492, sigmoid_14); convert_element_type_492 = sigmoid_14 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_493: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_118, torch.bfloat16); mul_118 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_267: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_135, [1, 0]); contiguous_view_as_strided_135 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_103: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_117, permute_267); permute_267 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_119: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_493, mm_103); convert_element_type_493 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_269: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_134, [1, 0]); contiguous_view_as_strided_134 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_104: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_119, permute_269); permute_269 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_59: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_57, mm_104); add_57 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
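mm_102/mm_103 are the w1/w3 up-projections, sigmoid_14/mul_118 evaluate silu in fp32 before casting back to bf16, mm_104 is the w2 down-projection, and add_59 is the second residual (model.py:299 and 411 in the comments). A sketch of the SwiGLU feed-forward being traced, with the weight shapes from the graph (w1/w3: [5632, 2048], w2: [2048, 5632]):

import torch.nn.functional as F

def feed_forward(x, w1, w2, w3):
    # F.silu(u) = u * sigmoid(u); the trace computes this in fp32, then casts back
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)
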
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_498: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_280, torch.bfloat16); primals_280 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_499: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_281, torch.bfloat16); primals_281 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_500: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_282, torch.bfloat16); primals_282 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_501: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_283, torch.bfloat16); primals_283 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_502: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_284, torch.bfloat16); primals_284 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_503: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_285, torch.bfloat16); primals_285 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_504: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_286, torch.bfloat16); primals_286 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_505: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_287, torch.bfloat16); primals_287 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_506: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_288, torch.bfloat16); primals_288 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_16 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_498, convert_element_type_499, convert_element_type_500, convert_element_type_501, convert_element_type_502, convert_element_type_503, convert_element_type_504, convert_element_type_505, convert_element_type_506]); convert_element_type_498 = convert_element_type_499 = convert_element_type_500 = convert_element_type_501 = convert_element_type_502 = convert_element_type_503 = convert_element_type_504 = convert_element_type_505 = convert_element_type_506 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1529: "bf16[6423040]" = all_gather_copy_in_16[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1530: "bf16[51384320]" = all_gather_copy_in_16[1]; all_gather_copy_in_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
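The nine split sizes are the per-rank shard numels of this block's parameters under 8-way sharding: 524288 = 2048*2048/8 for each of wq/wk/wv/wo, 1441792 = 5632*2048/8 for each of w1/w2/w3, and 256 = 2048/8 for the two norm weights. A sanity check on the buffer sizes seen in getitem_1529/getitem_1530, plus an unfused sketch of what the fused torch.ops.fsdp.all_gather_copy_in op does conceptually:

import torch

world_size = 8
split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
assert sum(split_sizes) == 6423040                  # per-rank all-gather input numel
assert world_size * sum(split_sizes) == 51384320    # gathered output numel

# unfused sketch: copy the bf16 shards into one flat input buffer and allocate
# the output buffer that the all-gather will fill
shards = [torch.empty(n, dtype=torch.bfloat16) for n in split_sizes]
all_gather_input = torch.cat(shards)
all_gather_output = torch.empty(world_size * all_gather_input.numel(), dtype=torch.bfloat16)
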
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_16: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1529, 8, '0'); getitem_1529 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_16: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_16); all_gather_into_tensor_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
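all_gather_into_tensor launches the collective asynchronously; wait_tensor is the explicit synchronization point, so the compiler can schedule independent work between the two and overlap communication with compute. A sketch using the same functional-collective ops that appear in the graph (gather_flat_param is a hypothetical helper name, and a process group must already be initialized):

import torch

def gather_flat_param(all_gather_input: torch.Tensor, world_size: int, group_name: str) -> torch.Tensor:
    # launch the collective; the returned tensor's storage is only valid after the wait
    out = torch.ops._c10d_functional.all_gather_into_tensor(all_gather_input, world_size, group_name)
    # sync point: the graph places this right before the first real use of the data
    return torch.ops._c10d_functional.wait_tensor(out)
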
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_291: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_16, [8, -1]); wait_tensor_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_155 = torch.ops.aten.split_with_sizes.default(view_291, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_291 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1540: "bf16[8, 524288]" = split_with_sizes_155[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_138: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1540, [2048, 2048], [2048, 1]); getitem_1540 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1550: "bf16[8, 524288]" = split_with_sizes_155[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_139: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1550, [2048, 2048], [2048, 1]); getitem_1550 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1560: "bf16[8, 524288]" = split_with_sizes_155[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_140: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1560, [2048, 2048], [2048, 1]); getitem_1560 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1570: "bf16[8, 524288]" = split_with_sizes_155[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_141: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1570, [2048, 2048], [2048, 1]); getitem_1570 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1580: "bf16[8, 1441792]" = split_with_sizes_155[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_142: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1580, [5632, 2048], [2048, 1]); getitem_1580 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1590: "bf16[8, 1441792]" = split_with_sizes_155[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_143: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1590, [2048, 5632], [5632, 1]); getitem_1590 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1600: "bf16[8, 1441792]" = split_with_sizes_155[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_144: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1600, [5632, 2048], [2048, 1]); getitem_1600 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1610: "bf16[8, 256]" = split_with_sizes_155[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_145: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1610, [2048], [1]); getitem_1610 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1620: "bf16[8, 256]" = split_with_sizes_155[8]; split_with_sizes_155 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_146: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1620, [2048], [1]); getitem_1620 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
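After the wait, the flat [51384320] buffer is viewed as [8, 6423040] (one row per rank), split back into per-parameter slabs, and each slab is presented as the full unsharded parameter via torch.ops.fsdp.contiguous_view_as_strided (e.g. [8, 524288] -> [2048, 2048]) without materializing a copy. An unfused sketch of the same bookkeeping, assuming dim-0 sharding so that stacking the 8 shards row-wise recovers each weight; reshape here may copy, which is what the strided-view op avoids:

import torch

world_size = 8
split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2
gathered = torch.empty(world_size * sum(split_sizes), dtype=torch.bfloat16)

per_param = gathered.view(world_size, -1).split(split_sizes, dim=1)
wq = per_param[0].reshape(2048, 2048)        # 8 * 524288  == 2048 * 2048
w1 = per_param[4].reshape(5632, 2048)        # 8 * 1441792 == 5632 * 2048
attn_norm_w = per_param[7].reshape(2048)     # 8 * 256     == 2048
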
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_507: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_59, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_31: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_507, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_30: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_31, [-1], True); pow_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_60: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_30, 1e-05); mean_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_30: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_60); add_60 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_120: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_507, rsqrt_30); convert_element_type_507 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_508: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_120, torch.bfloat16); mul_120 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_121: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_508, contiguous_view_as_strided_145); convert_element_type_508 = contiguous_view_as_strided_145 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_271: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_138, [1, 0]); contiguous_view_as_strided_138 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_105: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_121, permute_271); permute_271 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_273: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_139, [1, 0]); contiguous_view_as_strided_139 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_106: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_121, permute_273); permute_273 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_275: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_140, [1, 0]); contiguous_view_as_strided_140 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_107: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_121, permute_275); permute_275 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_300: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_105, [1, 2048, 16, 128]); mm_105 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_301: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_106, [1, 2048, 16, 128]); mm_106 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_302: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_107, [1, 2048, 16, 128]); mm_107 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_515: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_300, torch.float32); view_300 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_303: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_515, [1, 2048, 16, -1, 2]); convert_element_type_515 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_30: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_303); view_303 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_516: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_301, torch.float32); view_301 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_304: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_516, [1, 2048, 16, -1, 2]); convert_element_type_516 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_31: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_304); view_304 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_122: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_30, view_20); view_as_complex_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_30: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_122); mul_122 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_306: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_30, [1, 2048, 16, 128]); view_as_real_30 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_123: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_31, view_20); view_as_complex_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_31: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_123); mul_123 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_307: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_31, [1, 2048, 16, 128]); view_as_real_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_517: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_306, torch.bfloat16); view_306 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_518: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_307, torch.bfloat16); view_307 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
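The c64 ops above implement rotary position embeddings by viewing adjacent fp32 value pairs as complex numbers and rotating them by the unit-modulus freqs_cis. The cited model.py:146-151 source, for reference:

import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # pair adjacent fp32 values into complex numbers: [..., 128] -> c64[..., 64]
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # multiplying by unit-modulus freqs_cis rotates each (even, odd) pair
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # cast back to the input dtype (bf16), matching convert_element_type_517/518
    return xq_out.type_as(xq), xk_out.type_as(xk)
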
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_276: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_517, [0, 2, 1, 3]); convert_element_type_517 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_277: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_518, [0, 2, 1, 3]); convert_element_type_518 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_278: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_302, [0, 2, 1, 3]); view_302 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_15 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_276, permute_277, permute_278, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1621: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_15[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1622: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_15[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1627: "i64[]" = _scaled_dot_product_flash_attention_15[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1628: "i64[]" = _scaled_dot_product_flash_attention_15[7]; _scaled_dot_product_flash_attention_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
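The scale argument 0.08838834764831843 is the default 1/sqrt(head_dim) with head_dim = 128; getitem_1622 (the fp32 logsumexp) and the two i64 outputs (Philox RNG state) are saved for the backward pass. A quick check:

import math

head_dim = 128
assert math.isclose(head_dim ** -0.5, 0.08838834764831843)
# F.scaled_dot_product_attention(xq, xk, xv, is_causal=True) applies this scale by
# default and dispatches here to the flash-attention kernel for the bf16 inputs
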
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_279: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1621, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_308: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_279, [2048, -1]); permute_279 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_281: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_141, [1, 0]); contiguous_view_as_strided_141 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_108: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_308, permute_281); permute_281 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_61: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_59, mm_108); add_59 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_521: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_61, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_32: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_521, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_31: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_32, [-1], True); pow_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_62: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_31, 1e-05); mean_31 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_31: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_62); add_62 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_124: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_521, rsqrt_31); convert_element_type_521 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_522: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_124, torch.bfloat16); mul_124 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_125: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_522, contiguous_view_as_strided_146); convert_element_type_522 = contiguous_view_as_strided_146 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_283: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_142, [1, 0]); contiguous_view_as_strided_142 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_109: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_125, permute_283); permute_283 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_525: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_109, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_15: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_525)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_126: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_525, sigmoid_15); convert_element_type_525 = sigmoid_15 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_526: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_126, torch.bfloat16); mul_126 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_285: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_144, [1, 0]); contiguous_view_as_strided_144 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_110: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_125, permute_285); permute_285 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_127: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_526, mm_110); convert_element_type_526 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_287: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_143, [1, 0]); contiguous_view_as_strided_143 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_111: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_127, permute_287); permute_287 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_63: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_61, mm_111); add_61 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_531: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_298, torch.bfloat16); primals_298 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_532: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_299, torch.bfloat16); primals_299 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_533: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_300, torch.bfloat16); primals_300 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_534: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_301, torch.bfloat16); primals_301 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_535: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_302, torch.bfloat16); primals_302 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_536: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_303, torch.bfloat16); primals_303 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_537: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_304, torch.bfloat16); primals_304 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_538: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_305, torch.bfloat16); primals_305 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_539: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_306, torch.bfloat16); primals_306 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_17 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_531, convert_element_type_532, convert_element_type_533, convert_element_type_534, convert_element_type_535, convert_element_type_536, convert_element_type_537, convert_element_type_538, convert_element_type_539]); convert_element_type_531 = convert_element_type_532 = convert_element_type_533 = convert_element_type_534 = convert_element_type_535 = convert_element_type_536 = convert_element_type_537 = convert_element_type_538 = convert_element_type_539 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1630: "bf16[6423040]" = all_gather_copy_in_17[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1631: "bf16[51384320]" = all_gather_copy_in_17[1]; all_gather_copy_in_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_17: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1630, 8, '0'); getitem_1630 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_17: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_17); all_gather_into_tensor_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_310: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_17, [8, -1]); wait_tensor_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_165 = torch.ops.aten.split_with_sizes.default(view_310, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_310 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1641: "bf16[8, 524288]" = split_with_sizes_165[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_147: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1641, [2048, 2048], [2048, 1]); getitem_1641 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1651: "bf16[8, 524288]" = split_with_sizes_165[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_148: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1651, [2048, 2048], [2048, 1]); getitem_1651 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1661: "bf16[8, 524288]" = split_with_sizes_165[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_149: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1661, [2048, 2048], [2048, 1]); getitem_1661 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1671: "bf16[8, 524288]" = split_with_sizes_165[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_150: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1671, [2048, 2048], [2048, 1]); getitem_1671 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1681: "bf16[8, 1441792]" = split_with_sizes_165[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_151: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1681, [5632, 2048], [2048, 1]); getitem_1681 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1691: "bf16[8, 1441792]" = split_with_sizes_165[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_152: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1691, [2048, 5632], [5632, 1]); getitem_1691 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1701: "bf16[8, 1441792]" = split_with_sizes_165[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_153: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1701, [5632, 2048], [2048, 1]); getitem_1701 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1711: "bf16[8, 256]" = split_with_sizes_165[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_154: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1711, [2048], [1]); getitem_1711 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1721: "bf16[8, 256]" = split_with_sizes_165[8]; split_with_sizes_165 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_155: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1721, [2048], [1]); getitem_1721 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_540: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_63, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_33: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_540, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_32: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_33, [-1], True); pow_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_64: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_32, 1e-05); mean_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_32: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_64); add_64 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_128: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_540, rsqrt_32); convert_element_type_540 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_541: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_128, torch.bfloat16); mul_128 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_129: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_541, contiguous_view_as_strided_154); convert_element_type_541 = contiguous_view_as_strided_154 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_289: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_147, [1, 0]); contiguous_view_as_strided_147 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_112: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_129, permute_289); permute_289 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_291: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_148, [1, 0]); contiguous_view_as_strided_148 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_113: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_129, permute_291); permute_291 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_293: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_149, [1, 0]); contiguous_view_as_strided_149 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_114: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_129, permute_293); permute_293 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_319: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_112, [1, 2048, 16, 128]); mm_112 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_320: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_113, [1, 2048, 16, 128]); mm_113 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_321: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_114, [1, 2048, 16, 128]); mm_114 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
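mm_112/mm_113/mm_114 above are the three bias-free q/k/v projections applied to the same normalized activation mul_129: with no bias, F.linear(x, W) lowers to a weight transpose plus matmul (the permute + mm pairs), and the views restore the (bsz, seqlen, n_heads, head_dim) layout. A shape-only sketch with this config's traced sizes (dim=2048, 16 heads of head_dim 128); the tensor values are dummies:

import torch

x  = torch.randn(2048, 2048)    # (bsz*seqlen, dim), the shape of mul_129
wq = torch.randn(2048, 2048)    # wq.weight, stored as (out_features, in_features)
xq = torch.mm(x, wq.t())        # what F.linear lowers to: permute + mm
xq = xq.view(1, 2048, 16, 128)  # model.py:235 as quoted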
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_548: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_319, torch.float32); view_319 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_322: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_548, [1, 2048, 16, -1, 2]); convert_element_type_548 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_32: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_322); view_322 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_549: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_320, torch.float32); view_320 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_323: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_549, [1, 2048, 16, -1, 2]); convert_element_type_549 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_33: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_323); view_323 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_130: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_32, view_20); view_as_complex_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_32: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_130); mul_130 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_325: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_32, [1, 2048, 16, 128]); view_as_real_32 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_131: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_33, view_20); view_as_complex_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_33: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_131); mul_131 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_326: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_33, [1, 2048, 16, 128]); view_as_real_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_550: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_325, torch.bfloat16); view_325 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_551: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_326, torch.bfloat16); view_326 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
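The c64 ops above are the complex-number formulation of rotary embeddings, reconstructed here from the model.py:146-151 lines quoted in the trace: each (even, odd) pair of the head dimension is reinterpreted as one complex number, multiplied by the precomputed freqs_cis phasor (view_20 in the graph), and flattened back to real. Only the function body comes from the quoted source; the torch.polar construction of freqs_cis in the usage below is an assumption for illustration:

import torch

def apply_rotary_emb(xq, xk, freqs_cis):
    # model.py:146-147 as quoted: view the last dim as complex pairs
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # model.py:149-150 as quoted: complex multiply rotates each pair; flatten(3) undoes the pairing
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # model.py:151 as quoted: cast back to the input dtype (bf16 in this run)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# assumed-shape usage: unit phasors broadcasting against (1, 2048, 16, 64)
xq = torch.randn(1, 2048, 16, 128)
xk = torch.randn(1, 2048, 16, 128)
freqs_cis = torch.polar(torch.ones(2048, 64), torch.randn(2048, 64))
xq_r, xk_r = apply_rotary_emb(xq, xk, freqs_cis.view(1, 2048, 1, 64))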
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_294: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_550, [0, 2, 1, 3]); convert_element_type_550 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_295: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_551, [0, 2, 1, 3]); convert_element_type_551 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_296: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_321, [0, 2, 1, 3]); view_321 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_16 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_294, permute_295, permute_296, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1722: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_16[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1723: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_16[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1728: "i64[]" = _scaled_dot_product_flash_attention_16[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1729: "i64[]" = _scaled_dot_product_flash_attention_16[7]; _scaled_dot_product_flash_attention_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
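The flash-attention call above returns a tuple; only index 0 is the attention output, while the logsumexp (getitem_1723) and the philox RNG seed/offset scalars (getitem_1728/1729) are kept for the backward pass. The traced scale 0.08838834764831843 is 1/sqrt(head_dim) = 1/sqrt(128), the default SDPA scaling. A minimal equivalent through the public API (dummy fp32 tensors on CPU; the trace runs in bf16 on CUDA, which is what selects the flash backend):

import math
import torch
import torch.nn.functional as F

xq = torch.randn(1, 16, 2048, 128)  # (bs, n_heads, seqlen, head_dim)
xk = torch.randn(1, 16, 2048, 128)
xv = torch.randn(1, 16, 2048, 128)

# model.py:254 as quoted; is_causal=True applies the causal mask
out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
assert math.isclose(1 / math.sqrt(128), 0.08838834764831843)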
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_297: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1722, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_327: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_297, [2048, -1]); permute_297 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_299: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_150, [1, 0]); contiguous_view_as_strided_150 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_115: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_327, permute_299); permute_299 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_65: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_63, mm_115); add_63 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
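add_65 above is the first of the two residual connections in each transformer block; the second one (add_67, model.py:411) appears further down after the feed-forward. The block forward reconstructed from the two quoted lines, with the submodule wiring assumed and injectable so the sketch stays self-contained:

import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, attention, feed_forward, attention_norm, ffn_norm):
        super().__init__()
        self.attention, self.feed_forward = attention, feed_forward
        self.attention_norm, self.ffn_norm = attention_norm, ffn_norm

    def forward(self, x, freqs_cis):
        # model.py:410 as quoted: pre-norm attention with residual (add_65 above)
        h = x + self.attention(self.attention_norm(x), freqs_cis)
        # model.py:411 as quoted: pre-norm MLP with residual (add_67 further down)
        return h + self.feed_forward(self.ffn_norm(h))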
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_554: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_65, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_34: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_554, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_33: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_34, [-1], True); pow_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_66: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_33, 1e-05); mean_33 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_33: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_66); add_66 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_132: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_554, rsqrt_33); convert_element_type_554 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_555: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_132, torch.bfloat16); mul_132 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_133: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_555, contiguous_view_as_strided_155); convert_element_type_555 = contiguous_view_as_strided_155 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_301: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_151, [1, 0]); contiguous_view_as_strided_151 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_116: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_133, permute_301); permute_301 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_558: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_116, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_16: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_558)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_134: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_558, sigmoid_16); convert_element_type_558 = sigmoid_16 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_559: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_134, torch.bfloat16); mul_134 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_303: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_153, [1, 0]); contiguous_view_as_strided_153 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_117: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_133, permute_303); permute_303 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_135: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_559, mm_117); convert_element_type_559 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_305: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_152, [1, 0]); contiguous_view_as_strided_152 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_118: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_135, permute_305); permute_305 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_67: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_65, mm_118); add_65 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
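The w1 -> sigmoid*x -> w3 -> w2 chain above is the SwiGLU feed-forward from the quoted model.py:299; Inductor decomposes F.silu(x) = x * sigmoid(x), which is why the graph shows an explicit sigmoid followed by a mul in fp32. A sketch with this config's traced sizes (dim=2048, hidden 5632); the forward line is the quoted source, the constructor is assumed:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate proj (mm_116 above)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down proj (mm_118 above)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up proj   (mm_117 above)

    def forward(self, x):
        # model.py:299 as quoted; silu(x) = x * sigmoid(x)
        return self.w2(F.silu(self.w1(x)) * self.w3(x))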
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_564: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_316, torch.bfloat16); primals_316 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_565: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_317, torch.bfloat16); primals_317 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_566: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_318, torch.bfloat16); primals_318 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_567: "bf16[524288]" = torch.ops.prims.convert_element_type.default(primals_319, torch.bfloat16); primals_319 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_568: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_320, torch.bfloat16); primals_320 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_569: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_321, torch.bfloat16); primals_321 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_570: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(primals_322, torch.bfloat16); primals_322 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_571: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_323, torch.bfloat16); primals_323 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_572: "bf16[256]" = torch.ops.prims.convert_element_type.default(primals_324, torch.bfloat16); primals_324 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_copy_in_18 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_564, convert_element_type_565, convert_element_type_566, convert_element_type_567, convert_element_type_568, convert_element_type_569, convert_element_type_570, convert_element_type_571, convert_element_type_572]); convert_element_type_564 = convert_element_type_565 = convert_element_type_566 = convert_element_type_567 = convert_element_type_568 = convert_element_type_569 = convert_element_type_570 = convert_element_type_571 = convert_element_type_572 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1731: "bf16[6423040]" = all_gather_copy_in_18[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1732: "bf16[51384320]" = all_gather_copy_in_18[1]; all_gather_copy_in_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] all_gather_into_tensor_18: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1731, 8, '0'); getitem_1731 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] wait_tensor_18: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_18); all_gather_into_tensor_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_329: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_18, [8, -1]); wait_tensor_18 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] split_with_sizes_175 = torch.ops.aten.split_with_sizes.default(view_329, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_329 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1742: "bf16[8, 524288]" = split_with_sizes_175[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_156: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1742, [2048, 2048], [2048, 1]); getitem_1742 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1752: "bf16[8, 524288]" = split_with_sizes_175[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_157: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1752, [2048, 2048], [2048, 1]); getitem_1752 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1762: "bf16[8, 524288]" = split_with_sizes_175[2]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_158: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1762, [2048, 2048], [2048, 1]); getitem_1762 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1772: "bf16[8, 524288]" = split_with_sizes_175[3]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_159: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1772, [2048, 2048], [2048, 1]); getitem_1772 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1782: "bf16[8, 1441792]" = split_with_sizes_175[4]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_160: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1782, [5632, 2048], [2048, 1]); getitem_1782 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1792: "bf16[8, 1441792]" = split_with_sizes_175[5]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_161: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1792, [2048, 5632], [5632, 1]); getitem_1792 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1802: "bf16[8, 1441792]" = split_with_sizes_175[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_162: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1802, [5632, 2048], [2048, 1]); getitem_1802 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1812: "bf16[8, 256]" = split_with_sizes_175[7]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_163: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1812, [2048], [1]); getitem_1812 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1822: "bf16[8, 256]" = split_with_sizes_175[8]; split_with_sizes_175 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] contiguous_view_as_strided_164: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_1822, [2048], [1]); getitem_1822 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
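The segment above is one compiled FSDP all-gather for the next transformer block: the nine sharded parameters are cast to bf16, packed into a single flat 6,423,040-element buffer (all_gather_copy_in), gathered across the 8-rank group into 51,384,320 elements, and waited on (functional collectives return immediately; wait_tensor blocks until the result is ready, which lets the compiler overlap the gather with the previous block's compute). The result is then viewed as [8, 6423040], split back per parameter, and re-laid-out into the full shapes, e.g. 8 * 524288 = 2048 * 2048 for each attention weight and 8 * 256 = 2048 for each norm weight. A hedged sketch of the same dataflow using only public torch.distributed APIs (the trace uses custom torch.ops.fsdp kernels that fuse these steps); it assumes an already-initialized 8-rank NCCL process group:

import torch
import torch.distributed as dist

world_size = 8
split_sizes = [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256]
shapes = [(2048, 2048)] * 4 + [(5632, 2048), (2048, 5632), (5632, 2048), (2048,), (2048,)]

# copy-in: cast each local shard to bf16 and pack into one flat buffer
shards = [torch.randn(n, device="cuda").to(torch.bfloat16) for n in split_sizes]
inp = torch.cat(shards)                        # 6,423,040 elements (getitem_1731 above)
out = inp.new_empty(world_size * inp.numel())  # 51,384,320 elements (getitem_1732 above)

dist.all_gather_into_tensor(out, inp)          # the all_gather + wait_tensor pair above

# copy-out: one row per rank, split per parameter, reassemble the full shapes
splits = out.view(world_size, -1).split(split_sizes, dim=1)
params = [s.contiguous().view(shape) for s, shape in zip(splits, shapes)]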
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_573: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_67, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_35: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_573, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_34: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_35, [-1], True); pow_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_68: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_34, 1e-05); mean_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_34: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_68); add_68 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_136: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_573, rsqrt_34); convert_element_type_573 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_574: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_136, torch.bfloat16); mul_136 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_137: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_574, contiguous_view_as_strided_163); convert_element_type_574 = contiguous_view_as_strided_163 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_307: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_156, [1, 0]); contiguous_view_as_strided_156 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_119: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_137, permute_307); permute_307 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_309: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_157, [1, 0]); contiguous_view_as_strided_157 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_120: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_137, permute_309); permute_309 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_311: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_158, [1, 0]); contiguous_view_as_strided_158 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_121: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_137, permute_311); permute_311 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:235 in forward, code: xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_338: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_119, [1, 2048, 16, 128]); mm_119 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:236 in forward, code: xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_339: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_120, [1, 2048, 16, 128]); mm_120 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:237 in forward, code: xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_340: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_121, [1, 2048, 16, 128]); mm_121 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:146 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_581: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_338, torch.float32); view_338 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_341: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_581, [1, 2048, 16, -1, 2]); convert_element_type_581 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_34: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_341); view_341 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:147 in apply_rotary_emb, code: xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_582: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_339, torch.float32); view_339 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_342: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_582, [1, 2048, 16, -1, 2]); convert_element_type_582 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_complex_35: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_342); view_342 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:149 in apply_rotary_emb, code: xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_138: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_34, view_20); view_as_complex_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_34: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_138); mul_138 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_344: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_34, [1, 2048, 16, 128]); view_as_real_34 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_139: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_35, view_20); view_as_complex_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_as_real_35: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_139); mul_139 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_345: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_35, [1, 2048, 16, 128]); view_as_real_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:151 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_583: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_344, torch.bfloat16); view_344 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_584: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_345, torch.bfloat16); view_345 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:249 in forward, code: xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_312: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_583, [0, 2, 1, 3]); convert_element_type_583 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:250 in forward, code: xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_313: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(convert_element_type_584, [0, 2, 1, 3]); convert_element_type_584 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:251 in forward, code: xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_314: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_340, [0, 2, 1, 3]); view_340 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:254 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _scaled_dot_product_flash_attention_17 = torch.ops.aten._scaled_dot_product_flash_attention.default(permute_312, permute_313, permute_314, 0.0, True, scale = 0.08838834764831843)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1823: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_17[0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1824: "f32[1, 16, 2048]" = _scaled_dot_product_flash_attention_17[1]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1829: "i64[]" = _scaled_dot_product_flash_attention_17[6]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] getitem_1830: "i64[]" = _scaled_dot_product_flash_attention_17[7]; _scaled_dot_product_flash_attention_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:255 in forward, code: output = output.transpose(
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_315: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_1823, [0, 2, 1, 3])
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:259 in forward, code: output = output.view(bsz * seqlen, -1)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_346: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_315, [2048, -1]); permute_315 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_317: "bf16[2048, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_159, [1, 0]); contiguous_view_as_strided_159 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_122: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_346, permute_317); permute_317 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:410 in forward, code: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_69: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_67, mm_122); add_67 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_587: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_69, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_36: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_587, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_35: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_36, [-1], True); pow_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_70: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_35, 1e-05); mean_35 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_35: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_70); add_70 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_140: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_587, rsqrt_35); convert_element_type_587 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_588: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_140, torch.bfloat16); mul_140 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_141: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_588, contiguous_view_as_strided_164); convert_element_type_588 = contiguous_view_as_strided_164 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_319: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_160, [1, 0]); contiguous_view_as_strided_160 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_123: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_141, permute_319); permute_319 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_591: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(mm_123, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] sigmoid_17: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_591)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_142: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_591, sigmoid_17); convert_element_type_591 = sigmoid_17 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_592: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_142, torch.bfloat16); mul_142 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
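
F.silu has no single kernel in this graph; it is decomposed into an f32 upcast, sigmoid, an elementwise multiply, and a bf16 downcast (convert_element_type_591/592 above). The equivalent eager code, as a sketch:

import torch
import torch.nn.functional as F

x = torch.randn(2048, 5632, dtype=torch.bfloat16)

x_fp32 = x.float()                     # convert_element_type_591
silu = x_fp32 * torch.sigmoid(x_fp32)  # sigmoid_17 and mul_142: silu(x) = x * sigmoid(x)
silu_bf16 = silu.to(torch.bfloat16)    # convert_element_type_592

assert torch.allclose(silu_bf16.float(), F.silu(x).float(), atol=1e-2)
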
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_321: "bf16[2048, 5632]" = torch.ops.aten.permute.default(contiguous_view_as_strided_162, [1, 0]); contiguous_view_as_strided_162 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_124: "bf16[2048, 5632]" = torch.ops.aten.mm.default(mul_141, permute_321); permute_321 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:299 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_143: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_592, mm_124); convert_element_type_592 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_323: "bf16[5632, 2048]" = torch.ops.aten.permute.default(contiguous_view_as_strided_161, [1, 0]); contiguous_view_as_strided_161 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_125: "bf16[2048, 2048]" = torch.ops.aten.mm.default(mul_143, permute_323); permute_323 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:411 in forward, code: out = h + self.feed_forward(self.ffn_norm(h))
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_71: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_69, mm_125); add_69 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
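
Taken together, mm_123 / mm_124 / mul_143 / mm_125 above are the SwiGLU feed-forward from model.py:299, self.w2(F.silu(self.w1(x)) * self.w3(x)), followed by the residual add from model.py:411. A compact module sketch; the 2048 -> 5632 -> 2048 dimensions match the trace, the class and variable names are hypothetical:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim: int = 2048, hidden_dim: int = 5632):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

ffn = FeedForward()
h = torch.randn(2048, 2048)
out = h + ffn(h)  # the residual add, as in add_71 above
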
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_597: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_71, torch.float32); add_71 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:61 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] pow_37: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_597, 2)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mean_36: "f32[2048, 1]" = torch.ops.aten.mean.dim(pow_37, [-1], True); pow_37 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] add_72: "f32[2048, 1]" = torch.ops.aten.add.Tensor(mean_36, 1e-05); mean_36 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] rsqrt_36: "f32[2048, 1]" = torch.ops.aten.rsqrt.default(add_72); add_72 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_144: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_597, rsqrt_36); convert_element_type_597 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:74 in forward, code: output = self._norm(x.float()).type_as(x)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_598: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_144, torch.bfloat16); mul_144 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:75 in forward, code: return output * self.weight
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mul_145: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_598, contiguous_view_as_strided_1); convert_element_type_598 = contiguous_view_as_strided_1 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:493 in forward, code: h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_347: "bf16[1, 2048, 2048]" = torch.ops.aten.reshape.default(mul_145, [1, 2048, 2048]); mul_145 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_348: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(view_347, [2048, 2048]); view_347 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] permute_325: "bf16[2048, 32000]" = torch.ops.aten.permute.default(contiguous_view_as_strided_2, [1, 0]); contiguous_view_as_strided_2 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] mm_126: "bf16[2048, 32000]" = torch.ops.aten.mm.default(view_348, permute_325); permute_325 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] view_349: "bf16[1, 2048, 32000]" = torch.ops.aten.reshape.default(mm_126, [1, 2048, 32000]); mm_126 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:494 in forward, code: output = self.output(h).float()
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] convert_element_type_601: "f32[1, 2048, 32000]" = torch.ops.prims.convert_element_type.default(view_349, torch.float32); view_349 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
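
The closing ops of the forward graph reshape the final hidden states back to 2-D, run the 32000-way vocabulary projection as a single mm, and upcast the logits to float32 (model.py:494, output = self.output(h).float()). In eager terms, roughly:

import torch

h = torch.randn(1, 2048, 2048, dtype=torch.bfloat16)            # [bsz, seqlen, dim]
output_weight = torch.randn(32000, 2048, dtype=torch.bfloat16)  # vocab projection

h2d = h.reshape(2048, 2048)                          # view_348
logits = torch.mm(h2d, output_weight.permute(1, 0))  # permute_325 + mm_126
logits = logits.reshape(1, 2048, 32000).float()      # view_349 + convert_element_type_601
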
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] # File: /data/users/willfeng/torchtrain/torchtrain/models/llama/model.py:150 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] _conj: "c64[1, 2048, 1, 64]" = torch.ops.aten._conj.default(view_20); view_20 = None
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0] return [convert_element_type_601, primals_1, primals_6, primals_7, primals_18, primals_19, primals_20, primals_21, primals_22, primals_23, primals_24, primals_25, primals_26, primals_37, primals_38, primals_39, primals_40, primals_41, primals_42, primals_43, primals_44, primals_45, primals_55, primals_56, primals_57, primals_58, primals_59, primals_60, primals_61, primals_62, primals_63, primals_73, primals_74, primals_75, primals_76, primals_77, primals_78, primals_79, primals_80, primals_81, primals_91, primals_92, primals_93, primals_94, primals_95, primals_96, primals_97, primals_98, primals_99, primals_109, primals_110, primals_111, primals_112, primals_113, primals_114, primals_115, primals_116, primals_117, primals_127, primals_128, primals_129, primals_130, primals_131, primals_132, primals_133, primals_134, primals_135, primals_145, primals_146, primals_147, primals_148, primals_149, primals_150, primals_151, primals_152, primals_153, primals_163, primals_164, primals_165, primals_166, primals_167, primals_168, primals_169, primals_170, primals_171, primals_181, primals_182, primals_183, primals_184, primals_185, primals_186, primals_187, primals_188, primals_189, primals_199, primals_200, primals_201, primals_202, primals_203, primals_204, primals_205, primals_206, primals_207, primals_217, primals_218, primals_219, primals_220, primals_221, primals_222, primals_223, primals_224, primals_225, primals_235, primals_236, primals_237, primals_238, primals_239, primals_240, primals_241, primals_242, primals_243, primals_253, primals_254, primals_255, primals_256, primals_257, primals_258, primals_259, primals_260, primals_261, primals_271, primals_272, primals_273, primals_274, primals_275, primals_276, primals_277, primals_278, primals_279, primals_289, primals_290, primals_291, primals_292, primals_293, primals_294, primals_295, primals_296, primals_297, primals_307, primals_308, primals_309, primals_310, primals_311, primals_312, primals_313, primals_314, primals_315, primals_325, primals_326, primals_327, primals_328, primals_329, primals_330, primals_331, primals_332, primals_333, embedding, rsqrt, mul_1, permute_6, permute_7, permute_8, getitem_107, getitem_112, getitem_113, view_23, mm_3, rsqrt_1, mul_5, mm_4, mm_5, mul_7, mm_6, rsqrt_2, mul_9, permute_24, permute_25, permute_26, getitem_208, getitem_213, getitem_214, view_42, mm_10, rsqrt_3, mul_13, mm_11, mm_12, mul_15, mm_13, rsqrt_4, mul_17, permute_42, permute_43, permute_44, getitem_309, getitem_314, getitem_315, view_61, mm_17, rsqrt_5, mul_21, mm_18, mm_19, mul_23, mm_20, rsqrt_6, mul_25, permute_60, permute_61, permute_62, getitem_410, getitem_415, getitem_416, view_80, mm_24, rsqrt_7, mul_29, mm_25, mm_26, mul_31, mm_27, rsqrt_8, mul_33, permute_78, permute_79, permute_80, getitem_511, getitem_516, getitem_517, view_99, mm_31, rsqrt_9, mul_37, mm_32, mm_33, mul_39, mm_34, rsqrt_10, mul_41, permute_96, permute_97, permute_98, getitem_612, getitem_617, getitem_618, view_118, mm_38, rsqrt_11, mul_45, mm_39, mm_40, mul_47, mm_41, rsqrt_12, mul_49, permute_114, permute_115, permute_116, getitem_713, getitem_718, getitem_719, view_137, mm_45, rsqrt_13, mul_53, mm_46, mm_47, mul_55, mm_48, rsqrt_14, mul_57, permute_132, permute_133, permute_134, getitem_814, getitem_819, getitem_820, view_156, mm_52, rsqrt_15, mul_61, mm_53, mm_54, mul_63, mm_55, rsqrt_16, mul_65, permute_150, permute_151, permute_152, 
getitem_915, getitem_920, getitem_921, view_175, mm_59, rsqrt_17, mul_69, mm_60, mm_61, mul_71, mm_62, rsqrt_18, mul_73, permute_168, permute_169, permute_170, getitem_1016, getitem_1021, getitem_1022, view_194, mm_66, rsqrt_19, mul_77, mm_67, mm_68, mul_79, mm_69, rsqrt_20, mul_81, permute_186, permute_187, permute_188, getitem_1117, getitem_1122, getitem_1123, view_213, mm_73, rsqrt_21, mul_85, mm_74, mm_75, mul_87, mm_76, rsqrt_22, mul_89, permute_204, permute_205, permute_206, getitem_1218, getitem_1223, getitem_1224, view_232, mm_80, rsqrt_23, mul_93, mm_81, mm_82, mul_95, mm_83, rsqrt_24, mul_97, permute_222, permute_223, permute_224, getitem_1319, getitem_1324, getitem_1325, view_251, mm_87, rsqrt_25, mul_101, mm_88, mm_89, mul_103, mm_90, rsqrt_26, mul_105, permute_240, permute_241, permute_242, getitem_1420, getitem_1425, getitem_1426, view_270, mm_94, rsqrt_27, mul_109, mm_95, mm_96, mul_111, mm_97, rsqrt_28, mul_113, permute_258, permute_259, permute_260, getitem_1521, getitem_1526, getitem_1527, view_289, mm_101, rsqrt_29, mul_117, mm_102, mm_103, mul_119, mm_104, rsqrt_30, mul_121, permute_276, permute_277, permute_278, getitem_1622, getitem_1627, getitem_1628, view_308, mm_108, rsqrt_31, mul_125, mm_109, mm_110, mul_127, mm_111, rsqrt_32, mul_129, permute_294, permute_295, permute_296, getitem_1723, getitem_1728, getitem_1729, view_327, mm_115, rsqrt_33, mul_133, mm_116, mm_117, mul_135, mm_118, rsqrt_34, mul_137, permute_312, permute_313, permute_314, getitem_1824, getitem_1829, getitem_1830, view_346, mm_122, rsqrt_35, mul_141, mm_123, mm_124, mul_143, mm_125, rsqrt_36, view_348, getitem_1823, _conj, getitem_1722, getitem_1621, getitem_1520, getitem_1419, getitem_1318, getitem_1217, getitem_1116, getitem_1015, getitem_914, getitem_813, getitem_712, getitem_611, getitem_510, getitem_409, getitem_308, getitem_207, getitem_106]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
[rank0]:[rank0]:W2024-03-28 18:04:12,078.078000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [0/0]
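
The lone c64 node at the end of the forward graph (_conj of view_20) is the conjugate of the complex rotary frequencies, saved for the backward pass: the rotation in apply_rotary_emb is unitary, so gradients are rotated back by multiplying with conj(freqs_cis). A sketch of the forward rotation from the quoted model.py:150 line, with a hypothetical freqs_cis setup (2048 positions, 16 heads, head_dim 128, i.e. 64 complex pairs, matching the trace shapes):

import torch

xk = torch.randn(1, 2048, 16, 128)
angles = torch.outer(torch.arange(2048, dtype=torch.float32), torch.ones(64))
freqs_cis = torch.polar(torch.ones(2048, 64), angles)  # c64[2048, 64], illustrative angles

# model.py:150: pair the last dim into complex numbers, rotate, flatten back.
xk_ = torch.view_as_complex(xk.reshape(1, 2048, 16, 64, 2))
xk_out = torch.view_as_real(xk_ * freqs_cis.reshape(1, 2048, 1, 64)).flatten(3)

# Backward multiplies the incoming gradient by freqs_cis.conj() -- the _conj
# node returned by the graph above.
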
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] TRACED GRAPH
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] ===== after FSDP FX passes: =====
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] def forward(self, arg0_1: "f32[]", arg1_1: "f32[2048, 32000]", arg2_1: "i64[2048]", arg3_1: "f32[]", arg4_1: "f32[2048, 32000]", arg5_1: "i64[1, 2048]", arg6_1: "bf16[2048]", arg7_1: "bf16[32000, 2048]", arg8_1: "bf16[2048, 2048]", arg9_1: "bf16[2048, 2048]", arg10_1: "bf16[2048, 2048]", arg11_1: "bf16[2048, 2048]", arg12_1: "bf16[5632, 2048]", arg13_1: "bf16[2048, 5632]", arg14_1: "bf16[5632, 2048]", arg15_1: "bf16[2048]", arg16_1: "bf16[2048]", arg17_1: "bf16[2048, 2048]", arg18_1: "bf16[2048, 2048]", arg19_1: "bf16[2048, 2048]", arg20_1: "bf16[2048, 2048]", arg21_1: "bf16[5632, 2048]", arg22_1: "bf16[2048, 5632]", arg23_1: "bf16[5632, 2048]", arg24_1: "bf16[2048]", arg25_1: "bf16[2048]", arg26_1: "bf16[2048, 2048]", arg27_1: "bf16[2048, 2048]", arg28_1: "bf16[2048, 2048]", arg29_1: "bf16[2048, 2048]", arg30_1: "bf16[5632, 2048]", arg31_1: "bf16[2048, 5632]", arg32_1: "bf16[5632, 2048]", arg33_1: "bf16[2048]", arg34_1: "bf16[2048]", arg35_1: "bf16[2048, 2048]", arg36_1: "bf16[2048, 2048]", arg37_1: "bf16[2048, 2048]", arg38_1: "bf16[2048, 2048]", arg39_1: "bf16[5632, 2048]", arg40_1: "bf16[2048, 5632]", arg41_1: "bf16[5632, 2048]", arg42_1: "bf16[2048]", arg43_1: "bf16[2048]", arg44_1: "bf16[2048, 2048]", arg45_1: "bf16[2048, 2048]", arg46_1: "bf16[2048, 2048]", arg47_1: "bf16[2048, 2048]", arg48_1: "bf16[5632, 2048]", arg49_1: "bf16[2048, 5632]", arg50_1: "bf16[5632, 2048]", arg51_1: "bf16[2048]", arg52_1: "bf16[2048]", arg53_1: "bf16[2048, 2048]", arg54_1: "bf16[2048, 2048]", arg55_1: "bf16[2048, 2048]", arg56_1: "bf16[2048, 2048]", arg57_1: "bf16[5632, 2048]", arg58_1: "bf16[2048, 5632]", arg59_1: "bf16[5632, 2048]", arg60_1: "bf16[2048]", arg61_1: "bf16[2048]", arg62_1: "bf16[2048, 2048]", arg63_1: "bf16[2048, 2048]", arg64_1: "bf16[2048, 2048]", arg65_1: "bf16[2048, 2048]", arg66_1: "bf16[5632, 2048]", arg67_1: "bf16[2048, 5632]", arg68_1: "bf16[5632, 2048]", arg69_1: "bf16[2048]", arg70_1: "bf16[2048]", arg71_1: "bf16[2048, 2048]", arg72_1: "bf16[2048, 2048]", arg73_1: "bf16[2048, 2048]", arg74_1: "bf16[2048, 2048]", arg75_1: "bf16[5632, 2048]", arg76_1: "bf16[2048, 5632]", arg77_1: "bf16[5632, 2048]", arg78_1: "bf16[2048]", arg79_1: "bf16[2048]", arg80_1: "bf16[2048, 2048]", arg81_1: "bf16[2048, 2048]", arg82_1: "bf16[2048, 2048]", arg83_1: "bf16[2048, 2048]", arg84_1: "bf16[5632, 2048]", arg85_1: "bf16[2048, 5632]", arg86_1: "bf16[5632, 2048]", arg87_1: "bf16[2048]", arg88_1: "bf16[2048]", arg89_1: "bf16[2048, 2048]", arg90_1: "bf16[2048, 2048]", arg91_1: "bf16[2048, 2048]", arg92_1: "bf16[2048, 2048]", arg93_1: "bf16[5632, 2048]", arg94_1: "bf16[2048, 5632]", arg95_1: "bf16[5632, 2048]", arg96_1: "bf16[2048]", arg97_1: "bf16[2048]", arg98_1: "bf16[2048, 2048]", arg99_1: "bf16[2048, 2048]", arg100_1: "bf16[2048, 2048]", arg101_1: "bf16[2048, 2048]", arg102_1: "bf16[5632, 2048]", arg103_1: "bf16[2048, 5632]", arg104_1: "bf16[5632, 2048]", arg105_1: "bf16[2048]", arg106_1: "bf16[2048]", arg107_1: "bf16[2048, 2048]", arg108_1: "bf16[2048, 2048]", arg109_1: "bf16[2048, 2048]", arg110_1: "bf16[2048, 2048]", arg111_1: "bf16[5632, 2048]", arg112_1: "bf16[2048, 5632]", arg113_1: "bf16[5632, 2048]", arg114_1: "bf16[2048]", arg115_1: "bf16[2048]", arg116_1: "bf16[2048, 2048]", arg117_1: "bf16[2048, 2048]", arg118_1: "bf16[2048, 2048]", arg119_1: "bf16[2048, 2048]", arg120_1: "bf16[5632, 2048]", arg121_1: "bf16[2048, 5632]", arg122_1: "bf16[5632, 2048]", arg123_1: "bf16[2048]", 
arg124_1: "bf16[2048]", arg125_1: "bf16[2048, 2048]", arg126_1: "bf16[2048, 2048]", arg127_1: "bf16[2048, 2048]", arg128_1: "bf16[2048, 2048]", arg129_1: "bf16[5632, 2048]", arg130_1: "bf16[2048, 5632]", arg131_1: "bf16[5632, 2048]", arg132_1: "bf16[2048]", arg133_1: "bf16[2048]", arg134_1: "bf16[2048, 2048]", arg135_1: "bf16[2048, 2048]", arg136_1: "bf16[2048, 2048]", arg137_1: "bf16[2048, 2048]", arg138_1: "bf16[5632, 2048]", arg139_1: "bf16[2048, 5632]", arg140_1: "bf16[5632, 2048]", arg141_1: "bf16[2048]", arg142_1: "bf16[2048]", arg143_1: "bf16[2048, 2048]", arg144_1: "bf16[2048, 2048]", arg145_1: "bf16[2048, 2048]", arg146_1: "bf16[2048, 2048]", arg147_1: "bf16[5632, 2048]", arg148_1: "bf16[2048, 5632]", arg149_1: "bf16[5632, 2048]", arg150_1: "bf16[2048]", arg151_1: "bf16[2048]", arg152_1: "bf16[2048, 2048]", arg153_1: "bf16[2048, 2048]", arg154_1: "bf16[2048, 2048]", arg155_1: "bf16[2048, 2048]", arg156_1: "bf16[5632, 2048]", arg157_1: "bf16[2048, 5632]", arg158_1: "bf16[5632, 2048]", arg159_1: "bf16[2048]", arg160_1: "bf16[2048]", arg161_1: "bf16[2048, 2048]", arg162_1: "bf16[2048, 2048]", arg163_1: "bf16[2048, 2048]", arg164_1: "bf16[2048, 2048]", arg165_1: "bf16[5632, 2048]", arg166_1: "bf16[2048, 5632]", arg167_1: "bf16[5632, 2048]", arg168_1: "bf16[2048]", arg169_1: "bf16[2048]", arg170_1: "bf16[1, 2048, 2048]", arg171_1: "f32[2048, 1]", arg172_1: "bf16[2048, 2048]", arg173_1: "bf16[1, 16, 2048, 128]", arg174_1: "bf16[1, 16, 2048, 128]", arg175_1: "bf16[1, 16, 2048, 128]", arg176_1: "f32[1, 16, 2048]", arg177_1: "i64[]", arg178_1: "i64[]", arg179_1: "bf16[2048, 2048]", arg180_1: "bf16[2048, 2048]", arg181_1: "f32[2048, 1]", arg182_1: "bf16[2048, 2048]", arg183_1: "bf16[2048, 5632]", arg184_1: "bf16[2048, 5632]", arg185_1: "bf16[2048, 5632]", arg186_1: "bf16[2048, 2048]", arg187_1: "f32[2048, 1]", arg188_1: "bf16[2048, 2048]", arg189_1: "bf16[1, 16, 2048, 128]", arg190_1: "bf16[1, 16, 2048, 128]", arg191_1: "bf16[1, 16, 2048, 128]", arg192_1: "f32[1, 16, 2048]", arg193_1: "i64[]", arg194_1: "i64[]", arg195_1: "bf16[2048, 2048]", arg196_1: "bf16[2048, 2048]", arg197_1: "f32[2048, 1]", arg198_1: "bf16[2048, 2048]", arg199_1: "bf16[2048, 5632]", arg200_1: "bf16[2048, 5632]", arg201_1: "bf16[2048, 5632]", arg202_1: "bf16[2048, 2048]", arg203_1: "f32[2048, 1]", arg204_1: "bf16[2048, 2048]", arg205_1: "bf16[1, 16, 2048, 128]", arg206_1: "bf16[1, 16, 2048, 128]", arg207_1: "bf16[1, 16, 2048, 128]", arg208_1: "f32[1, 16, 2048]", arg209_1: "i64[]", arg210_1: "i64[]", arg211_1: "bf16[2048, 2048]", arg212_1: "bf16[2048, 2048]", arg213_1: "f32[2048, 1]", arg214_1: "bf16[2048, 2048]", arg215_1: "bf16[2048, 5632]", arg216_1: "bf16[2048, 5632]", arg217_1: "bf16[2048, 5632]", arg218_1: "bf16[2048, 2048]", arg219_1: "f32[2048, 1]", arg220_1: "bf16[2048, 2048]", arg221_1: "bf16[1, 16, 2048, 128]", arg222_1: "bf16[1, 16, 2048, 128]", arg223_1: "bf16[1, 16, 2048, 128]", arg224_1: "f32[1, 16, 2048]", arg225_1: "i64[]", arg226_1: "i64[]", arg227_1: "bf16[2048, 2048]", arg228_1: "bf16[2048, 2048]", arg229_1: "f32[2048, 1]", arg230_1: "bf16[2048, 2048]", arg231_1: "bf16[2048, 5632]", arg232_1: "bf16[2048, 5632]", arg233_1: "bf16[2048, 5632]", arg234_1: "bf16[2048, 2048]", arg235_1: "f32[2048, 1]", arg236_1: "bf16[2048, 2048]", arg237_1: "bf16[1, 16, 2048, 128]", arg238_1: "bf16[1, 16, 2048, 128]", arg239_1: "bf16[1, 16, 2048, 128]", arg240_1: "f32[1, 16, 2048]", arg241_1: "i64[]", arg242_1: "i64[]", arg243_1: "bf16[2048, 2048]", arg244_1: "bf16[2048, 2048]", arg245_1: "f32[2048, 1]", arg246_1: 
"bf16[2048, 2048]", arg247_1: "bf16[2048, 5632]", arg248_1: "bf16[2048, 5632]", arg249_1: "bf16[2048, 5632]", arg250_1: "bf16[2048, 2048]", arg251_1: "f32[2048, 1]", arg252_1: "bf16[2048, 2048]", arg253_1: "bf16[1, 16, 2048, 128]", arg254_1: "bf16[1, 16, 2048, 128]", arg255_1: "bf16[1, 16, 2048, 128]", arg256_1: "f32[1, 16, 2048]", arg257_1: "i64[]", arg258_1: "i64[]", arg259_1: "bf16[2048, 2048]", arg260_1: "bf16[2048, 2048]", arg261_1: "f32[2048, 1]", arg262_1: "bf16[2048, 2048]", arg263_1: "bf16[2048, 5632]", arg264_1: "bf16[2048, 5632]", arg265_1: "bf16[2048, 5632]", arg266_1: "bf16[2048, 2048]", arg267_1: "f32[2048, 1]", arg268_1: "bf16[2048, 2048]", arg269_1: "bf16[1, 16, 2048, 128]", arg270_1: "bf16[1, 16, 2048, 128]", arg271_1: "bf16[1, 16, 2048, 128]", arg272_1: "f32[1, 16, 2048]", arg273_1: "i64[]", arg274_1: "i64[]", arg275_1: "bf16[2048, 2048]", arg276_1: "bf16[2048, 2048]", arg277_1: "f32[2048, 1]", arg278_1: "bf16[2048, 2048]", arg279_1: "bf16[2048, 5632]", arg280_1: "bf16[2048, 5632]", arg281_1: "bf16[2048, 5632]", arg282_1: "bf16[2048, 2048]", arg283_1: "f32[2048, 1]", arg284_1: "bf16[2048, 2048]", arg285_1: "bf16[1, 16, 2048, 128]", arg286_1: "bf16[1, 16, 2048, 128]", arg287_1: "bf16[1, 16, 2048, 128]", arg288_1: "f32[1, 16, 2048]", arg289_1: "i64[]", arg290_1: "i64[]", arg291_1: "bf16[2048, 2048]", arg292_1: "bf16[2048, 2048]", arg293_1: "f32[2048, 1]", arg294_1: "bf16[2048, 2048]", arg295_1: "bf16[2048, 5632]", arg296_1: "bf16[2048, 5632]", arg297_1: "bf16[2048, 5632]", arg298_1: "bf16[2048, 2048]", arg299_1: "f32[2048, 1]", arg300_1: "bf16[2048, 2048]", arg301_1: "bf16[1, 16, 2048, 128]", arg302_1: "bf16[1, 16, 2048, 128]", arg303_1: "bf16[1, 16, 2048, 128]", arg304_1: "f32[1, 16, 2048]", arg305_1: "i64[]", arg306_1: "i64[]", arg307_1: "bf16[2048, 2048]", arg308_1: "bf16[2048, 2048]", arg309_1: "f32[2048, 1]", arg310_1: "bf16[2048, 2048]", arg311_1: "bf16[2048, 5632]", arg312_1: "bf16[2048, 5632]", arg313_1: "bf16[2048, 5632]", arg314_1: "bf16[2048, 2048]", arg315_1: "f32[2048, 1]", arg316_1: "bf16[2048, 2048]", arg317_1: "bf16[1, 16, 2048, 128]", arg318_1: "bf16[1, 16, 2048, 128]", arg319_1: "bf16[1, 16, 2048, 128]", arg320_1: "f32[1, 16, 2048]", arg321_1: "i64[]", arg322_1: "i64[]", arg323_1: "bf16[2048, 2048]", arg324_1: "bf16[2048, 2048]", arg325_1: "f32[2048, 1]", arg326_1: "bf16[2048, 2048]", arg327_1: "bf16[2048, 5632]", arg328_1: "bf16[2048, 5632]", arg329_1: "bf16[2048, 5632]", arg330_1: "bf16[2048, 2048]", arg331_1: "f32[2048, 1]", arg332_1: "bf16[2048, 2048]", arg333_1: "bf16[1, 16, 2048, 128]", arg334_1: "bf16[1, 16, 2048, 128]", arg335_1: "bf16[1, 16, 2048, 128]", arg336_1: "f32[1, 16, 2048]", arg337_1: "i64[]", arg338_1: "i64[]", arg339_1: "bf16[2048, 2048]", arg340_1: "bf16[2048, 2048]", arg341_1: "f32[2048, 1]", arg342_1: "bf16[2048, 2048]", arg343_1: "bf16[2048, 5632]", arg344_1: "bf16[2048, 5632]", arg345_1: "bf16[2048, 5632]", arg346_1: "bf16[2048, 2048]", arg347_1: "f32[2048, 1]", arg348_1: "bf16[2048, 2048]", arg349_1: "bf16[1, 16, 2048, 128]", arg350_1: "bf16[1, 16, 2048, 128]", arg351_1: "bf16[1, 16, 2048, 128]", arg352_1: "f32[1, 16, 2048]", arg353_1: "i64[]", arg354_1: "i64[]", arg355_1: "bf16[2048, 2048]", arg356_1: "bf16[2048, 2048]", arg357_1: "f32[2048, 1]", arg358_1: "bf16[2048, 2048]", arg359_1: "bf16[2048, 5632]", arg360_1: "bf16[2048, 5632]", arg361_1: "bf16[2048, 5632]", arg362_1: "bf16[2048, 2048]", arg363_1: "f32[2048, 1]", arg364_1: "bf16[2048, 2048]", arg365_1: "bf16[1, 16, 2048, 128]", arg366_1: "bf16[1, 16, 2048, 128]", arg367_1: 
"bf16[1, 16, 2048, 128]", arg368_1: "f32[1, 16, 2048]", arg369_1: "i64[]", arg370_1: "i64[]", arg371_1: "bf16[2048, 2048]", arg372_1: "bf16[2048, 2048]", arg373_1: "f32[2048, 1]", arg374_1: "bf16[2048, 2048]", arg375_1: "bf16[2048, 5632]", arg376_1: "bf16[2048, 5632]", arg377_1: "bf16[2048, 5632]", arg378_1: "bf16[2048, 2048]", arg379_1: "f32[2048, 1]", arg380_1: "bf16[2048, 2048]", arg381_1: "bf16[1, 16, 2048, 128]", arg382_1: "bf16[1, 16, 2048, 128]", arg383_1: "bf16[1, 16, 2048, 128]", arg384_1: "f32[1, 16, 2048]", arg385_1: "i64[]", arg386_1: "i64[]", arg387_1: "bf16[2048, 2048]", arg388_1: "bf16[2048, 2048]", arg389_1: "f32[2048, 1]", arg390_1: "bf16[2048, 2048]", arg391_1: "bf16[2048, 5632]", arg392_1: "bf16[2048, 5632]", arg393_1: "bf16[2048, 5632]", arg394_1: "bf16[2048, 2048]", arg395_1: "f32[2048, 1]", arg396_1: "bf16[2048, 2048]", arg397_1: "bf16[1, 16, 2048, 128]", arg398_1: "bf16[1, 16, 2048, 128]", arg399_1: "bf16[1, 16, 2048, 128]", arg400_1: "f32[1, 16, 2048]", arg401_1: "i64[]", arg402_1: "i64[]", arg403_1: "bf16[2048, 2048]", arg404_1: "bf16[2048, 2048]", arg405_1: "f32[2048, 1]", arg406_1: "bf16[2048, 2048]", arg407_1: "bf16[2048, 5632]", arg408_1: "bf16[2048, 5632]", arg409_1: "bf16[2048, 5632]", arg410_1: "bf16[2048, 2048]", arg411_1: "f32[2048, 1]", arg412_1: "bf16[2048, 2048]", arg413_1: "bf16[1, 16, 2048, 128]", arg414_1: "bf16[1, 16, 2048, 128]", arg415_1: "bf16[1, 16, 2048, 128]", arg416_1: "f32[1, 16, 2048]", arg417_1: "i64[]", arg418_1: "i64[]", arg419_1: "bf16[2048, 2048]", arg420_1: "bf16[2048, 2048]", arg421_1: "f32[2048, 1]", arg422_1: "bf16[2048, 2048]", arg423_1: "bf16[2048, 5632]", arg424_1: "bf16[2048, 5632]", arg425_1: "bf16[2048, 5632]", arg426_1: "bf16[2048, 2048]", arg427_1: "f32[2048, 1]", arg428_1: "bf16[2048, 2048]", arg429_1: "bf16[1, 16, 2048, 128]", arg430_1: "bf16[1, 16, 2048, 128]", arg431_1: "bf16[1, 16, 2048, 128]", arg432_1: "f32[1, 16, 2048]", arg433_1: "i64[]", arg434_1: "i64[]", arg435_1: "bf16[2048, 2048]", arg436_1: "bf16[2048, 2048]", arg437_1: "f32[2048, 1]", arg438_1: "bf16[2048, 2048]", arg439_1: "bf16[2048, 5632]", arg440_1: "bf16[2048, 5632]", arg441_1: "bf16[2048, 5632]", arg442_1: "bf16[2048, 2048]", arg443_1: "f32[2048, 1]", arg444_1: "bf16[2048, 2048]", arg445_1: "bf16[1, 16, 2048, 128]", arg446_1: "bf16[1, 16, 2048, 128]", arg447_1: "bf16[1, 16, 2048, 128]", arg448_1: "f32[1, 16, 2048]", arg449_1: "i64[]", arg450_1: "i64[]", arg451_1: "bf16[2048, 2048]", arg452_1: "bf16[2048, 2048]", arg453_1: "f32[2048, 1]", arg454_1: "bf16[2048, 2048]", arg455_1: "bf16[2048, 5632]", arg456_1: "bf16[2048, 5632]", arg457_1: "bf16[2048, 5632]", arg458_1: "bf16[2048, 2048]", arg459_1: "f32[2048, 1]", arg460_1: "bf16[2048, 2048]", arg461_1: "bf16[1, 16, 2048, 128]", arg462_1: "c64[1, 2048, 1, 64]", arg463_1: "bf16[1, 16, 2048, 128]", arg464_1: "bf16[1, 16, 2048, 128]", arg465_1: "bf16[1, 16, 2048, 128]", arg466_1: "bf16[1, 16, 2048, 128]", arg467_1: "bf16[1, 16, 2048, 128]", arg468_1: "bf16[1, 16, 2048, 128]", arg469_1: "bf16[1, 16, 2048, 128]", arg470_1: "bf16[1, 16, 2048, 128]", arg471_1: "bf16[1, 16, 2048, 128]", arg472_1: "bf16[1, 16, 2048, 128]", arg473_1: "bf16[1, 16, 2048, 128]", arg474_1: "bf16[1, 16, 2048, 128]", arg475_1: "bf16[1, 16, 2048, 128]", arg476_1: "bf16[1, 16, 2048, 128]", arg477_1: "bf16[1, 16, 2048, 128]", arg478_1: "bf16[1, 16, 2048, 128]", arg479_1: "bf16[1, 16, 2048, 128]", arg480_1: "bf16[32000, 2048]", arg481_1: "f32[8192000]", arg482_1: "f32[256]", arg483_1: "f32[8192000]", arg484_1: "f32[524288]", arg485_1: 
"f32[524288]", arg486_1: "f32[524288]", arg487_1: "f32[524288]", arg488_1: "f32[1441792]", arg489_1: "f32[1441792]", arg490_1: "f32[1441792]", arg491_1: "f32[256]", arg492_1: "f32[256]", arg493_1: "f32[524288]", arg494_1: "f32[524288]", arg495_1: "f32[524288]", arg496_1: "f32[524288]", arg497_1: "f32[1441792]", arg498_1: "f32[1441792]", arg499_1: "f32[1441792]", arg500_1: "f32[256]", arg501_1: "f32[256]", arg502_1: "f32[524288]", arg503_1: "f32[524288]", arg504_1: "f32[524288]", arg505_1: "f32[524288]", arg506_1: "f32[1441792]", arg507_1: "f32[1441792]", arg508_1: "f32[1441792]", arg509_1: "f32[256]", arg510_1: "f32[256]", arg511_1: "f32[524288]", arg512_1: "f32[524288]", arg513_1: "f32[524288]", arg514_1: "f32[524288]", arg515_1: "f32[1441792]", arg516_1: "f32[1441792]", arg517_1: "f32[1441792]", arg518_1: "f32[256]", arg519_1: "f32[256]", arg520_1: "f32[524288]", arg521_1: "f32[524288]", arg522_1: "f32[524288]", arg523_1: "f32[524288]", arg524_1: "f32[1441792]", arg525_1: "f32[1441792]", arg526_1: "f32[1441792]", arg527_1: "f32[256]", arg528_1: "f32[256]", arg529_1: "f32[524288]", arg530_1: "f32[524288]", arg531_1: "f32[524288]", arg532_1: "f32[524288]", arg533_1: "f32[1441792]", arg534_1: "f32[1441792]", arg535_1: "f32[1441792]", arg536_1: "f32[256]", arg537_1: "f32[256]", arg538_1: "f32[524288]", arg539_1: "f32[524288]", arg540_1: "f32[524288]", arg541_1: "f32[524288]", arg542_1: "f32[1441792]", arg543_1: "f32[1441792]", arg544_1: "f32[1441792]", arg545_1: "f32[256]", arg546_1: "f32[256]", arg547_1: "f32[524288]", arg548_1: "f32[524288]", arg549_1: "f32[524288]", arg550_1: "f32[524288]", arg551_1: "f32[1441792]", arg552_1: "f32[1441792]", arg553_1: "f32[1441792]", arg554_1: "f32[256]", arg555_1: "f32[256]", arg556_1: "f32[524288]", arg557_1: "f32[524288]", arg558_1: "f32[524288]", arg559_1: "f32[524288]", arg560_1: "f32[1441792]", arg561_1: "f32[1441792]", arg562_1: "f32[1441792]", arg563_1: "f32[256]", arg564_1: "f32[256]", arg565_1: "f32[524288]", arg566_1: "f32[524288]", arg567_1: "f32[524288]", arg568_1: "f32[524288]", arg569_1: "f32[1441792]", arg570_1: "f32[1441792]", arg571_1: "f32[1441792]", arg572_1: "f32[256]", arg573_1: "f32[256]", arg574_1: "f32[524288]", arg575_1: "f32[524288]", arg576_1: "f32[524288]", arg577_1: "f32[524288]", arg578_1: "f32[1441792]", arg579_1: "f32[1441792]", arg580_1: "f32[1441792]", arg581_1: "f32[256]", arg582_1: "f32[256]", arg583_1: "f32[524288]", arg584_1: "f32[524288]", arg585_1: "f32[524288]", arg586_1: "f32[524288]", arg587_1: "f32[1441792]", arg588_1: "f32[1441792]", arg589_1: "f32[1441792]", arg590_1: "f32[256]", arg591_1: "f32[256]", arg592_1: "f32[524288]", arg593_1: "f32[524288]", arg594_1: "f32[524288]", arg595_1: "f32[524288]", arg596_1: "f32[1441792]", arg597_1: "f32[1441792]", arg598_1: "f32[1441792]", arg599_1: "f32[256]", arg600_1: "f32[256]", arg601_1: "f32[524288]", arg602_1: "f32[524288]", arg603_1: "f32[524288]", arg604_1: "f32[524288]", arg605_1: "f32[1441792]", arg606_1: "f32[1441792]", arg607_1: "f32[1441792]", arg608_1: "f32[256]", arg609_1: "f32[256]", arg610_1: "f32[524288]", arg611_1: "f32[524288]", arg612_1: "f32[524288]", arg613_1: "f32[524288]", arg614_1: "f32[1441792]", arg615_1: "f32[1441792]", arg616_1: "f32[1441792]", arg617_1: "f32[256]", arg618_1: "f32[256]", arg619_1: "f32[524288]", arg620_1: "f32[524288]", arg621_1: "f32[524288]", arg622_1: "f32[524288]", arg623_1: "f32[1441792]", arg624_1: "f32[1441792]", arg625_1: "f32[1441792]", arg626_1: "f32[256]", arg627_1: "f32[256]", arg628_1: "f32[524288]", arg629_1: 
"f32[524288]", arg630_1: "f32[524288]", arg631_1: "f32[524288]", arg632_1: "f32[1441792]", arg633_1: "f32[1441792]", arg634_1: "f32[1441792]", arg635_1: "f32[256]", arg636_1: "f32[256]", arg637_1: "f32[524288]", arg638_1: "f32[524288]", arg639_1: "f32[524288]", arg640_1: "f32[524288]", arg641_1: "f32[1441792]", arg642_1: "f32[1441792]", arg643_1: "f32[1441792]", arg644_1: "f32[256]", arg645_1: "f32[256]"):
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_110: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(arg481_1, torch.bfloat16); arg481_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_111: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg482_1, torch.bfloat16); arg482_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_112: "bf16[8192000]" = torch.ops.prims.convert_element_type.default(arg483_1, torch.bfloat16); arg483_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(16384256, 8, 0, torch.bfloat16, device(type='cuda', index=0), [8192000, 256, 8192000], [convert_element_type_110, convert_element_type_111, convert_element_type_112]); convert_element_type_110 = convert_element_type_111 = convert_element_type_112 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem: "bf16[16384256]" = all_gather_copy_in[0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_1: "bf16[131074048]" = all_gather_copy_in[1]; all_gather_copy_in = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor: "bf16[131074048]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, 8, '0'); getitem = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor: "bf16[131074048]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
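
This [1/0] (backward) graph begins by re-materializing unsharded parameters: each rank casts its f32 shards to bf16, packs them into one flat buffer, and issues a single all-gather for the whole parameter group. torch.ops.fsdp.all_gather_copy_in is an internal per-parameter-group op; the sketch below only reproduces the idea with public torch.distributed APIs and is not the actual implementation:

import torch
import torch.distributed as dist

def foreach_all_gather_sketch(shards: list[torch.Tensor], world_size: int) -> torch.Tensor:
    # Assumes a default process group is already initialized (e.g. via torchrun).
    # Cast each f32 shard to bf16 and pack into one flat input buffer,
    # mirroring the convert_element_type_* + all_gather_copy_in steps above.
    flat_in = torch.cat([s.to(torch.bfloat16).flatten() for s in shards])
    flat_out = torch.empty(world_size * flat_in.numel(),
                           dtype=torch.bfloat16, device=flat_in.device)
    # One collective per parameter group (all_gather_into_tensor + wait_tensor).
    dist.all_gather_into_tensor(flat_out, flat_in)
    return flat_out
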
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_4: "bf16[8, 16384256]" = torch.ops.aten.reshape.default(wait_tensor, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(view_4, [8192000, 256, 8192000], 1); view_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_9: "bf16[8, 256]" = split_with_sizes_2[1]; split_with_sizes_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_5: "bf16[8, 16384256]" = torch.ops.aten.reshape.default(wait_tensor, [8, -1]); wait_tensor = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_3 = torch.ops.aten.split_with_sizes.default(view_5, [8192000, 256, 8192000], 1); view_5 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_13: "bf16[8, 8192000]" = split_with_sizes_3[2]; split_with_sizes_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
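
The copy-out side undoes the packing: the gathered flat buffer is viewed as [world_size, per_rank_numel], split column-wise with the per-parameter shard numels, and each [world_size, numel] slice is then re-strided into the parameter's unsharded shape. Continuing the sketch above, with the [8192000, 256, 8192000] split sizes taken from the trace (8192000 * 8 = 32000 * 2048, the embedding weight; 256 * 8 = 2048, a norm weight):

import torch

world_size = 8
sizes = [8192000, 256, 8192000]  # per-rank shard numels from the trace

flat_out = torch.empty(world_size * sum(sizes), dtype=torch.bfloat16)

per_rank = flat_out.reshape(world_size, -1)   # view_4 / view_5: [8, 16384256]
pieces = torch.split(per_rank, sizes, dim=1)  # split_with_sizes_2 / _3
# pieces[i] is bf16[8, numel_i]; flattening it row-major yields the unsharded
# flat parameter, which is then viewed as the original shape.
embedding_weight = pieces[0].reshape(32000, 2048)
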
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_113: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg484_1, torch.bfloat16); arg484_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_114: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg485_1, torch.bfloat16); arg485_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_115: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg486_1, torch.bfloat16); arg486_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_116: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg487_1, torch.bfloat16); arg487_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_117: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg488_1, torch.bfloat16); arg488_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_118: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg489_1, torch.bfloat16); arg489_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_119: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg490_1, torch.bfloat16); arg490_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_120: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg491_1, torch.bfloat16); arg491_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_121: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg492_1, torch.bfloat16); arg492_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_1 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_113, convert_element_type_114, convert_element_type_115, convert_element_type_116, convert_element_type_117, convert_element_type_118, convert_element_type_119, convert_element_type_120, convert_element_type_121]); convert_element_type_113 = convert_element_type_114 = convert_element_type_115 = convert_element_type_116 = convert_element_type_117 = convert_element_type_118 = convert_element_type_119 = convert_element_type_120 = convert_element_type_121 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_14: "bf16[6423040]" = all_gather_copy_in_1[0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_15: "bf16[51384320]" = all_gather_copy_in_1[1]; all_gather_copy_in_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_14, 8, '0'); getitem_14 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor_1: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_11: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_5 = torch.ops.aten.split_with_sizes.default(view_11, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_11 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_25: "bf16[8, 524288]" = split_with_sizes_5[0]; split_with_sizes_5 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_12: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_12, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_12 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_35: "bf16[8, 524288]" = split_with_sizes_6[1]; split_with_sizes_6 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_13: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_13, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_13 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_45: "bf16[8, 524288]" = split_with_sizes_7[2]; split_with_sizes_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_14: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_14, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_14 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_55: "bf16[8, 524288]" = split_with_sizes_8[3]; split_with_sizes_8 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_15: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_15, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_15 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_65: "bf16[8, 1441792]" = split_with_sizes_9[4]; split_with_sizes_9 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_16: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_16, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_16 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_75: "bf16[8, 1441792]" = split_with_sizes_10[5]; split_with_sizes_10 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_17: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_11 = torch.ops.aten.split_with_sizes.default(view_17, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_17 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_85: "bf16[8, 1441792]" = split_with_sizes_11[6]; split_with_sizes_11 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_18: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(view_18, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_18 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_95: "bf16[8, 256]" = split_with_sizes_12[7]; split_with_sizes_12 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_19: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_1, [8, -1]); wait_tensor_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_13 = torch.ops.aten.split_with_sizes.default(view_19, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_19 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_105: "bf16[8, 256]" = split_with_sizes_13[8]; split_with_sizes_13 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_129: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg493_1, torch.bfloat16); arg493_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_130: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg494_1, torch.bfloat16); arg494_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_131: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg495_1, torch.bfloat16); arg495_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_132: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg496_1, torch.bfloat16); arg496_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_133: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg497_1, torch.bfloat16); arg497_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_134: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg498_1, torch.bfloat16); arg498_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_135: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg499_1, torch.bfloat16); arg499_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_136: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg500_1, torch.bfloat16); arg500_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_137: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg501_1, torch.bfloat16); arg501_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_2 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_129, convert_element_type_130, convert_element_type_131, convert_element_type_132, convert_element_type_133, convert_element_type_134, convert_element_type_135, convert_element_type_136, convert_element_type_137]); convert_element_type_129 = convert_element_type_130 = convert_element_type_131 = convert_element_type_132 = convert_element_type_133 = convert_element_type_134 = convert_element_type_135 = convert_element_type_136 = convert_element_type_137 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_106: "bf16[6423040]" = all_gather_copy_in_2[0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_107: "bf16[51384320]" = all_gather_copy_in_2[1]; all_gather_copy_in_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
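
Note the scheduling here: all_gather_copy_in_2 for the next parameter group is issued before any backward compute, so the communication for the upcoming layer overlaps with the gradient math that follows. A hedged sketch of that overlap pattern using public APIs (compute is a stand-in callable, not from the source):

import torch
import torch.distributed as dist

def overlapped_gather(flat_in: torch.Tensor, world_size: int, compute):
    # Issue the next group's all-gather asynchronously, run compute while the
    # collective is in flight, and only then wait -- the same ordering as
    # all_gather_copy_in_2 being scheduled before the loss backward above.
    flat_out = torch.empty(world_size * flat_in.numel(),
                           dtype=flat_in.dtype, device=flat_in.device)
    work = dist.all_gather_into_tensor(flat_out, flat_in, async_op=True)
    result = compute()  # e.g. the log-softmax backward that follows
    work.wait()         # the wait_tensor node: block only when params are needed
    return flat_out, result
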
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:487 in forward, code: _log_softmax_backward_data = torch.ops.aten._log_softmax_backward_data.default(nll_loss_backward, getitem_4, 1, torch.float32); nll_loss_backward = getitem_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] exp: "f32[2048, 32000]" = torch.ops.aten.exp.default(arg4_1); arg4_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:486 in forward, code: nll_loss_backward = torch.ops.aten.nll_loss_backward.default(getitem, getitem_1, getitem_2, None, 1, -100, getitem_3); getitem = getitem_1 = getitem_2 = getitem_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] full: "f32[2048, 32000]" = torch.ops.aten.full.default([2048, 32000], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] unsqueeze: "i64[2048, 1]" = torch.ops.aten.unsqueeze.default(arg2_1, 1); arg2_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] ne: "b8[2048, 1]" = torch.ops.aten.ne.Scalar(unsqueeze, -100)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] scalar_tensor: "i64[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0))
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] where: "i64[2048, 1]" = torch.ops.aten.where.self(ne, unsqueeze, scalar_tensor); ne = scalar_tensor = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] scatter: "f32[2048, 32000]" = torch.ops.aten.scatter.value(full, 1, where, -1.0); full = where = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] ne_1: "b8[2048, 1]" = torch.ops.aten.ne.Scalar(unsqueeze, -100); unsqueeze = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div: "f32[]" = torch.ops.aten.div.Tensor(arg0_1, arg3_1); arg0_1 = arg3_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] scalar_tensor_1: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] where_1: "f32[2048, 1]" = torch.ops.aten.where.self(ne_1, div, scalar_tensor_1); ne_1 = div = scalar_tensor_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul: "f32[2048, 32000]" = torch.ops.aten.mul.Tensor(scatter, where_1); scatter = where_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:487 in forward, code: _log_softmax_backward_data = torch.ops.aten._log_softmax_backward_data.default(nll_loss_backward, getitem_4, 1, torch.float32); nll_loss_backward = getitem_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_1: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul, [1], True)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_1: "f32[2048, 32000]" = torch.ops.aten.mul.Tensor(exp, sum_1); exp = sum_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sub: "f32[2048, 32000]" = torch.ops.aten.sub.Tensor(mul, mul_1); mul = mul_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
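The two segments above are the decomposed cross-entropy backward: nll_loss_backward becomes a scatter of -1.0 into the target column of each non-ignored row, scaled by grad_output / total_weight, and _log_softmax_backward_data becomes g - exp(log_probs) * sum(g). A minimal eager-mode sketch of the same math (the helper name and signature are hypothetical; the shapes and ignore_index=-100 match the graph):

    import torch

    def cross_entropy_backward(grad_out, total_weight, log_probs, target, ignore_index=-100):
        # nll_loss backward: -grad_out/total_weight at the target column, 0 for ignored rows.
        tgt = target.unsqueeze(1)                                  # i64[2048, 1]
        keep = tgt.ne(ignore_index)                                # b8[2048, 1]
        safe_tgt = torch.where(keep, tgt, torch.zeros_like(tgt))   # route ignored rows to col 0
        g = torch.zeros_like(log_probs).scatter(1, safe_tgt, -1.0)
        g = g * (grad_out / total_weight) * keep                   # zero out ignored rows
        # _log_softmax backward: g - softmax(x) * sum(g)
        return g - log_probs.exp() * g.sum(dim=1, keepdim=True)
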
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:488 in forward, code: view = torch.ops.aten.view.default(_log_softmax_backward_data, [1, 2048, 32000]); _log_softmax_backward_data = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view: "f32[1, 2048, 32000]" = torch.ops.aten.reshape.default(sub, [1, 2048, 32000]); sub = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:950 in forward, code: convert_element_type_110 = torch.ops.prims.convert_element_type.default(trace_wrapped, torch.bfloat16); trace_wrapped = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_122: "bf16[1, 2048, 32000]" = torch.ops.prims.convert_element_type.default(view, torch.bfloat16); view = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:951 in forward, code: view_2 = torch.ops.aten.view.default(convert_element_type_110, [2048, 32000]); convert_element_type_110 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_6: "bf16[2048, 32000]" = torch.ops.aten.reshape.default(convert_element_type_122, [2048, 32000]); convert_element_type_122 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_2: "bf16[32000, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_13, [32000, 2048], [2048, 1]); getitem_13 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:957 in forward, code: mm_1 = torch.ops.aten.mm.default(view_2, permute_3); view_2 = permute_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_1: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_6, contiguous_view_as_strided_2); contiguous_view_as_strided_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
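contiguous_view_as_strided_2 above reinterprets one contiguous slice of the all-gathered flat buffer as the [32000, 2048] output-projection weight, and mm_1 then computes the input gradient of that projection, dX = dY @ W. A rough eager-mode equivalent, assuming hypothetical placeholders `flat` for getitem_13 and `dy` for view_6:

    import torch

    flat = torch.empty(32000 * 2048, dtype=torch.bfloat16)  # stand-in for getitem_13
    dy = torch.empty(2048, 32000, dtype=torch.bfloat16)     # stand-in for view_6

    # Same size/stride as contiguous_view_as_strided_2: a plain row-major view.
    weight = flat.as_strided([32000, 2048], [2048, 1])

    dx = dy @ weight                                        # mm_1: bf16[2048, 2048]
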
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:958 in forward, code: view_3 = torch.ops.aten.view.default(mm_1, [1, 2048, 2048]); mm_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_7: "bf16[1, 2048, 2048]" = torch.ops.aten.reshape.default(mm_1, [1, 2048, 2048]); mm_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:961 in forward, code: view_4 = torch.ops.aten.view.default(view_3, [2048, 2048]); view_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_8: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(view_7, [2048, 2048]); view_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_1: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_9, [2048], [1]); getitem_9 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:963 in forward, code: mul_56 = torch.ops.aten.mul.Tensor(view_4, getitem_6); view_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_58: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(view_8, contiguous_view_as_strided_1); contiguous_view_as_strided_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:967 in forward, code: convert_element_type_111 = torch.ops.prims.convert_element_type.default(mul_56, torch.float32); mul_56 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_127: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_58, torch.float32); mul_58 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:969 in forward, code: mul_58 = torch.ops.aten.mul.Tensor(convert_element_type_111, getitem_459); convert_element_type_111 = getitem_459 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_60: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_127, arg459_1)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:655 in forward, code: view_1 = torch.ops.aten.view.default(getitem_170, [-1, 2048]); getitem_170 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_1: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(arg170_1, [-1, 2048]); arg170_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:661 in forward, code: add = torch.ops.aten.add.Tensor(view_1, getitem_180); view_1 = getitem_180 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(view_1, arg180_1); arg180_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:671 in forward, code: add_1 = torch.ops.aten.add.Tensor(add, getitem_186); add = getitem_186 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_1: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add, arg186_1); arg186_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:677 in forward, code: add_2 = torch.ops.aten.add.Tensor(add_1, getitem_196); add_1 = getitem_196 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_2: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_1, arg196_1); arg196_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:687 in forward, code: add_3 = torch.ops.aten.add.Tensor(add_2, getitem_202); add_2 = getitem_202 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_3: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_2, arg202_1); arg202_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:693 in forward, code: add_4 = torch.ops.aten.add.Tensor(add_3, getitem_212); add_3 = getitem_212 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_4: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_3, arg212_1); arg212_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:703 in forward, code: add_5 = torch.ops.aten.add.Tensor(add_4, getitem_218); add_4 = getitem_218 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_5: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_4, arg218_1); arg218_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:709 in forward, code: add_6 = torch.ops.aten.add.Tensor(add_5, getitem_228); add_5 = getitem_228 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_6: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_5, arg228_1); arg228_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:719 in forward, code: add_7 = torch.ops.aten.add.Tensor(add_6, getitem_234); add_6 = getitem_234 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_7: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_6, arg234_1); arg234_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:725 in forward, code: add_8 = torch.ops.aten.add.Tensor(add_7, getitem_244); add_7 = getitem_244 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_8: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_7, arg244_1); arg244_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:735 in forward, code: add_9 = torch.ops.aten.add.Tensor(add_8, getitem_250); add_8 = getitem_250 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_9: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_8, arg250_1); arg250_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:741 in forward, code: add_10 = torch.ops.aten.add.Tensor(add_9, getitem_260); add_9 = getitem_260 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_10: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_9, arg260_1); arg260_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:751 in forward, code: add_11 = torch.ops.aten.add.Tensor(add_10, getitem_266); add_10 = getitem_266 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_11: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_10, arg266_1); arg266_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:757 in forward, code: add_12 = torch.ops.aten.add.Tensor(add_11, getitem_276); add_11 = getitem_276 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_12: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_11, arg276_1); arg276_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:767 in forward, code: add_13 = torch.ops.aten.add.Tensor(add_12, getitem_282); add_12 = getitem_282 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_13: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_12, arg282_1); arg282_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:773 in forward, code: add_14 = torch.ops.aten.add.Tensor(add_13, getitem_292); add_13 = getitem_292 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_14: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_13, arg292_1); arg292_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:783 in forward, code: add_15 = torch.ops.aten.add.Tensor(add_14, getitem_298); add_14 = getitem_298 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_15: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_14, arg298_1); arg298_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:789 in forward, code: add_16 = torch.ops.aten.add.Tensor(add_15, getitem_308); add_15 = getitem_308 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_16: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_15, arg308_1); arg308_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:799 in forward, code: add_17 = torch.ops.aten.add.Tensor(add_16, getitem_314); add_16 = getitem_314 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_17: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_16, arg314_1); arg314_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:805 in forward, code: add_18 = torch.ops.aten.add.Tensor(add_17, getitem_324); add_17 = getitem_324 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_18: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_17, arg324_1); arg324_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:815 in forward, code: add_19 = torch.ops.aten.add.Tensor(add_18, getitem_330); add_18 = getitem_330 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_19: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_18, arg330_1); arg330_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:821 in forward, code: add_20 = torch.ops.aten.add.Tensor(add_19, getitem_340); add_19 = getitem_340 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_20: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_19, arg340_1); arg340_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:831 in forward, code: add_21 = torch.ops.aten.add.Tensor(add_20, getitem_346); add_20 = getitem_346 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_21: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_20, arg346_1); arg346_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:837 in forward, code: add_22 = torch.ops.aten.add.Tensor(add_21, getitem_356); add_21 = getitem_356 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_22: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_21, arg356_1); arg356_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:847 in forward, code: add_23 = torch.ops.aten.add.Tensor(add_22, getitem_362); add_22 = getitem_362 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_23: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_22, arg362_1); arg362_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:853 in forward, code: add_24 = torch.ops.aten.add.Tensor(add_23, getitem_372); add_23 = getitem_372 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_24: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_23, arg372_1); arg372_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:863 in forward, code: add_25 = torch.ops.aten.add.Tensor(add_24, getitem_378); add_24 = getitem_378 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_25: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_24, arg378_1); arg378_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:869 in forward, code: add_26 = torch.ops.aten.add.Tensor(add_25, getitem_388); add_25 = getitem_388 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_26: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_25, arg388_1); arg388_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:879 in forward, code: add_27 = torch.ops.aten.add.Tensor(add_26, getitem_394); add_26 = getitem_394 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_27: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_26, arg394_1); arg394_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:885 in forward, code: add_28 = torch.ops.aten.add.Tensor(add_27, getitem_404); add_27 = getitem_404 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_28: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_27, arg404_1); arg404_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:895 in forward, code: add_29 = torch.ops.aten.add.Tensor(add_28, getitem_410); add_28 = getitem_410 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_29: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_28, arg410_1); arg410_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:901 in forward, code: add_30 = torch.ops.aten.add.Tensor(add_29, getitem_420); add_29 = getitem_420 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_30: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_29, arg420_1); arg420_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:911 in forward, code: add_31 = torch.ops.aten.add.Tensor(add_30, getitem_426); add_30 = getitem_426 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_31: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_30, arg426_1); arg426_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:917 in forward, code: add_32 = torch.ops.aten.add.Tensor(add_31, getitem_436); add_31 = getitem_436 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_32: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_31, arg436_1); arg436_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:927 in forward, code: add_33 = torch.ops.aten.add.Tensor(add_32, getitem_442); add_32 = getitem_442 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_33: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_32, arg442_1); arg442_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:933 in forward, code: add_34 = torch.ops.aten.add.Tensor(add_33, getitem_452); add_33 = getitem_452 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_34: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_33, arg452_1); arg452_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:943 in forward, code: add_35 = torch.ops.aten.add.Tensor(add_34, getitem_458); add_34 = getitem_458 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_35: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_34, arg458_1); arg458_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
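The long chain add ... add_35 above is the gradient at the embedding-layer output accumulating one contribution per residual connection: 36 branch gradients (18 transformer layers from the n_layers=18 config, times two residual branches per layer, attention and feed-forward) added onto view_1, the gradient arriving through the final norm. A toy illustration of why residuals lower to a chain of adds in the backward graph (all tensors here are hypothetical placeholders):

    import torch

    dx = torch.randn(2048, 2048, dtype=torch.bfloat16)        # view_1: grad from the final-norm path
    branch_grads = [torch.randn_like(dx) for _ in range(36)]  # 18 layers x (attention + FFN) residuals
    for g in branch_grads:
        dx = dx + g                                           # mirrors add, add_1, ..., add_35
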
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:944 in forward, code: convert_element_type_108 = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_108: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_35, torch.float32); add_35 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:968 in forward, code: mul_57 = torch.ops.aten.mul.Tensor(convert_element_type_111, convert_element_type_108)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_59: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_127, convert_element_type_108); convert_element_type_127 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:970 in forward, code: sum_2 = torch.ops.aten.sum.dim_IntList(mul_57, [1], True); mul_57 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_3: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_59, [1], True); mul_59 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:974 in forward, code: mul_59 = torch.ops.aten.mul.Scalar(sum_2, -0.5); sum_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_61: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_3, -0.5); sum_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:973 in forward, code: pow_1 = torch.ops.aten.pow.Tensor_Scalar(alias_75, 3); alias_75 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_1: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg459_1, 3)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:975 in forward, code: mul_60 = torch.ops.aten.mul.Tensor(mul_59, pow_1); mul_59 = pow_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_62: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_61, pow_1); mul_61 = pow_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:976 in forward, code: expand = torch.ops.aten.expand.default(mul_60, [2048, 2048]); mul_60 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_62, [2048, 2048]); mul_62 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:977 in forward, code: div = torch.ops.aten.div.Scalar(expand, 2048); expand = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_1: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand, 2048); expand = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:978 in forward, code: pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_108, 1.0); convert_element_type_108 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_2: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_108, 1.0)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:979 in forward, code: mul_61 = torch.ops.aten.mul.Scalar(pow_2, 2.0); pow_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_63: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_2, 2.0); pow_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:980 in forward, code: mul_62 = torch.ops.aten.mul.Tensor(div, mul_61); div = mul_61 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_64: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_1, mul_63); div_1 = mul_63 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:981 in forward, code: add_36 = torch.ops.aten.add.Tensor(mul_58, mul_62); mul_58 = mul_62 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_36: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_60, mul_64); mul_60 = mul_64 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:982 in forward, code: convert_element_type_112 = torch.ops.prims.convert_element_type.default(add_36, torch.bfloat16); add_36 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_128: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_36, torch.bfloat16); add_36 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
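Taken together, the segments from mul_58 through convert_element_type_128 form the RMSNorm input gradient: a direct term dY * w * rstd plus a correction through the rstd = 1/sqrt(mean(x^2) + eps) statistic (the -0.5 scale, rstd**3, the division by the hidden size 2048, and the 2x factor above fold into -x * rstd^3 * sum(dY * w * x) / N). A minimal sketch under those definitions (function name and signature are hypothetical):

    import torch

    def rmsnorm_input_grad(dy, x, weight, rstd):
        # dy, x: f32[2048, 2048]; weight: f32[2048]; rstd: f32[2048, 1]
        dyw = dy * weight
        direct = dyw * rstd                                              # mul_60 above
        n = x.shape[-1]                                                  # 2048
        correction = x * rstd.pow(3) * (dyw * x).sum(-1, keepdim=True) / n
        return direct - correction                                       # add_36 above
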
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_8: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_75, [2048, 5632], [5632, 1]); getitem_75 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:989 in forward, code: mm_3 = torch.ops.aten.mm.default(trace_wrapped_1, permute_8); permute_8 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_3: "bf16[2048, 5632]" = torch.ops.aten.mm.default(convert_element_type_128, contiguous_view_as_strided_8); contiguous_view_as_strided_8 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:939 in forward, code: convert_element_type_106 = torch.ops.prims.convert_element_type.default(getitem_455, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_106: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(arg455_1, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:940 in forward, code: sigmoid_17 = torch.ops.aten.sigmoid.default(convert_element_type_106)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_17: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_106)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:941 in forward, code: mul_53 = torch.ops.aten.mul.Tensor(convert_element_type_106, sigmoid_17); convert_element_type_106 = sigmoid_17 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_55: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_106, sigmoid_17); convert_element_type_106 = sigmoid_17 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:942 in forward, code: convert_element_type_107 = torch.ops.prims.convert_element_type.default(mul_53, torch.bfloat16); mul_53 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_107: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_55, torch.bfloat16); mul_55 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:992 in forward, code: mul_63 = torch.ops.aten.mul.Tensor(mm_3, convert_element_type_107); convert_element_type_107 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_65: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_3, convert_element_type_107); convert_element_type_107 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_9: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_85, [5632, 2048], [2048, 1]); getitem_85 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:993 in forward, code: mul_64 = torch.ops.aten.mul.Tensor(mm_3, getitem_456); mm_3 = getitem_456 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_66: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_3, arg456_1); mm_3 = arg456_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1003 in forward, code: full = torch.ops.aten.full.default([2048, 5632], 1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] full_166: "bf16[2048, 5632]" = torch.ops.aten.full.default([2048, 5632], 1, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1002 in forward, code: sigmoid_18 = torch.ops.aten.sigmoid.default(getitem_455)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_18: "bf16[2048, 5632]" = torch.ops.aten.sigmoid.default(arg455_1)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1004 in forward, code: sub = torch.ops.aten.sub.Tensor(full, sigmoid_18)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sub_1: "bf16[2048, 5632]" = torch.ops.aten.sub.Tensor(full_166, sigmoid_18)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1005 in forward, code: mul_65 = torch.ops.aten.mul.Tensor(getitem_455, sub); getitem_455 = sub = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_67: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(arg455_1, sub_1); arg455_1 = sub_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1006 in forward, code: add_37 = torch.ops.aten.add.Scalar(mul_65, 1); mul_65 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_37: "bf16[2048, 5632]" = torch.ops.aten.add.Scalar(mul_67, 1); mul_67 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1007 in forward, code: mul_66 = torch.ops.aten.mul.Tensor(sigmoid_18, add_37); sigmoid_18 = add_37 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_68: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(sigmoid_18, add_37); sigmoid_18 = add_37 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1008 in forward, code: mul_67 = torch.ops.aten.mul.Tensor(mul_64, mul_66); mul_64 = mul_66 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_69: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mul_66, mul_68); mul_66 = mul_68 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
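The block from full_166 through mul_69 recomputes sigmoid and applies the SiLU derivative, d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), to the gradient flowing into the gate branch of the SwiGLU MLP. A compact eager-mode sketch (hypothetical helper; x corresponds to arg455_1 and grad_out to mul_66):

    import torch

    def silu_backward(grad_out, x):
        s = torch.sigmoid(x)                       # sigmoid_18
        return grad_out * s * (1 + x * (1 - s))    # sub_1, mul_67, add_37, mul_68, mul_69
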
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_7: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_65, [5632, 2048], [2048, 1]); getitem_65 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1015 in forward, code: add_38 = torch.ops.aten.add.Tensor(mm_5, mm_7); mm_5 = mm_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_plus_mm_35: "bf16[2048, 2048]" = torch__inductor_fx_passes_post_grad_mm_plus_mm(mat1 = mul_65, mat2 = contiguous_view_as_strided_9, mat3 = mul_69, mat4 = contiguous_view_as_strided_7); contiguous_view_as_strided_9 = contiguous_view_as_strided_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
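mm_plus_mm_35 above is an Inductor post-grad pattern match in action: the traced add_38 = mm_5 + mm_7 (the two per-branch input gradients of the SwiGLU MLP, each multiplied by its all-gathered weight) is rewritten into a single fused node. Semantically it is just two matmuls summed; a stand-in for the fused kernel, not the kernel itself:

    import torch

    def mm_plus_mm(mat1, mat2, mat3, mat4):
        # What the fused node computes: mm(mat1, mat2) + mm(mat3, mat4) in one kernel.
        return mat1 @ mat2 + mat3 @ mat4
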
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_11: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_105, [2048], [1]); getitem_105 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1019 in forward, code: mul_69 = torch.ops.aten.mul.Tensor(add_38, getitem_169); add_38 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_71: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(mm_plus_mm_35, contiguous_view_as_strided_11); contiguous_view_as_strided_11 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1023 in forward, code: convert_element_type_113 = torch.ops.prims.convert_element_type.default(mul_69, torch.float32); mul_69 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_150: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_71, torch.float32); mul_71 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1025 in forward, code: mul_71 = torch.ops.aten.mul.Tensor(convert_element_type_113, getitem_453); convert_element_type_113 = getitem_453 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_73: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_150, arg453_1)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:934 in forward, code: convert_element_type_104 = torch.ops.prims.convert_element_type.default(add_34, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_104: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_34, torch.float32); add_34 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1024 in forward, code: mul_70 = torch.ops.aten.mul.Tensor(convert_element_type_113, convert_element_type_104)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_72: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_150, convert_element_type_104); convert_element_type_150 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1026 in forward, code: sum_4 = torch.ops.aten.sum.dim_IntList(mul_70, [1], True); mul_70 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_5: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_72, [1], True); mul_72 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1030 in forward, code: mul_72 = torch.ops.aten.mul.Scalar(sum_4, -0.5); sum_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_74: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_5, -0.5); sum_5 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1029 in forward, code: pow_3 = torch.ops.aten.pow.Tensor_Scalar(alias_77, 3); alias_77 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_3: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg453_1, 3)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1031 in forward, code: mul_73 = torch.ops.aten.mul.Tensor(mul_72, pow_3); mul_72 = pow_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_75: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_74, pow_3); mul_74 = pow_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1032 in forward, code: expand_1 = torch.ops.aten.expand.default(mul_73, [2048, 2048]); mul_73 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand_1: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_75, [2048, 2048]); mul_75 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1033 in forward, code: div_1 = torch.ops.aten.div.Scalar(expand_1, 2048); expand_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_2: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand_1, 2048); expand_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1034 in forward, code: pow_4 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_104, 1.0); convert_element_type_104 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_4: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_104, 1.0)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1035 in forward, code: mul_74 = torch.ops.aten.mul.Scalar(pow_4, 2.0); pow_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_76: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_4, 2.0); pow_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1036 in forward, code: mul_75 = torch.ops.aten.mul.Tensor(div_1, mul_74); div_1 = mul_74 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_77: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_2, mul_76); div_2 = mul_76 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1037 in forward, code: add_39 = torch.ops.aten.add.Tensor(mul_71, mul_75); mul_71 = mul_75 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_39: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_73, mul_77); mul_73 = mul_77 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1038 in forward, code: convert_element_type_114 = torch.ops.prims.convert_element_type.default(add_39, torch.bfloat16); add_39 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_151: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_39, torch.bfloat16); add_39 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1039 in forward, code: add_40 = torch.ops.aten.add.Tensor(trace_wrapped_1, convert_element_type_114); trace_wrapped_1 = convert_element_type_114 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_40: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(convert_element_type_128, convert_element_type_151); convert_element_type_151 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
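The mul_74 .. add_40 chain above is the compiled RMSNorm backward for this block's norm. A minimal eager sketch of the same math, assuming x is the fp32 layer input, rstd = (mean(x**2, dim=-1, keepdim=True) + eps) ** -0.5 saved from the forward (arg453_1 here), and g is the incoming gradient already scaled by the norm weight:

    import torch

    def rmsnorm_backward_sketch(g, x, rstd):
        # d/dx of x * rstd(x), where rstd(x) = (mean(x**2, dim=-1) + eps) ** -0.5
        n = x.shape[-1]                            # 2048 in this trace
        s = (g * x).sum(dim=-1, keepdim=True)      # sum_5 above
        # mul_74 = -0.5 * s; pow_3 = rstd ** 3; div_2 spreads the mean over n;
        # mul_76 = 2 * x is the derivative of x**2 inside the mean
        dx = g * rstd + (-0.5 * (rstd ** 3) * s) * (2.0 * x / n)
        return dx                                  # add_39, before the bf16 cast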
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_6: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_55, [2048, 2048], [2048, 1]); getitem_55 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
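torch.ops.fsdp.contiguous_view_as_strided is a custom op from the compiled-FSDP prototype; judging purely from its call sites in this trace, it reinterprets a contiguous per-parameter chunk of the all-gather output as the unsharded parameter with the requested size and stride. A rough functional stand-in (an assumption, not the op's actual implementation):

    import torch

    def contiguous_view_as_strided_sketch(chunk, size, stride):
        # e.g. the bf16[8, 524288] slice for one 2048x2048 weight, flattened and
        # re-strided to size [2048, 2048] with stride [2048, 1]
        return chunk.contiguous().view(-1).as_strided(size, stride)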
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1045 in forward, code: mm_9 = torch.ops.aten.mm.default(add_40, permute_23); permute_23 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_9: "bf16[2048, 2048]" = torch.ops.aten.mm.default(add_40, contiguous_view_as_strided_6); contiguous_view_as_strided_6 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1048 in forward, code: view_7 = torch.ops.aten.view.default(mm_9, [1, 2048, 16, 128]); mm_9 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_21: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_9, [1, 2048, 16, 128]); mm_9 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1049 in forward, code: permute_25 = torch.ops.aten.permute.default(view_7, [0, 2, 1, 3]); view_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_35: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_21, [0, 2, 1, 3]); view_21 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1050 in forward, code: _scaled_dot_product_flash_attention_backward = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_25, getitem_445, getitem_446, getitem_447, getitem_461, getitem_448, None, None, 2048, 2048, 0.0, True, getitem_449, getitem_450, scale = 0.08838834764831843); permute_25 = getitem_445 = getitem_446 = getitem_447 = getitem_461 = getitem_448 = getitem_449 = getitem_450 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] _scaled_dot_product_flash_attention_backward = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_35, arg445_1, arg446_1, arg447_1, arg461_1, arg448_1, None, None, 2048, 2048, 0.0, True, arg449_1, arg450_1, scale = 0.08838834764831843); permute_35 = arg445_1 = arg446_1 = arg447_1 = arg461_1 = arg448_1 = arg449_1 = arg450_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_108: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward[0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_109: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward[1]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_110: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward[2]; _scaled_dot_product_flash_attention_backward = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
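The single _scaled_dot_product_flash_attention_backward node packs the whole attention backward: its three outputs (getitem_108..110) are the gradients for q, k and v, and scale = 0.08838834764831843 is 1/sqrt(128), matching head_dim = 128. A minimal eager reproduction that exercises the same kernel (a sketch; needs a CUDA device with flash attention available):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 16, 2048, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # default scale = 1/sqrt(128)
    out.sum().backward()  # autograd runs the flash attention backward; q.grad, k.grad,
                          # v.grad correspond to getitem_108..110 above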
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_106, 8, '0'); getitem_106 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor_2: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
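all_gather_into_tensor and wait_tensor are the traceable functional collectives from torch.distributed._functional_collectives (the file the trace comments point at): the collective is issued asynchronously and wait_tensor marks where its result must be materialized. Shapes match the trace, with each of the 8 ranks contributing a bf16[6423040] shard to the bf16[51384320] output. A sketch of the Python-level API these nodes lower from:

    import torch
    import torch.distributed._functional_collectives as funcol

    def gather_shard(shard, group):
        gathered = funcol.all_gather_tensor(shard, gather_dim=0, group=group)
        return funcol.wait_tensor(gathered)  # normally triggered implicitly on first use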
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_31: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_15 = torch.ops.aten.split_with_sizes.default(view_31, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_31 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_120: "bf16[8, 524288]" = split_with_sizes_15[0]; split_with_sizes_15 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_32: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_32, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_32 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_130: "bf16[8, 524288]" = split_with_sizes_16[1]; split_with_sizes_16 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_33: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_17 = torch.ops.aten.split_with_sizes.default(view_33, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_33 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_140: "bf16[8, 524288]" = split_with_sizes_17[2]; split_with_sizes_17 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_34: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_18 = torch.ops.aten.split_with_sizes.default(view_34, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_34 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_150: "bf16[8, 524288]" = split_with_sizes_18[3]; split_with_sizes_18 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_35: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_19 = torch.ops.aten.split_with_sizes.default(view_35, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_35 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_160: "bf16[8, 1441792]" = split_with_sizes_19[4]; split_with_sizes_19 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_36: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_20 = torch.ops.aten.split_with_sizes.default(view_36, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_36 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_170: "bf16[8, 1441792]" = split_with_sizes_20[5]; split_with_sizes_20 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_37: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_21 = torch.ops.aten.split_with_sizes.default(view_37, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_37 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_180: "bf16[8, 1441792]" = split_with_sizes_21[6]; split_with_sizes_21 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_38: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_22 = torch.ops.aten.split_with_sizes.default(view_38, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_38 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_190: "bf16[8, 256]" = split_with_sizes_22[7]; split_with_sizes_22 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_39: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_2, [8, -1]); wait_tensor_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_23 = torch.ops.aten.split_with_sizes.default(view_39, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_39 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_200: "bf16[8, 256]" = split_with_sizes_23[8]; split_with_sizes_23 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
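The repeated reshape + split_with_sizes blocks above all slice the same bf16[51384320] all-gather output: it is viewed as [world_size, shard_numel] = [8, 6423040] and a different column range is taken per parameter. The shard sizes are consistent with one transformer block's weights: 524288 = 2048*2048/8 (wq, wk, wv, wo), 1441792 = 2048*5632/8 (w1, w2, w3), and 256 = 2048/8 (the two RMSNorm weights). A plain-PyTorch sketch of the same slicing:

    import torch

    world_size = 8
    split_sizes = [524288] * 4 + [1441792] * 3 + [256] * 2  # sums to 6423040
    flat = torch.empty(world_size * sum(split_sizes), dtype=torch.bfloat16)
    per_param = torch.split(flat.view(world_size, -1), split_sizes, dim=1)
    # per_param[i] is bf16[8, split_sizes[i]]; contiguous_view_as_strided then turns
    # each slice back into its unsharded parameter shape, e.g. [2048, 2048]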
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_174: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg502_1, torch.bfloat16); arg502_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_175: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg503_1, torch.bfloat16); arg503_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_176: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg504_1, torch.bfloat16); arg504_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_177: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg505_1, torch.bfloat16); arg505_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_178: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg506_1, torch.bfloat16); arg506_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_179: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg507_1, torch.bfloat16); arg507_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_180: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg508_1, torch.bfloat16); arg508_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_181: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg509_1, torch.bfloat16); arg509_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_182: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg510_1, torch.bfloat16); arg510_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_3 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_174, convert_element_type_175, convert_element_type_176, convert_element_type_177, convert_element_type_178, convert_element_type_179, convert_element_type_180, convert_element_type_181, convert_element_type_182]); convert_element_type_174 = convert_element_type_175 = convert_element_type_176 = convert_element_type_177 = convert_element_type_178 = convert_element_type_179 = convert_element_type_180 = convert_element_type_181 = convert_element_type_182 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_201: "bf16[6423040]" = all_gather_copy_in_3[0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_202: "bf16[51384320]" = all_gather_copy_in_3[1]; all_gather_copy_in_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
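Before the next layer's all-gather, the fp32 sharded parameters are cast to bf16 (_to_dtype_if_needed, i.e. FSDP mixed precision's param_dtype) and torch.ops.fsdp.all_gather_copy_in packs the nine bf16 shards into a single flat input buffer while allocating the world_size-times-larger output buffer. A functional sketch of what the trace suggests the op does (the real op is a custom kernel; the names below are mine):

    import torch

    def all_gather_copy_in_sketch(shards, inp_split_sizes, numel, world_size, dtype, device):
        all_gather_input = torch.empty(numel, dtype=dtype, device=device)
        all_gather_output = torch.empty(numel * world_size, dtype=dtype, device=device)
        for dst, src in zip(all_gather_input.split(inp_split_sizes), shards):
            dst.copy_(src.reshape(-1))  # pack each parameter shard contiguously
        return all_gather_input, all_gather_output  # getitem_201 / getitem_202 above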
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1054 in forward, code: permute_26 = torch.ops.aten.permute.default(getitem_484, [0, 2, 1, 3]); getitem_484 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_36: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_110, [0, 2, 1, 3]); getitem_110 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1072 in forward, code: view_12 = torch.ops.aten.view.default(permute_26, [2048, 2048]); permute_26 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_26: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_36, [2048, 2048]); permute_36 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_5: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_45, [2048, 2048], [2048, 1]); getitem_45 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1055 in forward, code: permute_27 = torch.ops.aten.permute.default(getitem_483, [0, 2, 1, 3]); getitem_483 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_37: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_109, [0, 2, 1, 3]); getitem_109 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1057 in forward, code: convert_element_type_115 = torch.ops.prims.convert_element_type.default(permute_27, torch.float32); permute_27 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_156: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(permute_37, torch.float32); permute_37 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1059 in forward, code: view_8 = torch.ops.aten.view.default(convert_element_type_115, [1, 2048, 16, 64, 2]); convert_element_type_115 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_22: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_156, [1, 2048, 16, 64, 2]); convert_element_type_156 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1060 in forward, code: view_as_complex = torch.ops.aten.view_as_complex.default(view_8); view_8 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_complex: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_22); view_22 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1062 in forward, code: mul_76 = torch.ops.aten.mul.Tensor(view_as_complex, clone); view_as_complex = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_78: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex, arg462_1); view_as_complex = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1066 in forward, code: view_as_real = torch.ops.aten.view_as_real.default(mul_76); mul_76 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_real: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_78); mul_78 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1067 in forward, code: view_10 = torch.ops.aten.view.default(view_as_real, [1, 2048, 16, 128]); view_as_real = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_24: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real, [1, 2048, 16, 128]); view_as_real = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1068 in forward, code: convert_element_type_117 = torch.ops.prims.convert_element_type.default(view_10, torch.bfloat16); view_10 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_158: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_24, torch.bfloat16); view_24 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1073 in forward, code: view_13 = torch.ops.aten.view.default(convert_element_type_117, [2048, 2048]); convert_element_type_117 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_27: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(convert_element_type_158, [2048, 2048]); convert_element_type_158 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
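The view_as_complex / mul / view_as_real chain differentiates the Llama-style rotary embedding: head-dim pairs are viewed as complex numbers and rotated by the precomputed freqs_cis buffer (arg462_1, complex64). A sketch of the forward this gradient corresponds to; since the rotation is an elementwise complex multiply, the backward reuses the same pattern on the gradient (up to conjugation of freqs_cis):

    import torch

    def apply_rotary_emb_sketch(x, freqs_cis):
        # x: [bsz, seqlen, n_heads, head_dim] real; freqs_cis: complex64,
        # broadcastable to [bsz, seqlen, n_heads, head_dim // 2]
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        out = torch.view_as_real(x_c * freqs_cis).flatten(3)  # rotate each pair
        return out.type_as(x)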
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_4: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_35, [2048, 2048], [2048, 1]); getitem_35 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1089 in forward, code: add_41 = torch.ops.aten.add.Tensor(mm_11, mm_13); mm_11 = mm_13 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_plus_mm_34: "bf16[2048, 2048]" = torch__inductor_fx_passes_post_grad_mm_plus_mm_1(mat1 = view_26, mat2 = contiguous_view_as_strided_5, mat3 = view_27, mat4 = contiguous_view_as_strided_4); contiguous_view_as_strided_5 = contiguous_view_as_strided_4 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
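mm_plus_mm_34 is an Inductor post-grad pattern replacement: the traced add_41 = mm_11 + mm_13 (quoted in the comment above) has been rewritten so the two matmuls and the add can be emitted as one fused candidate kernel. Its reference semantics are simply:

    import torch

    def mm_plus_mm_reference(mat1, mat2, mat3, mat4):
        # what torch__inductor_fx_passes_post_grad_mm_plus_mm computes
        return mat1 @ mat2 + mat3 @ mat4

Here it accumulates two of the attention-input gradients, each a [2048, 2048] mm against a gathered 2048x2048 weight.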
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1056 in forward, code: permute_28 = torch.ops.aten.permute.default(getitem_482, [0, 2, 1, 3]); getitem_482 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_38: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_108, [0, 2, 1, 3]); getitem_108 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1058 in forward, code: convert_element_type_116 = torch.ops.prims.convert_element_type.default(permute_28, torch.float32); permute_28 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_157: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(permute_38, torch.float32); permute_38 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1063 in forward, code: view_9 = torch.ops.aten.view.default(convert_element_type_116, [1, 2048, 16, 64, 2]); convert_element_type_116 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_23: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_157, [1, 2048, 16, 64, 2]); convert_element_type_157 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1064 in forward, code: view_as_complex_1 = torch.ops.aten.view_as_complex.default(view_9); view_9 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_complex_1: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_23); view_23 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1065 in forward, code: mul_77 = torch.ops.aten.mul.Tensor(view_as_complex_1, clone); view_as_complex_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_79: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_1, arg462_1); view_as_complex_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1069 in forward, code: view_as_real_1 = torch.ops.aten.view_as_real.default(mul_77); mul_77 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_real_1: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_79); mul_79 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1070 in forward, code: view_11 = torch.ops.aten.view.default(view_as_real_1, [1, 2048, 16, 128]); view_as_real_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_25: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_1, [1, 2048, 16, 128]); view_as_real_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1071 in forward, code: convert_element_type_118 = torch.ops.prims.convert_element_type.default(view_11, torch.bfloat16); view_11 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_159: "bf16[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(view_25, torch.bfloat16); view_25 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1074 in forward, code: view_14 = torch.ops.aten.view.default(convert_element_type_118, [2048, 2048]); convert_element_type_118 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_28: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(convert_element_type_159, [2048, 2048]); convert_element_type_159 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_3: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_25, [2048, 2048], [2048, 1]); getitem_25 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1097 in forward, code: mm_15 = torch.ops.aten.mm.default(view_14, permute_42); view_14 = permute_42 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_15: "bf16[2048, 2048]" = torch.ops.aten.mm.default(view_28, contiguous_view_as_strided_3); contiguous_view_as_strided_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1098 in forward, code: add_42 = torch.ops.aten.add.Tensor(add_41, mm_15); add_41 = mm_15 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_42: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(mm_plus_mm_34, mm_15); mm_plus_mm_34 = mm_15 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_10: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_95, [2048], [1]); getitem_95 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1102 in forward, code: mul_79 = torch.ops.aten.mul.Tensor(add_42, getitem_168); add_42 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_81: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(add_42, contiguous_view_as_strided_10); contiguous_view_as_strided_10 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1106 in forward, code: convert_element_type_119 = torch.ops.prims.convert_element_type.default(mul_79, torch.float32); mul_79 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_172: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_81, torch.float32); mul_81 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1108 in forward, code: mul_81 = torch.ops.aten.mul.Tensor(convert_element_type_119, getitem_443); convert_element_type_119 = getitem_443 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_83: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_172, arg443_1)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:928 in forward, code: convert_element_type_102 = torch.ops.prims.convert_element_type.default(add_33, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_102: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_33, torch.float32); add_33 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1107 in forward, code: mul_80 = torch.ops.aten.mul.Tensor(convert_element_type_119, convert_element_type_102)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_82: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_172, convert_element_type_102); convert_element_type_172 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1109 in forward, code: sum_6 = torch.ops.aten.sum.dim_IntList(mul_80, [1], True); mul_80 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_7: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_82, [1], True); mul_82 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1113 in forward, code: mul_82 = torch.ops.aten.mul.Scalar(sum_6, -0.5); sum_6 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_84: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_7, -0.5); sum_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1112 in forward, code: pow_5 = torch.ops.aten.pow.Tensor_Scalar(alias_79, 3); alias_79 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_5: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg443_1, 3)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1114 in forward, code: mul_83 = torch.ops.aten.mul.Tensor(mul_82, pow_5); mul_82 = pow_5 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_85: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_84, pow_5); mul_84 = pow_5 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1115 in forward, code: expand_2 = torch.ops.aten.expand.default(mul_83, [2048, 2048]); mul_83 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand_2: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_85, [2048, 2048]); mul_85 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1116 in forward, code: div_2 = torch.ops.aten.div.Scalar(expand_2, 2048); expand_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_3: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand_2, 2048); expand_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1117 in forward, code: pow_6 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_102, 1.0); convert_element_type_102 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_6: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_102, 1.0)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1118 in forward, code: mul_84 = torch.ops.aten.mul.Scalar(pow_6, 2.0); pow_6 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_86: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_6, 2.0); pow_6 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1119 in forward, code: mul_85 = torch.ops.aten.mul.Tensor(div_2, mul_84); div_2 = mul_84 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_87: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_3, mul_86); div_3 = mul_86 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1120 in forward, code: add_43 = torch.ops.aten.add.Tensor(mul_81, mul_85); mul_81 = mul_85 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_43: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_83, mul_87); mul_83 = mul_87 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1121 in forward, code: convert_element_type_120 = torch.ops.prims.convert_element_type.default(add_43, torch.bfloat16); add_43 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_173: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_43, torch.bfloat16); add_43 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1122 in forward, code: add_44 = torch.ops.aten.add.Tensor(add_40, convert_element_type_120); add_40 = convert_element_type_120 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_44: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_40, convert_element_type_173); convert_element_type_173 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_17: "bf16[2048, 5632]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_170, [2048, 5632], [5632, 1]); getitem_170 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1129 in forward, code: mm_17 = torch.ops.aten.mm.default(trace_wrapped_2, permute_47); permute_47 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_17: "bf16[2048, 5632]" = torch.ops.aten.mm.default(add_44, contiguous_view_as_strided_17); contiguous_view_as_strided_17 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:923 in forward, code: convert_element_type_100 = torch.ops.prims.convert_element_type.default(getitem_439, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_100: "f32[2048, 5632]" = torch.ops.prims.convert_element_type.default(arg439_1, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:924 in forward, code: sigmoid_16 = torch.ops.aten.sigmoid.default(convert_element_type_100)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_16: "f32[2048, 5632]" = torch.ops.aten.sigmoid.default(convert_element_type_100)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:925 in forward, code: mul_50 = torch.ops.aten.mul.Tensor(convert_element_type_100, sigmoid_16); convert_element_type_100 = sigmoid_16 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_52: "f32[2048, 5632]" = torch.ops.aten.mul.Tensor(convert_element_type_100, sigmoid_16); convert_element_type_100 = sigmoid_16 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:926 in forward, code: convert_element_type_101 = torch.ops.prims.convert_element_type.default(mul_50, torch.bfloat16); mul_50 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_101: "bf16[2048, 5632]" = torch.ops.prims.convert_element_type.default(mul_52, torch.bfloat16); mul_52 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1132 in forward, code: mul_86 = torch.ops.aten.mul.Tensor(mm_17, convert_element_type_101); convert_element_type_101 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_88: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_17, convert_element_type_101); convert_element_type_101 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_18: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_180, [5632, 2048], [2048, 1]); getitem_180 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1133 in forward, code: mul_87 = torch.ops.aten.mul.Tensor(mm_17, getitem_440); mm_17 = getitem_440 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_89: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mm_17, arg440_1); mm_17 = arg440_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
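This stretch is the backward of the SwiGLU feed-forward, FFN(x) = w2(silu(w1 x) * w3 x): mm_17 carries the incoming gradient through the gathered w2, mul_88 routes it to the w3 branch (scaled by the recomputed silu(w1 x), convert_element_type_101), and mul_89 routes it to the silu branch (scaled by w3 x, arg440_1); the sigmoid chain just below supplies silu'(w1 x) (see the sketch after it). The product-rule step, with h1 = w1 x and h3 = w3 x being names introduced here for the saved activations:

    import torch
    import torch.nn.functional as F

    def swiglu_prod_backward_sketch(g, h1, h3):
        # forward: y = F.silu(h1) * h3; the product rule splits the gradient:
        g_h3 = g * F.silu(h1)  # mul_88: gradient into the w3 branch
        g_silu = g * h3        # mul_89: gradient into silu(h1), before silu'
        return g_silu, g_h3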
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1142 in forward, code: sigmoid_19 = torch.ops.aten.sigmoid.default(getitem_439)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sigmoid_19: "bf16[2048, 5632]" = torch.ops.aten.sigmoid.default(arg439_1)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1143 in forward, code: sub_1 = torch.ops.aten.sub.Tensor(full, sigmoid_19)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sub_2: "bf16[2048, 5632]" = torch.ops.aten.sub.Tensor(full_166, sigmoid_19)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1144 in forward, code: mul_88 = torch.ops.aten.mul.Tensor(getitem_439, sub_1); getitem_439 = sub_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_90: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(arg439_1, sub_2); arg439_1 = sub_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1145 in forward, code: add_45 = torch.ops.aten.add.Scalar(mul_88, 1); mul_88 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_45: "bf16[2048, 5632]" = torch.ops.aten.add.Scalar(mul_90, 1); mul_90 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1146 in forward, code: mul_89 = torch.ops.aten.mul.Tensor(sigmoid_19, add_45); sigmoid_19 = add_45 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_91: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(sigmoid_19, add_45); sigmoid_19 = add_45 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1147 in forward, code: mul_90 = torch.ops.aten.mul.Tensor(mul_87, mul_89); mul_87 = mul_89 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_92: "bf16[2048, 5632]" = torch.ops.aten.mul.Tensor(mul_89, mul_91); mul_89 = mul_91 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
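The sigmoid/sub/add/mul sequence above evaluates the SiLU derivative in bf16 (full_166 is an all-ones tensor, so sub_2 = 1 - sigmoid(x)): d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))), and mul_92 folds it into the gradient from mul_89. As a one-function sketch:

    import torch

    def silu_backward_sketch(g, x):
        s = torch.sigmoid(x)                # sigmoid_19
        return g * (s * (1 + x * (1 - s)))  # sub_2, mul_90, add_45, mul_91, mul_92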
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_16: "bf16[5632, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_160, [5632, 2048], [2048, 1]); getitem_160 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1154 in forward, code: add_46 = torch.ops.aten.add.Tensor(mm_19, mm_21); mm_19 = mm_21 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_plus_mm_33: "bf16[2048, 2048]" = torch__inductor_fx_passes_post_grad_mm_plus_mm_2(mat1 = mul_88, mat2 = contiguous_view_as_strided_18, mat3 = mul_92, mat4 = contiguous_view_as_strided_16); contiguous_view_as_strided_18 = contiguous_view_as_strided_16 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_20: "bf16[2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_200, [2048], [1]); getitem_200 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1158 in forward, code: mul_92 = torch.ops.aten.mul.Tensor(add_46, getitem_160); add_46 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_94: "bf16[2048, 2048]" = torch.ops.aten.mul.Tensor(mm_plus_mm_33, contiguous_view_as_strided_20); contiguous_view_as_strided_20 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1162 in forward, code: convert_element_type_121 = torch.ops.prims.convert_element_type.default(mul_92, torch.float32); mul_92 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_195: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(mul_94, torch.float32); mul_94 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1164 in forward, code: mul_94 = torch.ops.aten.mul.Tensor(convert_element_type_121, getitem_437); convert_element_type_121 = getitem_437 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_96: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_195, arg437_1)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:918 in forward, code: convert_element_type_98 = torch.ops.prims.convert_element_type.default(add_32, torch.float32)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_98: "f32[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_32, torch.float32); add_32 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1163 in forward, code: mul_93 = torch.ops.aten.mul.Tensor(convert_element_type_121, convert_element_type_98)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_95: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(convert_element_type_195, convert_element_type_98); convert_element_type_195 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1165 in forward, code: sum_8 = torch.ops.aten.sum.dim_IntList(mul_93, [1], True); mul_93 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] sum_9: "f32[2048, 1]" = torch.ops.aten.sum.dim_IntList(mul_95, [1], True); mul_95 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1169 in forward, code: mul_95 = torch.ops.aten.mul.Scalar(sum_8, -0.5); sum_8 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_97: "f32[2048, 1]" = torch.ops.aten.mul.Scalar(sum_9, -0.5); sum_9 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1168 in forward, code: pow_7 = torch.ops.aten.pow.Tensor_Scalar(alias_81, 3); alias_81 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_7: "f32[2048, 1]" = torch.ops.aten.pow.Tensor_Scalar(arg437_1, 3)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1170 in forward, code: mul_96 = torch.ops.aten.mul.Tensor(mul_95, pow_7); mul_95 = pow_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_98: "f32[2048, 1]" = torch.ops.aten.mul.Tensor(mul_97, pow_7); mul_97 = pow_7 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1171 in forward, code: expand_3 = torch.ops.aten.expand.default(mul_96, [2048, 2048]); mul_96 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] expand_3: "f32[2048, 2048]" = torch.ops.aten.expand.default(mul_98, [2048, 2048]); mul_98 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1172 in forward, code: div_3 = torch.ops.aten.div.Scalar(expand_3, 2048); expand_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] div_4: "f32[2048, 2048]" = torch.ops.aten.div.Scalar(expand_3, 2048); expand_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1173 in forward, code: pow_8 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_98, 1.0); convert_element_type_98 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] pow_8: "f32[2048, 2048]" = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_98, 1.0)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1174 in forward, code: mul_97 = torch.ops.aten.mul.Scalar(pow_8, 2.0); pow_8 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_99: "f32[2048, 2048]" = torch.ops.aten.mul.Scalar(pow_8, 2.0); pow_8 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1175 in forward, code: mul_98 = torch.ops.aten.mul.Tensor(div_3, mul_97); div_3 = mul_97 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_100: "f32[2048, 2048]" = torch.ops.aten.mul.Tensor(div_4, mul_99); div_4 = mul_99 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1176 in forward, code: add_47 = torch.ops.aten.add.Tensor(mul_94, mul_98); mul_94 = mul_98 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_47: "f32[2048, 2048]" = torch.ops.aten.add.Tensor(mul_96, mul_100); mul_96 = mul_100 = None
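The run of nodes from mul_96 through add_47 is the decomposed backward of an RMSNorm-style normalization y = x * rsqrt(mean(x^2, dim=-1) + eps): the sum/mul(-0.5)/pow(3) chain is d rsqrt(m)/dm = -0.5 * r^3, and the pow(1.0)/mul(2.0)/div(2048) chain is dm/dx = 2x / 2048. A sketch of the same math in plain PyTorch (grad_y here is assumed to already include the elementwise weight applied in mul_94):

import torch

def rmsnorm_backward_sketch(grad_y: torch.Tensor, x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Forward, for reference: r = rsqrt(mean(x^2, dim=-1) + eps); y = x * r
    n = x.shape[-1]
    r = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    direct = grad_y * r                                                  # mul_96
    through_r = (grad_y * x).sum(-1, keepdim=True) * (-0.5) * r.pow(3)   # sum_9 / mul_97 / mul_98
    return direct + through_r * 2.0 * x / n                              # mul_99 / mul_100 / add_47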
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1177 in forward, code: convert_element_type_122 = torch.ops.prims.convert_element_type.default(add_47, torch.bfloat16); add_47 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_196: "bf16[2048, 2048]" = torch.ops.prims.convert_element_type.default(add_47, torch.bfloat16); add_47 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1178 in forward, code: add_48 = torch.ops.aten.add.Tensor(trace_wrapped_2, convert_element_type_122); trace_wrapped_2 = convert_element_type_122 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] add_48: "bf16[2048, 2048]" = torch.ops.aten.add.Tensor(add_44, convert_element_type_196); convert_element_type_196 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_15: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_150, [2048, 2048], [2048, 1]); getitem_150 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1184 in forward, code: mm_23 = torch.ops.aten.mm.default(add_48, permute_62); permute_62 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mm_23: "bf16[2048, 2048]" = torch.ops.aten.mm.default(add_48, contiguous_view_as_strided_15); contiguous_view_as_strided_15 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1187 in forward, code: view_17 = torch.ops.aten.view.default(mm_23, [1, 2048, 16, 128]); mm_23 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_41: "bf16[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(mm_23, [1, 2048, 16, 128]); mm_23 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1188 in forward, code: permute_64 = torch.ops.aten.permute.default(view_17, [0, 2, 1, 3]); view_17 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_88: "bf16[1, 16, 2048, 128]" = torch.ops.aten.permute.default(view_41, [0, 2, 1, 3]); view_41 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1189 in forward, code: _scaled_dot_product_flash_attention_backward_1 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_64, getitem_429, getitem_430, getitem_431, getitem_463, getitem_432, None, None, 2048, 2048, 0.0, True, getitem_433, getitem_434, scale = 0.08838834764831843); permute_64 = getitem_429 = getitem_430 = getitem_431 = getitem_463 = getitem_432 = getitem_433 = getitem_434 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] _scaled_dot_product_flash_attention_backward_1 = torch.ops.aten._scaled_dot_product_flash_attention_backward.default(permute_88, arg429_1, arg430_1, arg431_1, arg463_1, arg432_1, None, None, 2048, 2048, 0.0, True, arg433_1, arg434_1, scale = 0.08838834764831843); permute_88 = arg429_1 = arg430_1 = arg431_1 = arg463_1 = arg432_1 = arg433_1 = arg434_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_203: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward_1[0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_204: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward_1[1]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_205: "bf16[1, 16, 2048, 128]" = _scaled_dot_product_flash_attention_backward_1[2]; _scaled_dot_product_flash_attention_backward_1 = None
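The private op here consumes the saved forward tensors (q, k, v, attention output, logsumexp) plus the philox RNG state and returns (dq, dk, dv), which the three getitem nodes unpack. An eager-level sketch that recovers the same gradients through the public API instead of the private backward op (is_causal=True and scale mirror the arguments in the node above):

import torch
import torch.nn.functional as F

def sdpa_backward_sketch(grad_out, q, k, v, scale):
    # Recompute the attention forward and let autograd produce (dq, dk, dv),
    # standing in for aten._scaled_dot_product_flash_attention_backward.
    q, k, v = (t.detach().requires_grad_(True) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
    out.backward(grad_out)
    return q.grad, k.grad, v.grad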
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_into_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_201, 8, '0'); getitem_201 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] wait_tensor_3: "bf16[51384320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_3); all_gather_into_tensor_3 = None
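This pair is the functional-collectives form of FSDP's parameter all-gather: all_gather_into_tensor returns an async tensor, and wait_tensor is the explicit synchronization point before first use. The public wrapper hides the wait; a minimal sketch (the group name '0' in the node is whatever group the compiled graph registered, so a plain process-group handle is assumed here):

import torch
import torch.distributed._functional_collectives as funcol

def gather_flat_shard(shard: torch.Tensor, group) -> torch.Tensor:
    # Concatenates every rank's flat shard along dim 0. The wait happens
    # implicitly on first use; traced graphs spell it out as wait_tensor.
    return funcol.all_gather_tensor(shard, gather_dim=0, group=group)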
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_51: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_25 = torch.ops.aten.split_with_sizes.default(view_51, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_51 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_215: "bf16[8, 524288]" = split_with_sizes_25[0]; split_with_sizes_25 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_52: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_52, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_52 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_225: "bf16[8, 524288]" = split_with_sizes_26[1]; split_with_sizes_26 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_53: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_27 = torch.ops.aten.split_with_sizes.default(view_53, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_53 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_235: "bf16[8, 524288]" = split_with_sizes_27[2]; split_with_sizes_27 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_54: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_28 = torch.ops.aten.split_with_sizes.default(view_54, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_54 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_245: "bf16[8, 524288]" = split_with_sizes_28[3]; split_with_sizes_28 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_55: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_29 = torch.ops.aten.split_with_sizes.default(view_55, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_55 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_255: "bf16[8, 1441792]" = split_with_sizes_29[4]; split_with_sizes_29 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_56: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_30 = torch.ops.aten.split_with_sizes.default(view_56, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_56 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_265: "bf16[8, 1441792]" = split_with_sizes_30[5]; split_with_sizes_30 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_57: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_31 = torch.ops.aten.split_with_sizes.default(view_57, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_57 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_275: "bf16[8, 1441792]" = split_with_sizes_31[6]; split_with_sizes_31 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_58: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1])
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_32 = torch.ops.aten.split_with_sizes.default(view_58, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_58 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_285: "bf16[8, 256]" = split_with_sizes_32[7]; split_with_sizes_32 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_59: "bf16[8, 6423040]" = torch.ops.aten.reshape.default(wait_tensor_3, [8, -1]); wait_tensor_3 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] split_with_sizes_33 = torch.ops.aten.split_with_sizes.default(view_59, [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], 1); view_59 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_295: "bf16[8, 256]" = split_with_sizes_33[8]; split_with_sizes_33 = None
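The repeated reshape/split_with_sizes pairs above are the copy-out half of foreach_all_gather: the waited [51384320] buffer is viewed as [8, 6423040] (world_size x per-rank numel) and sliced once per parameter. The split sizes are the per-rank shard numels: 524288 = 2048*2048/8 for the attention projections, 1441792 = 5632*2048/8 for the FFN weights, 256 = 2048/8 for the norm weights. A sketch of the same unpacking in eager terms:

import math
import torch

def copy_out_sketch(flat_out: torch.Tensor, world_size: int, split_sizes, param_shapes):
    per_rank = flat_out.view(world_size, -1)
    pieces = per_rank.split(split_sizes, dim=1)  # one [world_size, shard_numel] piece per param
    params = []
    for piece, shape in zip(pieces, param_shapes):
        numel = math.prod(shape)
        # Drop any all-gather padding, then take the unsharded parameter's
        # shape (the graph uses fsdp.contiguous_view_as_strided for this step).
        params.append(piece.reshape(-1)[:numel].view(shape))
    return params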
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_common.py:140 in _to_dtype_if_needed, code: return tensor.to(dtype)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_219: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg511_1, torch.bfloat16); arg511_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_220: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg512_1, torch.bfloat16); arg512_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_221: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg513_1, torch.bfloat16); arg513_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_222: "bf16[524288]" = torch.ops.prims.convert_element_type.default(arg514_1, torch.bfloat16); arg514_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_223: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg515_1, torch.bfloat16); arg515_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_224: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg516_1, torch.bfloat16); arg516_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_225: "bf16[1441792]" = torch.ops.prims.convert_element_type.default(arg517_1, torch.bfloat16); arg517_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_226: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg518_1, torch.bfloat16); arg518_1 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_227: "bf16[256]" = torch.ops.prims.convert_element_type.default(arg519_1, torch.bfloat16); arg519_1 = None
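The nine convert_element_type nodes are FSDP mixed precision at work: the fp32 sharded parameters (arg511_1 through arg519_1) are cast to bf16 before being packed for the next layer's all-gather. This is the inlined body of _to_dtype_if_needed from the cited file; a sketch with the guard it presumably carries in eager mode:

from typing import Optional

import torch

def _to_dtype_if_needed(tensor: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.Tensor:
    # Cast only when a mixed-precision param_dtype is configured and differs
    # from the tensor's current dtype; otherwise pass the tensor through.
    if dtype is not None and tensor.dtype != dtype:
        return tensor.to(dtype)
    return tensor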
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:80 in foreach_all_gather, code: all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in(all_gather_input_numel, world_size, rank, dtype, device, inp_split_sizes, param_all_gather_inputs)
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] all_gather_copy_in_4 = torch.ops.fsdp.all_gather_copy_in.default(6423040, 8, 0, torch.bfloat16, device(type='cuda', index=0), [524288, 524288, 524288, 524288, 1441792, 1441792, 1441792, 256, 256], [convert_element_type_219, convert_element_type_220, convert_element_type_221, convert_element_type_222, convert_element_type_223, convert_element_type_224, convert_element_type_225, convert_element_type_226, convert_element_type_227]); convert_element_type_219 = convert_element_type_220 = convert_element_type_221 = convert_element_type_222 = convert_element_type_223 = convert_element_type_224 = convert_element_type_225 = convert_element_type_226 = convert_element_type_227 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_296: "bf16[6423040]" = all_gather_copy_in_4[0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] getitem_297: "bf16[51384320]" = all_gather_copy_in_4[1]; all_gather_copy_in_4 = None
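all_gather_copy_in packs the nine bf16 shards back-to-back into one flat 6423040-element input buffer (6423040 is the sum of the split sizes above) and allocates the world_size-times-larger 51384320-element output that the collective will fill. This is another custom fsdp op, so the body below is an assumption about its behavior, not its implementation:

import torch

def all_gather_copy_in_sketch(inp_split_sizes, shards, world_size, dtype, device):
    numel = sum(inp_split_sizes)
    all_gather_input = torch.empty(numel, dtype=dtype, device=device)
    all_gather_output = torch.empty(numel * world_size, dtype=dtype, device=device)
    # Copy each shard into its slot of the flat input buffer.
    for dst, src in zip(all_gather_input.split(inp_split_sizes), shards):
        dst.copy_(src.reshape(-1))
    return all_gather_input, all_gather_output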
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1193 in forward, code: permute_65 = torch.ops.aten.permute.default(getitem_487, [0, 2, 1, 3]); getitem_487 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_89: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_205, [0, 2, 1, 3]); getitem_205 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1210 in forward, code: view_22 = torch.ops.aten.view.default(permute_65, [2048, 2048]); permute_65 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_46: "bf16[2048, 2048]" = torch.ops.aten.reshape.default(permute_89, [2048, 2048]); permute_89 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:170 in foreach_all_gather_copy_out, code: split_unpadded = torch.ops.fsdp.contiguous_view_as_strided(
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] contiguous_view_as_strided_14: "bf16[2048, 2048]" = torch.ops.fsdp.contiguous_view_as_strided.default(getitem_140, [2048, 2048], [2048, 1]); getitem_140 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1194 in forward, code: permute_66 = torch.ops.aten.permute.default(getitem_486, [0, 2, 1, 3]); getitem_486 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] permute_90: "bf16[1, 2048, 16, 128]" = torch.ops.aten.permute.default(getitem_204, [0, 2, 1, 3]); getitem_204 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1196 in forward, code: convert_element_type_123 = torch.ops.prims.convert_element_type.default(permute_66, torch.float32); permute_66 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] convert_element_type_201: "f32[1, 2048, 16, 128]" = torch.ops.prims.convert_element_type.default(permute_90, torch.float32); permute_90 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1198 in forward, code: view_18 = torch.ops.aten.view.default(convert_element_type_123, [1, 2048, 16, 64, 2]); convert_element_type_123 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_42: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.reshape.default(convert_element_type_201, [1, 2048, 16, 64, 2]); convert_element_type_201 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1199 in forward, code: view_as_complex_2 = torch.ops.aten.view_as_complex.default(view_18); view_18 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_complex_2: "c64[1, 2048, 16, 64]" = torch.ops.aten.view_as_complex.default(view_42); view_42 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1200 in forward, code: mul_99 = torch.ops.aten.mul.Tensor(view_as_complex_2, clone); view_as_complex_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] mul_101: "c64[1, 2048, 16, 64]" = torch.ops.aten.mul.Tensor(view_as_complex_2, arg462_1); view_as_complex_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1204 in forward, code: view_as_real_2 = torch.ops.aten.view_as_real.default(mul_99); mul_99 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_as_real_2: "f32[1, 2048, 16, 64, 2]" = torch.ops.aten.view_as_real.default(mul_101); mul_101 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1205 in forward, code: view_20 = torch.ops.aten.view.default(view_as_real_2, [1, 2048, 16, 128]); view_as_real_2 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] view_44: "f32[1, 2048, 16, 128]" = torch.ops.aten.reshape.default(view_as_real_2, [1, 2048, 16, 128]); view_as_real_2 = None
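The convert/view/view_as_complex/mul/view_as_real sequence is the traced gradient of Llama's rotary embedding: lane pairs are reinterpreted as complex numbers, rotated by the precomputed frequencies, and unpacked again. For reference, a sketch of the forward rotation this corresponds to, in the style of the Llama reference code (freqs_cis is assumed precomputed as complex64; autograd's backward multiplies by the conjugate of freqs_cis, which is presumably what the clone constant in the mul node holds):

import torch

def apply_rotary_emb_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: [bsz, seqlen, n_heads, head_dim]; freqs_cis: complex64 [seqlen, head_dim // 2]
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rot = x_c * freqs_cis.view(1, x_c.shape[1], 1, x_c.shape[-1])
    return torch.view_as_real(x_rot).flatten(-2).type_as(x)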
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0]
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inductor/fx_passes/post_grad.py:143] [1/0] # File: <eval_with_key>.7:1206 in forward, code: convert_element_type_125 = torch.ops.prims.convert_element_type.default(view_20, torch.bfloat16); view_20 = None
[rank0]:[rank0]:W2024-03-28 18:05:19,715.715000 139804330415936 torch/_inducto