TRACED GRAPH
===== AFTER POST GRAD =====
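# [annotation] Reading guide, not part of the original dump: this is the
# Inductor post-grad graph for three FSDP2 (per-parameter fully_shard)
# parameter groups of two linear layers each, traced with world size 2.
# For each group the graph repeats four phases: copy the sharded params
# into a flat all-gather input buffer, all_gather_into_tensor + wait_tensor,
# copy the gathered buffer out into the unsharded parameters, then run the
# group's linear + relu compute.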
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]); slice_1 = None
        getitem: "f32[2048]" = split_with_sizes[0]
        getitem_1: "f32[64]" = split_with_sizes[1]
        getitem_2: "f32[2048]" = split_with_sizes[2]
        getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None
        getitem_4: "f32[2048]" = _foreach_copy[0]
        getitem_5: "f32[64]" = _foreach_copy[1]
        getitem_6: "f32[2048]" = _foreach_copy[2]
        getitem_7: "f32[16]" = _foreach_copy[3]; _foreach_copy = None
        # No stacktrace found for following nodes
        slice_tensor: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
        slice_scatter_default: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_4, 0, 0, 2048); slice_tensor = getitem_4 = None
        slice_scatter_default_1: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 4176, 8352); slice_scatter_default = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_1: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)
        slice_scatter_default_2: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_5, 0, 2048, 2112); slice_tensor_1 = getitem_5 = None
        slice_scatter_default_3: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 4176, 8352); slice_scatter_default_1 = slice_scatter_default_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_4: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_2: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)
        slice_scatter_default_4: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_6, 0, 2112, 4160); slice_tensor_2 = getitem_6 = None
        slice_scatter_default_5: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_3, slice_scatter_default_4, 0, 4176, 8352); slice_scatter_default_3 = slice_scatter_default_4 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)
        slice_scatter_default_6: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_7, 0, 4160, 4176); slice_tensor_3 = getitem_7 = None
        slice_scatter_default_7: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4176, 8352); slice_scatter_default_5 = slice_scatter_default_6 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4176, 8352); slice_scatter_default_7 = None
        all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
        split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None
        getitem_28: "f32[2, 2048]" = split_with_sizes_6[0]
        clone: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
        view_2: "f32[4096]" = torch.ops.aten.reshape.default(clone, [4096]); clone = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_33: "f32[2, 64]" = split_with_sizes_6[1]
        clone_1: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
        view_4: "f32[128]" = torch.ops.aten.reshape.default(clone_1, [128]); clone_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_38: "f32[2, 2048]" = split_with_sizes_6[2]
        clone_2: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
        view_6: "f32[4096]" = torch.ops.aten.reshape.default(clone_2, [4096]); clone_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_43: "f32[2, 16]" = split_with_sizes_6[3]; split_with_sizes_6 = None
        clone_3: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
        view_8: "f32[32]" = torch.ops.aten.reshape.default(clone_3, [32]); clone_3 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
        getitem_44: "f32[128, 32]" = _foreach_copy_1[0]
        getitem_45: "f32[128]" = _foreach_copy_1[1]
        getitem_46: "f32[32, 128]" = _foreach_copy_1[2]
        getitem_47: "f32[32]" = _foreach_copy_1[3]; _foreach_copy_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_1: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None

        # No stacktrace found for following nodes
        mm_default_5: "f32[8, 128]" = torch.ops.aten.mm.default(primals_1, permute_1); permute_1 = None
        add_tensor_5: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_5, getitem_45); mm_default_5 = getitem_45 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
        relu: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_5); add_tensor_5 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_3: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None

        # No stacktrace found for following nodes
        mm_default_4: "f32[8, 32]" = torch.ops.aten.mm.default(relu, permute_3); permute_3 = None
        add_tensor_4: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_4, getitem_47); mm_default_4 = getitem_47 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
        relu_1: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_4); add_tensor_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_11, primals_12, primals_13, primals_14]); primals_11 = primals_12 = primals_13 = primals_14 = None
        getitem_52: "f32[2048]" = _foreach_copy_2[0]
        getitem_53: "f32[64]" = _foreach_copy_2[1]
        getitem_54: "f32[2048]" = _foreach_copy_2[2]
        getitem_55: "f32[16]" = _foreach_copy_2[3]; _foreach_copy_2 = None

        # No stacktrace found for following nodes
        slice_tensor_4: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
        slice_scatter_default_8: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_52, 0, 0, 2048); slice_tensor_4 = getitem_52 = None
        slice_scatter_default_9: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_8, 0, 4176, 8352); slice_scatter_default_8 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_13: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)
        slice_scatter_default_10: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_53, 0, 2048, 2112); slice_tensor_5 = getitem_53 = None
        slice_scatter_default_11: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 4176, 8352); slice_scatter_default_9 = slice_scatter_default_10 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_14: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_6: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)
        slice_scatter_default_12: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_6, getitem_54, 0, 2112, 4160); slice_tensor_6 = getitem_54 = None
        slice_scatter_default_13: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_11, slice_scatter_default_12, 0, 4176, 8352); slice_scatter_default_11 = slice_scatter_default_12 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_15: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_7: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)
        slice_scatter_default_14: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_7, getitem_55, 0, 4160, 4176); slice_tensor_7 = getitem_55 = None
        slice_scatter_default_15: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_13, slice_scatter_default_14, 0, 4176, 8352); slice_scatter_default_13 = slice_scatter_default_14 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_20: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_15, 0, 4176, 8352); slice_scatter_default_15 = None
        all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_20, 2, '0'); slice_20 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_10: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
        split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_10, [2048, 64, 2048, 16], 1); view_10 = None
        getitem_76: "f32[2, 2048]" = split_with_sizes_16[0]
        clone_4: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_76, memory_format = torch.contiguous_format); getitem_76 = None
        view_11: "f32[4096]" = torch.ops.aten.reshape.default(clone_4, [4096]); clone_4 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_11, [128, 32], [32, 1], 0); view_11 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_81: "f32[2, 64]" = split_with_sizes_16[1]
        clone_5: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_81, memory_format = torch.contiguous_format); getitem_81 = None
        view_13: "f32[128]" = torch.ops.aten.reshape.default(clone_5, [128]); clone_5 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_13, [128], [1], 0); view_13 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_86: "f32[2, 2048]" = split_with_sizes_16[2]
        clone_6: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_86, memory_format = torch.contiguous_format); getitem_86 = None
        view_15: "f32[4096]" = torch.ops.aten.reshape.default(clone_6, [4096]); clone_6 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_15, [32, 128], [128, 1], 0); view_15 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_91: "f32[2, 16]" = split_with_sizes_16[3]; split_with_sizes_16 = None
        clone_7: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_91, memory_format = torch.contiguous_format); getitem_91 = None
        view_17: "f32[32]" = torch.ops.aten.reshape.default(clone_7, [32]); clone_7 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_17, [32], [1], 0); view_17 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_15, primals_16, primals_17, primals_18], [as_strided_4, as_strided_5, as_strided_6, as_strided_7]); primals_16 = primals_18 = as_strided_4 = as_strided_5 = as_strided_6 = as_strided_7 = None
        getitem_92: "f32[128, 32]" = _foreach_copy_3[0]
        getitem_93: "f32[128]" = _foreach_copy_3[1]
        getitem_94: "f32[32, 128]" = _foreach_copy_3[2]
        getitem_95: "f32[32]" = _foreach_copy_3[3]; _foreach_copy_3 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_5: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_92, [1, 0]); getitem_92 = None

        # No stacktrace found for following nodes
        mm_default_3: "f32[8, 128]" = torch.ops.aten.mm.default(relu_1, permute_5); permute_5 = None
        add_tensor_3: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_3, getitem_93); mm_default_3 = getitem_93 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
        relu_2: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_3); add_tensor_3 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_7: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_94, [1, 0]); getitem_94 = None

        # No stacktrace found for following nodes
        mm_default_2: "f32[8, 32]" = torch.ops.aten.mm.default(relu_2, permute_7); permute_7 = None
        add_tensor_2: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_2, getitem_95); mm_default_2 = getitem_95 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
        relu_3: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_2); add_tensor_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_19, primals_20, primals_21, primals_22]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_19 = primals_20 = primals_21 = primals_22 = None
        getitem_100: "f32[2048]" = _foreach_copy_4[0]
        getitem_101: "f32[64]" = _foreach_copy_4[1]
        getitem_102: "f32[2048]" = _foreach_copy_4[2]
        getitem_103: "f32[16]" = _foreach_copy_4[3]; _foreach_copy_4 = None
        # No stacktrace found for following nodes
        slice_tensor_8: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
        slice_scatter_default_16: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_8, getitem_100, 0, 0, 2048); slice_tensor_8 = getitem_100 = None
        slice_scatter_default_17: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_16, 0, 4176, 8352); empty = slice_scatter_default_16 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_23: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_9: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)
        slice_scatter_default_18: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_9, getitem_101, 0, 2048, 2112); slice_tensor_9 = getitem_101 = None
        slice_scatter_default_19: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_17, slice_scatter_default_18, 0, 4176, 8352); slice_scatter_default_17 = slice_scatter_default_18 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_24: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)
        slice_scatter_default_20: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_10, getitem_102, 0, 2112, 4160); slice_tensor_10 = getitem_102 = None
        slice_scatter_default_21: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_19, slice_scatter_default_20, 0, 4176, 8352); slice_scatter_default_19 = slice_scatter_default_20 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        slice_25: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)

        # No stacktrace found for following nodes
        slice_tensor_11: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)
        slice_scatter_default_22: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_11, getitem_103, 0, 4160, 4176); slice_tensor_11 = getitem_103 = None
        slice_scatter_default_23: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_21, slice_scatter_default_22, 0, 4176, 8352); slice_scatter_default_21 = slice_scatter_default_22 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_30: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_23, 0, 4176, 8352); slice_scatter_default_23 = None
        all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_30, 2, '0'); slice_30 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_19: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
        split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_19, [2048, 64, 2048, 16], 1); view_19 = None
        getitem_124: "f32[2, 2048]" = split_with_sizes_26[0]
        clone_8: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_124, memory_format = torch.contiguous_format); getitem_124 = None
        view_20: "f32[4096]" = torch.ops.aten.reshape.default(clone_8, [4096]); clone_8 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_20, [128, 32], [32, 1], 0); view_20 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_129: "f32[2, 64]" = split_with_sizes_26[1]
        clone_9: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_129, memory_format = torch.contiguous_format); getitem_129 = None
        view_22: "f32[128]" = torch.ops.aten.reshape.default(clone_9, [128]); clone_9 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_22, [128], [1], 0); view_22 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_134: "f32[2, 2048]" = split_with_sizes_26[2]
        clone_10: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_134, memory_format = torch.contiguous_format); getitem_134 = None
        view_24: "f32[4096]" = torch.ops.aten.reshape.default(clone_10, [4096]); clone_10 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_24, [32, 128], [128, 1], 0); view_24 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_139: "f32[2, 16]" = split_with_sizes_26[3]; split_with_sizes_26 = None
        clone_11: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_139, memory_format = torch.contiguous_format); getitem_139 = None
        view_26: "f32[32]" = torch.ops.aten.reshape.default(clone_11, [32]); clone_11 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_26, [32], [1], 0); view_26 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_23, primals_24, primals_25, primals_26], [as_strided_8, as_strided_9, as_strided_10, as_strided_11]); primals_24 = primals_26 = as_strided_8 = as_strided_9 = as_strided_10 = as_strided_11 = None
        getitem_140: "f32[128, 32]" = _foreach_copy_5[0]
        getitem_141: "f32[128]" = _foreach_copy_5[1]
        getitem_142: "f32[32, 128]" = _foreach_copy_5[2]
        getitem_143: "f32[32]" = _foreach_copy_5[3]; _foreach_copy_5 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_140, [1, 0]); getitem_140 = None

        # No stacktrace found for following nodes
        mm_default_1: "f32[8, 128]" = torch.ops.aten.mm.default(relu_3, permute_9); permute_9 = None
        add_tensor_1: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_1, getitem_141); mm_default_1 = getitem_141 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
        relu_4: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_1); add_tensor_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_11: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_142, [1, 0]); getitem_142 = None

        # No stacktrace found for following nodes
        mm_default: "f32[8, 32]" = torch.ops.aten.mm.default(relu_4, permute_11); permute_11 = None
        add_tensor: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default, getitem_143); mm_default = getitem_143 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
        relu_5: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor); add_tensor = None
le: "b8[8, 32]" = torch.ops.aten.le.Scalar(relu_5, 0)
return [relu_5, primals_1, primals_8, primals_15, primals_17, primals_23, primals_25, relu, relu_1, relu_2, relu_3, relu_4, le]
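
A minimal eager-mode sketch of the pattern traced above, for orientation only. It assumes world_size=2 and the shard sizes seen in this graph; the names (all_gather_params, shards, unsharded_params) are illustrative, not FSDP internals. The real foreach_all_gather / foreach_all_gather_copy_out paths also handle padding, mixed dtypes, and stream scheduling, and the graph uses as_strided where this sketch uses plain view().

import torch
import torch.distributed as dist

def all_gather_params(shards, unsharded_params, group):
    # Phase 1 (copy-in): pack this rank's flat shards into its slice of a
    # single flat all-gather buffer (cf. foreach_all_gather).
    sizes = [s.numel() for s in shards]              # [2048, 64, 2048, 16] here
    shard_numel = sum(sizes)                         # 4176
    world_size = dist.get_world_size(group)          # 2
    rank = dist.get_rank(group)
    out = torch.empty(world_size * shard_numel,
                      dtype=shards[0].dtype, device=shards[0].device)
    inp = out.narrow(0, rank * shard_numel, shard_numel)  # rank-local slice
    torch._foreach_copy_(list(torch.split(inp, sizes)), shards)

    # Phase 2: gather every rank's slice into the full buffer, then sync.
    dist.all_gather_into_tensor(out, inp, group=group)

    # Phase 3 (copy-out): view as [world_size, shard_numel], split per
    # parameter, and rebuild each unsharded parameter
    # (cf. foreach_all_gather_copy_out).
    splits = torch.split(out.view(world_size, shard_numel), sizes, dim=1)
    gathered = [
        s.contiguous().view(-1)[: p.numel()].view(p.shape)
        for s, p in zip(splits, unsharded_params)
    ]
    torch._foreach_copy_(unsharded_params, gathered)

Packing all of a group's shards into one flat buffer is what lets FSDP2 issue a single all_gather_into_tensor per parameter group, which is exactly why this graph shows one collective (plus one wait_tensor) per layer group rather than one per parameter.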