Created March 27, 2024 01:43
===== AFTER POST GRAD =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[]", arg1_1: "f32[8, 32]", arg2_1: "f32[32, 128]", arg3_1: "f32[128, 32]", arg4_1: "f32[32, 128]", arg5_1: "f32[128, 32]", arg6_1: "f32[32, 128]", arg7_1: "f32[8, 128]", arg8_1: "f32[8, 32]", arg9_1: "f32[8, 128]", arg10_1: "f32[8, 32]", arg11_1: "f32[8, 128]", arg12_1: "b8[8, 32]", arg13_1: "f32[32]", arg14_1: "f32[128]", arg15_1: "f32[32]", arg16_1: "f32[128]", arg17_1: "f32[32]", arg18_1: "f32[128]", arg19_1: "f32[128, 32]", arg20_1: "f32[2048]", arg21_1: "f32[64]", arg22_1: "f32[2048]", arg23_1: "f32[16]", arg24_1: "f32[2048]", arg25_1: "f32[64]", arg26_1: "f32[2048]", arg27_1: "f32[16]", arg28_1: "f32[2048]", arg29_1: "f32[64]", arg30_1: "f32[2048]", arg31_1: "f32[16]"):
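        # Note: reading the graph, this appears to be the backward of a small ReLU MLP compiled
        # under the composable per-parameter FSDP (torch/distributed/_composable/fsdp), traced on
        # device cuda:1 with world size 2 (group size 2, group name '0'). Three parameter groups,
        # each with flat shard split sizes [2048, 64, 2048, 16], go through the same three-phase
        # pattern: (1) foreach_all_gather packs the local shards into an all-gather input buffer,
        # (2) _c10d_functional.all_gather_into_tensor + wait_tensor produce the unsharded flat
        # buffer, (3) foreach_all_gather_copy_out views it back into full parameter shapes and
        # copies into the unsharded parameter storages. Gradients are then computed with the usual
        # mm/sum backward ops and packed and reduce-scattered by foreach_reduce.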
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]); slice_1 = None
        getitem: "f32[2048]" = split_with_sizes[0]
        getitem_1: "f32[64]" = split_with_sizes[1]
        getitem_2: "f32[2048]" = split_with_sizes[2]
        getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [arg20_1, arg21_1, arg22_1, arg23_1]); getitem = getitem_1 = getitem_2 = getitem_3 = arg20_1 = arg21_1 = arg22_1 = arg23_1 = None
        getitem_4: "f32[2048]" = _foreach_copy[0]
        getitem_5: "f32[64]" = _foreach_copy[1]
        getitem_6: "f32[2048]" = _foreach_copy[2]
        getitem_7: "f32[16]" = _foreach_copy[3]; _foreach_copy = None
        slice_2: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
        slice_scatter_default: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_4, 0, 0, 2048); slice_tensor = getitem_4 = None
        slice_scatter_default_1: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 4176, 8352); empty = slice_scatter_default = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_1: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)
        slice_scatter_default_2: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_5, 0, 2048, 2112); slice_tensor_1 = getitem_5 = None
        slice_scatter_default_3: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 4176, 8352); slice_scatter_default_1 = slice_scatter_default_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_4: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_2: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)
        slice_scatter_default_4: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_6, 0, 2112, 4160); slice_tensor_2 = getitem_6 = None
        slice_scatter_default_5: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_3, slice_scatter_default_4, 0, 4176, 8352); slice_scatter_default_3 = slice_scatter_default_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)
        slice_scatter_default_6: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_7, 0, 4160, 4176); slice_tensor_3 = getitem_7 = None
        slice_scatter_default_7: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4176, 8352); slice_scatter_default_5 = slice_scatter_default_6 = None
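        # Note: the slice / slice_scatter chains above are the functionalized form of the in-place
        # torch._foreach_copy_ into narrowed views of the all-gather output buffer; each
        # slice_scatter writes one shard ([0:2048], [2048:2112], [2112:4160], [4160:4176]) back
        # into this rank's half ([4176:8352]) of the flat buffer.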
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4176, 8352); slice_scatter_default_7 = None
        all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
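        # Note: all_gather_into_tensor here is the functional collective from
        # torch.ops._c10d_functional: it takes the 4176-element local input, the group size (2)
        # and the group name ('0'), and returns an asynchronous 8352-element result that
        # wait_tensor synchronizes before any consumer reads it.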
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1])
        split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None
        getitem_28: "f32[2, 2048]" = split_with_sizes_6[0]; split_with_sizes_6 = None
        view_3: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1])
        split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_3, [2048, 64, 2048, 16], 1); view_3 = None
        getitem_33: "f32[2, 64]" = split_with_sizes_7[1]; split_with_sizes_7 = None
        view_5: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1])
        split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_5, [2048, 64, 2048, 16], 1); view_5 = None
        getitem_38: "f32[2, 2048]" = split_with_sizes_8[2]; split_with_sizes_8 = None
        view_7: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
        split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_7, [2048, 64, 2048, 16], 1); view_7 = None
        getitem_43: "f32[2, 16]" = split_with_sizes_9[3]; split_with_sizes_9 = None
        clone_1: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
        view_2: "f32[4096]" = torch.ops.aten.reshape.default(clone_1, [4096]); clone_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_2: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
        view_4: "f32[128]" = torch.ops.aten.reshape.default(clone_2, [128]); clone_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_3: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
        view_6: "f32[4096]" = torch.ops.aten.reshape.default(clone_3, [4096]); clone_3 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_4: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
        view_8: "f32[32]" = torch.ops.aten.reshape.default(clone_4, [32]); clone_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_1 = torch.ops.aten._foreach_copy.default([arg5_1, arg14_1, arg6_1, arg13_1], [as_strided, as_strided_1, as_strided_2, as_strided_3]); arg5_1 = arg14_1 = arg6_1 = arg13_1 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
        getitem_44: "f32[128, 32]" = _foreach_copy_1[0]
        getitem_46: "f32[32, 128]" = _foreach_copy_1[2]; _foreach_copy_1 = None
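        # Note: foreach_all_gather_copy_out above views the gathered flat buffer as [2, 4176]
        # (one row per rank), splits it per parameter, makes each chunk contiguous, and
        # as_strides the result into the full parameter shapes ([128, 32], [128], [32, 128],
        # [32]) before _foreach_copy_-ing into the unsharded parameter storages
        # (arg5_1, arg14_1, arg6_1, arg13_1). The two blocks that follow repeat the same
        # copy-in / all-gather / copy-out pattern for the other two parameter groups.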
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty_1: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_11: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4176, 8352)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(slice_11, [2048, 64, 2048, 16]); slice_11 = None
        getitem_48: "f32[2048]" = split_with_sizes_10[0]
        getitem_49: "f32[64]" = split_with_sizes_10[1]
        getitem_50: "f32[2048]" = split_with_sizes_10[2]
        getitem_51: "f32[16]" = split_with_sizes_10[3]; split_with_sizes_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_48, getitem_49, getitem_50, getitem_51], [arg24_1, arg25_1, arg26_1, arg27_1]); getitem_48 = getitem_49 = getitem_50 = getitem_51 = arg24_1 = arg25_1 = arg26_1 = arg27_1 = None
        getitem_52: "f32[2048]" = _foreach_copy_2[0]
        getitem_53: "f32[64]" = _foreach_copy_2[1]
        getitem_54: "f32[2048]" = _foreach_copy_2[2]
        getitem_55: "f32[16]" = _foreach_copy_2[3]; _foreach_copy_2 = None
        slice_12: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_4: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4176, 8352)
        slice_scatter_default_8: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_52, 0, 0, 2048); slice_tensor_4 = getitem_52 = None
        slice_scatter_default_9: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_8, 0, 4176, 8352); empty_1 = slice_scatter_default_8 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_13: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)
        slice_scatter_default_10: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_53, 0, 2048, 2112); slice_tensor_5 = getitem_53 = None
        slice_scatter_default_11: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 4176, 8352); slice_scatter_default_9 = slice_scatter_default_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_14: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_6: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)
        slice_scatter_default_12: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_6, getitem_54, 0, 2112, 4160); slice_tensor_6 = getitem_54 = None
        slice_scatter_default_13: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_11, slice_scatter_default_12, 0, 4176, 8352); slice_scatter_default_11 = slice_scatter_default_12 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_15: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_7: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)
        slice_scatter_default_14: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_7, getitem_55, 0, 4160, 4176); slice_tensor_7 = getitem_55 = None
        slice_scatter_default_15: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_13, slice_scatter_default_14, 0, 4176, 8352); slice_scatter_default_13 = slice_scatter_default_14 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_20: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_15, 0, 4176, 8352); slice_scatter_default_15 = None
        all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_20, 2, '0'); slice_20 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_12: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1])
        split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_12, [2048, 64, 2048, 16], 1); view_12 = None
        getitem_76: "f32[2, 2048]" = split_with_sizes_16[0]; split_with_sizes_16 = None
        view_14: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1])
        split_with_sizes_17 = torch.ops.aten.split_with_sizes.default(view_14, [2048, 64, 2048, 16], 1); view_14 = None
        getitem_81: "f32[2, 64]" = split_with_sizes_17[1]; split_with_sizes_17 = None
        view_16: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1])
        split_with_sizes_18 = torch.ops.aten.split_with_sizes.default(view_16, [2048, 64, 2048, 16], 1); view_16 = None
        getitem_86: "f32[2, 2048]" = split_with_sizes_18[2]; split_with_sizes_18 = None
        view_18: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
        split_with_sizes_19 = torch.ops.aten.split_with_sizes.default(view_18, [2048, 64, 2048, 16], 1); view_18 = None
        getitem_91: "f32[2, 16]" = split_with_sizes_19[3]; split_with_sizes_19 = None
        clone_9: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_76, memory_format = torch.contiguous_format); getitem_76 = None
        view_13: "f32[4096]" = torch.ops.aten.reshape.default(clone_9, [4096]); clone_9 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_13, [128, 32], [32, 1], 0); view_13 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_10: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_81, memory_format = torch.contiguous_format); getitem_81 = None
        view_15: "f32[128]" = torch.ops.aten.reshape.default(clone_10, [128]); clone_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_15, [128], [1], 0); view_15 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_11: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_86, memory_format = torch.contiguous_format); getitem_86 = None
        view_17: "f32[4096]" = torch.ops.aten.reshape.default(clone_11, [4096]); clone_11 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_17, [32, 128], [128, 1], 0); view_17 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_12: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_91, memory_format = torch.contiguous_format); getitem_91 = None
        view_19: "f32[32]" = torch.ops.aten.reshape.default(clone_12, [32]); clone_12 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_19, [32], [1], 0); view_19 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_3 = torch.ops.aten._foreach_copy.default([arg3_1, arg16_1, arg4_1, arg15_1], [as_strided_4, as_strided_5, as_strided_6, as_strided_7]); arg3_1 = arg16_1 = arg4_1 = arg15_1 = as_strided_4 = as_strided_5 = as_strided_6 = as_strided_7 = None
        getitem_92: "f32[128, 32]" = _foreach_copy_3[0]
        getitem_94: "f32[32, 128]" = _foreach_copy_3[2]; _foreach_copy_3 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty_2: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_21: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_2, 0, 4176, 8352)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes_20 = torch.ops.aten.split_with_sizes.default(slice_21, [2048, 64, 2048, 16]); slice_21 = None
        getitem_96: "f32[2048]" = split_with_sizes_20[0]
        getitem_97: "f32[64]" = split_with_sizes_20[1]
        getitem_98: "f32[2048]" = split_with_sizes_20[2]
        getitem_99: "f32[16]" = split_with_sizes_20[3]; split_with_sizes_20 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_96, getitem_97, getitem_98, getitem_99], [arg28_1, arg29_1, arg30_1, arg31_1]); getitem_96 = getitem_97 = getitem_98 = getitem_99 = arg28_1 = arg29_1 = arg30_1 = arg31_1 = None
        getitem_100: "f32[2048]" = _foreach_copy_4[0]
        getitem_101: "f32[64]" = _foreach_copy_4[1]
        getitem_102: "f32[2048]" = _foreach_copy_4[2]
        getitem_103: "f32[16]" = _foreach_copy_4[3]; _foreach_copy_4 = None
        slice_22: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_2, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_8: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_2, 0, 4176, 8352)
        slice_scatter_default_16: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_8, getitem_100, 0, 0, 2048); slice_tensor_8 = getitem_100 = None
        slice_scatter_default_17: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_16, 0, 4176, 8352); empty_2 = slice_scatter_default_16 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_23: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_9: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)
        slice_scatter_default_18: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_9, getitem_101, 0, 2048, 2112); slice_tensor_9 = getitem_101 = None
        slice_scatter_default_19: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_17, slice_scatter_default_18, 0, 4176, 8352); slice_scatter_default_17 = slice_scatter_default_18 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_24: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)
        slice_scatter_default_20: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_10, getitem_102, 0, 2112, 4160); slice_tensor_10 = getitem_102 = None
        slice_scatter_default_21: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_19, slice_scatter_default_20, 0, 4176, 8352); slice_scatter_default_19 = slice_scatter_default_20 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_25: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)
        # No stacktrace found for following nodes
        slice_tensor_11: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)
        slice_scatter_default_22: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_11, getitem_103, 0, 4160, 4176); slice_tensor_11 = getitem_103 = None
        slice_scatter_default_23: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_21, slice_scatter_default_22, 0, 4176, 8352); slice_scatter_default_21 = slice_scatter_default_22 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_30: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_23, 0, 4176, 8352); slice_scatter_default_23 = None
        all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_30, 2, '0'); slice_30 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_23: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1])
        split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_23, [2048, 64, 2048, 16], 1); view_23 = None
        getitem_124: "f32[2, 2048]" = split_with_sizes_26[0]; split_with_sizes_26 = None
        view_25: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1])
        split_with_sizes_27 = torch.ops.aten.split_with_sizes.default(view_25, [2048, 64, 2048, 16], 1); view_25 = None
        getitem_129: "f32[2, 64]" = split_with_sizes_27[1]; split_with_sizes_27 = None
        view_27: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1])
        split_with_sizes_28 = torch.ops.aten.split_with_sizes.default(view_27, [2048, 64, 2048, 16], 1); view_27 = None
        getitem_134: "f32[2, 2048]" = split_with_sizes_28[2]; split_with_sizes_28 = None
        view_29: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
        split_with_sizes_29 = torch.ops.aten.split_with_sizes.default(view_29, [2048, 64, 2048, 16], 1); view_29 = None
        getitem_139: "f32[2, 16]" = split_with_sizes_29[3]; split_with_sizes_29 = None
        clone_17: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_124, memory_format = torch.contiguous_format); getitem_124 = None
        view_24: "f32[4096]" = torch.ops.aten.reshape.default(clone_17, [4096]); clone_17 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_24, [128, 32], [32, 1], 0); view_24 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_18: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_129, memory_format = torch.contiguous_format); getitem_129 = None
        view_26: "f32[128]" = torch.ops.aten.reshape.default(clone_18, [128]); clone_18 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_26, [128], [1], 0); view_26 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_19: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_134, memory_format = torch.contiguous_format); getitem_134 = None
        view_28: "f32[4096]" = torch.ops.aten.reshape.default(clone_19, [4096]); clone_19 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_28, [32, 128], [128, 1], 0); view_28 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        clone_20: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_139, memory_format = torch.contiguous_format); getitem_139 = None
        view_30: "f32[32]" = torch.ops.aten.reshape.default(clone_20, [32]); clone_20 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_30, [32], [1], 0); view_30 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_5 = torch.ops.aten._foreach_copy.default([arg19_1, arg18_1, arg2_1, arg17_1], [as_strided_8, as_strided_9, as_strided_10, as_strided_11]); arg19_1 = arg18_1 = arg2_1 = arg17_1 = as_strided_8 = as_strided_9 = as_strided_10 = as_strided_11 = None
        getitem_142: "f32[32, 128]" = _foreach_copy_5[2]; _foreach_copy_5 = None
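        # Note: from here on the graph interleaves the backward compute with foreach_reduce. The
        # le / where pairs look like ReLU backward and the permute / mm / sum chains like linear
        # weight and bias gradients; per-layer gradients are viewed as [2, -1], concatenated into
        # a flat [2, 4176] reduce-scatter input, pre-divided by 2.0 (_div_if_needed), and
        # reduce-scattered with op 'sum' over the same group of 2 ranks.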
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:191 in foreach_reduce, code: post_reduce_output = reduce_scatter_input.new_empty(
        empty_8: "f32[4176]" = torch.ops.aten.empty.memory_format([4176], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:180 in foreach_reduce, code: reduce_scatter_input = torch.empty(
        empty_7: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:264 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view = reduce_scatter_input.view(world_size, -1)
        view_59: "f32[2, 4176]" = torch.ops.aten.reshape.default(empty_7, [2, -1]); empty_7 = None
        # File: <eval_with_key>.201:67 in forward, code: le = torch.ops.aten.le.Scalar(alias_11, 0); alias_11 = None
        le: "b8[8, 128]" = torch.ops.aten.le.Scalar(arg11_1, 0)
        # File: <eval_with_key>.201:52 in forward, code: scalar_tensor = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1))
        scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1))
        # File: <eval_with_key>.201:25 in forward, code: expand = torch.ops.aten.expand.default(getitem, [8, 32]); getitem = None
        expand: "f32[8, 32]" = torch.ops.aten.expand.default(arg0_1, [8, 32]); arg0_1 = None
        # File: <eval_with_key>.201:39 in forward, code: clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
        clone: "f32[8, 32]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
        # File: <eval_with_key>.201:53 in forward, code: where = torch.ops.aten.where.self(getitem_12, scalar_tensor, trace_wrapped_1); getitem_12 = trace_wrapped_1 = None
        where: "f32[8, 32]" = torch.ops.aten.where.self(arg12_1, scalar_tensor, clone); arg12_1 = clone = None
        # File: <eval_with_key>.201:56 in forward, code: mm = torch.ops.aten.mm.default(where, permute_1); permute_1 = None
        permute_2: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
        permute_3: "f32[32, 128]" = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
        mm: "f32[8, 128]" = torch.ops.aten.mm.default(where, permute_3); permute_3 = None
        # File: <eval_with_key>.201:68 in forward, code: where_1 = torch.ops.aten.where.self(le, scalar_tensor, mm); le = mm = None
        where_1: "f32[8, 128]" = torch.ops.aten.where.self(le, scalar_tensor, mm); le = mm = None
        # File: <eval_with_key>.201:72 in forward, code: permute_7 = torch.ops.aten.permute.default(where_1, [1, 0])
        permute_11: "f32[128, 8]" = torch.ops.aten.permute.default(where_1, [1, 0])
        # File: <eval_with_key>.201:73 in forward, code: mm_3 = torch.ops.aten.mm.default(permute_7, getitem_10); permute_7 = getitem_10 = None
        mm_3: "f32[128, 32]" = torch.ops.aten.mm.default(permute_11, arg10_1); permute_11 = None
        # File: <eval_with_key>.201:74 in forward, code: permute_8 = torch.ops.aten.permute.default(mm_3, [1, 0]); mm_3 = None
        permute_12: "f32[32, 128]" = torch.ops.aten.permute.default(mm_3, [1, 0]); mm_3 = None
        # File: <eval_with_key>.201:78 in forward, code: permute_9 = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None
        permute_13: "f32[128, 32]" = torch.ops.aten.permute.default(permute_12, [1, 0]); permute_12 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_55: "f32[2, 2048]" = torch.ops.aten.reshape.default(permute_13, [2, -1]); permute_13 = None
        # File: <eval_with_key>.201:75 in forward, code: sum_2 = torch.ops.aten.sum.dim_IntList(where_1, [0], True); where_1 = None
        sum_2: "f32[1, 128]" = torch.ops.aten.sum.dim_IntList(where_1, [0], True)
        # File: <eval_with_key>.201:76 in forward, code: view_1 = torch.ops.aten.view.default(sum_2, [128]); sum_2 = None
        view_10: "f32[128]" = torch.ops.aten.reshape.default(sum_2, [128]); sum_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_56: "f32[2, 64]" = torch.ops.aten.reshape.default(view_10, [2, -1]); view_10 = None
        # File: <eval_with_key>.201:57 in forward, code: permute_2 = torch.ops.aten.permute.default(where, [1, 0])
        permute_4: "f32[32, 8]" = torch.ops.aten.permute.default(where, [1, 0])
        # File: <eval_with_key>.201:58 in forward, code: mm_1 = torch.ops.aten.mm.default(permute_2, getitem_11); permute_2 = getitem_11 = None
        mm_1: "f32[32, 128]" = torch.ops.aten.mm.default(permute_4, arg11_1); permute_4 = arg11_1 = None
        # File: <eval_with_key>.201:59 in forward, code: permute_3 = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        permute_5: "f32[128, 32]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        # File: <eval_with_key>.201:63 in forward, code: permute_4 = torch.ops.aten.permute.default(permute_3, [1, 0]); permute_3 = None
        permute_6: "f32[32, 128]" = torch.ops.aten.permute.default(permute_5, [1, 0]); permute_5 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_57: "f32[2, 2048]" = torch.ops.aten.reshape.default(permute_6, [2, -1]); permute_6 = None
        # File: <eval_with_key>.201:60 in forward, code: sum_1 = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
        sum_1: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
        # File: <eval_with_key>.201:61 in forward, code: view = torch.ops.aten.view.default(sum_1, [32]); sum_1 = None
        view_9: "f32[32]" = torch.ops.aten.reshape.default(sum_1, [32]); sum_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_58: "f32[2, 16]" = torch.ops.aten.reshape.default(view_9, [2, -1]); view_9 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:263 in foreach_reduce_scatter_copy_in, code: cat_out = torch.cat(grad_views, dim=-1)
        cat_2: "f32[2, 4176]" = torch.ops.aten.cat.default([view_55, view_56, view_57, view_58], 1); view_55 = view_56 = view_57 = view_58 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:266 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view.copy_(cat_out)
        view_60: "f32[8352]" = torch.ops.aten.reshape.default(cat_2, [8352]); cat_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:297 in _div_if_needed, code: tensor.div_(div_factor)
        div_2: "f32[8352]" = torch.ops.aten.div.Tensor(view_60, 2.0); view_60 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:287 in reduce_scatter_tensor, code: tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
        reduce_scatter_tensor_2: "f32[4176]" = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_2, 'sum', 2, '0'); div_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor_5: "f32[4176]" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None
        # No stacktrace found for following nodes
        as_strided_24: "f32[64, 32]" = torch.ops.aten.as_strided.default(wait_tensor_5, [64, 32], [32, 1], 0)
        as_strided_25: "f32[64]" = torch.ops.aten.as_strided.default(wait_tensor_5, [64], [1], 2048)
        as_strided_26: "f32[16, 128]" = torch.ops.aten.as_strided.default(wait_tensor_5, [16, 128], [128, 1], 2112)
        as_strided_27: "f32[16]" = torch.ops.aten.as_strided.default(wait_tensor_5, [16], [1], 4160); wait_tensor_5 = None
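        # Note: the as_strided views above carve the 4176-element reduce-scatter output into this
        # rank's gradient shards ([64, 32], [64], [16, 128], [16]), i.e. the dim-0 halves of the
        # full [128, 32], [128], [32, 128], [32] parameters; they are returned at the end of the
        # graph together with the outputs of the two reduce-scatters that follow.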
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:191 in foreach_reduce, code: post_reduce_output = reduce_scatter_input.new_empty(
        empty_6: "f32[4176]" = torch.ops.aten.empty.memory_format([4176], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:180 in foreach_reduce, code: reduce_scatter_input = torch.empty(
        empty_5: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:264 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view = reduce_scatter_input.view(world_size, -1)
        view_48: "f32[2, 4176]" = torch.ops.aten.reshape.default(empty_5, [2, -1]); empty_5 = None
        # File: <eval_with_key>.201:98 in forward, code: le_2 = torch.ops.aten.le.Scalar(alias_15, 0); alias_15 = None
        le_2: "b8[8, 128]" = torch.ops.aten.le.Scalar(arg9_1, 0)
        # File: <eval_with_key>.201:83 in forward, code: le_1 = torch.ops.aten.le.Scalar(alias_13, 0); alias_13 = None
        le_1: "b8[8, 32]" = torch.ops.aten.le.Scalar(arg10_1, 0); arg10_1 = None
        # File: <eval_with_key>.201:71 in forward, code: mm_2 = torch.ops.aten.mm.default(where_1, permute_6); permute_6 = None
        permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None
        permute_10: "f32[128, 32]" = torch.ops.aten.permute.default(permute_9, [1, 0]); permute_9 = None
        mm_2: "f32[8, 32]" = torch.ops.aten.mm.default(where_1, permute_10); where_1 = permute_10 = None
        # File: <eval_with_key>.201:84 in forward, code: where_2 = torch.ops.aten.where.self(le_1, scalar_tensor, trace_wrapped_2); le_1 = trace_wrapped_2 = None
        where_2: "f32[8, 32]" = torch.ops.aten.where.self(le_1, scalar_tensor, mm_2); le_1 = mm_2 = None
        # File: <eval_with_key>.201:87 in forward, code: mm_4 = torch.ops.aten.mm.default(where_2, permute_11); permute_11 = None
        permute_16: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_94, [1, 0]); getitem_94 = None
        permute_17: "f32[32, 128]" = torch.ops.aten.permute.default(permute_16, [1, 0]); permute_16 = None
        mm_4: "f32[8, 128]" = torch.ops.aten.mm.default(where_2, permute_17); permute_17 = None
        # File: <eval_with_key>.201:99 in forward, code: where_3 = torch.ops.aten.where.self(le_2, scalar_tensor, mm_4); le_2 = mm_4 = None
        where_3: "f32[8, 128]" = torch.ops.aten.where.self(le_2, scalar_tensor, mm_4); le_2 = mm_4 = None
        # File: <eval_with_key>.201:103 in forward, code: permute_17 = torch.ops.aten.permute.default(where_3, [1, 0])
        permute_25: "f32[128, 8]" = torch.ops.aten.permute.default(where_3, [1, 0])
        # File: <eval_with_key>.201:104 in forward, code: mm_7 = torch.ops.aten.mm.default(permute_17, getitem_8); permute_17 = getitem_8 = None
        mm_7: "f32[128, 32]" = torch.ops.aten.mm.default(permute_25, arg8_1); permute_25 = None
        # File: <eval_with_key>.201:105 in forward, code: permute_18 = torch.ops.aten.permute.default(mm_7, [1, 0]); mm_7 = None
        permute_26: "f32[32, 128]" = torch.ops.aten.permute.default(mm_7, [1, 0]); mm_7 = None
        # File: <eval_with_key>.201:109 in forward, code: permute_19 = torch.ops.aten.permute.default(permute_18, [1, 0]); permute_18 = None
        permute_27: "f32[128, 32]" = torch.ops.aten.permute.default(permute_26, [1, 0]); permute_26 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_44: "f32[2, 2048]" = torch.ops.aten.reshape.default(permute_27, [2, -1]); permute_27 = None
        # File: <eval_with_key>.201:106 in forward, code: sum_4 = torch.ops.aten.sum.dim_IntList(where_3, [0], True); where_3 = None
        sum_4: "f32[1, 128]" = torch.ops.aten.sum.dim_IntList(where_3, [0], True)
        # File: <eval_with_key>.201:107 in forward, code: view_3 = torch.ops.aten.view.default(sum_4, [128]); sum_4 = None
        view_21: "f32[128]" = torch.ops.aten.reshape.default(sum_4, [128]); sum_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_45: "f32[2, 64]" = torch.ops.aten.reshape.default(view_21, [2, -1]); view_21 = None
        # File: <eval_with_key>.201:88 in forward, code: permute_12 = torch.ops.aten.permute.default(where_2, [1, 0])
        permute_18: "f32[32, 8]" = torch.ops.aten.permute.default(where_2, [1, 0])
        # File: <eval_with_key>.201:89 in forward, code: mm_5 = torch.ops.aten.mm.default(permute_12, getitem_9); permute_12 = getitem_9 = None
        mm_5: "f32[32, 128]" = torch.ops.aten.mm.default(permute_18, arg9_1); permute_18 = arg9_1 = None
        # File: <eval_with_key>.201:90 in forward, code: permute_13 = torch.ops.aten.permute.default(mm_5, [1, 0]); mm_5 = None
        permute_19: "f32[128, 32]" = torch.ops.aten.permute.default(mm_5, [1, 0]); mm_5 = None
        # File: <eval_with_key>.201:94 in forward, code: permute_14 = torch.ops.aten.permute.default(permute_13, [1, 0]); permute_13 = None
        permute_20: "f32[32, 128]" = torch.ops.aten.permute.default(permute_19, [1, 0]); permute_19 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_46: "f32[2, 2048]" = torch.ops.aten.reshape.default(permute_20, [2, -1]); permute_20 = None
        # File: <eval_with_key>.201:91 in forward, code: sum_3 = torch.ops.aten.sum.dim_IntList(where_2, [0], True); where_2 = None
        sum_3: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(where_2, [0], True); where_2 = None
        # File: <eval_with_key>.201:92 in forward, code: view_2 = torch.ops.aten.view.default(sum_3, [32]); sum_3 = None
        view_20: "f32[32]" = torch.ops.aten.reshape.default(sum_3, [32]); sum_3 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_47: "f32[2, 16]" = torch.ops.aten.reshape.default(view_20, [2, -1]); view_20 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:263 in foreach_reduce_scatter_copy_in, code: cat_out = torch.cat(grad_views, dim=-1)
        cat_1: "f32[2, 4176]" = torch.ops.aten.cat.default([view_44, view_45, view_46, view_47], 1); view_44 = view_45 = view_46 = view_47 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:266 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view.copy_(cat_out)
        view_49: "f32[8352]" = torch.ops.aten.reshape.default(cat_1, [8352]); cat_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:297 in _div_if_needed, code: tensor.div_(div_factor)
        div_1: "f32[8352]" = torch.ops.aten.div.Tensor(view_49, 2.0); view_49 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:287 in reduce_scatter_tensor, code: tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
        reduce_scatter_tensor_1: "f32[4176]" = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_1, 'sum', 2, '0'); div_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor_4: "f32[4176]" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None
        # No stacktrace found for following nodes
        as_strided_28: "f32[64, 32]" = torch.ops.aten.as_strided.default(wait_tensor_4, [64, 32], [32, 1], 0)
        as_strided_29: "f32[64]" = torch.ops.aten.as_strided.default(wait_tensor_4, [64], [1], 2048)
        as_strided_30: "f32[16, 128]" = torch.ops.aten.as_strided.default(wait_tensor_4, [16, 128], [128, 1], 2112)
        as_strided_31: "f32[16]" = torch.ops.aten.as_strided.default(wait_tensor_4, [16], [1], 4160); wait_tensor_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:191 in foreach_reduce, code: post_reduce_output = reduce_scatter_input.new_empty(
        empty_4: "f32[4176]" = torch.ops.aten.empty.memory_format([4176], dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:180 in foreach_reduce, code: reduce_scatter_input = torch.empty(
        empty_3: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:264 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view = reduce_scatter_input.view(world_size, -1)
        view_37: "f32[2, 4176]" = torch.ops.aten.reshape.default(empty_3, [2, -1]); empty_3 = None
        # File: <eval_with_key>.201:129 in forward, code: le_4 = torch.ops.aten.le.Scalar(alias_19, 0); alias_19 = None
        le_4: "b8[8, 128]" = torch.ops.aten.le.Scalar(arg7_1, 0)
        # File: <eval_with_key>.201:114 in forward, code: le_3 = torch.ops.aten.le.Scalar(alias_17, 0); alias_17 = None
        le_3: "b8[8, 32]" = torch.ops.aten.le.Scalar(arg8_1, 0); arg8_1 = None
        # File: <eval_with_key>.201:102 in forward, code: mm_6 = torch.ops.aten.mm.default(where_3, permute_16); permute_16 = None
        permute_23: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_92, [1, 0]); getitem_92 = None
        permute_24: "f32[128, 32]" = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None
        mm_6: "f32[8, 32]" = torch.ops.aten.mm.default(where_3, permute_24); where_3 = permute_24 = None
        # File: <eval_with_key>.201:115 in forward, code: where_4 = torch.ops.aten.where.self(le_3, scalar_tensor, trace_wrapped_3); le_3 = trace_wrapped_3 = None
        where_4: "f32[8, 32]" = torch.ops.aten.where.self(le_3, scalar_tensor, mm_6); le_3 = mm_6 = None
        # File: <eval_with_key>.201:118 in forward, code: mm_8 = torch.ops.aten.mm.default(where_4, permute_21); permute_21 = None
        permute_30: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_142, [1, 0]); getitem_142 = None
        permute_31: "f32[32, 128]" = torch.ops.aten.permute.default(permute_30, [1, 0]); permute_30 = None
        mm_8: "f32[8, 128]" = torch.ops.aten.mm.default(where_4, permute_31); permute_31 = None
        # File: <eval_with_key>.201:130 in forward, code: where_5 = torch.ops.aten.where.self(le_4, scalar_tensor, mm_8); le_4 = scalar_tensor = mm_8 = None
        where_5: "f32[8, 128]" = torch.ops.aten.where.self(le_4, scalar_tensor, mm_8); le_4 = scalar_tensor = mm_8 = None
        # File: <eval_with_key>.201:131 in forward, code: permute_25 = torch.ops.aten.permute.default(where_5, [1, 0])
        permute_35: "f32[128, 8]" = torch.ops.aten.permute.default(where_5, [1, 0])
        # File: <eval_with_key>.201:132 in forward, code: mm_10 = torch.ops.aten.mm.default(permute_25, getitem_1); permute_25 = getitem_1 = None
        mm_10: "f32[128, 32]" = torch.ops.aten.mm.default(permute_35, arg1_1); permute_35 = arg1_1 = None
        # File: <eval_with_key>.201:133 in forward, code: permute_26 = torch.ops.aten.permute.default(mm_10, [1, 0]); mm_10 = None
        permute_36: "f32[32, 128]" = torch.ops.aten.permute.default(mm_10, [1, 0]); mm_10 = None
        # File: <eval_with_key>.201:137 in forward, code: permute_27 = torch.ops.aten.permute.default(permute_26, [1, 0]); permute_26 = None
        permute_37: "f32[128, 32]" = torch.ops.aten.permute.default(permute_36, [1, 0]); permute_36 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_33: "f32[2, 2048]" = torch.ops.aten.reshape.default(permute_37, [2, -1]); permute_37 = None
        # File: <eval_with_key>.201:134 in forward, code: sum_6 = torch.ops.aten.sum.dim_IntList(where_5, [0], True); where_5 = None
        sum_6: "f32[1, 128]" = torch.ops.aten.sum.dim_IntList(where_5, [0], True); where_5 = None
        # File: <eval_with_key>.201:135 in forward, code: view_5 = torch.ops.aten.view.default(sum_6, [128]); sum_6 = None
        view_32: "f32[128]" = torch.ops.aten.reshape.default(sum_6, [128]); sum_6 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_34: "f32[2, 64]" = torch.ops.aten.reshape.default(view_32, [2, -1]); view_32 = None
        # File: <eval_with_key>.201:119 in forward, code: permute_22 = torch.ops.aten.permute.default(where_4, [1, 0])
        permute_32: "f32[32, 8]" = torch.ops.aten.permute.default(where_4, [1, 0])
        # File: <eval_with_key>.201:120 in forward, code: mm_9 = torch.ops.aten.mm.default(permute_22, getitem_7); permute_22 = getitem_7 = None
        mm_9: "f32[32, 128]" = torch.ops.aten.mm.default(permute_32, arg7_1); permute_32 = arg7_1 = None
        # File: <eval_with_key>.201:121 in forward, code: permute_23 = torch.ops.aten.permute.default(mm_9, [1, 0]); mm_9 = None
        permute_33: "f32[128, 32]" = torch.ops.aten.permute.default(mm_9, [1, 0]); mm_9 = None
        # File: <eval_with_key>.201:125 in forward, code: permute_24 = torch.ops.aten.permute.default(permute_23, [1, 0]); permute_23 = None
        permute_34: "f32[32, 128]" = torch.ops.aten.permute.default(permute_33, [1, 0]); permute_33 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_35: "f32[2, 2048]" = torch.ops.aten.reshape.default(permute_34, [2, -1]); permute_34 = None
        # File: <eval_with_key>.201:122 in forward, code: sum_5 = torch.ops.aten.sum.dim_IntList(where_4, [0], True); where_4 = None
        sum_5: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(where_4, [0], True); where_4 = None
        # File: <eval_with_key>.201:123 in forward, code: view_4 = torch.ops.aten.view.default(sum_5, [32]); sum_5 = None
        view_31: "f32[32]" = torch.ops.aten.reshape.default(sum_5, [32]); sum_5 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1))
        view_36: "f32[2, 16]" = torch.ops.aten.reshape.default(view_31, [2, -1]); view_31 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:263 in foreach_reduce_scatter_copy_in, code: cat_out = torch.cat(grad_views, dim=-1)
        cat: "f32[2, 4176]" = torch.ops.aten.cat.default([view_33, view_34, view_35, view_36], 1); view_33 = view_34 = view_35 = view_36 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:266 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view.copy_(cat_out)
        view_38: "f32[8352]" = torch.ops.aten.reshape.default(cat, [8352]); cat = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:297 in _div_if_needed, code: tensor.div_(div_factor)
        div: "f32[8352]" = torch.ops.aten.div.Tensor(view_38, 2.0); view_38 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:287 in reduce_scatter_tensor, code: tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
        reduce_scatter_tensor: "f32[4176]" = torch.ops._c10d_functional.reduce_scatter_tensor.default(div, 'sum', 2, '0'); div = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor_3: "f32[4176]" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None
        # No stacktrace found for following nodes
        as_strided_32: "f32[64, 32]" = torch.ops.aten.as_strided.default(wait_tensor_3, [64, 32], [32, 1], 0)
        as_strided_33: "f32[64]" = torch.ops.aten.as_strided.default(wait_tensor_3, [64], [1], 2048)
        as_strided_34: "f32[16, 128]" = torch.ops.aten.as_strided.default(wait_tensor_3, [16, 128], [128, 1], 2112)
        as_strided_35: "f32[16]" = torch.ops.aten.as_strided.default(wait_tensor_3, [16], [1], 4160); wait_tensor_3 = None
        return [as_strided_24, as_strided_25, as_strided_26, as_strided_27, as_strided_28, as_strided_29, as_strided_30, as_strided_31, as_strided_32, as_strided_33, as_strided_34, as_strided_35]
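
# A minimal single-process sketch (an illustration, not part of the dump above) of the flat-buffer
# layout that foreach_all_gather / foreach_all_gather_copy_out implement in this graph, using the
# same sizes: world size 2, per-rank split sizes [2048, 64, 2048, 16], unsharded parameter shapes
# [128, 32], [128], [32, 128], [32]. The collective is replaced by a plain torch.cat so the sketch
# runs without a process group; names such as copy_in, copy_out and rank_shards are illustrative
# only and not FSDP internals.

import torch

world_size = 2
split_sizes = [2048, 64, 2048, 16]                    # per-rank shard numels, as in the graph
full_shapes = [(128, 32), (128,), (32, 128), (32,)]   # unsharded parameter shapes
shard_numel = sum(split_sizes)                        # 4176

def copy_in(shards):
    # Pack one rank's flat shards into its all-gather input buffer
    # (the graph's empty + split_with_sizes + _foreach_copy sequence).
    buf = torch.empty(shard_numel)
    torch._foreach_copy_(list(torch.split(buf, split_sizes)), [s.reshape(-1) for s in shards])
    return buf

def copy_out(gathered):
    # Unpack the gathered [world_size * 4176] buffer into full parameters
    # (the graph's view([2, -1]) + split + contiguous + as_strided sequence;
    # reshape is equivalent here because the target strides are contiguous).
    rows = gathered.view(world_size, -1)
    return [
        chunk.contiguous().reshape(shape)
        for chunk, shape in zip(torch.split(rows, split_sizes, dim=1), full_shapes)
    ]

# Shard reference parameters on dim 0 across two "ranks", pack each rank's buffer, stand in for
# all_gather_into_tensor with torch.cat, and check that copy_out reproduces the originals.
full_params = [torch.randn(*shape) for shape in full_shapes]
rank_shards = [[p.chunk(world_size, dim=0)[rank] for p in full_params] for rank in range(world_size)]
gathered = torch.cat([copy_in(shards) for shards in rank_shards])
for original, rebuilt in zip(full_params, copy_out(gathered)):
    assert torch.equal(original, rebuilt)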