Created
March 26, 2024 23:03
-
-
Save yf225/774e6b86df0d6b5e5045e88fae7d026a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| BWD graph | |
| ===== AFTER POST GRAD ===== | |
# Backward FX graph ("BWD graph" header row), printed after the post-grad passes
# ("AFTER POST GRAD"). In node order it first unshards three FSDP parameter
# groups with all-gathers, then runs the backward masks/matmuls layer by layer.
# Each group's four gradients are reduce-scattered right after they are computed.
# The graph returns 12 gradient shards.
# NOTE(review): this is a pasted log. Each row is wrapped in "| ... | |" table
# cells and the printout's indentation was lost, so it is not executable Python
# as-is.
| /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module): | |
# Argument roles, inferred only from how each one is used in this graph
# (TODO confirm against the forward graph that produced them):
#   arg0_1            f32[] incoming gradient; expanded to [8, 32] before masking.
#   arg12_1           b8[8, 32] mask; the expanded gradient is zeroed where True.
#   arg7_1..arg11_1   saved activations. Each is the matmul operand for one weight
#                     gradient and, via `le(x, 0)`, a mask that zeroes the
#                     gradient wherever x <= 0.
#   arg1_1            saved [8, 32] tensor; used only for one weight gradient.
#   arg20_1..arg31_1  this rank's parameter shards, four per all-gather group
#                     (2048, 64, 2048 and 16 elements).
#   arg2_1..arg6_1, arg13_1..arg19_1
#                     not referenced anywhere in this graph.
| def forward(self, arg0_1: "f32[]", arg1_1: "f32[8, 32]", arg2_1: "f32[32, 128]", arg3_1: "f32[128, 32]", arg4_1: "f32[32, 128]", arg5_1: "f32[128, 32]", arg6_1: "f32[32, 128]", arg7_1: "f32[8, 128]", arg8_1: "f32[8, 32]", arg9_1: "f32[8, 128]", arg10_1: "f32[8, 32]", arg11_1: "f32[8, 128]", arg12_1: "b8[8, 32]", arg13_1: "f32[32]", arg14_1: "f32[128]", arg15_1: "f32[32]", arg16_1: "f32[128]", arg17_1: "f32[32]", arg18_1: "f32[128]", arg19_1: "f32[128, 32]", arg20_1: "f32[2048]", arg21_1: "f32[64]", arg22_1: "f32[2048]", arg23_1: "f32[16]", arg24_1: "f32[2048]", arg25_1: "f32[64]", arg26_1: "f32[2048]", arg27_1: "f32[16]", arg28_1: "f32[2048]", arg29_1: "f32[64]", arg30_1: "f32[2048]", arg31_1: "f32[16]"): | |
# Unshard parameter group 1 (shards arg20_1..arg23_1) with one all-gather over the
# size-2 process group '0'.
# Copy-in: allocate a [8352] (= 2 * 4176) buffer and take the slice [4176, 8352)
# as this rank's all-gather input. That is the second half, consistent with rank 1
# on cuda:1 (TODO confirm). The four shards (2048 + 64 + 2048 + 16 = 4176) are
# copy_'d into views of that slice. These copy_ nodes must therefore stay ahead of
# the all_gather_into_tensor that reads `slice_1`.
# NOTE(review): only `slice_1` of `empty` is ever referenced here, so the first
# half of the allocation is unused in this functionalized form.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( | |
| empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( | |
| slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352); empty = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) | |
| split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]) | |
| getitem: "f32[2048]" = split_with_sizes[0] | |
| getitem_1: "f32[64]" = split_with_sizes[1] | |
| getitem_2: "f32[2048]" = split_with_sizes[2] | |
| getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None | |
| # No stacktrace found for following nodes | |
| copy__default = torch.ops.aten.copy_.default(getitem, arg20_1); getitem = arg20_1 = None | |
| copy__default_1 = torch.ops.aten.copy_.default(getitem_1, arg21_1); getitem_1 = arg21_1 = None | |
| copy__default_2 = torch.ops.aten.copy_.default(getitem_2, arg22_1); getitem_2 = arg22_1 = None | |
| copy__default_3 = torch.ops.aten.copy_.default(getitem_3, arg23_1); getitem_3 = arg23_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_1, 2, '0'); slice_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None | |
# Copy-out, in four steps:
#   1. View the gathered [8352] result as [2, 4176] (one row per rank).
#   2. Split each row per parameter.
#   3. Flatten every [2, shard] piece to the full numel. This is a contiguous
#      copy, per the traced `.contiguous().view(...)`.
#   4. as_strided each flat piece to the unsharded shape.
# NOTE(review): of the four unsharded views, only `as_strided` ([128, 32]) and
# `as_strided_2` ([32, 128]) are read later in this graph. `as_strided_1` and
# `as_strided_3` (the 1-D parameters) are never used.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None | |
| split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None | |
| getitem_28: "f32[2, 2048]" = split_with_sizes_6[0] | |
| view_2: "f32[4096]" = torch.ops.aten.reshape.default(getitem_28, [4096]); getitem_28 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None | |
| # No stacktrace found for following nodes | |
| getitem_144 = split_with_sizes_6[1] | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_4: "f32[128]" = torch.ops.aten.reshape.default(getitem_144, [128]); getitem_144 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None | |
| # No stacktrace found for following nodes | |
| getitem_145 = split_with_sizes_6[2] | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_6: "f32[4096]" = torch.ops.aten.reshape.default(getitem_145, [4096]); getitem_145 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None | |
| # No stacktrace found for following nodes | |
| getitem_146 = split_with_sizes_6[3]; split_with_sizes_6 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_8: "f32[32]" = torch.ops.aten.reshape.default(getitem_146, [32]); getitem_146 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None | |
# Unshard parameter group 2 (shards arg24_1..arg27_1) with one all-gather over the
# size-2 process group '0'.
# Copy-in: allocate a [8352] buffer and use its second half [4176, 8352) as this
# rank's input. Then copy_ the four shards (2048 + 64 + 2048 + 16) into views of
# that slice. These copy_ nodes must stay ahead of the all-gather that reads
# `slice_11`.
# NOTE(review): only `slice_11` of `empty_1` is ever referenced here.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( | |
| empty_1: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( | |
| slice_11: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4176, 8352); empty_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) | |
| split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(slice_11, [2048, 64, 2048, 16]) | |
| getitem_48: "f32[2048]" = split_with_sizes_10[0] | |
| getitem_49: "f32[64]" = split_with_sizes_10[1] | |
| getitem_50: "f32[2048]" = split_with_sizes_10[2] | |
| getitem_51: "f32[16]" = split_with_sizes_10[3]; split_with_sizes_10 = None | |
| # No stacktrace found for following nodes | |
| copy__default_4 = torch.ops.aten.copy_.default(getitem_48, arg24_1); getitem_48 = arg24_1 = None | |
| copy__default_5 = torch.ops.aten.copy_.default(getitem_49, arg25_1); getitem_49 = arg25_1 = None | |
| copy__default_6 = torch.ops.aten.copy_.default(getitem_50, arg26_1); getitem_50 = arg26_1 = None | |
| copy__default_7 = torch.ops.aten.copy_.default(getitem_51, arg27_1); getitem_51 = arg27_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_11, 2, '0'); slice_11 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None | |
# Copy-out: reshape the result to [2, 4176] (one row per rank) and split it per
# parameter. Each [2, shard] piece is flattened (a contiguous copy), then
# as_strided to its unsharded shape.
# NOTE(review): only `as_strided_4` ([128, 32]) and `as_strided_6` ([32, 128]) are
# read later in this graph. `as_strided_5` and `as_strided_7` are never used.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_12: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None | |
| split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_12, [2048, 64, 2048, 16], 1); view_12 = None | |
| getitem_76: "f32[2, 2048]" = split_with_sizes_16[0] | |
| view_13: "f32[4096]" = torch.ops.aten.reshape.default(getitem_76, [4096]); getitem_76 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_13, [128, 32], [32, 1], 0); view_13 = None | |
| # No stacktrace found for following nodes | |
| getitem_147 = split_with_sizes_16[1] | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_15: "f32[128]" = torch.ops.aten.reshape.default(getitem_147, [128]); getitem_147 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_15, [128], [1], 0); view_15 = None | |
| # No stacktrace found for following nodes | |
| getitem_148 = split_with_sizes_16[2] | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_17: "f32[4096]" = torch.ops.aten.reshape.default(getitem_148, [4096]); getitem_148 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_17, [32, 128], [128, 1], 0); view_17 = None | |
| # No stacktrace found for following nodes | |
| getitem_149 = split_with_sizes_16[3]; split_with_sizes_16 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_19: "f32[32]" = torch.ops.aten.reshape.default(getitem_149, [32]); getitem_149 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_19, [32], [1], 0); view_19 = None | |
# Unshard parameter group 3 (shards arg28_1..arg31_1) with one all-gather over the
# size-2 process group '0'. The layout is the same as the other two groups:
# second-half input slice, in-place copy_ of four shards, then all-gather.
# NOTE(review): in this node order, each all_gather_into_tensor is immediately
# followed by its wait_tensor, and all three groups are gathered before any
# gradient math starts. The printed order therefore has no comm/compute overlap.
# Whether a later scheduling pass reorders these nodes cannot be told from this
# graph.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( | |
| empty_2: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( | |
| slice_21: "f32[4176]" = torch.ops.aten.slice.Tensor(empty_2, 0, 4176, 8352); empty_2 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) | |
| split_with_sizes_20 = torch.ops.aten.split_with_sizes.default(slice_21, [2048, 64, 2048, 16]) | |
| getitem_96: "f32[2048]" = split_with_sizes_20[0] | |
| getitem_97: "f32[64]" = split_with_sizes_20[1] | |
| getitem_98: "f32[2048]" = split_with_sizes_20[2] | |
| getitem_99: "f32[16]" = split_with_sizes_20[3]; split_with_sizes_20 = None | |
| # No stacktrace found for following nodes | |
| copy__default_8 = torch.ops.aten.copy_.default(getitem_96, arg28_1); getitem_96 = arg28_1 = None | |
| copy__default_9 = torch.ops.aten.copy_.default(getitem_97, arg29_1); getitem_97 = arg29_1 = None | |
| copy__default_10 = torch.ops.aten.copy_.default(getitem_98, arg30_1); getitem_98 = arg30_1 = None | |
| copy__default_11 = torch.ops.aten.copy_.default(getitem_99, arg31_1); getitem_99 = arg31_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( | |
| all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_21, 2, '0'); slice_21 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None | |
# Copy-out: reshape the result to [2, 4176], split it per parameter, flatten each
# piece (a contiguous copy), then as_strided it to the unsharded shape.
# NOTE(review): only `as_strided_10` ([32, 128]) is read later in this graph.
# `as_strided_8` ([128, 32]), `as_strided_9` and `as_strided_11` are never used.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_23: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None | |
| split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_23, [2048, 64, 2048, 16], 1); view_23 = None | |
| getitem_124: "f32[2, 2048]" = split_with_sizes_26[0] | |
| view_24: "f32[4096]" = torch.ops.aten.reshape.default(getitem_124, [4096]); getitem_124 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_24, [128, 32], [32, 1], 0); view_24 = None | |
| # No stacktrace found for following nodes | |
| getitem_150 = split_with_sizes_26[1] | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_26: "f32[128]" = torch.ops.aten.reshape.default(getitem_150, [128]); getitem_150 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_26, [128], [1], 0); view_26 = None | |
| # No stacktrace found for following nodes | |
| getitem_151 = split_with_sizes_26[2] | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_28: "f32[4096]" = torch.ops.aten.reshape.default(getitem_151, [4096]); getitem_151 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_28, [32, 128], [128, 1], 0); view_28 = None | |
| # No stacktrace found for following nodes | |
| getitem_152 = split_with_sizes_26[3]; split_with_sizes_26 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), | |
| view_30: "f32[32]" = torch.ops.aten.reshape.default(getitem_152, [32]); getitem_152 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( | |
| as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_30, [32], [1], 0); view_30 = None | |
# Backward through the layer whose weights come from parameter group 1
# (`as_strided_2` [32, 128] and `as_strided` [128, 32]), then a reduce-scatter of
# that group's four gradients.
# The op pattern matches a Linear -> ReLU -> Linear -> ReLU backward (TODO confirm
# against the model):
#   - `where(mask, 0, g)` zeroes the masked gradient entries.
#   - `mm(g, W)` propagates the gradient to the layer input.
#   - `mm(g.T, act)` and `sum(g, [0])` give the weight and bias gradients.
# NOTE: the `code:` text in each `<eval_with_key>` row is the line from the graph
# before post-grad rewriting. Its names (e.g. `permute_1`) therefore differ from
# the node names used here (e.g. `as_strided_2`).
# `scalar_tensor` (0 as float32 on cuda:1) is created once and reused as the fill
# value by every `where` in this graph.
| # File: <eval_with_key>.373:67 in forward, code: le = torch.ops.aten.le.Scalar(alias_11, 0); alias_11 = None | |
| le: "b8[8, 128]" = torch.ops.aten.le.Scalar(arg11_1, 0) | |
| # File: <eval_with_key>.373:52 in forward, code: scalar_tensor = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1)) | |
| scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1)) | |
| # File: <eval_with_key>.373:25 in forward, code: expand = torch.ops.aten.expand.default(getitem, [8, 32]); getitem = None | |
| expand: "f32[8, 32]" = torch.ops.aten.expand.default(arg0_1, [8, 32]); arg0_1 = None | |
| # File: <eval_with_key>.373:53 in forward, code: where = torch.ops.aten.where.self(getitem_12, scalar_tensor, trace_wrapped_1); getitem_12 = trace_wrapped_1 = None | |
| where: "f32[8, 32]" = torch.ops.aten.where.self(arg12_1, scalar_tensor, expand); arg12_1 = expand = None | |
| # File: <eval_with_key>.373:56 in forward, code: mm = torch.ops.aten.mm.default(where, permute_1); permute_1 = None | |
| mm: "f32[8, 128]" = torch.ops.aten.mm.default(where, as_strided_2); as_strided_2 = None | |
| # File: <eval_with_key>.373:68 in forward, code: where_1 = torch.ops.aten.where.self(le, scalar_tensor, mm); le = mm = None | |
| where_1: "f32[8, 128]" = torch.ops.aten.where.self(le, scalar_tensor, mm); le = mm = None | |
| # File: <eval_with_key>.373:72 in forward, code: permute_7 = torch.ops.aten.permute.default(where_1, [1, 0]) | |
| permute_11: "f32[128, 8]" = torch.ops.aten.permute.default(where_1, [1, 0]) | |
| # File: <eval_with_key>.373:73 in forward, code: mm_3 = torch.ops.aten.mm.default(permute_7, getitem_10); permute_7 = getitem_10 = None | |
| mm_3: "f32[128, 32]" = torch.ops.aten.mm.default(permute_11, arg10_1); permute_11 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_55: "f32[2, 2048]" = torch.ops.aten.reshape.default(mm_3, [2, -1]); mm_3 = None | |
# Unlike the traced `code:`, `where_1` is not freed here. `mm_2` reads it later to
# propagate the gradient to this layer's input.
| # File: <eval_with_key>.373:75 in forward, code: sum_2 = torch.ops.aten.sum.dim_IntList(where_1, [0], True); where_1 = None | |
| sum_2: "f32[1, 128]" = torch.ops.aten.sum.dim_IntList(where_1, [0], True) | |
| # File: <eval_with_key>.373:76 in forward, code: view_1 = torch.ops.aten.view.default(sum_2, [128]); sum_2 = None | |
| view_10: "f32[128]" = torch.ops.aten.reshape.default(sum_2, [128]); sum_2 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_56: "f32[2, 64]" = torch.ops.aten.reshape.default(view_10, [2, -1]); view_10 = None | |
| # File: <eval_with_key>.373:57 in forward, code: permute_2 = torch.ops.aten.permute.default(where, [1, 0]) | |
| permute_4: "f32[32, 8]" = torch.ops.aten.permute.default(where, [1, 0]) | |
| # File: <eval_with_key>.373:58 in forward, code: mm_1 = torch.ops.aten.mm.default(permute_2, getitem_11); permute_2 = getitem_11 = None | |
| mm_1: "f32[32, 128]" = torch.ops.aten.mm.default(permute_4, arg11_1); permute_4 = arg11_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_57: "f32[2, 2048]" = torch.ops.aten.reshape.default(mm_1, [2, -1]); mm_1 = None | |
| # File: <eval_with_key>.373:60 in forward, code: sum_1 = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None | |
| sum_1: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None | |
| # File: <eval_with_key>.373:61 in forward, code: view = torch.ops.aten.view.default(sum_1, [32]); sum_1 = None | |
| view_9: "f32[32]" = torch.ops.aten.reshape.default(sum_1, [32]); sum_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_58: "f32[2, 16]" = torch.ops.aten.reshape.default(view_9, [2, -1]); view_9 = None | |
# Reduce-scatter copy-in: each gradient above was viewed as [2, numel / 2], one row
# per destination rank. The steps below:
#   1. Concatenate along dim 1 into [2, 4176] and flatten.
#   2. Pre-divide by 2.0.
#   3. Reduce-scatter with 'sum' over the size-2 group, which yields the average
#      across ranks.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:263 in foreach_reduce_scatter_copy_in, code: cat_out = torch.cat(grad_views, dim=-1) | |
| cat_2: "f32[2, 4176]" = torch.ops.aten.cat.default([view_55, view_56, view_57, view_58], 1); view_55 = view_56 = view_57 = view_58 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:266 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view.copy_(cat_out) | |
| view_60: "f32[8352]" = torch.ops.aten.reshape.default(cat_2, [8352]); cat_2 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:297 in _div_if_needed, code: tensor.div_(div_factor) | |
| div_2: "f32[8352]" = torch.ops.aten.div.Tensor(view_60, 2.0); view_60 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:287 in reduce_scatter_tensor, code: tensor = torch.ops._c10d_functional.reduce_scatter_tensor( | |
| reduce_scatter_tensor_2: "f32[4176]" = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_2, 'sum', 2, '0'); div_2 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| wait_tensor_5: "f32[4176]" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None | |
# Split this rank's [4176] result back into per-parameter gradient shards:
# [64, 32] at offset 0, [64] at 2048, [16, 128] at 2112, and [16] at 4160.
| # No stacktrace found for following nodes | |
| as_strided_24: "f32[64, 32]" = torch.ops.aten.as_strided.default(wait_tensor_5, [64, 32], [32, 1], 0) | |
| as_strided_25: "f32[64]" = torch.ops.aten.as_strided.default(wait_tensor_5, [64], [1], 2048) | |
| as_strided_26: "f32[16, 128]" = torch.ops.aten.as_strided.default(wait_tensor_5, [16, 128], [128, 1], 2112) | |
| as_strided_27: "f32[16]" = torch.ops.aten.as_strided.default(wait_tensor_5, [16], [1], 4160); wait_tensor_5 = None | |
# This section does three things:
#   1. Finish propagating the gradient into the input of the group-1 layer
#      (`mm_2 = where_1 @ as_strided`, masked by `le(arg10_1, 0)`).
#   2. Backward through the layer whose weights come from parameter group 2
#      (`as_strided_6` [32, 128], `as_strided_4` [128, 32]).
#   3. Reduce-scatter group 2's four gradients.
# The where / mm / sum pattern and the reduce-scatter layout are the same as for
# group 1.
| # File: <eval_with_key>.373:98 in forward, code: le_2 = torch.ops.aten.le.Scalar(alias_15, 0); alias_15 = None | |
| le_2: "b8[8, 128]" = torch.ops.aten.le.Scalar(arg9_1, 0) | |
| # File: <eval_with_key>.373:83 in forward, code: le_1 = torch.ops.aten.le.Scalar(alias_13, 0); alias_13 = None | |
| le_1: "b8[8, 32]" = torch.ops.aten.le.Scalar(arg10_1, 0); arg10_1 = None | |
| # File: <eval_with_key>.373:71 in forward, code: mm_2 = torch.ops.aten.mm.default(where_1, permute_6); permute_6 = None | |
| mm_2: "f32[8, 32]" = torch.ops.aten.mm.default(where_1, as_strided); where_1 = as_strided = None | |
| # File: <eval_with_key>.373:84 in forward, code: where_2 = torch.ops.aten.where.self(le_1, scalar_tensor, trace_wrapped_2); le_1 = trace_wrapped_2 = None | |
| where_2: "f32[8, 32]" = torch.ops.aten.where.self(le_1, scalar_tensor, mm_2); le_1 = mm_2 = None | |
| # File: <eval_with_key>.373:87 in forward, code: mm_4 = torch.ops.aten.mm.default(where_2, permute_11); permute_11 = None | |
| mm_4: "f32[8, 128]" = torch.ops.aten.mm.default(where_2, as_strided_6); as_strided_6 = None | |
| # File: <eval_with_key>.373:99 in forward, code: where_3 = torch.ops.aten.where.self(le_2, scalar_tensor, mm_4); le_2 = mm_4 = None | |
| where_3: "f32[8, 128]" = torch.ops.aten.where.self(le_2, scalar_tensor, mm_4); le_2 = mm_4 = None | |
| # File: <eval_with_key>.373:103 in forward, code: permute_17 = torch.ops.aten.permute.default(where_3, [1, 0]) | |
| permute_25: "f32[128, 8]" = torch.ops.aten.permute.default(where_3, [1, 0]) | |
| # File: <eval_with_key>.373:104 in forward, code: mm_7 = torch.ops.aten.mm.default(permute_17, getitem_8); permute_17 = getitem_8 = None | |
| mm_7: "f32[128, 32]" = torch.ops.aten.mm.default(permute_25, arg8_1); permute_25 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_44: "f32[2, 2048]" = torch.ops.aten.reshape.default(mm_7, [2, -1]); mm_7 = None | |
# `where_3` is kept alive past this sum. `mm_6` reads it later to propagate the
# gradient into the group-2 layer's input.
| # File: <eval_with_key>.373:106 in forward, code: sum_4 = torch.ops.aten.sum.dim_IntList(where_3, [0], True); where_3 = None | |
| sum_4: "f32[1, 128]" = torch.ops.aten.sum.dim_IntList(where_3, [0], True) | |
| # File: <eval_with_key>.373:107 in forward, code: view_3 = torch.ops.aten.view.default(sum_4, [128]); sum_4 = None | |
| view_21: "f32[128]" = torch.ops.aten.reshape.default(sum_4, [128]); sum_4 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_45: "f32[2, 64]" = torch.ops.aten.reshape.default(view_21, [2, -1]); view_21 = None | |
| # File: <eval_with_key>.373:88 in forward, code: permute_12 = torch.ops.aten.permute.default(where_2, [1, 0]) | |
| permute_18: "f32[32, 8]" = torch.ops.aten.permute.default(where_2, [1, 0]) | |
| # File: <eval_with_key>.373:89 in forward, code: mm_5 = torch.ops.aten.mm.default(permute_12, getitem_9); permute_12 = getitem_9 = None | |
| mm_5: "f32[32, 128]" = torch.ops.aten.mm.default(permute_18, arg9_1); permute_18 = arg9_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_46: "f32[2, 2048]" = torch.ops.aten.reshape.default(mm_5, [2, -1]); mm_5 = None | |
| # File: <eval_with_key>.373:91 in forward, code: sum_3 = torch.ops.aten.sum.dim_IntList(where_2, [0], True); where_2 = None | |
| sum_3: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(where_2, [0], True); where_2 = None | |
| # File: <eval_with_key>.373:92 in forward, code: view_2 = torch.ops.aten.view.default(sum_3, [32]); sum_3 = None | |
| view_20: "f32[32]" = torch.ops.aten.reshape.default(sum_3, [32]); sum_3 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_47: "f32[2, 16]" = torch.ops.aten.reshape.default(view_20, [2, -1]); view_20 = None | |
# Reduce-scatter group 2's gradients:
#   1. Pack them as [2, 4176] and flatten.
#   2. Pre-divide by 2.0.
#   3. Reduce-scatter ('sum', group size 2).
#   4. Slice this rank's shard into [64, 32] @0, [64] @2048, [16, 128] @2112,
#      and [16] @4160.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:263 in foreach_reduce_scatter_copy_in, code: cat_out = torch.cat(grad_views, dim=-1) | |
| cat_1: "f32[2, 4176]" = torch.ops.aten.cat.default([view_44, view_45, view_46, view_47], 1); view_44 = view_45 = view_46 = view_47 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:266 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view.copy_(cat_out) | |
| view_49: "f32[8352]" = torch.ops.aten.reshape.default(cat_1, [8352]); cat_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:297 in _div_if_needed, code: tensor.div_(div_factor) | |
| div_1: "f32[8352]" = torch.ops.aten.div.Tensor(view_49, 2.0); view_49 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:287 in reduce_scatter_tensor, code: tensor = torch.ops._c10d_functional.reduce_scatter_tensor( | |
| reduce_scatter_tensor_1: "f32[4176]" = torch.ops._c10d_functional.reduce_scatter_tensor.default(div_1, 'sum', 2, '0'); div_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| wait_tensor_4: "f32[4176]" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None | |
| # No stacktrace found for following nodes | |
| as_strided_28: "f32[64, 32]" = torch.ops.aten.as_strided.default(wait_tensor_4, [64, 32], [32, 1], 0) | |
| as_strided_29: "f32[64]" = torch.ops.aten.as_strided.default(wait_tensor_4, [64], [1], 2048) | |
| as_strided_30: "f32[16, 128]" = torch.ops.aten.as_strided.default(wait_tensor_4, [16, 128], [128, 1], 2112) | |
| as_strided_31: "f32[16]" = torch.ops.aten.as_strided.default(wait_tensor_4, [16], [1], 4160); wait_tensor_4 = None | |
# This section does three things:
#   1. Finish propagating the gradient into the input of the group-2 layer
#      (`mm_6 = where_3 @ as_strided_4`, masked by `le(arg8_1, 0)`).
#   2. Backward through the layer whose weights come from parameter group 3
#      (only `as_strided_10` [32, 128] is needed).
#   3. Reduce-scatter that group's gradients and return all 12 gradient shards.
| # File: <eval_with_key>.373:129 in forward, code: le_4 = torch.ops.aten.le.Scalar(alias_19, 0); alias_19 = None | |
| le_4: "b8[8, 128]" = torch.ops.aten.le.Scalar(arg7_1, 0) | |
| # File: <eval_with_key>.373:114 in forward, code: le_3 = torch.ops.aten.le.Scalar(alias_17, 0); alias_17 = None | |
| le_3: "b8[8, 32]" = torch.ops.aten.le.Scalar(arg8_1, 0); arg8_1 = None | |
| # File: <eval_with_key>.373:102 in forward, code: mm_6 = torch.ops.aten.mm.default(where_3, permute_16); permute_16 = None | |
| mm_6: "f32[8, 32]" = torch.ops.aten.mm.default(where_3, as_strided_4); where_3 = as_strided_4 = None | |
| # File: <eval_with_key>.373:115 in forward, code: where_4 = torch.ops.aten.where.self(le_3, scalar_tensor, trace_wrapped_3); le_3 = trace_wrapped_3 = None | |
| where_4: "f32[8, 32]" = torch.ops.aten.where.self(le_3, scalar_tensor, mm_6); le_3 = mm_6 = None | |
| # File: <eval_with_key>.373:118 in forward, code: mm_8 = torch.ops.aten.mm.default(where_4, permute_21); permute_21 = None | |
| mm_8: "f32[8, 128]" = torch.ops.aten.mm.default(where_4, as_strided_10); as_strided_10 = None | |
# Last `where` in the graph, so `scalar_tensor` is freed here.
| # File: <eval_with_key>.373:130 in forward, code: where_5 = torch.ops.aten.where.self(le_4, scalar_tensor, mm_8); le_4 = scalar_tensor = mm_8 = None | |
| where_5: "f32[8, 128]" = torch.ops.aten.where.self(le_4, scalar_tensor, mm_8); le_4 = scalar_tensor = mm_8 = None | |
| # File: <eval_with_key>.373:131 in forward, code: permute_25 = torch.ops.aten.permute.default(where_5, [1, 0]) | |
| permute_35: "f32[128, 8]" = torch.ops.aten.permute.default(where_5, [1, 0]) | |
# `arg1_1` feeds only this weight gradient. No gradient is propagated past this
# layer's first matmul, which is why group 3's [128, 32] weight view
# (`as_strided_8`) is never consumed.
| # File: <eval_with_key>.373:132 in forward, code: mm_10 = torch.ops.aten.mm.default(permute_25, getitem_1); permute_25 = getitem_1 = None | |
| mm_10: "f32[128, 32]" = torch.ops.aten.mm.default(permute_35, arg1_1); permute_35 = arg1_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_33: "f32[2, 2048]" = torch.ops.aten.reshape.default(mm_10, [2, -1]); mm_10 = None | |
| # File: <eval_with_key>.373:134 in forward, code: sum_6 = torch.ops.aten.sum.dim_IntList(where_5, [0], True); where_5 = None | |
| sum_6: "f32[1, 128]" = torch.ops.aten.sum.dim_IntList(where_5, [0], True); where_5 = None | |
| # File: <eval_with_key>.373:135 in forward, code: view_5 = torch.ops.aten.view.default(sum_6, [128]); sum_6 = None | |
| view_32: "f32[128]" = torch.ops.aten.reshape.default(sum_6, [128]); sum_6 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_34: "f32[2, 64]" = torch.ops.aten.reshape.default(view_32, [2, -1]); view_32 = None | |
| # File: <eval_with_key>.373:119 in forward, code: permute_22 = torch.ops.aten.permute.default(where_4, [1, 0]) | |
| permute_32: "f32[32, 8]" = torch.ops.aten.permute.default(where_4, [1, 0]) | |
| # File: <eval_with_key>.373:120 in forward, code: mm_9 = torch.ops.aten.mm.default(permute_22, getitem_7); permute_22 = getitem_7 = None | |
| mm_9: "f32[32, 128]" = torch.ops.aten.mm.default(permute_32, arg7_1); permute_32 = arg7_1 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_35: "f32[2, 2048]" = torch.ops.aten.reshape.default(mm_9, [2, -1]); mm_9 = None | |
| # File: <eval_with_key>.373:122 in forward, code: sum_5 = torch.ops.aten.sum.dim_IntList(where_4, [0], True); where_4 = None | |
| sum_5: "f32[1, 32]" = torch.ops.aten.sum.dim_IntList(where_4, [0], True); where_4 = None | |
| # File: <eval_with_key>.373:123 in forward, code: view_4 = torch.ops.aten.view.default(sum_5, [32]); sum_5 = None | |
| view_31: "f32[32]" = torch.ops.aten.reshape.default(sum_5, [32]); sum_5 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:256 in foreach_reduce_scatter_copy_in, code: grad_views.append(grad.view(world_size, -1)) | |
| view_36: "f32[2, 16]" = torch.ops.aten.reshape.default(view_31, [2, -1]); view_31 = None | |
# Reduce-scatter group 3's gradients:
#   1. Pack them as [2, 4176] and flatten.
#   2. Pre-divide by 2.0.
#   3. Reduce-scatter ('sum', group size 2).
#   4. Slice this rank's shard into [64, 32] @0, [64] @2048, [16, 128] @2112,
#      and [16] @4160.
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:263 in foreach_reduce_scatter_copy_in, code: cat_out = torch.cat(grad_views, dim=-1) | |
| cat: "f32[2, 4176]" = torch.ops.aten.cat.default([view_33, view_34, view_35, view_36], 1); view_33 = view_34 = view_35 = view_36 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:266 in foreach_reduce_scatter_copy_in, code: reduce_scatter_input_view.copy_(cat_out) | |
| view_38: "f32[8352]" = torch.ops.aten.reshape.default(cat, [8352]); cat = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:297 in _div_if_needed, code: tensor.div_(div_factor) | |
| div: "f32[8352]" = torch.ops.aten.div.Tensor(view_38, 2.0); view_38 = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:287 in reduce_scatter_tensor, code: tensor = torch.ops._c10d_functional.reduce_scatter_tensor( | |
| reduce_scatter_tensor: "f32[4176]" = torch.ops._c10d_functional.reduce_scatter_tensor.default(div, 'sum', 2, '0'); div = None | |
| # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] | |
| wait_tensor_3: "f32[4176]" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None | |
| # No stacktrace found for following nodes | |
| as_strided_32: "f32[64, 32]" = torch.ops.aten.as_strided.default(wait_tensor_3, [64, 32], [32, 1], 0) | |
| as_strided_33: "f32[64]" = torch.ops.aten.as_strided.default(wait_tensor_3, [64], [1], 2048) | |
| as_strided_34: "f32[16, 128]" = torch.ops.aten.as_strided.default(wait_tensor_3, [16, 128], [128, 1], 2112) | |
| as_strided_35: "f32[16]" = torch.ops.aten.as_strided.default(wait_tensor_3, [16], [1], 4160); wait_tensor_3 = None | |
# Returns 12 gradient shards, four per parameter group, in group order 1, 2, 3.
# Their numels match the parameter shards arg20_1..arg31_1 position for position
# (presumably the same parameters; TODO confirm against the caller).
| return [as_strided_24, as_strided_25, as_strided_26, as_strided_27, as_strided_28, as_strided_29, as_strided_30, as_strided_31, as_strided_32, as_strided_33, as_strided_34, as_strided_35] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment