# write a hook to flatten optimizer state_dict when saving
def flatten_state_dict(optim, osd):
    flattened_sd = {}
    state = osd['state']
    for idx, param_group in enumerate(osd['param_groups']):
        assert 'param_names' in param_group, "param names are required as they'll be used as keys"
        for param_name, param_id in zip(param_group['param_names'], param_group['params']):
            # add all the state
            if param_id in state:
                for key, value in state[param_id].items():
                    flattened_sd[param_name + '.' + key] = value
            # add all the param_group values
            for key, value in param_group.items():
                if key != 'param_names' and key != 'params':
                    flattened_sd[param_name + '.' + key] = value
            # add the param id
            flattened_sd[param_name + '.id'] = param_id
            # add the param_group idx to make unflattening easier
            flattened_sd[param_name + '.param_group_id'] = idx
    return flattened_sd
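
# For illustration only (not produced by this file): with the AdamW setup below,
# the flattened entries for a single parameter would look roughly like
#   'linear_relu_stack.0.weight.step'           -> Tensor (scalar step count)
#   'linear_relu_stack.0.weight.exp_avg'        -> Tensor shaped like the param
#   'linear_relu_stack.0.weight.exp_avg_sq'     -> Tensor shaped like the param
#   'linear_relu_stack.0.weight.lr'             -> float copied from its param_group
#   'linear_relu_stack.0.weight.id'             -> int param id
#   'linear_relu_stack.0.weight.param_group_id' -> int index of its param_group
# plus the remaining AdamW defaults such as betas, eps, and weight_decay.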
# and another hook to unflatten for when loading
def unflatten_state_dict(optim, sd):
    state = {}
    param_groups = {}  # will be converted to a list later!
    for key, value in sd.items():
        assert '.' in key, "key must contain a '.'"
        # key will look like one of the following:
        # (a) state: linear_relu_stack.4.bias.exp_avg_sq, with value Tensor
        # (b) param_group default: linear_relu_stack.4.weight.fused, with value Optional[bool]
        # (c) param id: linear_relu_stack.4.weight.id, with value int
        # (d) param_group id: linear_relu_stack.4.weight.param_group_id, with value int
        string_pieces = key.rsplit('.', 1)
        param_name = string_pieces[0]
        actual_key = string_pieces[1]
        # we'll retrieve and handle (c) and (d) as we go, so nothing explicit to do for them here
        if actual_key == 'id' or actual_key == 'param_group_id':
            continue
        # add the (param_name, param_id) pair into a param_group, if it isn't there already
        param_group_id = sd[param_name + '.param_group_id']
        param_id = sd[param_name + '.id']
        if param_group_id not in param_groups:
            param_groups[param_group_id] = {}
        if 'params' not in param_groups[param_group_id]:
            assert 'param_names' not in param_groups[param_group_id], 'param_names should be set at the same time as params but already exists'
            param_groups[param_group_id]['params'] = [param_id]
            param_groups[param_group_id]['param_names'] = [param_name]
        else:
            # params and param_names already exist!
            assert len(param_groups[param_group_id]['params']) == len(param_groups[param_group_id]['param_names']), 'params and param_names should be in sync, but their lengths are different'
            already_in = False
            # go in reverse because, heuristically, we're usually updating a recently added param
            for existing_id, existing_name in zip(reversed(param_groups[param_group_id]['params']), reversed(param_groups[param_group_id]['param_names'])):
                if existing_id == param_id:
                    assert existing_name == param_name, f'{existing_name=} found for {param_id=}, which differs from the in-process {param_name=}'
                    already_in = True
                    break
            if not already_in:
                param_groups[param_group_id]['params'].append(param_id)
                param_groups[param_group_id]['param_names'].append(param_name)
        # if the actual key is in defaults, it belongs to the param_group; otherwise, it goes in state
        if actual_key in optim.defaults:
            assert param_group_id in param_groups, f'we should have created a param_group for {param_group_id=} already!'
            if actual_key in param_groups[param_group_id]:
                assert value == param_groups[param_group_id][actual_key], f'param_groups of the same index must have the same values, but {value=} for {actual_key=} at {param_group_id=} does not match the pre-existing {param_groups[param_group_id][actual_key]=}'
            else:
                # by this point, param_groups[param_group_id] exists and actual_key is not yet in it
                param_groups[param_group_id][actual_key] = value
        else:
            # state
            if param_id not in state:
                state[param_id] = {}
            assert actual_key not in state[param_id], f'duplicate state {actual_key} for {param_name=} and {param_id=}'
            state[param_id][actual_key] = value
    return {
        'state': state,
        'param_groups': list(param_groups.values()),
    }
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


net = NeuralNetwork()
optimizer = optim.AdamW(net.named_parameters())
optimizer.register_state_dict_post_hook(flatten_state_dict)


def train_loop(optim):
    logits = net(torch.rand(1, 28, 28))
    pred_probab = nn.Softmax(dim=1)(logits)
    y_pred = pred_probab.sum()  # this should not be sum, but alas, this is a fake example
    loss = y_pred.to(dtype=torch.float) - 5
    loss.backward()
    optim.step()
    optim.zero_grad()


def print_fsd(fsd):
    for fqn, v in fsd.items():
        if isinstance(v, torch.Tensor):
            print(fqn, v.shape)
        else:
            print(fqn, v)


print('NO OPTIM STATE YET:')
print_fsd(optimizer.state_dict())

train_loop(optimizer)

print('\n\nAFTER OPTIM STEP, THERE IS STATE:')
saved_state_dict = deepcopy(optimizer.state_dict())
print_fsd(saved_state_dict)

# ==== now pretend you're loading a checkpoint ==== #
new_optim = optim.AdamW(net.named_parameters())
new_optim.register_load_state_dict_pre_hook(unflatten_state_dict)
new_optim.load_state_dict(saved_state_dict)
print('\n\nLOADING TO A NEW OPTIM SUCCEEDED!')

train_loop(new_optim)
print('\n\nNEW OPTIM STATE DICT AFTER ONE STEP')
print_fsd(new_optim.state_dict())
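
# As a final sketch (not part of the original flow), the flattened state_dict also
# round-trips through torch.save / torch.load like any other checkpoint; the
# 'flat_optim_sd.pt' filename here is just an example.
torch.save(new_optim.state_dict(), 'flat_optim_sd.pt')
loaded_fsd = torch.load('flat_optim_sd.pt')

another_optim = optim.AdamW(net.named_parameters())
another_optim.register_load_state_dict_pre_hook(unflatten_state_dict)
another_optim.load_state_dict(loaded_fsd)
print('\n\nROUND-TRIP THROUGH torch.save / torch.load ALSO WORKS!')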