@garrett361
Created April 21, 2025 18:19
test torch._grouped_mm
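
A minimal reproduction: the backward pass of torch._grouped_mm over non-contiguous (strided) bfloat16 inputs succeeds when given an explicit upstream gradient, but out.sum().backward() raises a stride error. Output from a run of the script below: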
torch.__version__='2.8.0.dev20250421+cu126'
Passed first test
Traceback (most recent call last):
  File "/app/torchtitan/gemm_test.py", line 38, in <module>
    test_gemm_backward_fails()
  File "/app/torchtitan/gemm_test.py", line 31, in test_gemm_backward_fails
    out.sum().backward()
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/torch/autograd/__init__.py", line 354, in backward
    _engine_run_backward(
  File "/opt/conda/envs/ai/lib/python3.11/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Tensor should have a contiguous dimension and not be self-overlapping, got [0, 0, 0] for strides and [4, 16, 32] for sizes
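
The zero strides in the error come from autograd's backward for sum(): the incoming gradient is a single value expanded (not materialized) to the output shape, so every stride is zero. A minimal sketch of the same stride pattern, with the shape matching the (n_groups, m, n) output of the script below:

import torch

# An expanded tensor shares one underlying element across the whole shape,
# so all strides are zero -- the same [0, 0, 0] / [4, 16, 32] pair that the
# grouped-mm backward kernel rejects in the error above.
g = torch.ones(1, 1, 1).expand(4, 16, 32)
print(g.stride())  # (0, 0, 0)
print(g.shape)     # torch.Size([4, 16, 32])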
import torch


def get_output():
    device = "cuda"
    dtype = torch.bfloat16
    strided = True
    s_int = int(strided)
    m, n, k, n_groups = 16, 32, 16, 4
    # Over-allocate, then take every other group and narrow the last dim so
    # that a and b are non-contiguous (strided) views.
    a = torch.randn(
        n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype
    )[:: (1 + s_int), :, :k]
    b = torch.randn(
        n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype
    )[:: (1 + s_int), :, :k]
    a.requires_grad_(True)
    b.requires_grad_(True)
    # 3D x 3D grouped matmul: (n_groups, m, k) @ (n_groups, k, n) -> (n_groups, m, n)
    out = torch._grouped_mm(a, b.transpose(-2, -1), out_dtype=torch.bfloat16)
    return out


def test_gemm_backward_runs():
    # Backward with an explicit, contiguous upstream gradient succeeds.
    out = get_output()
    gO = torch.rand_like(out)
    out.backward(gO)


def test_gemm_backward_fails():
    # sum() feeds backward an expanded, zero-strided gradient, which the
    # grouped-mm backward kernel rejects.
    out = get_output()
    out.sum().backward()


if __name__ == "__main__":
    print(f"{torch.__version__=}")
    test_gemm_backward_runs()
    print("Passed first test")
    test_gemm_backward_fails()
    print("Passed second test")