@janeyx99 · Created March 18, 2024 21:56
Prototype OffloadTensor subclass
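A prototype of CPU offloading via a tensor subclass: OffloadTensor wraps a CUDA tensor and can swap its payload between cuda_elem and cpu_elem, while OffloadTensorMode intercepts every op, materializes the inputs it needs, and evicts registered tensors once allocated CUDA memory exceeds a fixed budget.
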
import torch
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import TorchDispatchMode

# Ordered weak "set" of OffloadTensors whose CUDA payload may be offloaded
# (values are unused; weak keys let wrappers be garbage collected).
evictable_tensors: WeakTensorKeyDictionary = WeakTensorKeyDictionary()

class OffloadTensor(torch.Tensor):
    cuda_elem: torch.Tensor
    cpu_elem: torch.Tensor
    can_evict: int
    __slots__ = ['cpu_elem', 'cuda_elem']

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.cuda_elem = elem.detach() if r.requires_grad else elem
        r.cpu_elem = None  # filled in by evict(), cleared by materialize()
        evictable_tensors[r] = None
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"cuda: {self.cuda_elem}, cpu: {self.cpu_elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # All ops are expected to be intercepted by OffloadTensorMode instead.
        assert False, "this should never be called!"

    def evict(self):
        # Move the payload to CPU, free the CUDA copy, and deregister.
        if self.cuda_elem is not None:
            print("About to EVICTTTTTT!")
            self.cpu_elem = self.cuda_elem.cpu()
            self.cuda_elem = None
            del evictable_tensors[self]
        return self

    def materialize(self):
        # Move the payload back to CUDA and mark it evictable again.
        if self.cuda_elem is None:
            print("About to MATERIALIZE!")
            self.cuda_elem = self.cpu_elem.cuda()
            self.cpu_elem = None
            evictable_tensors[self] = None
        return self
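
# A minimal round-trip sketch of the API above (commented out, like the example
# further down; assumes a CUDA device is available):
# t = OffloadTensor(torch.randn(4, device="cuda"))
# t.evict()        # payload now lives on CPU in t.cpu_elem
# t.materialize()  # payload moved back to CUDA; t is evictable again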

class OffloadTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def unwrap(e):
            if isinstance(e, OffloadTensor):
                t = e.materialize()
                # Deregister inputs for the duration of the op so the eviction
                # loop below cannot offload them mid-use; re-registered after.
                del evictable_tensors[t]
                return t.cuda_elem
            return e

        def wrap(e):
            return OffloadTensor(e) if isinstance(e, torch.Tensor) else e

        def make_evictable(e):
            if isinstance(e, OffloadTensor):
                evictable_tensors[e] = None

        if kwargs is None:
            kwargs = {}
        rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        # Evict registered tensors (in insertion order) until allocated CUDA
        # memory is back under the budget: 32M float32 elements * 4 B = 128 MB.
        while torch.cuda.memory_allocated() > 32_000_000 * 4 and evictable_tensors:
            evict_t = next(iter(evictable_tensors))
            evict_t.evict()
        tree_map(make_evictable, (args, kwargs))
        return tree_map(wrap, rs)

# with OffloadTensorMode():
#     a = torch.randn(8_000_000, device="cuda")
#     b = a.sin()
#     c = b.clone()
#     d = c.clone()
#     e = d.clone()
#     f = e.clone()
#     g = f.clone()
#     h = a + b
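
# The live demo below drives AdamW through the mode. foreach=False keeps the
# optimizer on its single-tensor path, so each update op dispatches through
# __torch_dispatch__ individually and its outputs get wrapped.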
params = [torch.rand(2_000_000, device="cuda") for _ in range(6)]
for p in params:
    p.grad = torch.rand_like(p)
optimizer = torch.optim.AdamW(params, foreach=False)
# Params + grads take ~96 MB; the AdamW state (exp_avg, exp_avg_sq) created
# inside the mode adds another ~96 MB, pushing past the 128 MB budget and
# forcing evictions.
with OffloadTensorMode():
    for _ in range(3):
        optimizer.step()
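
# A hedged way to inspect the result (commented out; `exp_avg` is AdamW's real
# state key, and memory_allocated is the counter the mode budgets against):
# print(torch.cuda.memory_allocated())  # should sit near or under 128 MB
# print(type(optimizer.state[params[0]]["exp_avg"]))  # expected: OffloadTensor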