Created
April 21, 2025 18:19
-
-
Save garrett361/c84a721ef33cd3b142a093c6e09bfcee to your computer and use it in GitHub Desktop.
Repro script for a `torch._grouped_mm` backward-pass failure: backward with an explicit gradient tensor succeeds, but `out.sum().backward()` raises a stride error (traceback below).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| torch.__version__='2.8.0.dev20250421+cu126' | |
| Passed first test | |
| Traceback (most recent call last): | |
| File "/app/torchtitan/gemm_test.py", line 38, in <module> | |
| test_gemm_backward_fails() | |
| File "/app/torchtitan/gemm_test.py", line 31, in test_gemm_backward_fails | |
| out.sum().backward() | |
| File "/opt/conda/envs/ai/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward | |
| torch.autograd.backward( | |
| File "/opt/conda/envs/ai/lib/python3.11/site-packages/torch/autograd/__init__.py", line 354, in backward | |
| _engine_run_backward( | |
| File "/opt/conda/envs/ai/lib/python3.11/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward | |
| return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass | |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
| RuntimeError: Tensor should have a contiguous dimension and not be self-overlapping, got [0, 0, 0] for strides and [4, 16, 32] for sizes |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
def get_output():
    """Run ``torch._grouped_mm`` on deliberately non-contiguous operands.

    Returns the grouped-matmul output with ``a`` and ``b`` both marked
    ``requires_grad`` so callers can exercise the backward pass.
    Requires a CUDA device (operands are bfloat16 on "cuda").
    """
    device = "cuda"
    dtype = torch.bfloat16
    strided = True
    s_int = int(strided)
    m, n, k, n_groups = 16, 32, 16, 4
    # Allocate with doubled group and k dimensions, then slice with step
    # (1 + s_int) and ``:k`` so the resulting (n_groups, m, k) /
    # (n_groups, n, k) views are non-contiguous. NOTE(review): these
    # strided views appear to be what triggers the backward failure shown
    # in the traceback above — confirm against the upstream issue.
    a = torch.randn(
        n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype
    )[:: (1 + s_int), :, :k]
    b = torch.randn(
        n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype
    )[:: (1 + s_int), :, :k]
    a.requires_grad_(True)
    b.requires_grad_(True)
    # b is transposed so the grouped matmul contracts over the shared k dim.
    out = torch._grouped_mm(a, b.transpose(-2, -1), out_dtype=torch.bfloat16)
    return out
def test_gemm_backward_runs():
    """Backward through an explicit gradient tensor — this path succeeds
    (the script log above prints "Passed first test")."""
    result = get_output()
    grad_out = torch.rand_like(result)
    result.backward(grad_out)
def test_gemm_backward_fails():
    """Backward through ``.sum()`` — this path raises the RuntimeError
    about zero strides captured in the traceback above."""
    result = get_output()
    result.sum().backward()
| if __name__ == "__main__": | |
| print(f"{torch.__version__=}") | |
| test_gemm_backward_runs() | |
| print("Passed first test") | |
| test_gemm_backward_fails() | |
| print("Passed second test") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment