# write a hook to flatten optimizer state_dict when saving
def flatten_state_dict(optim, osd):
    flattened_sd = {}
    state = osd['state']
    for idx, param_group in enumerate(osd['param_groups']):
        assert 'param_names' in param_group, "param names are required as they'll be used as keys"
        for param_name, param_id in zip(param_group['param_names'], param_group['params']):
            # add all the state
            if param_id in state:
                for key, value in state[param_id].items():
                    # assumed completion of the truncated preview: key each state entry by param name
                    flattened_sd[f"{param_name}.{key}"] = value
    return flattened_sd
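
A hedged usage sketch for the hook above, assuming a PyTorch version where optimizers accept named parameters (so 'param_names' appears in param_groups) and expose Optimizer.register_state_dict_post_hook; the model and training lines are only illustrative:

import torch

model = torch.nn.Linear(4, 2)
# passing named_parameters() records 'param_names' in each param_group
optim = torch.optim.AdamW(model.named_parameters(), lr=1e-3)
optim.register_state_dict_post_hook(flatten_state_dict)

# run one step so there is optimizer state to flatten
model(torch.randn(8, 4)).sum().backward()
optim.step()

# keys now look like "weight.step", "weight.exp_avg", "bias.exp_avg_sq", ...
print(optim.state_dict().keys())

Since a non-None return value from a state_dict post hook replaces the original dict, state_dict() then yields the flattened mapping directly.
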
janeyx99 / offloadtensor.py
Created March 18, 2024 21:56
Prototype OffloadTensor subclass
import torch
from torch.utils.weak import WeakTensorKeyDictionary
from torch.utils._pytree import tree_map
from torch.utils._python_dispatch import TorchDispatchMode
evictable_tensors: WeakTensorKeyDictionary = {}
class OffloadTensor(torch.Tensor):
    cuda_elem: torch.Tensor
    cpu_elem: torch.Tensor
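
The preview cuts off before the interesting parts. As a rough sketch of the wrapper-subclass pattern such a prototype typically follows (an illustrative guess, not the gist's implementation: the class name OffloadTensorSketch, the eager CPU offload, and the unwrap-to-CUDA dispatch are all assumptions), it might look like this:

import torch
from torch.utils._pytree import tree_map

class OffloadTensorSketch(torch.Tensor):
    """Illustrative offload tensor: metadata says CUDA, payload lives on CPU."""

    @staticmethod
    def __new__(cls, cuda_elem: torch.Tensor):
        # wrapper subclass carrying only metadata (size/stride/dtype/device)
        return torch.Tensor._make_wrapper_subclass(
            cls,
            cuda_elem.size(),
            strides=cuda_elem.stride(),
            dtype=cuda_elem.dtype,
            device=cuda_elem.device,
            requires_grad=cuda_elem.requires_grad,
        )

    def __init__(self, cuda_elem: torch.Tensor):
        # eagerly move the payload off the GPU
        self.cpu_elem = cuda_elem.to("cpu", non_blocking=True)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # rehydrate offloaded payloads onto the GPU so the real op can run
        def unwrap(t):
            return t.cpu_elem.cuda() if isinstance(t, OffloadTensorSketch) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

The actual prototype presumably keeps both cuda_elem and cpu_elem around and uses evictable_tensors plus a TorchDispatchMode to decide when to evict and when to bring data back.
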
janeyx99 / c55nzp6gjeaxbq2uk3om2dmdybo5daxnp32w7czxaoxlz73dvza6.py
Created March 23, 2023 21:41
Compiled kernels for ASGD too many args
from ctypes import c_void_p, c_long
import torch
import math
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
aten = torch.ops.aten
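
For context, files like this one (the name is an Inductor code-cache content hash) are what TorchInductor writes out when an optimizer step is compiled. A hedged sketch of how a dump for a foreach ASGD step could be reproduced, assuming a CUDA device and a PyTorch build where optim.step() can be traced by torch.compile:

import torch

params = [torch.randn(128, device="cuda", requires_grad=True) for _ in range(8)]
optim = torch.optim.ASGD(params, lr=1e-2, foreach=True)

@torch.compile()
def compiled_step():
    optim.step()

for p in params:
    p.grad = torch.randn_like(p)

# The first call triggers Dynamo + Inductor codegen; running with
# TORCH_COMPILE_DEBUG=1 also leaves a readable copy of the generated code.
compiled_step()
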