@eldar
Created March 27, 2025 14:56
import torch
from typing import Sized

# `misc` (parameter groups, lr schedule), the model / data loaders and `args` come from the
# surrounding Mast3r/Dust3r-style codebase and are not defined in this gist.


def train():
    torch.backends.cuda.matmul.allow_tf32 = True
    # ...
    args.lr = 0.00005         # best for Mast3r
    args.weight_decay = 0.05  # from Mast3r
    # following timm: set wd as 0 for bias and norm layers
    param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    loss_scaler = NativeScaler()
    train_stats = train_one_epoch(
        model, train_criterion, data_loader_train,
        optimizer, device, epoch, loss_scaler,
        args=args)
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Sized, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args):
    assert torch.backends.cuda.matmul.allow_tf32
    model.train(True)
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    for data_iter_step, batch in enumerate(data_loader):
        epoch_f = epoch + data_iter_step / len(data_loader)
        # we use a per-iteration (instead of per-epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, epoch_f, args)
        # ret='loss' so we get back the scalar loss rather than the full result dict
        loss = loss_of_one_batch(batch, model, criterion, device, amp=args.amp, ret='loss')
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
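
# The loop above relies on misc.adjust_learning_rate, which is not included in this gist.
# Below is a minimal sketch of what such a per-iteration scheduler typically looks like in
# MAE/Dust3r-style codebases (linear warmup followed by half-cycle cosine decay). The fields
# args.warmup_epochs, args.min_lr and args.epochs are assumptions here; the author's actual
# implementation may differ.
import math

def adjust_learning_rate_sketch(optimizer, epoch_f, args):
    if epoch_f < args.warmup_epochs:
        lr = args.lr * epoch_f / args.warmup_epochs  # linear warmup from 0 to args.lr
    else:
        # half-cycle cosine decay from args.lr down to args.min_lr
        progress = (epoch_f - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        # honour the per-group lr_scale set by get_parameter_groups (always 1.0 here)
        param_group["lr"] = lr * param_group.get("lr_scale", 1.0)
    return lr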
def loss_of_one_batch(batch, model, criterion, device, amp="", ret=None):
    view1, view2 = batch
    # move all tensors of both views onto the target device
    for view in batch:
        for name in 'img pts3d valid_mask pts3d_t0 pts3d_t1 valid_mask_t0 valid_mask_t1 camera_pose camera_intrinsics F_matrix corres'.split():  # pseudo_focal
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)
    if amp == "bfloat16":
        autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
    else:
        autocast_amp = torch.amp.autocast("cuda", enabled=False)
    with autocast_amp:
        pred1, pred2 = model(view1, view2)
    # no autocast for the loss, as in Dust3r (the loss is supposed to be symmetric)
    with torch.amp.autocast("cuda", enabled=False):
        loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None
    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
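
# A small usage sketch (not part of the original gist): loss_of_one_batch can also be reused
# for evaluation, pulling out only the scalar loss via ret="loss". The averaging logic below
# is an illustration, not the author's evaluation code.
@torch.no_grad()
def evaluate_sketch(model, criterion, data_loader, device, amp=""):
    model.eval()
    total, count = 0.0, 0
    for batch in data_loader:
        loss = loss_of_one_batch(batch, model, criterion, device, amp=amp, ret="loss")
        total += float(loss)
        count += 1
    return total / max(count, 1)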
class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self, enabled=True):
        self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
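
# NativeScaler calls get_grad_norm_, which is not defined in this gist (in MAE/croco-style
# codebases it lives in the shared misc utilities). The version below follows the timm/MAE
# helper of the same name and is included as an assumption so the snippet is self-contained.
def get_grad_norm_(parameters, norm_type=2.0):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == float('inf'):
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
            norm_type)
    return total_norm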
def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]):
    parameter_group_names = {}
    parameter_group_vars = {}
    enc_depth, dec_depth = None, None  # leftovers from layer-wise lr decay; unused when layer_decay == 1.0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # Assign weight decay values: no decay for 1-D params (biases, norm layers)
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
            group_name = "no_decay"
            this_weight_decay = 0.
        else:
            group_name = "decay"
            this_weight_decay = weight_decay
        layer_id = 0
        skip_scale = True  # likewise unused here: every group keeps lr_scale = 1.0
        if group_name not in parameter_group_names:
            scale = 1.
            parameter_group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }
            parameter_group_vars[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }
        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)
    return list(parameter_group_vars.values())
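
# Quick sanity-check sketch (not in the original gist): biases and norm parameters land in
# the no_decay group with weight_decay = 0, everything else keeps the requested decay, and
# the groups can be handed directly to AdamW as in train() above.
if __name__ == "__main__":
    _toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
    _groups = get_parameter_groups(_toy, weight_decay=0.05)
    for _g in _groups:
        print(len(_g["params"]), "params with weight_decay =", _g["weight_decay"])
    _opt = torch.optim.AdamW(_groups, lr=5e-5, betas=(0.9, 0.95))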