This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # write a hook to flatten optimizer state_dict when saving | |
| def flatten_state_dict(optim, osd): | |
| flattened_sd = {} | |
| state = osd['state'] | |
| for idx, param_group in enumerate(osd['param_groups']): | |
| assert 'param_names' in param_group, "param names are required as they'll be used as keys" | |
| for param_name, param_id in zip(param_group['param_names'], param_group['params']): | |
| # add all the state | |
| if param_id in state: | |
| for key, value in state[param_id].items(): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch.utils.weak import WeakTensorKeyDictionary | |
| from torch.utils._pytree import tree_map | |
| from torch.utils._python_dispatch import TorchDispatchMode | |
# Registry of tensors that may be evicted, keyed weakly by tensor identity so an
# entry disappears automatically once its tensor is garbage-collected.
# This must be a real WeakTensorKeyDictionary instance, not `{}`: a plain dict
# holds strong references to its keys and would keep every registered tensor
# (and its memory) alive indefinitely, contradicting the annotation.
# NOTE(review): presumably populated by OffloadTensor — confirm against usage.
evictable_tensors: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
| class OffloadTensor(torch.Tensor): | |
| cuda_elem: torch.Tensor | |
| cpu_elem: torch.Tensor |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from ctypes import c_void_p, c_long | |
| import torch | |
| import math | |
| import random | |
| from torch import empty_strided, as_strided, device | |
| from torch._inductor.codecache import AsyncCompile | |
| from torch._inductor.select_algorithm import extern_kernels | |
| aten = torch.ops.aten  # module-level alias so ATen ops can be referenced as aten.<op>