@shunting314
Created December 13, 2025 00:01
class inner_f(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "bf16[32, 1024, 768]"; primals_2: "bf16[50257, 768]"; primals_3: "bf16[50257]"; primals_4: "i64[32, 1024]"; tangents_1: "bf16[]";
        primals_1, primals_2, primals_3, primals_4, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
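        # primals_1 is the bf16 input activation, primals_2/primals_3 the linear weight and bias,
        # primals_4 the int64 target ids, and tangents_1 the incoming gradient of the scalar loss.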
        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
        mul: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
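        # Linear layer computed as a single addmm over the flattened (B*T, 768) input: out = bias + x @ weight.T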
view: "bf16[32768, 768]" = torch.ops.aten.view.default(mul, [32768, 768]); mul = None
permute: "bf16[768, 50257]" = torch.ops.aten.permute.default(primals_2, [1, 0]); primals_2 = None
addmm: "bf16[32768, 50257]" = torch.ops.aten.addmm.default(primals_3, view, permute); primals_3 = None
view_3: "i64[32768]" = torch.ops.aten.view.default(primals_4, [-1]); primals_4 = None
convert_element_type_3: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(addmm, torch.float32); addmm = None
amax: "f32[32768, 1]" = torch.ops.aten.amax.default(convert_element_type_3, [1], True)
sub: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(convert_element_type_3, amax); convert_element_type_3 = amax = None
exp: "f32[32768, 50257]" = torch.ops.aten.exp.default(sub)
sum_1: "f32[32768, 1]" = torch.ops.aten.sum.dim_IntList(exp, [1], True); exp = None
log: "f32[32768, 1]" = torch.ops.aten.log.default(sum_1); sum_1 = None
sub_1: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(sub, log); sub = log = None
convert_element_type_4: "bf16[32768, 50257]" = torch.ops.prims.convert_element_type.default(sub_1, torch.bfloat16); sub_1 = None
ne: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
full_default: "i64[]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
where: "i64[32768]" = torch.ops.aten.where.self(ne, view_3, full_default); ne = full_default = None
unsqueeze: "i64[32768, 1]" = torch.ops.aten.unsqueeze.default(where, 1); where = None
gather: "bf16[32768, 1]" = torch.ops.aten.gather.default(convert_element_type_4, 1, unsqueeze); unsqueeze = None
squeeze: "bf16[32768]" = torch.ops.aten.squeeze.dim(gather, 1); gather = None
neg: "bf16[32768]" = torch.ops.aten.neg.default(squeeze); squeeze = None
ne_1: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
full_default_1: "bf16[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
where_1: "bf16[32768]" = torch.ops.aten.where.self(ne_1, neg, full_default_1); ne_1 = neg = full_default_1 = None
ne_2: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
sum_2: "i64[]" = torch.ops.aten.sum.default(ne_2); ne_2 = None
convert_element_type_5: "bf16[]" = torch.ops.prims.convert_element_type.default(sum_2, torch.bfloat16); sum_2 = None
sum_3: "bf16[]" = torch.ops.aten.sum.default(where_1); where_1 = None
div: "bf16[]" = torch.ops.aten.div.Tensor(sum_3, convert_element_type_5); sum_3 = None
div_1: "bf16[]" = torch.ops.aten.div.Tensor(tangents_1, convert_element_type_5); tangents_1 = convert_element_type_5 = None
unsqueeze_1: "i64[32768, 1]" = torch.ops.aten.unsqueeze.default(view_3, 1); view_3 = None
ne_3: "b8[32768, 1]" = torch.ops.aten.ne.Scalar(unsqueeze_1, -100)
full_default_2: "i64[]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
where_2: "i64[32768, 1]" = torch.ops.aten.where.self(ne_3, unsqueeze_1, full_default_2); ne_3 = full_default_2 = None
full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0); full_default_3 = where_2 = None
ne_4: "b8[32768, 1]" = torch.ops.aten.ne.Scalar(unsqueeze_1, -100); unsqueeze_1 = None
full_default_4: "bf16[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
where_3: "bf16[32768, 1]" = torch.ops.aten.where.self(ne_4, div_1, full_default_4); ne_4 = div_1 = full_default_4 = None
mul_1: "bf16[32768, 50257]" = torch.ops.aten.mul.Tensor(scatter, where_3); scatter = where_3 = None
convert_element_type_6: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(mul_1, torch.float32); mul_1 = None
convert_element_type_7: "f32[32768, 50257]" = torch.ops.prims.convert_element_type.default(convert_element_type_4, torch.float32); convert_element_type_4 = None
exp_1: "f32[32768, 50257]" = torch.ops.aten.exp.default(convert_element_type_7); convert_element_type_7 = None
sum_4: "f32[32768, 1]" = torch.ops.aten.sum.dim_IntList(convert_element_type_6, [1], True)
mul_2: "f32[32768, 50257]" = torch.ops.aten.mul.Tensor(exp_1, sum_4); exp_1 = sum_4 = None
sub_2: "f32[32768, 50257]" = torch.ops.aten.sub.Tensor(convert_element_type_6, mul_2); convert_element_type_6 = mul_2 = None
convert_element_type_8: "bf16[32768, 50257]" = torch.ops.prims.convert_element_type.default(sub_2, torch.bfloat16); sub_2 = None
permute_1: "bf16[50257, 768]" = torch.ops.aten.permute.default(permute, [1, 0]); permute = None
mm: "bf16[32768, 768]" = torch.ops.aten.mm.default(convert_element_type_8, permute_1); permute_1 = None
permute_2: "bf16[50257, 32768]" = torch.ops.aten.permute.default(convert_element_type_8, [1, 0])
mm_1: "bf16[50257, 768]" = torch.ops.aten.mm.default(permute_2, view); permute_2 = view = None
sum_5: "bf16[1, 50257]" = torch.ops.aten.sum.dim_IntList(convert_element_type_8, [0], True); convert_element_type_8 = None
view_6: "bf16[50257]" = torch.ops.aten.view.default(sum_5, [50257]); sum_5 = None
view_7: "bf16[32, 1024, 768]" = torch.ops.aten.view.default(mm, [32, 1024, 768]); mm = None
# File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
mul_3: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(view_7, 2); view_7 = None
return pytree.tree_unflatten([div, mul_3, mm_1, view_6, None], self._out_spec)
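
For orientation, here is a minimal eager-mode sketch of the model this joint forward/backward graph appears to capture, reconstructed from the tensor shape annotations and the source-line comments above (B = 32, T = 1024, hidden size 768, vocab size 50257); the class and variable names are assumptions, not taken from the gist:

import torch

class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # bf16 Linear(768 -> 50257) with bias, matching primals_2/primals_3 above
        self.linear = torch.nn.Linear(768, 50257, dtype=torch.bfloat16)
        # ignore_index defaults to -100 and reduction to "mean", matching the
        # ne.Scalar(..., -100) masking and the final sum/count division in the graph
        self.ce = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        B, T, _ = x.shape
        return self.ce(self.linear(x).view(B * T, -1), y.view(-1))

def f(model, x, y):
    x = x * 2  # corresponds to the "x = x * 2" source-line comment in the graph
    return model(x, y)

model = TinyLM().cuda()
x = torch.randn(32, 1024, 768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = torch.randint(0, 50257, (32, 1024), device="cuda")
loss = f(model, x, y)  # forward: the graph's `div` output
loss.backward()        # backward: the graph's mul_3 / mm_1 / view_6 outputs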