@janeyx99
Last active April 8, 2025 16:31
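
Flatten/unflatten hooks for torch.optim state_dicts: the state_dict post hook saves a flat mapping keyed by parameter fully-qualified names, and the load_state_dict pre hook rebuilds the nested state/param_groups layout on load.
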
# write a hook to flatten optimizer state_dict when saving
def flatten_state_dict(optim, osd):
    flattened_sd = {}
    state = osd['state']
    for idx, param_group in enumerate(osd['param_groups']):
        assert 'param_names' in param_group, "param names are required as they'll be used as keys"
        for param_name, param_id in zip(param_group['param_names'], param_group['params']):
            # add all the state
            if param_id in state:
                for key, value in state[param_id].items():
                    flattened_sd[param_name + '.' + key] = value
            # add all the param_group values
            for key, value in param_group.items():
                if key != 'param_names' and key != 'params':
                    flattened_sd[param_name + '.' + key] = value
            # add the param id
            flattened_sd[param_name + '.id'] = param_id
            # add the param_group idx to make unflattening easier
            flattened_sd[param_name + '.param_group_id'] = idx
    return flattened_sd
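
# Illustrative sketch (an addition, not part of the original gist): for an AdamW parameter
# named 'linear_relu_stack.0.weight' in param_group 0, the flattened dict would contain
# fully-qualified keys such as
#   'linear_relu_stack.0.weight.step'           -> Tensor (scalar step count)
#   'linear_relu_stack.0.weight.exp_avg'        -> Tensor
#   'linear_relu_stack.0.weight.exp_avg_sq'     -> Tensor
#   'linear_relu_stack.0.weight.lr'             -> float (copied from the param_group)
#   'linear_relu_stack.0.weight.id'             -> int (param id within the group)
#   'linear_relu_stack.0.weight.param_group_id' -> 0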
# and another hook to unflatten for when loading
def unflatten_state_dict(optim, sd):
    state = {}
    param_groups = {}  # will be converted to a list later!
    for key, value in sd.items():
        assert '.' in key, "key must contain a '.'"
        # key will look like one of the following:
        # (a) state: linear_relu_stack.4.bias.exp_avg_sq, with value Tensor
        # (b) param_group default: linear_relu_stack.4.weight.fused, with value Optional[bool]
        # (c) param id: linear_relu_stack.4.weight.id, with value int
        # (d) param_group id: linear_relu_stack.4.weight.param_group_id, with value int
        string_pieces = key.rsplit('.', 1)
        param_name = string_pieces[0]
        actual_key = string_pieces[1]
        # we'll retrieve and handle c and d as we go, so no need to explicitly do anything
        if actual_key == 'id' or actual_key == 'param_group_id':
            continue
        # add the param_name, param_id pair into a param_group, if not completed already
        param_group_id = sd[param_name + '.param_group_id']
        param_id = sd[param_name + '.id']
        if param_group_id not in param_groups:
            param_groups[param_group_id] = {}
        if 'params' not in param_groups[param_group_id]:
            assert 'param_names' not in param_groups[param_group_id], 'param_names should be set at the same time as params but already exists'
            param_groups[param_group_id]['params'] = [param_id]
            param_groups[param_group_id]['param_names'] = [param_name]
        else:
            # params and param_names already exist!
            assert len(param_groups[param_group_id]['params']) == len(param_groups[param_group_id]['param_names']), 'params and param_names should be in sync, but their lengths are different'
            already_in = False
            # go in reverse because heuristics have it such that we're usually updating a recent param
            for existing_id, existing_name in zip(reversed(param_groups[param_group_id]['params']), reversed(param_groups[param_group_id]['param_names'])):
                if existing_id == param_id:
                    assert existing_name == param_name, f'{existing_name=} found for {param_id=} which is different from the in process {param_name=}'
                    already_in = True
                    break
            if not already_in:
                param_groups[param_group_id]['params'].append(param_id)
                param_groups[param_group_id]['param_names'].append(param_name)
        # if the actual key is in defaults, then it belongs to param_group, else, it goes in state
        if actual_key in optim.defaults:
            assert param_group_id in param_groups, f'we should have created a param_group for {param_group_id=} already!'
            if actual_key in param_groups[param_group_id]:
                assert value == param_groups[param_group_id][actual_key], f'param_group of the same index must have the same values but {value=} for {actual_key=} at {param_group_id=} is not the same as pre-existing {param_groups[param_group_id][actual_key]=}'
            else:
                # by this point, param_groups[param_group_id] will exist and actual_key will not be in it
                param_groups[param_group_id][actual_key] = value
        else:
            # state
            if param_id not in state:
                state[param_id] = {}
            assert actual_key not in state[param_id], f'duplicate state {actual_key} for {param_name=} and {param_id=}'
            state[param_id][actual_key] = value
    return {
        'state': state,
        'param_groups': list(param_groups.values()),
    }
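
# Illustrative round-trip check (an addition, not part of the original gist). It assumes `o` is
# an optimizer constructed from named_parameters and WITHOUT the hooks registered, so that
# o.state_dict() still returns the usual nested layout.
def check_roundtrip(o):
    nested = o.state_dict()
    flat = flatten_state_dict(o, nested)
    rebuilt = unflatten_state_dict(o, flat)
    # flatten followed by unflatten should cover the same params and groups
    assert set(rebuilt['state'].keys()) == set(nested['state'].keys())
    assert len(rebuilt['param_groups']) == len(nested['param_groups'])
    for rebuilt_pg, nested_pg in zip(rebuilt['param_groups'], nested['param_groups']):
        assert set(rebuilt_pg['params']) == set(nested_pg['params'])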

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


net = NeuralNetwork()
optimizer = optim.AdamW(net.named_parameters())
optimizer.register_state_dict_post_hook(flatten_state_dict)
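# Note (added, not in the original gist): constructing the optimizer from net.named_parameters()
# is what populates 'param_names' in each param_group, which flatten_state_dict relies on; this
# and the state_dict/load_state_dict hooks require a reasonably recent PyTorch release. With the
# post hook registered, optimizer.state_dict() below returns the flattened mapping.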

def train_loop(optim):
    logits = net(torch.rand(1, 28, 28))
    pred_probab = nn.Softmax(dim=1)(logits)
    y_pred = pred_probab.sum()  # this should not be sum but alas this is a fake example
    loss = y_pred.to(dtype=torch.float) - 5
    loss.backward()
    optim.step()
    optim.zero_grad()


def print_fsd(fsd):
    for fqn, v in fsd.items():
        if isinstance(v, torch.Tensor):
            print(fqn, v.shape)
        else:
            print(fqn, v)

print('NO OPTIM STATE YET:')
print_fsd(optimizer.state_dict())

train_loop(optimizer)

print('\n\nAFTER OPTIM STEP, THERE IS STATE:')
saved_state_dict = deepcopy(optimizer.state_dict())
print_fsd(saved_state_dict)

# ==== now pretend you're loading a checkpoint ==== #
new_optim = optim.AdamW(net.named_parameters())
new_optim.register_load_state_dict_pre_hook(unflatten_state_dict)
new_optim.load_state_dict(saved_state_dict)
print('\n\nLOADING TO A NEW OPTIM SUCCEEDED!')

train_loop(new_optim)
print('\n\nNEW OPTIM STATE DICT AFTER ONE STEP')
print_fsd(new_optim.state_dict())
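
# A minimal checkpointing sketch (an addition, not part of the original gist): because the
# flattened state_dict is a plain mapping of fully-qualified names to tensors and scalars,
# it can be saved and restored like a model state_dict. The path below is a placeholder.
torch.save(saved_state_dict, 'optim_flat.pt')
loaded_flat = torch.load('optim_flat.pt')
restored_optim = optim.AdamW(net.named_parameters())
restored_optim.register_load_state_dict_pre_hook(unflatten_state_dict)
restored_optim.load_state_dict(loaded_flat)  # the pre-hook rebuilds the nested layout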