Prototype OffloadTensor subclass
import torch
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import TorchDispatchMode

# Registry of OffloadTensors whose CUDA storage is currently evictable.
# Keys are held weakly, so entries disappear when a tensor is garbage collected.
evictable_tensors: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
class OffloadTensor(torch.Tensor):
    cuda_elem: torch.Tensor
    cpu_elem: torch.Tensor

    __slots__ = ['cpu_elem', 'cuda_elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor. Exactly one of
        # cuda_elem/cpu_elem holds the data at any given time.
        r.cuda_elem = elem.detach() if r.requires_grad else elem
        r.cpu_elem = None
        evictable_tensors[r] = None
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"cuda: {self.cuda_elem}, cpu: {self.cpu_elem}")
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        assert False, "ops on OffloadTensor should only run under OffloadTensorMode!"

    def evict(self):
        # Offload the data to CPU and drop the CUDA copy.
        if self.cuda_elem is not None:
            print("About to EVICT!")
            self.cpu_elem = self.cuda_elem.cpu()
            self.cuda_elem = None
            del evictable_tensors[self]
        return self

    def materialize(self):
        # Bring the data back onto CUDA so an op can consume it.
        if self.cuda_elem is None:
            print("About to MATERIALIZE!")
            self.cuda_elem = self.cpu_elem.cuda()
            self.cpu_elem = None
            evictable_tensors[self] = None
        return self
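
# A minimal sketch of the evict/materialize round trip (illustrative only;
# `t` here is a hypothetical standalone tensor, not part of the demo below):
#
#   t = OffloadTensor(torch.randn(4, device="cuda"))
#   t.evict()        # data moves to t.cpu_elem; the CUDA copy is freed
#   t.materialize()  # data moves back to t.cuda_elem; the CPU copy is freed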
class OffloadTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def unwrap(e):
            if isinstance(e, OffloadTensor):
                t = e.materialize()
                # Pin the input on CUDA while the op runs; pop() tolerates
                # the same tensor appearing more than once in the inputs.
                evictable_tensors.pop(t, None)
                return t.cuda_elem
            return e

        def wrap(e):
            return OffloadTensor(e) if isinstance(e, torch.Tensor) else e

        def make_evictable(e):
            if isinstance(e, OffloadTensor):
                evictable_tensors[e] = None

        if kwargs is None:
            kwargs = {}
        rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        # Evict tensors (in registration order) until we are back under the
        # budget of 32M float32 elements (~128 MB).
        while torch.cuda.memory_allocated() > 32_000_000 * 4:
            if not evictable_tensors:
                break  # nothing left to evict
            evict_t = next(iter(evictable_tensors))
            evict_t.evict()
        # Inputs are safe to evict again now that the op has run.
        tree_map(make_evictable, (args, kwargs))
        return tree_map(wrap, rs)
# with OffloadTensorMode():
#     a = torch.randn(8_000_000, device="cuda")
#     b = a.sin()
#     c = b.clone()
#     d = c.clone()
#     e = d.clone()
#     f = e.clone()
#     g = f.clone()
#     h = a + b

# Demo: 6 params + their grads already take ~96 MB, so the AdamW state
# (another ~96 MB) pushes past the ~128 MB budget and gets offloaded.
params = [torch.rand(2_000_000, device="cuda") for _ in range(6)]
for p in params:
    p.grad = torch.rand_like(p)
optimizer = torch.optim.AdamW(params, foreach=False)
with OffloadTensorMode():
    for _ in range(3):
        optimizer.step()
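
# To watch the offloading happen, one could also log allocator stats after
# each step (a sketch; exact numbers depend on the CUDA caching allocator):
#
# with OffloadTensorMode():
#     for _ in range(3):
#         optimizer.step()
#         print(f"allocated: {torch.cuda.memory_allocated() / 1e6:.0f} MB")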