@yf225
Created March 26, 2024 23:03
FWD graph
===== AFTER POST GRAD =====
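This appears to be the forward graph captured for a small stack of Linear + ReLU layers under the per-parameter-sharding FSDP from torch/distributed/_composable/fsdp, printed after the post-grad passes, on a 2-rank group. Each parameter group goes through the same sequence below: copy the local shards into a flat all-gather input, all-gather across the group, copy the gathered buffer out into full-size parameters, then run the Linear + ReLU compute.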
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"):
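        # Note: primals_2..5, primals_11..14, and primals_19..22 hold this rank's parameter
        # shards (2048 + 64 + 2048 + 16 elements per group); they are copied into the
        # all-gather input below. Several full-shape primals (primals_8, primals_15,
        # primals_17, primals_23, primals_25) are only passed through to the return list,
        # presumably as tensors saved for the backward graph.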
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352); empty = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16])
getitem: "f32[2048]" = split_with_sizes[0]
getitem_1: "f32[64]" = split_with_sizes[1]
getitem_2: "f32[2048]" = split_with_sizes[2]
getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None
# No stacktrace found for following nodes
copy__default = torch.ops.aten.copy_.default(getitem, primals_2); primals_2 = None
copy__default_1 = torch.ops.aten.copy_.default(getitem_1, primals_3); primals_3 = None
copy__default_2 = torch.ops.aten.copy_.default(getitem_2, primals_4); primals_4 = None
copy__default_3 = torch.ops.aten.copy_.default(getitem_3, primals_5); primals_5 = None
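        # Copy-in (foreach_all_gather): a flat f32[8352] all-gather output is allocated, this
        # rank's f32[4176] input slice is carved out of it, and the four local parameter shards
        # (primals_2..5) are copied into that slice's [2048, 64, 2048, 16] splits.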
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_1, 2, '0')
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
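        # The flat shard slice is all-gathered across the 2-rank group '0' via the functional
        # _c10d_functional.all_gather_into_tensor op; wait_tensor marks where the collective's
        # result must be ready before it is consumed.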
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None
getitem_28: "f32[2, 2048]" = split_with_sizes_6[0]
view_2: "f32[4096]" = torch.ops.aten.reshape.default(getitem_28, [4096]); getitem_28 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_33: "f32[2, 64]" = split_with_sizes_6[1]
view_4: "f32[128]" = torch.ops.aten.reshape.default(getitem_33, [128]); getitem_33 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_38: "f32[2, 2048]" = split_with_sizes_6[2]
view_6: "f32[4096]" = torch.ops.aten.reshape.default(getitem_38, [4096]); getitem_38 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_43: "f32[2, 16]" = split_with_sizes_6[3]; split_with_sizes_6 = None
view_8: "f32[32]" = torch.ops.aten.reshape.default(getitem_43, [32]); getitem_43 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None
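        # Copy-out (foreach_all_gather_copy_out): the gathered f32[8352] buffer is viewed as
        # [2, 4176] (one row per rank), split per parameter along dim 1, and each split is
        # flattened and reinterpreted with as_strided into full-size tensors of shapes
        # [128, 32], [128], [32, 128], [32], matching the original unsharded parameters.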
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_1: "f32[32, 128]" = torch.ops.aten.permute.default(as_strided, [1, 0]); as_strided = None
# No stacktrace found for following nodes
mm_default_5: "f32[8, 128]" = torch.ops.aten.mm.default(primals_1, permute_1); permute_1 = None
add_tensor_5: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_5, as_strided_1); mm_default_5 = as_strided_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
relu: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_5); add_tensor_5 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_3: "f32[128, 32]" = torch.ops.aten.permute.default(as_strided_2, [1, 0]); as_strided_2 = None
# No stacktrace found for following nodes
mm_default_4: "f32[8, 32]" = torch.ops.aten.mm.default(relu, permute_3); permute_3 = None
add_tensor_4: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_4, as_strided_3); mm_default_4 = as_strided_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
relu_1: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_4); add_tensor_4 = None
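        # Compute for the first two Linear + ReLU layers, using the just-gathered parameters.
        # In eager terms each permute + mm + add trio is F.linear; a rough sketch with
        # hypothetical names (weight0/bias0 etc. stand for the as_strided results above):
        #   z = torch.relu(torch.nn.functional.linear(x, weight0, bias0))  # permute_1, mm_default_5, add_tensor_5, relu
        #   z = torch.relu(torch.nn.functional.linear(z, weight1, bias1))  # permute_3, mm_default_4, add_tensor_4, relu_1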
# No stacktrace found for following nodes
copy__default_4 = torch.ops.aten.copy_.default(getitem, primals_11); primals_11 = None
copy__default_5 = torch.ops.aten.copy_.default(getitem_1, primals_12); primals_12 = None
copy__default_6 = torch.ops.aten.copy_.default(getitem_2, primals_13); primals_13 = None
copy__default_7 = torch.ops.aten.copy_.default(getitem_3, primals_14); primals_14 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_1, 2, '0')
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_10: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_10, [2048, 64, 2048, 16], 1); view_10 = None
getitem_76: "f32[2, 2048]" = split_with_sizes_16[0]
view_11: "f32[4096]" = torch.ops.aten.reshape.default(getitem_76, [4096]); getitem_76 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_11, [128, 32], [32, 1], 0); view_11 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_81: "f32[2, 64]" = split_with_sizes_16[1]
view_13: "f32[128]" = torch.ops.aten.reshape.default(getitem_81, [128]); getitem_81 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_13, [128], [1], 0); view_13 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_86: "f32[2, 2048]" = split_with_sizes_16[2]
view_15: "f32[4096]" = torch.ops.aten.reshape.default(getitem_86, [4096]); getitem_86 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_15, [32, 128], [128, 1], 0); view_15 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_91: "f32[2, 16]" = split_with_sizes_16[3]; split_with_sizes_16 = None
view_17: "f32[32]" = torch.ops.aten.reshape.default(getitem_91, [32]); getitem_91 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_17, [32], [1], 0); view_17 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_5: "f32[32, 128]" = torch.ops.aten.permute.default(as_strided_4, [1, 0]); as_strided_4 = None
# No stacktrace found for following nodes
mm_default_3: "f32[8, 128]" = torch.ops.aten.mm.default(relu_1, permute_5); permute_5 = None
add_tensor_3: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_3, as_strided_5); mm_default_3 = as_strided_5 = None
# File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
relu_2: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_3); add_tensor_3 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_7: "f32[128, 32]" = torch.ops.aten.permute.default(as_strided_6, [1, 0]); as_strided_6 = None
# No stacktrace found for following nodes
mm_default_2: "f32[8, 32]" = torch.ops.aten.mm.default(relu_2, permute_7); permute_7 = None
add_tensor_2: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_2, as_strided_7); mm_default_2 = as_strided_7 = None
# File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
relu_3: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_2); add_tensor_2 = None
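        # The same copy-in / all-gather / copy-out / Linear + ReLU sequence repeats for the
        # second parameter group above and the third one below, reusing slice_1 as the
        # all-gather input after the next group's shards are copied into it.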
# No stacktrace found for following nodes
copy__default_8 = torch.ops.aten.copy_.default(getitem, primals_19); getitem = primals_19 = None
copy__default_9 = torch.ops.aten.copy_.default(getitem_1, primals_20); getitem_1 = primals_20 = None
copy__default_10 = torch.ops.aten.copy_.default(getitem_2, primals_21); getitem_2 = primals_21 = None
copy__default_11 = torch.ops.aten.copy_.default(getitem_3, primals_22); getitem_3 = primals_22 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_1, 2, '0'); slice_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
view_19: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_19, [2048, 64, 2048, 16], 1); view_19 = None
getitem_124: "f32[2, 2048]" = split_with_sizes_26[0]
view_20: "f32[4096]" = torch.ops.aten.reshape.default(getitem_124, [4096]); getitem_124 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_20, [128, 32], [32, 1], 0); view_20 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_129: "f32[2, 64]" = split_with_sizes_26[1]
view_22: "f32[128]" = torch.ops.aten.reshape.default(getitem_129, [128]); getitem_129 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_22, [128], [1], 0); view_22 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_134: "f32[2, 2048]" = split_with_sizes_26[2]
view_24: "f32[4096]" = torch.ops.aten.reshape.default(getitem_134, [4096]); getitem_134 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_24, [32, 128], [128, 1], 0); view_24 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
getitem_139: "f32[2, 16]" = split_with_sizes_26[3]; split_with_sizes_26 = None
view_26: "f32[32]" = torch.ops.aten.reshape.default(getitem_139, [32]); getitem_139 = None
# File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_26, [32], [1], 0); view_26 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(as_strided_8, [1, 0]); as_strided_8 = None
# No stacktrace found for following nodes
mm_default_1: "f32[8, 128]" = torch.ops.aten.mm.default(relu_3, permute_9); permute_9 = None
add_tensor_1: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_1, as_strided_9); mm_default_1 = as_strided_9 = None
# File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
relu_4: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_1); add_tensor_1 = None
# File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
permute_11: "f32[128, 32]" = torch.ops.aten.permute.default(as_strided_10, [1, 0]); as_strided_10 = None
# No stacktrace found for following nodes
mm_default: "f32[8, 32]" = torch.ops.aten.mm.default(relu_4, permute_11); permute_11 = None
add_tensor: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default, as_strided_11); mm_default = as_strided_11 = None
# File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
relu_5: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor); add_tensor = None
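        # relu_5 <= 0 below is the mask of the last ReLU's non-positive outputs; it is included
        # in the returned saved tensors, presumably so the backward graph can reuse it for the
        # ReLU gradient instead of recomputing it.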
le: "b8[8, 32]" = torch.ops.aten.le.Scalar(relu_5, 0)
return [relu_5, primals_1, primals_8, primals_15, primals_17, primals_23, primals_25, relu, relu_1, relu_2, relu_3, relu_4, le]
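
For readers less familiar with the copy-in / copy-out pattern above, here is a minimal single-process sketch of the same idea. It is an assumption-laden illustration, not the FSDP implementation: torch.cat stands in for all_gather_into_tensor + wait_tensor, only two parameters (a f32[128, 32] weight and a f32[128] bias) are sharded instead of four, and all names (copy_in, copy_out, inp_split_sizes, ...) are hypothetical.

import torch

world_size = 2                   # matches the 2-rank group in the graph above
inp_split_sizes = [2048, 64]     # flattened shard sizes: 128*32 and 128 elements, split over 2 ranks
shard_numel = sum(inp_split_sizes)

def copy_in(shards, rank):
    # Allocate the flat all-gather output and carve out this rank's input slice,
    # then copy each local parameter shard into its split of that slice.
    all_gather_output = torch.empty(world_size * shard_numel)
    all_gather_input = all_gather_output.narrow(0, rank * shard_numel, shard_numel)
    for dst, src in zip(torch.split(all_gather_input, inp_split_sizes), shards):
        dst.copy_(src)
    return all_gather_input, all_gather_output

def copy_out(gathered):
    # View the gathered buffer as one row per rank, split per parameter along dim 1,
    # then flatten each split and reinterpret it as the full-size parameter.
    splits = torch.split(gathered.view(world_size, -1), inp_split_sizes, dim=1)
    weight = splits[0].contiguous().view(-1).as_strided([128, 32], [32, 1])
    bias = splits[1].contiguous().view(-1).as_strided([128], [1])
    return weight, bias

# Fake per-rank shards and a fake "collective": torch.cat plays the role of
# all_gather_into_tensor followed by wait_tensor.
per_rank_shards = [[torch.randn(2048), torch.randn(64)] for _ in range(world_size)]
per_rank_inputs = [copy_in(shards, rank)[0] for rank, shards in enumerate(per_rank_shards)]
gathered = torch.cat(per_rank_inputs)

weight, bias = copy_out(gathered)
x = torch.randn(8, 32)
z = torch.relu(torch.nn.functional.linear(x, weight, bias))  # the permute/mm/add/relu nodes above
print(weight.shape, bias.shape, z.shape)                      # [128, 32], [128], [8, 128]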