Skip to content

Instantly share code, notes, and snippets.

@yiliu30
Created August 25, 2025 03:20
Show Gist options
  • Select an option

  • Save yiliu30/5535ac154cdd000d731a9fdd385b5df8 to your computer and use it in GitHub Desktop.

Select an option

Save yiliu30/5535ac154cdd000d731a9fdd385b5df8 to your computer and use it in GitHub Desktop.
from triton.testing import do_bench
import torch
from test_packing import _create_random_e2m1_tensor, pack_fp4_to_uint8_old
from auto_round.export.export_to_autoround.qlinear_fp import FLOAT_TO_E2M1, pack_fp4_to_uint8
# Device on which bench_packing creates and packs the random E2M1 tensors.
# (A dead `device = "cpu"` store that was immediately overwritten was removed.)
# NOTE(review): triton.testing.do_bench presumably needs a GPU for timing;
# confirm before switching this to "cpu".
device = "cuda"

# Build a truncated Qwen3-30B-A3B purely to enumerate its Linear weight shapes.
model_name = "/models/Qwen3-30B-A3B"
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig
qwen_model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# Only 4 decoder layers are instantiated to keep model construction cheap.
# NOTE(review): this assumes the remaining layers repeat the same Linear shapes.
qwen_model_config.num_hidden_layers = 4
# Under the meta device, parameters carry shapes but no data, so building the
# model from its config allocates no real weight memory.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(qwen_model_config, trust_remote_code=True)
def get_weight_shape(model):
    """Return the distinct weight shapes of every ``torch.nn.Linear`` in *model*.

    Every submodule yielded by ``model.named_modules()`` is inspected, which
    includes *model* itself.

    Args:
        model: The ``torch.nn.Module`` to scan.

    Returns:
        A set of ``torch.Size`` objects, one per unique Linear weight shape.
    """
    linear_layers = (
        submodule
        for _, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    )
    return {layer.weight.shape for layer in linear_layers}
# Unique Linear weight shapes of the truncated model; these drive the benchmark loop.
qwen_shape = get_weight_shape(model=model)
shapes_msg = f"Got model shape: {qwen_shape}"
print(shapes_msg)
def bench_packing(shape, *, warmup=10, rep=1000):
    """Verify and benchmark ``pack_fp4_to_uint8`` against ``pack_fp4_to_uint8_old``.

    A random E2M1 tensor of the given 2-D *shape* is created on the
    module-level ``device``. It is packed by both implementations, and the
    outputs are checked for exact equality before either one is timed.

    Args:
        shape: ``(M, N)`` pair, e.g. a Linear weight shape from
            ``get_weight_shape``.
        warmup: Forwarded to ``triton.testing.do_bench`` (``warmup`` argument).
        rep: Forwarded to ``triton.testing.do_bench`` (``rep`` argument).

    Raises:
        AssertionError: If the old and new packers produce different tensors.
    """
    print(f"Benchmarking packing for shape: {shape}")
    with torch.device(device):
        M, N = shape
        random_tensor = _create_random_e2m1_tensor((M, N))
        packed_tensor = pack_fp4_to_uint8(random_tensor)
        packed_tensor_old = pack_fp4_to_uint8_old(random_tensor)
        # Explicit raise instead of `assert`, so the correctness gate also
        # runs under `python -O`.
        if not torch.equal(packed_tensor, packed_tensor_old):
            raise AssertionError("Packed tensors are not equal")
        old_time = do_bench(
            lambda: pack_fp4_to_uint8_old(random_tensor), warmup=warmup, rep=rep
        )
        new_time = do_bench(
            lambda: pack_fp4_to_uint8(random_tensor), warmup=warmup, rep=rep
        )
    # do_bench reports milliseconds; the previous labels wrongly said "seconds".
    print(f"Old packing time: {old_time:.6f} ms")
    print(f"New packing time: {new_time:.6f} ms")
    print(f"Speedup: {old_time / new_time:.2f}x")
# Benchmark every distinct Linear weight shape. Sets iterate in arbitrary
# order, so sort first to make the report order stable and easy to scan.
for shape in sorted(qwen_shape):
    bench_packing(shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment