Created January 27, 2025 22:36
-
-
Save Algomancer/692b905f05535614254400a96c1be5a3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
class broadcast_right:
    """Context manager that lets ``tensor + vector`` broadcast along a chosen dim.

    PyTorch aligns shapes from the right, so adding a ``[3]`` tensor to a
    ``[3, 4]`` tensor normally fails (3 is compared against 4). While this
    context is active, ``torch.Tensor.__add__`` is replaced by a wrapper: when
    the right operand is a 1-D tensor and the left operand has more dimensions,
    the right operand is first unsqueezed at ``dim`` and expanded to the left
    operand's shape. Then the previous ``__add__`` is called.

    Only ``torch.Tensor.__add__`` is replaced. ``__radd__``, ``__iadd__`` (``+=``)
    and ``torch.add`` are not touched. The replacement is made on the
    ``torch.Tensor`` class itself, so it is visible to all code in the process
    while the context is open.

    Args:
        dim: Position at which the 1-D right operand receives its new singleton
            axis. E.g. with ``dim=1``, a ``[3]`` vector added to a ``[3, 4]``
            tensor becomes ``[3, 1]`` and is then expanded to ``[3, 4]``.
    """

    def __init__(self, dim: int = 0):
        self.dim = dim
        # Implementations of ``torch.Tensor.__add__`` that were active before
        # each ``__enter__`` on this instance; popped in LIFO order on exit so
        # nested re-use of the same instance restores correctly.
        self._saved_adds = []

    def __enter__(self):
        previous_add = torch.Tensor.__add__
        self._saved_adds.append(previous_add)

        def _broadcast_add(left, right):
            # Python scalars, NumPy arrays, etc. have no ``.dim()``; leave them
            # to the normal ``__add__`` instead of raising AttributeError.
            if (
                isinstance(right, torch.Tensor)
                and right.dim() == 1
                and left.dim() > right.dim()
            ):
                right = right.unsqueeze(self.dim).expand_as(left)
            # Delegate to the implementation captured for *this* entry (not a
            # shared attribute), so nested entries never call themselves.
            return previous_add(left, right)

        torch.Tensor.__add__ = _broadcast_add
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.Tensor.__add__ = self._saved_adds.pop()
        # Implicitly returns None (falsy): exceptions raised in the block propagate.
if __name__ == "__main__":
    # Demo: add a length-3 vector to every column of a 3x4 matrix.
    # Without the context manager, `matrix + column` raises a RuntimeError,
    # because PyTorch aligns trailing dims and compares the vector's 3 with 4.
    matrix = torch.randn(3, 4)
    column = torch.randn(3)

    with broadcast_right(dim=1):
        # `column` is treated as shape [3, 1] and expanded to [3, 4].
        result = matrix + column

    print("C.shape:", result.shape)
    print("C:", result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment