Skip to content

Instantly share code, notes, and snippets.

@Algomancer
Created January 27, 2025 22:36
Show Gist options
  • Select an option

  • Save Algomancer/692b905f05535614254400a96c1be5a3 to your computer and use it in GitHub Desktop.

Select an option

Save Algomancer/692b905f05535614254400a96c1be5a3 to your computer and use it in GitHub Desktop.
import torch
class broadcast_right:
    """Context manager that temporarily changes how ``tensor + vector`` broadcasts.

    While the context is active, ``torch.Tensor.__add__`` is replaced by a
    wrapper. When the right operand is a 1-D tensor and the left operand has
    more than one dimension, the wrapper reshapes the right operand with
    ``unsqueeze(dim)`` and expands it to the left operand's shape before
    adding. Every other addition (same-rank tensors, Python scalars, any
    non-tensor operand) is forwarded unchanged to the original ``__add__``.

    The replacement is installed on the ``torch.Tensor`` class itself, so it
    applies to every ``Tensor.__add__`` call made while the context is open.
    The previously installed method is put back on exit, including when the
    body raises.

    An instance is not reentrant: entering the same instance a second time
    overwrites the saved method. Use a fresh instance for each ``with``.

    Args:
        dim: Index at which the new axis is inserted into the 1-D right
            operand (passed to ``Tensor.unsqueeze``).
    """

    def __init__(self, dim=0):
        self.dim = dim
        # Holds the __add__ that was active before __enter__ patched it.
        self._old_add = None

    def __enter__(self):
        """Install the broadcasting ``__add__`` and return this manager."""
        self._old_add = torch.Tensor.__add__

        def _broadcast_add(left, right):
            # Only tensors expose .dim(); scalars and other operand types
            # must go straight to the original implementation, otherwise
            # `tensor + 1.0` would fail with AttributeError.
            if (
                isinstance(right, torch.Tensor)
                and right.dim() == 1
                and left.dim() > 1
            ):
                right = right.unsqueeze(self.dim).expand_as(left)
            return self._old_add(left, right)

        torch.Tensor.__add__ = _broadcast_add
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the saved ``__add__``; exceptions from the body propagate."""
        torch.Tensor.__add__ = self._old_add
if __name__ == "__main__":
    # Demo: add a length-3 vector to a 3x4 matrix along dim 1.
    matrix = torch.randn(3, 4)
    vector = torch.randn(3)  # shape [3]

    # A plain `matrix + vector` here fails with
    # "RuntimeError: The size of tensor a (4) must match ...".
    with broadcast_right(dim=1):
        # vector is used as vector.unsqueeze(1).expand_as(matrix) => shape [3, 4]
        total = matrix + vector
        print("C.shape:", total.shape)
        print("C:", total)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment