class inner_f(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "bf16[32, 1024, 768]"; primals_2: "bf16[50257, 768]"; primals_3: "bf16[50257]"; primals_4: "i64[32, 1024]"; tangents_1: "bf16[]";
        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
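
        # primals_1..primals_4 are the flattened forward inputs (input activations,
        # linear weight, linear bias, integer targets); tangents_1 is the incoming
        # gradient of the scalar loss. The graph computes both the forward loss and
        # all input gradients, i.e. it is a joint forward+backward graph.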
        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
        mul: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None

        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
        view: "bf16[32768, 768]" = torch.ops.aten.view.default(mul, [32768, 768]); mul = None
        permute: "bf16[768, 50257]" = torch.ops.aten.permute.default(primals_2, [1, 0]); primals_2 = None
        addmm: "bf16[32768, 50257]" = torch.ops.aten.addmm.default(primals_3, view, permute); primals_3 = None
        view_3: "i64[32768]" = torch.ops.aten.view.default(primals_4, [-1]); primals_4 = None
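
        # Log-softmax over the vocabulary, upcast to fp32 for numerical stability:
        # subtract the per-row max, take log-sum-exp, then cast back to bf16.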
        convert_element_type_3: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(addmm, torch.float32); addmm = None
        amax: "f32[32768, 1]" = torch.ops.aten.amax.default(convert_element_type_3, [1], True)
        sub: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(convert_element_type_3, amax); convert_element_type_3 = amax = None
        exp: "f32[32768, 50257]" = torch.ops.aten.exp.default(sub)
        sum_1: "f32[32768, 1]" = torch.ops.aten.sum.dim_IntList(exp, [1], True); exp = None
        log: "f32[32768, 1]" = torch.ops.aten.log.default(sum_1); sum_1 = None
        sub_1: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(sub, log); sub = log = None
        convert_element_type_4: "bf16[32768, 50257]" = torch.ops.prims.convert_element_type.default(sub_1, torch.bfloat16); sub_1 = None
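
        # NLL: replace ignore_index (-100) targets with class 0 so the gather stays
        # in bounds, gather each row's log-probability at its target, and negate.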
        ne: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
        full_default: "i64[]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        where: "i64[32768]" = torch.ops.aten.where.self(ne, view_3, full_default); ne = full_default = None
        unsqueeze: "i64[32768, 1]" = torch.ops.aten.unsqueeze.default(where, 1); where = None
        gather: "bf16[32768, 1]" = torch.ops.aten.gather.default(convert_element_type_4, 1, unsqueeze); unsqueeze = None
        squeeze: "bf16[32768]" = torch.ops.aten.squeeze.dim(gather, 1); gather = None
        neg: "bf16[32768]" = torch.ops.aten.neg.default(squeeze); squeeze = None
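
        # Zero out the per-token losses at ignored positions.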
        ne_1: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
        full_default_1: "bf16[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        where_1: "bf16[32768]" = torch.ops.aten.where.self(ne_1, neg, full_default_1); ne_1 = neg = full_default_1 = None
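
        # Mean reduction: divide the summed per-token loss by the number of
        # non-ignored targets. `div` is the scalar loss returned below.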
        ne_2: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
        sum_2: "i64[]" = torch.ops.aten.sum.default(ne_2); ne_2 = None
        convert_element_type_5: "bf16[]" = torch.ops.prims.convert_element_type.default(sum_2, torch.bfloat16); sum_2 = None
        sum_3: "bf16[]" = torch.ops.aten.sum.default(where_1); where_1 = None
        div: "bf16[]" = torch.ops.aten.div.Tensor(sum_3, convert_element_type_5); sum_3 = None
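
        # Backward starts here. Gradient of the mean reduction: scale the incoming
        # loss gradient by 1 / (number of non-ignored targets).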
        div_1: "bf16[]" = torch.ops.aten.div.Tensor(tangents_1, convert_element_type_5); tangents_1 = convert_element_type_5 = None
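
        # Gradient of the NLL gather: scatter -1.0 into each row at its target
        # column, mask ignored rows, and scale by the incoming gradient.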
        unsqueeze_1: "i64[32768, 1]" = torch.ops.aten.unsqueeze.default(view_3, 1); view_3 = None
        ne_3: "b8[32768, 1]" = torch.ops.aten.ne.Scalar(unsqueeze_1, -100)
        full_default_2: "i64[]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        where_2: "i64[32768, 1]" = torch.ops.aten.where.self(ne_3, unsqueeze_1, full_default_2); ne_3 = full_default_2 = None
        full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0); full_default_3 = where_2 = None
        ne_4: "b8[32768, 1]" = torch.ops.aten.ne.Scalar(unsqueeze_1, -100); unsqueeze_1 = None
        full_default_4: "bf16[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        where_3: "bf16[32768, 1]" = torch.ops.aten.where.self(ne_4, div_1, full_default_4); ne_4 = div_1 = full_default_4 = None
        mul_1: "bf16[32768, 50257]" = torch.ops.aten.mul.Tensor(scatter, where_3); scatter = where_3 = None
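
        # Gradient of log-softmax: dx = dy - softmax(x) * sum(dy, dim=1), with the
        # softmax recomputed in fp32 as exp(log_softmax).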
        convert_element_type_6: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(mul_1, torch.float32); mul_1 = None
        convert_element_type_7: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(convert_element_type_4, torch.float32); convert_element_type_4 = None
        exp_1: "f32[32768, 50257]" = torch.ops.aten.exp.default(convert_element_type_7); convert_element_type_7 = None
        sum_4: "f32[32768, 1]" = torch.ops.aten.sum.dim_IntList(convert_element_type_6, [1], True)
        mul_2: "f32[32768, 50257]" = torch.ops.aten.mul.Tensor(exp_1, sum_4); exp_1 = sum_4 = None
        sub_2: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(convert_element_type_6, mul_2); convert_element_type_6 = mul_2 = None
        convert_element_type_8: "bf16[32768, 50257]" = torch.ops.prims.convert_element_type.default(sub_2, torch.bfloat16); sub_2 = None
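
        # Gradient of the addmm: grad_input = grad_logits @ W, grad_weight =
        # grad_logits.T @ input, grad_bias = grad_logits.sum(dim=0).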
        permute_1: "bf16[50257, 768]" = torch.ops.aten.permute.default(permute, [1, 0]); permute = None
        mm: "bf16[32768, 768]" = torch.ops.aten.mm.default(convert_element_type_8, permute_1); permute_1 = None
        permute_2: "bf16[50257, 32768]" = torch.ops.aten.permute.default(convert_element_type_8, [1, 0])
        mm_1: "bf16[50257, 768]" = torch.ops.aten.mm.default(permute_2, view); permute_2 = view = None
        sum_5: "bf16[1, 50257]" = torch.ops.aten.sum.dim_IntList(convert_element_type_8, [0], True); convert_element_type_8 = None
        view_6: "bf16[50257]" = torch.ops.aten.view.default(sum_5, [50257]); sum_5 = None
        view_7: "bf16[32, 1024, 768]" = torch.ops.aten.view.default(mm, [32, 1024, 768]); mm = None

        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
        mul_3: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(view_7, 2); view_7 = None

        return pytree.tree_unflatten([div, mul_3, mm_1, view_6, None], self._out_spec)
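
# -----------------------------------------------------------------------------
# For reference: a minimal eager-mode sketch of what this joint graph computes,
# reconstructed from the `# File:` and shape annotations above. The graph
# returns (loss, grad_x, grad_weight, grad_bias, None). Class and argument
# names below are assumptions; the actual module in test_auto_chunker.py may
# differ.
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn

class Model(nn.Module):  # hypothetical reconstruction of the traced module
    def __init__(self, dim=768, vocab=50257):
        super().__init__()
        self.linear = nn.Linear(dim, vocab)
        self.ce = nn.CrossEntropyLoss()  # mean reduction, ignore_index=-100

    def forward(self, x, y):
        B, T, _ = x.shape
        return self.ce(self.linear(x).view(B * T, -1), y.view(-1))

def f(model, x, y):
    x = x * 2  # matches test_auto_chunker.py:135 in the annotations above
    return model(x, y)

# Example usage, with shapes/dtypes taken from the annotations above:
#   model = Model().cuda().bfloat16()
#   x = torch.randn(32, 1024, 768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
#   y = torch.randint(0, 50257, (32, 1024), device="cuda")
#   loss = f(model, x, y)
#   loss.backward()  # produces the grad_x / grad_weight / grad_bias above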