from triton.testing import do_bench
import torch
from transformers import AutoConfig, AutoModelForCausalLM

from test_packing import _create_random_e2m1_tensor, pack_fp4_to_uint8_old
from auto_round.export.export_to_autoround.qlinear_fp import FLOAT_TO_E2M1, pack_fp4_to_uint8

# triton's do_bench runs the kernels on GPU; use "cpu" only for quick debugging of the packing code.
device = "cuda"

# Collect the nn.Linear weight shapes of a truncated Qwen3-30B-A3B model and
# benchmark the old vs. new FP4 packing implementations on each shape.
model_name = "/models/Qwen3-30B-A3B"
qwen_model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
qwen_model_config.num_hidden_layers = 4  # truncate so building the meta model stays cheap
# Instantiate on the meta device so no real weight memory is allocated.
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(qwen_model_config, trust_remote_code=True)
def get_weight_shape(model):
    """Collect the distinct weight shapes of every nn.Linear module in the model."""
    weight_shapes = set()
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            weight_shapes.add(module.weight.shape)
    return weight_shapes


qwen_shape = get_weight_shape(model)
print(f"Got Linear weight shapes: {qwen_shape}")
def bench_packing(shape):
    print(f"Benchmarking packing for shape: {shape}")
    with torch.device(device):
        M, N = shape
        random_tensor = _create_random_e2m1_tensor((M, N))
        # Pack with both implementations and make sure they agree.
        packed_tensor = pack_fp4_to_uint8(random_tensor)
        packed_tensor_old = pack_fp4_to_uint8_old(random_tensor)
        assert torch.equal(packed_tensor, packed_tensor_old), "Packed tensors are not equal"
        # do_bench reports the measured times in milliseconds.
        old_time = do_bench(lambda: pack_fp4_to_uint8_old(random_tensor), warmup=10, rep=1000)
        new_time = do_bench(lambda: pack_fp4_to_uint8(random_tensor), warmup=10, rep=1000)
        print(f"Old packing time: {old_time:.6f} ms")
        print(f"New packing time: {new_time:.6f} ms")
        print(f"Speedup: {old_time / new_time:.2f}x")


for shape in qwen_shape:
    bench_packing(shape)
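

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the benchmark): a naive reference of what
# FP4 (E2M1) packing does, assuming each value maps to a 4-bit code
# (1 sign bit + 3-bit index into the non-negative E2M1 values) and two codes
# are stored per uint8. The function name, code layout, and nibble order are
# assumptions for illustration only and need not match the auto_round kernels.
def _reference_pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
    e2m1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device)
    sign = (x < 0).to(torch.uint8)
    # Nearest non-negative E2M1 magnitude for each element (3-bit index).
    mag_idx = (x.abs().unsqueeze(-1) - e2m1).abs().argmin(dim=-1).to(torch.uint8)
    codes = (sign << 3) | mag_idx  # one 4-bit code per element
    assert x.shape[-1] % 2 == 0, "last dim must be even to pack element pairs"
    pairs = codes.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    # Assumed layout: even element in the low nibble, odd element in the high nibble.
    return pairs[..., 0] | (pairs[..., 1] << 4)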