Skip to content

Instantly share code, notes, and snippets.

@rebeccajae
Created August 27, 2025 18:25
Show Gist options
  • Select an option

  • Save rebeccajae/f1767b9ced719aee03a073977c60245b to your computer and use it in GitHub Desktop.

Select an option

Save rebeccajae/f1767b9ced719aee03a073977c60245b to your computer and use it in GitHub Desktop.
import torch
def _allclose_on_cpu(a: torch.Tensor, b: torch.Tensor, atol: float = 1e-4) -> bool:
    """Return whether ``a`` and ``b`` agree within ``atol``, compared on the CPU.

    Both tensors are copied to the CPU first, so the comparison itself never
    runs on the backend whose results are being checked.
    """
    return torch.allclose(a.cpu(), b.cpu(), atol=atol)


def main() -> None:
    """Reproduce an MPS ``F.linear`` discrepancy with a non-contiguous weight.

    Builds a weight that is an every-other-row view of a larger tensor. It then
    runs ``torch.nn.functional.linear`` on that view and on a contiguous copy
    of it, once on MPS and once on the CPU. Finally it prints layout
    diagnostics, pairwise agreement checks, and a short sample of each result.

    Prints "MPS not available" and returns early when the MPS backend cannot
    be used. Always returns None.
    """
    if not torch.backends.mps.is_available():
        print("MPS not available")
        return

    # Fixed seed so that a reported mismatch can be reproduced exactly.
    torch.manual_seed(0)

    large_weight = torch.randn(12, 8, device='mps')
    # Every other row of large_weight: shape (6, 8) with row stride 16, i.e. a
    # non-contiguous view. Its contiguous copy holds identical values.
    weight_sliced = large_weight[::2, ::1]
    weight_contiguous_equiv = weight_sliced.contiguous()
    input_s = torch.randn(2, 8, device='mps')
    result_sliced = torch.nn.functional.linear(input_s, weight_sliced)
    result_contig = torch.nn.functional.linear(input_s, weight_contiguous_equiv)

    # CPU reference. Slice *after* moving to the CPU. Copying the strided,
    # non-dense view across devices (`weight_sliced.cpu()`) produces a
    # contiguous tensor, which would silently turn the CPU "non-contig" case
    # into a second contiguous case.
    input_cpu = input_s.cpu()
    weight_sliced_cpu = large_weight.cpu()[::2, ::1]
    weight_contig_cpu = weight_sliced_cpu.contiguous()
    result_sliced_cpu = torch.nn.functional.linear(input_cpu, weight_sliced_cpu)
    result_contig_cpu = torch.nn.functional.linear(input_cpu, weight_contig_cpu)

    # Evaluate each agreement check once; the results feed both the report
    # and the final verdict.
    mps_matches_internally = _allclose_on_cpu(result_contig, result_sliced)
    cpu_matches_internally = _allclose_on_cpu(result_contig_cpu, result_sliced_cpu)
    mps_contig_matches_cpu = _allclose_on_cpu(result_contig, result_contig_cpu)
    mps_matches_cpu = _allclose_on_cpu(result_sliced, result_sliced_cpu)
    weights_equal = torch.equal(weight_sliced, weight_contiguous_equiv)

    print(f"weight_sliced.shape: {weight_sliced.shape}")
    print(f"input_s.shape: {input_s.shape}")
    print(f"weight_sliced.is_contiguous(): {weight_sliced.is_contiguous()}")
    print(f"weight_sliced.stride(): {weight_sliced.stride()}")
    print(f"weight_contiguous_equiv.stride(): {weight_contiguous_equiv.stride()}")
    print(f"torch.equal(weight_sliced, weight_contiguous_equiv): {weights_equal}")
    print(f"MPS: Contiguous and non-contiguous match: {mps_matches_internally}")
    print(f"CPU: Contiguous and non-contiguous match: {cpu_matches_internally}")
    print(f"MPS contiguous vs CPU contiguous: {mps_contig_matches_cpu}")
    print(f"MPS non-contig vs CPU non-contig: {mps_matches_cpu}")

    if not mps_matches_internally:
        print("MPS result not consistent")
    if not mps_matches_cpu:
        print("MPS result does not match CPU result")

    print(f"MPS contiguous result sample: {result_contig.flatten()[:5]}")
    print(f"MPS non-contig result sample: {result_sliced.flatten()[:5]}")
    print(f"CPU non-contig result sample: {result_sliced_cpu.flatten()[:5]}")
# Script entry point: run the repro only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment