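This is the joint forward+backward FX graph produced by PyTorch Inductor's auto-chunker for a linear + cross-entropy tail (captured from test/inductor/test_auto_chunker.py). The pass never materializes the full [32768, 50257] logit matrix: the 32768 flattened tokens are processed in four chunks of 8192 via the invoke_subgraph higher-order op, with the weight and bias gradients accumulated across the calls. From the embedded stack traces and tensor shapes, the traced code plausibly looks like the sketch below; Model, f, and the dimension names are my reconstruction, not the actual test source.

import torch
import torch.nn as nn

B, T, C, V = 32, 1024, 768, 50257  # batch, sequence, hidden, vocab (from the graph's shapes)

class Model(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(C, V)    # primals_2 (weight), primals_3 (bias)
        self.ce = nn.CrossEntropyLoss()  # ignore_index=-100 and mean reduction both visible in the graph

    def forward(self, x, y):
        # matches the "# File: ...:126" comments below
        return self.ce(self.linear(x).view(B * T, -1), y.view(-1))

def f(model, x, y):  # hypothetical wrapper
    x = x * 2        # matches the "# File: ...:135" comments below
    return model(x, y)
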
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "bf16[32, 1024, 768]", primals_2: "bf16[50257, 768]", primals_3: "bf16[50257]", primals_4: "i64[32, 1024]", tangents_1: "bf16[]"):
        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
        mul: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None

        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
        view: "bf16[32768, 768]" = torch.ops.aten.view.default(mul, [32768, 768]); mul = None
        permute: "bf16[768, 50257]" = torch.ops.aten.permute.default(primals_2, [1, 0]); primals_2 = None
        view_3: "i64[32768]" = torch.ops.aten.view.default(primals_4, [-1]); primals_4 = None
        ne: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
        full_default: "i64[]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        where: "i64[32768]" = torch.ops.aten.where.self(ne, view_3, full_default); ne = full_default = None
        unsqueeze: "i64[32768, 1]" = torch.ops.aten.unsqueeze.default(where, 1); where = None
        ne_1: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
        full_default_1: "bf16[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        ne_2: "b8[32768]" = torch.ops.aten.ne.Scalar(view_3, -100)
        sum_2: "i64[]" = torch.ops.aten.sum.default(ne_2); ne_2 = None
        convert_element_type_5: "bf16[]" = torch.ops.prims.convert_element_type.default(sum_2, torch.bfloat16); sum_2 = None
        unsqueeze_1: "i64[32768, 1]" = torch.ops.aten.unsqueeze.default(view_3, 1); view_3 = None
        ne_3: "b8[32768, 1]" = torch.ops.aten.ne.Scalar(unsqueeze_1, -100)
        full_default_2: "i64[]" = torch.ops.aten.full.default([], 0, dtype = torch.int64, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
        where_2: "i64[32768, 1]" = torch.ops.aten.where.self(ne_3, unsqueeze_1, full_default_2); ne_3 = full_default_2 = None
        ne_4: "b8[32768, 1]" = torch.ops.aten.ne.Scalar(unsqueeze_1, -100); unsqueeze_1 = None

        # No stacktrace found for following nodes
        tangent_overriden_as_one: "bf16[]" = torch.ops.aten.full.default((), 1, device = device(type='cuda', index=0), dtype = torch.bfloat16)
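        # Auto-chunking begins here: the 32768 flattened rows are split into
        # 4 chunks of 8192 so that each chunk's [8192, 50257] logits can be
        # materialized one at a time.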
        chunk_default = torch.ops.aten.chunk.default(view, 4, 0); view = None
        getitem: "bf16[8192, 768]" = chunk_default[0]
        getitem_1: "bf16[8192, 768]" = chunk_default[1]
        getitem_2: "bf16[8192, 768]" = chunk_default[2]
        getitem_3: "bf16[8192, 768]" = chunk_default[3]; chunk_default = None
        chunk_default_1 = torch.ops.aten.chunk.default(unsqueeze, 4, 0); unsqueeze = None
        getitem_4: "i64[8192, 1]" = chunk_default_1[0]
        getitem_5: "i64[8192, 1]" = chunk_default_1[1]
        getitem_6: "i64[8192, 1]" = chunk_default_1[2]
        getitem_7: "i64[8192, 1]" = chunk_default_1[3]; chunk_default_1 = None
        chunk_default_2 = torch.ops.aten.chunk.default(where_2, 4, 0); where_2 = None
        getitem_8: "i64[8192, 1]" = chunk_default_2[0]
        getitem_9: "i64[8192, 1]" = chunk_default_2[1]
        getitem_10: "i64[8192, 1]" = chunk_default_2[2]
        getitem_11: "i64[8192, 1]" = chunk_default_2[3]; chunk_default_2 = None
        chunk_default_3 = torch.ops.aten.chunk.default(ne_4, 4, 0); ne_4 = None
        getitem_12: "b8[8192, 1]" = chunk_default_3[0]
        getitem_13: "b8[8192, 1]" = chunk_default_3[1]
        getitem_14: "b8[8192, 1]" = chunk_default_3[2]
        getitem_15: "b8[8192, 1]" = chunk_default_3[3]; chunk_default_3 = None
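        # Zero-initialized accumulators for the linear weight grad and bias
        # grad; each invoke_subgraph call below adds its chunk's contribution,
        # and the running sums are threaded into the next call
        # (getitem_18/19 -> getitem_22/23 -> getitem_26/27 -> getitem_30/31).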
        full_default_6: "bf16[50257, 768]" = torch.ops.aten.full.default((50257, 768), 0, device = device(type='cuda', index=0), dtype = torch.bfloat16)
        full_default_7: "bf16[1, 50257]" = torch.ops.aten.full.default((1, 50257), 0, device = device(type='cuda', index=0), dtype = torch.bfloat16)
        chunking_subgraph_0 = self.chunking_subgraph_0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(chunking_subgraph_0, 'chunking_subgraph_0', primals_3, getitem, permute, getitem_4, convert_element_type_5, getitem_8, getitem_12, tangent_overriden_as_one, full_default_6, full_default_7); chunking_subgraph_0 = getitem = getitem_4 = getitem_8 = getitem_12 = full_default_6 = full_default_7 = None
        getitem_16: "bf16[8192, 1]" = invoke_subgraph[0]
        getitem_17: "bf16[8192, 768]" = invoke_subgraph[1]
        getitem_18: "bf16[50257, 768]" = invoke_subgraph[2]
        getitem_19: "bf16[1, 50257]" = invoke_subgraph[3]; invoke_subgraph = None
        chunking_subgraph_1 = self.chunking_subgraph_0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(chunking_subgraph_1, 'chunking_subgraph_0', primals_3, getitem_1, permute, getitem_5, convert_element_type_5, getitem_9, getitem_13, tangent_overriden_as_one, getitem_18, getitem_19); chunking_subgraph_1 = getitem_1 = getitem_5 = getitem_9 = getitem_13 = getitem_18 = getitem_19 = None
        getitem_20: "bf16[8192, 1]" = invoke_subgraph_1[0]
        getitem_21: "bf16[8192, 768]" = invoke_subgraph_1[1]
        getitem_22: "bf16[50257, 768]" = invoke_subgraph_1[2]
        getitem_23: "bf16[1, 50257]" = invoke_subgraph_1[3]; invoke_subgraph_1 = None
        chunking_subgraph_2 = self.chunking_subgraph_0
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(chunking_subgraph_2, 'chunking_subgraph_0', primals_3, getitem_2, permute, getitem_6, convert_element_type_5, getitem_10, getitem_14, tangent_overriden_as_one, getitem_22, getitem_23); chunking_subgraph_2 = getitem_2 = getitem_6 = getitem_10 = getitem_14 = getitem_22 = getitem_23 = None
        getitem_24: "bf16[8192, 1]" = invoke_subgraph_2[0]
        getitem_25: "bf16[8192, 768]" = invoke_subgraph_2[1]
        getitem_26: "bf16[50257, 768]" = invoke_subgraph_2[2]
        getitem_27: "bf16[1, 50257]" = invoke_subgraph_2[3]; invoke_subgraph_2 = None
        chunking_subgraph_3 = self.chunking_subgraph_0
        invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(chunking_subgraph_3, 'chunking_subgraph_0', primals_3, getitem_3, permute, getitem_7, convert_element_type_5, getitem_11, getitem_15, tangent_overriden_as_one, getitem_26, getitem_27); chunking_subgraph_3 = primals_3 = getitem_3 = permute = getitem_7 = getitem_11 = getitem_15 = tangent_overriden_as_one = getitem_26 = getitem_27 = None
        getitem_28: "bf16[8192, 1]" = invoke_subgraph_3[0]
        getitem_29: "bf16[8192, 768]" = invoke_subgraph_3[1]
        getitem_30: "bf16[50257, 768]" = invoke_subgraph_3[2]
        getitem_31: "bf16[1, 50257]" = invoke_subgraph_3[3]; invoke_subgraph_3 = None
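        # Stitch the per-chunk results back together: the gathered target
        # log-probs and the per-chunk input grads are concatenated, while
        # getitem_30/getitem_31 from the last call already hold the fully
        # accumulated weight/bias grads.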
        cat_default: "bf16[32768, 1]" = torch.ops.aten.cat.default([getitem_16, getitem_20, getitem_24, getitem_28], 0); getitem_16 = getitem_20 = getitem_24 = getitem_28 = None
        cat_default_1: "bf16[32768, 768]" = torch.ops.aten.cat.default([getitem_17, getitem_21, getitem_25, getitem_29], 0); getitem_17 = getitem_21 = getitem_25 = getitem_29 = None
        mul_tensor: "bf16[32768, 768]" = torch.ops.aten.mul.Tensor(cat_default_1, tangents_1); cat_default_1 = None
        mul_tensor_1: "bf16[50257, 768]" = torch.ops.aten.mul.Tensor(getitem_30, tangents_1); getitem_30 = None
        convert_element_type_default: "bf16[50257, 768]" = torch.ops.prims.convert_element_type.default(mul_tensor_1, torch.bfloat16); mul_tensor_1 = None
        mul_tensor_2: "bf16[1, 50257]" = torch.ops.aten.mul.Tensor(getitem_31, tangents_1); getitem_31 = tangents_1 = None
        convert_element_type_default_1: "bf16[1, 50257]" = torch.ops.prims.convert_element_type.default(mul_tensor_2, torch.bfloat16); mul_tensor_2 = None

        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
        squeeze: "bf16[32768]" = torch.ops.aten.squeeze.dim(cat_default, 1); cat_default = None
        neg: "bf16[32768]" = torch.ops.aten.neg.default(squeeze); squeeze = None
        where_1: "bf16[32768]" = torch.ops.aten.where.self(ne_1, neg, full_default_1); ne_1 = neg = full_default_1 = None
        sum_3: "bf16[]" = torch.ops.aten.sum.default(where_1); where_1 = None
        div: "bf16[]" = torch.ops.aten.div.Tensor(sum_3, convert_element_type_5); sum_3 = convert_element_type_5 = None
        view_6: "bf16[50257]" = torch.ops.aten.view.default(convert_element_type_default_1, [50257]); convert_element_type_default_1 = None
        view_7: "bf16[32, 1024, 768]" = torch.ops.aten.view.default(mul_tensor, [32, 1024, 768]); mul_tensor = None

        # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:135 in f, code: x = x * 2
        mul_3: "bf16[32, 1024, 768]" = torch.ops.aten.mul.Tensor(view_7, 2); view_7 = None
        return [div, mul_3, convert_element_type_default, view_6, None]
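
    # chunking_subgraph_0 handles one chunk end to end: forward (addmm +
    # log-softmax + NLL gather) fused with backward (softmax grad, matmuls for
    # the input/weight grads, bias-grad reduction) plus the dW/db accumulation.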
    class chunking_subgraph_0(torch.nn.Module):
        def forward(self, primals_3: "bf16[50257]", view: "bf16[8192, 768]", permute: "bf16[768, 50257]", unsqueeze: "i64[8192, 1]", convert_element_type_5: "bf16[]", where_2: "i64[8192, 1]", ne_4: "b8[8192, 1]", tangent_overriden_as_one: "bf16[]", full_default_6: "bf16[50257, 768]", full_default_7: "bf16[1, 50257]"):
            # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
            addmm: "bf16[8192, 50257]" = torch.ops.aten.addmm.default(primals_3, view, permute); primals_3 = None
            convert_element_type_3: "f32[8192, 50257]" = torch.ops.prims.convert_element_type.default(addmm, torch.float32); addmm = None
            amax: "f32[8192, 1]" = torch.ops.aten.amax.default(convert_element_type_3, [1], True)
            sub: "f32[8192, 50257]" = torch.ops.aten.sub.Tensor(convert_element_type_3, amax); convert_element_type_3 = amax = None
            exp: "f32[8192, 50257]" = torch.ops.aten.exp.default(sub)
            sum_1: "f32[8192, 1]" = torch.ops.aten.sum.dim_IntList(exp, [1], True); exp = None
            log: "f32[8192, 1]" = torch.ops.aten.log.default(sum_1); sum_1 = None
            sub_1: "f32[8192, 50257]" = torch.ops.aten.sub.Tensor(sub, log); sub = log = None
            convert_element_type_4: "bf16[8192, 50257]" = torch.ops.prims.convert_element_type.default(sub_1, torch.bfloat16); sub_1 = None
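            # Up to here is the chunk's forward pass: logits via addmm, then a
            # numerically stable log-softmax computed in fp32. The nodes below
            # replay the cross-entropy backward for the same chunk.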
            div_1: "bf16[]" = torch.ops.aten.div.Tensor(tangent_overriden_as_one, convert_element_type_5); tangent_overriden_as_one = convert_element_type_5 = None

            # No stacktrace found for following nodes
            full_default: "bf16[8192, 50257]" = torch.ops.aten.full.default([8192, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)

            # File: /home/shunting/ws/pytorch/test/inductor/test_auto_chunker.py:126 in forward, code: return self.ce(self.linear(x).view(B * T, -1), y.view(-1))
            scatter: "bf16[8192, 50257]" = torch.ops.aten.scatter.value(full_default, 1, where_2, -1.0); full_default = where_2 = None
            full_default_4: "bf16[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
            where_3: "bf16[8192, 1]" = torch.ops.aten.where.self(ne_4, div_1, full_default_4); ne_4 = div_1 = full_default_4 = None
            mul_1: "bf16[8192, 50257]" = torch.ops.aten.mul.Tensor(scatter, where_3); scatter = where_3 = None
            convert_element_type_6: "f32[8192, 50257]" = torch.ops.prims.convert_element_type.default(mul_1, torch.float32); mul_1 = None
            convert_element_type_7: "f32[8192, 50257]" = torch.ops.prims.convert_element_type.default(convert_element_type_4, torch.float32)
            exp_1: "f32[8192, 50257]" = torch.ops.aten.exp.default(convert_element_type_7); convert_element_type_7 = None
            sum_4: "f32[8192, 1]" = torch.ops.aten.sum.dim_IntList(convert_element_type_6, [1], True)
            mul_2: "f32[8192, 50257]" = torch.ops.aten.mul.Tensor(exp_1, sum_4); exp_1 = sum_4 = None
            sub_2: "f32[8192, 50257]" = torch.ops.aten.sub.Tensor(convert_element_type_6, mul_2); convert_element_type_6 = mul_2 = None
            convert_element_type_8: "bf16[8192, 50257]" = torch.ops.prims.convert_element_type.default(sub_2, torch.bfloat16); sub_2 = None
            permute_1: "bf16[50257, 768]" = torch.ops.aten.permute.default(permute, [1, 0]); permute = None
            permute_2: "bf16[50257, 8192]" = torch.ops.aten.permute.default(convert_element_type_8, [1, 0])
            gather: "bf16[8192, 1]" = torch.ops.aten.gather.default(convert_element_type_4, 1, unsqueeze); convert_element_type_4 = unsqueeze = None
            mm: "bf16[8192, 768]" = torch.ops.aten.mm.default(convert_element_type_8, permute_1); permute_1 = None
            mm_1: "bf16[50257, 768]" = torch.ops.aten.mm.default(permute_2, view); permute_2 = view = None
            sum_5: "bf16[1, 50257]" = torch.ops.aten.sum.dim_IntList(convert_element_type_8, [0], True); convert_element_type_8 = None

            # No stacktrace found for following nodes
            add_tensor: "bf16[50257, 768]" = torch.ops.aten.add.Tensor(mm_1, full_default_6); mm_1 = full_default_6 = None
            add_tensor_1: "bf16[1, 50257]" = torch.ops.aten.add.Tensor(sum_5, full_default_7); sum_5 = full_default_7 = None
            return (gather, mm, add_tensor, add_tensor_1)
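
Functionally, the chunked graph above matches the eager-mode sketch below. This is my own illustration of the pattern, not code from the gist or from the auto-chunker: instead of materializing the full [32768, 50257] logit matrix, each chunk's [8192, 50257] logits are produced, consumed for both the loss and the gradients, and freed before the next chunk, while the weight/bias gradients accumulate across chunks just as full_default_6/full_default_7 flow through the invoke_subgraph calls. A unit upstream tangent is assumed, mirroring tangent_overriden_as_one (the real graph rescales by tangents_1 afterwards).

import torch

def chunked_ce_fwd_bwd(x, weight, bias, y, num_chunks=4, ignore_index=-100):
    # x: [N, C], weight: [V, C], bias: [V], y: [N]; shapes follow the graph.
    n_valid = (y != ignore_index).sum()
    grad_x = torch.empty_like(x)
    grad_w = torch.zeros_like(weight)   # mirrors the full_default_6 accumulator
    grad_b = torch.zeros_like(bias)     # mirrors the full_default_7 accumulator
    loss_sum = x.new_zeros((), dtype=torch.float32)
    for xs, ys, gx in zip(x.chunk(num_chunks), y.chunk(num_chunks),
                          grad_x.chunk(num_chunks)):
        logits = xs @ weight.t() + bias                  # addmm in the subgraph
        logp = torch.log_softmax(logits.float(), dim=1)  # stable log-softmax in fp32
        valid = ys != ignore_index
        safe = torch.where(valid, ys, torch.zeros_like(ys))
        picked = logp.gather(1, safe.unsqueeze(1)).squeeze(1)
        loss_sum += -(picked * valid).sum()
        # d(mean NLL)/d(logits) = (softmax - onehot(target)) / n_valid,
        # zeroed on ignored rows
        g = logp.exp()
        g[torch.arange(len(ys), device=ys.device), safe] -= 1.0
        g = (g * valid.unsqueeze(1) / n_valid).to(xs.dtype)
        gx.copy_(g @ weight)      # "mm": input grad for this chunk
        grad_w += g.t() @ xs      # "mm_1" + "add_tensor": running weight grad
        grad_b += g.sum(0)        # "sum_5" + "add_tensor_1": running bias grad
    return loss_sum / n_valid, grad_x, grad_w, grad_b

With num_chunks=4, the largest live logits buffer shrinks from 32768 x 50257 to 8192 x 50257 elements (roughly 3.3 GB down to 0.8 GB in bf16), at the cost of running the per-chunk matmuls and softmaxes sequentially.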