===== Joint graph 0 =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[4, 16]"; primals_2: "f32[128]"; primals_3: "f32[8]"; primals_4: "f32[60]"; primals_5: "f32[4]"; primals_6: "f32[15, 16]"; primals_7: "f32[15]"; primals_8: "f32[8, 15]"; primals_9: "f32[8]"; tangents_1: "f32[4, 8]";
        primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[400]" = torch.ops.aten.empty.memory_format([400], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_1: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [128, 8, 60, 4]); slice_1 = None
        getitem: "f32[128]" = split_with_sizes[0]
        getitem_1: "f32[8]" = split_with_sizes[1]
        getitem_2: "f32[60]" = split_with_sizes[2]
        getitem_3: "f32[4]" = split_with_sizes[3]; split_with_sizes = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_2 = primals_3 = primals_4 = primals_5 = None
        getitem_4: "f32[128]" = _foreach_copy[0]
        getitem_5: "f32[8]" = _foreach_copy[1]
        getitem_6: "f32[60]" = _foreach_copy[2]
        getitem_7: "f32[4]" = _foreach_copy[3]; _foreach_copy = None
        slice_2: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)
        slice_scatter: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_2, getitem_4, 0, 0, 128); slice_2 = getitem_4 = None
        slice_scatter_1: "f32[400]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter, 0, 200, 400); empty = slice_scatter = None
        slice_3: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_1, 0, 200, 400)
        slice_scatter_2: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_3, getitem_5, 0, 128, 136); slice_3 = getitem_5 = None
        slice_scatter_3: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_1, slice_scatter_2, 0, 200, 400); slice_scatter_1 = slice_scatter_2 = None
        slice_4: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 200, 400)
        slice_scatter_4: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_4, getitem_6, 0, 136, 196); slice_4 = getitem_6 = None
        slice_scatter_5: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_3, slice_scatter_4, 0, 200, 400); slice_scatter_3 = slice_scatter_4 = None
        slice_5: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_5, 0, 200, 400)
        slice_scatter_6: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_5, getitem_7, 0, 196, 200); slice_5 = getitem_7 = None
        slice_scatter_7: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_5, slice_scatter_6, 0, 200, 400); slice_scatter_5 = slice_scatter_6 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_10: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_7, 0, 200, 400)
        all_gather_into_tensor: "f32[400]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor: "f32[400]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:975 in all_gather_tensor_inplace, code: return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
        copy: "f32[400]" = torch.ops.aten.copy.default(slice_scatter_7, wait_tensor); slice_scatter_7 = wait_tensor = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_1: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
        split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [128, 8, 60, 4], 1); view_1 = None
        getitem_28: "f32[2, 128]" = split_with_sizes_6[0]; split_with_sizes_6 = None
        clone: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
        view_2: "f32[256]" = torch.ops.aten.view.default(clone, [256]); clone = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided: "f32[15, 16]" = torch.ops.aten.as_strided.default(view_2, [15, 16], [16, 1], 0); view_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_3: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
        split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_3, [128, 8, 60, 4], 1); view_3 = None
        getitem_33: "f32[2, 8]" = split_with_sizes_7[1]; split_with_sizes_7 = None
        clone_1: "f32[2, 8]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
        view_4: "f32[16]" = torch.ops.aten.view.default(clone_1, [16]); clone_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_1: "f32[15]" = torch.ops.aten.as_strided.default(view_4, [15], [1], 0); view_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_5: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
        split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_5, [128, 8, 60, 4], 1); view_5 = None
        getitem_38: "f32[2, 60]" = split_with_sizes_8[2]; split_with_sizes_8 = None
        clone_2: "f32[2, 60]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
        view_6: "f32[120]" = torch.ops.aten.view.default(clone_2, [120]); clone_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_2: "f32[8, 15]" = torch.ops.aten.as_strided.default(view_6, [8, 15], [15, 1], 0); view_6 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_7: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1]); copy = None
        split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_7, [128, 8, 60, 4], 1); view_7 = None
        getitem_43: "f32[2, 4]" = split_with_sizes_9[3]; split_with_sizes_9 = None
        clone_3: "f32[2, 4]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
        view_8: "f32[8]" = torch.ops.aten.view.default(clone_3, [8]); clone_3 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_3: "f32[8]" = torch.ops.aten.as_strided.default(view_8, [8], [1], 0); view_8 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
        getitem_44: "f32[15, 16]" = _foreach_copy_1[0]
        getitem_45: "f32[15]" = _foreach_copy_1[1]
        getitem_46: "f32[8, 15]" = _foreach_copy_1[2]
        getitem_47: "f32[8]" = _foreach_copy_1[3]; _foreach_copy_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_1: "f32[16, 15]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None
        addmm: "f32[4, 15]" = torch.ops.aten.addmm.default(getitem_45, primals_1, permute_1); getitem_45 = permute_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[4, 15]" = torch.ops.aten.relu.default(addmm); addmm = None
        alias: "f32[4, 15]" = torch.ops.aten.alias.default(relu)
        alias_1: "f32[4, 15]" = torch.ops.aten.alias.default(alias); alias = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
        addmm_1: "f32[4, 8]" = torch.ops.aten.addmm.default(getitem_47, relu, permute_3); getitem_47 = permute_3 = None
        # No stacktrace found for following nodes
        trace_wrapped: "f32[4, 8]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(tangents_1, bw_state = primals_10); tangents_1 = primals_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_5: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0]); primals_8 = None
        permute_6: "f32[8, 15]" = torch.ops.aten.permute.default(permute_5, [1, 0]); permute_5 = None
        mm: "f32[4, 15]" = torch.ops.aten.mm.default(trace_wrapped, permute_6); permute_6 = None
        permute_7: "f32[8, 4]" = torch.ops.aten.permute.default(trace_wrapped, [1, 0])
        mm_1: "f32[8, 15]" = torch.ops.aten.mm.default(permute_7, relu); permute_7 = relu = None
        permute_8: "f32[15, 8]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(trace_wrapped, [0], True); trace_wrapped = None
        view_9: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]); sum_1 = None
        permute_9: "f32[8, 15]" = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        alias_2: "f32[4, 15]" = torch.ops.aten.alias.default(alias_1); alias_1 = None
        alias_3: "f32[4, 15]" = torch.ops.aten.alias.default(alias_2); alias_2 = None
        le: "b8[4, 15]" = torch.ops.aten.le.Scalar(alias_3, 0); alias_3 = None
        scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1))
        where: "f32[4, 15]" = torch.ops.aten.where.self(le, scalar_tensor, mm); le = scalar_tensor = mm = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_10: "f32[15, 4]" = torch.ops.aten.permute.default(where, [1, 0])
        mm_2: "f32[15, 16]" = torch.ops.aten.mm.default(permute_10, primals_1); permute_10 = primals_1 = None
        permute_11: "f32[16, 15]" = torch.ops.aten.permute.default(mm_2, [1, 0]); mm_2 = None
        sum_2: "f32[1, 15]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
        view_10: "f32[15]" = torch.ops.aten.view.default(sum_2, [15]); sum_2 = None
        permute_12: "f32[15, 16]" = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None
        return pytree.tree_unflatten([addmm_1, None, None, None, None, None, permute_12, view_10, permute_9, view_9, None], self._out_spec)
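
The split sizes [128, 8, 60, 4] and the 400-element all_gather_output in the joint graph are consistent with a two-rank run of a small Linear(16, 15) -> ReLU -> Linear(15, 8) model; that model shape is inferred from the parameter annotations above, not stated anywhere in the gist. A minimal sketch of the sharding arithmetic, assuming FSDP2-style dim-0 sharding padded to a multiple of the world size:

import math

# Hypothetical reconstruction: parameter shapes inferred from the graph above.
world_size = 2
param_shapes = [(15, 16), (15,), (8, 15), (8,)]  # linear1.weight, linear1.bias, linear2.weight, linear2.bias

shard_numels = []
for shape in param_shapes:
    rows, rest = shape[0], math.prod(shape[1:])              # shard along dim 0
    padded_rows = math.ceil(rows / world_size) * world_size  # pad dim 0 to a multiple of world_size
    shard_numels.append((padded_rows // world_size) * rest)  # per-rank shard numel

print(shard_numels)                    # [128, 8, 60, 4] -> inp_split_sizes in foreach_all_gather
print(sum(shard_numels))               # 200             -> per-rank all_gather_input ("f32[200]")
print(sum(shard_numels) * world_size)  # 400             -> all_gather_output ("f32[400]")
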
/data/users/willfeng/pytorch_yf225/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
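
The UserWarning above comes from Inductor and suggests opting into TF32 for float32 matmuls. A one-line way to follow that suggestion, placed before compiling the model (the API is the one named in the warning itself):

import torch

# Opt into TensorFloat32 tensor cores for float32 matmul, as suggested by the
# UserWarning above; trades a small amount of precision for matmul throughput.
torch.set_float32_matmul_precision("high")
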
TRACED GRAPH
===== Forward graph 0 =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 16]", primals_2: "f32[128]", primals_3: "f32[8]", primals_4: "f32[60]", primals_5: "f32[4]", primals_6: "f32[15, 16]", primals_7: "f32[15]", primals_8: "f32[8, 15]", primals_9: "f32[8]", primals_10):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[400]" = torch.ops.aten.empty.memory_format([400], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_1: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [128, 8, 60, 4])
        getitem: "f32[128]" = split_with_sizes[0]
        full_default: "f32[8]" = torch.ops.aten.full.default([8], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
        getitem_2: "f32[60]" = split_with_sizes[2]
        getitem_3: "f32[4]" = split_with_sizes[3]; split_with_sizes = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy = torch.ops.aten._foreach_copy_.default([getitem, full_default, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None
        slice_scatter: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_1, getitem, 0, 0, 128); slice_1 = getitem = None
        slice_scatter_1: "f32[400]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter, 0, 200, 400); empty = slice_scatter = None
        slice_3: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_1, 0, 200, 400)
        slice_scatter_2: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_3, full_default, 0, 128, 136); slice_3 = full_default = None
        slice_scatter_3: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_1, slice_scatter_2, 0, 200, 400); slice_scatter_1 = slice_scatter_2 = None
        slice_4: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 200, 400)
        slice_scatter_4: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_4, getitem_2, 0, 136, 196); slice_4 = getitem_2 = None
        slice_scatter_5: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_3, slice_scatter_4, 0, 200, 400); slice_scatter_3 = slice_scatter_4 = None
        slice_5: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_5, 0, 200, 400)
        slice_scatter_6: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_5, getitem_3, 0, 196, 200); slice_5 = getitem_3 = None
        slice_scatter_7: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_5, slice_scatter_6, 0, 200, 400); slice_scatter_5 = slice_scatter_6 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_10: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_7, 0, 200, 400)
        all_gather_into_tensor: "f32[400]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
        wait_tensor: "f32[400]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:975 in all_gather_tensor_inplace, code: return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
        copy: "f32[400]" = torch.ops.aten.copy.default(slice_scatter_7, wait_tensor); slice_scatter_7 = wait_tensor = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_1: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1]); copy = None
        split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [128, 8, 60, 4], 1); view_1 = None
        getitem_28: "f32[2, 128]" = split_with_sizes_6[0]
        clone: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
        view_2: "f32[256]" = torch.ops.aten.view.default(clone, [256]); clone = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided: "f32[15, 16]" = torch.ops.aten.as_strided.default(view_2, [15, 16], [16, 1], 0); view_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_33: "f32[2, 8]" = split_with_sizes_6[1]
        clone_1: "f32[2, 8]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
        view_4: "f32[16]" = torch.ops.aten.view.default(clone_1, [16]); clone_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_1: "f32[15]" = torch.ops.aten.as_strided.default(view_4, [15], [1], 0); view_4 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_38: "f32[2, 60]" = split_with_sizes_6[2]
        clone_2: "f32[2, 60]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
        view_6: "f32[120]" = torch.ops.aten.view.default(clone_2, [120]); clone_2 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_2: "f32[8, 15]" = torch.ops.aten.as_strided.default(view_6, [8, 15], [15, 1], 0); view_6 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_43: "f32[2, 4]" = split_with_sizes_6[3]; split_with_sizes_6 = None
        clone_3: "f32[2, 4]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
        view_8: "f32[8]" = torch.ops.aten.view.default(clone_3, [8]); clone_3 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_3: "f32[8]" = torch.ops.aten.as_strided.default(view_8, [8], [1], 0); view_8 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_1 = torch.ops.aten._foreach_copy_.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_1: "f32[16, 15]" = torch.ops.aten.permute.default(primals_6, [1, 0]); primals_6 = None
        addmm: "f32[4, 15]" = torch.ops.aten.addmm.default(primals_7, primals_1, permute_1); primals_7 = permute_1 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[4, 15]" = torch.ops.aten.relu.default(addmm); addmm = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0])
        addmm_1: "f32[4, 8]" = torch.ops.aten.addmm.default(primals_9, relu, permute_3); primals_9 = permute_3 = None
        return [addmm_1, primals_1, primals_8, relu]
TRACED GRAPH
===== Backward graph 0 =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 16]", primals_8: "f32[8, 15]", relu: "f32[4, 15]", tangents_1: "f32[4, 8]", primals_10):
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        alias: "f32[4, 15]" = torch.ops.aten.alias.default(relu)
        alias_1: "f32[4, 15]" = torch.ops.aten.alias.default(alias); alias = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0]); primals_8 = None
        # No stacktrace found for following nodes
        trace_wrapped: "f32[4, 8]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(tangents_1, bw_state = primals_10); tangents_1 = primals_10 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_6: "f32[8, 15]" = torch.ops.aten.permute.default(permute_3, [1, 0]); permute_3 = None
        mm: "f32[4, 15]" = torch.ops.aten.mm.default(trace_wrapped, permute_6); permute_6 = None
        permute_7: "f32[8, 4]" = torch.ops.aten.permute.default(trace_wrapped, [1, 0])
        mm_1: "f32[8, 15]" = torch.ops.aten.mm.default(permute_7, relu); permute_7 = relu = None
        permute_8: "f32[15, 8]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(trace_wrapped, [0], True); trace_wrapped = None
        view_9: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]); sum_1 = None
        permute_9: "f32[8, 15]" = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        alias_2: "f32[4, 15]" = torch.ops.aten.alias.default(alias_1); alias_1 = None
        alias_3: "f32[4, 15]" = torch.ops.aten.alias.default(alias_2); alias_2 = None
        le: "b8[4, 15]" = torch.ops.aten.le.Scalar(alias_3, 0); alias_3 = None
        full_default_1: "f32[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
        where: "f32[4, 15]" = torch.ops.aten.where.self(le, full_default_1, mm); le = full_default_1 = mm = None
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_10: "f32[15, 4]" = torch.ops.aten.permute.default(where, [1, 0])
        mm_2: "f32[15, 16]" = torch.ops.aten.mm.default(permute_10, primals_1); permute_10 = primals_1 = None
        permute_11: "f32[16, 15]" = torch.ops.aten.permute.default(mm_2, [1, 0]); mm_2 = None
        sum_2: "f32[1, 15]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
        view_10: "f32[15]" = torch.ops.aten.view.default(sum_2, [15]); sum_2 = None
        permute_12: "f32[15, 16]" = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None
        return [None, None, None, None, None, permute_12, view_10, permute_9, view_9, None]
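
For reference, graph dumps in this format can be produced by compiling an FSDP2-sharded module with AOT graph logging enabled. The script below is a hedged reconstruction: the module shape (Linear(16, 15) -> ReLU -> Linear(15, 8)), the batch size of 4, and the world size of 2 are inferred from the traces above, and the exact configuration used for the original run is not given in the gist.

# Hypothetical repro sketch -- run with, e.g.:
#   TORCH_LOGS="aot_joint_graph,aot_graphs" torchrun --nproc_per_node=2 repro.py
# Assumes a PyTorch build that ships the per-parameter-sharding FSDP prototype
# under torch.distributed._composable.fsdp, as the file paths in the traces suggest.
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard

def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())

    # Shapes inferred from the graphs: Linear(16, 15) -> ReLU -> Linear(15, 8), batch of 4.
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 15),
        torch.nn.ReLU(),
        torch.nn.Linear(15, 8),
    ).cuda()
    fully_shard(model)            # shard each parameter over dim 0 across the ranks
    compiled = torch.compile(model)

    x = torch.randn(4, 16, device="cuda")
    compiled(x).sum().backward()  # forward + backward trigger the graph dumps above

    dist.destroy_process_group()

if __name__ == "__main__":
    main()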