Created March 27, 2025 14:56
import torch
from typing import Sized

# `misc` refers to the project's utility module (adjust_learning_rate,
# get_parameter_groups); it is imported in the original training code and a
# version of its get_parameter_groups is reproduced at the bottom of this file.


def train():
    torch.backends.cuda.matmul.allow_tf32 = True
    # ...
    args.lr = 0.00005  # best for MASt3R
    args.weight_decay = 0.05  # from MASt3R
    # following timm: set wd to 0 for bias and norm layers
    param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    loss_scaler = NativeScaler()
    train_stats = train_one_epoch(
        model, train_criterion, data_loader_train,
        optimizer, device, epoch, loss_scaler,
        args=args)
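# --- Illustrative sketch, not part of the original gist ---------------------
# train() reads a handful of fields from `args`. A minimal, hypothetical
# stand-in (the real code uses an argparse namespace with many more options)
# covering only what this snippet touches could look like this:
from types import SimpleNamespace

example_args = SimpleNamespace(
    lr=0.00005,        # overridden above as the "best for MASt3R" value
    weight_decay=0.05,
    accum_iter=1,      # gradient accumulation steps used in train_one_epoch
    amp="bfloat16",    # "" disables autocast in loss_of_one_batch
)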
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Sized, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    args):
    assert torch.backends.cuda.matmul.allow_tf32 == True
    model.train(True)
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    for data_iter_step, batch in enumerate(data_loader):
        epoch_f = epoch + data_iter_step / len(data_loader)
        # we use a per-iteration (instead of per-epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            misc.adjust_learning_rate(optimizer, epoch_f, args)
        loss = loss_of_one_batch(batch, model, criterion, device, amp=args.amp)
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
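# --- Illustrative sketch, not part of the original gist ---------------------
# misc.adjust_learning_rate is called above but not shown. The sketch below is
# an assumption modelled on the usual MAE/CroCo-style schedule (linear warmup
# followed by a half-cycle cosine decay), applied per iteration through the
# fractional `epoch_f`. The field names warmup_epochs, min_lr and epochs are
# hypothetical; the per-group "lr_scale" matches get_parameter_groups below.
import math

def adjust_learning_rate_sketch(optimizer, epoch, args):
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr * param_group.get("lr_scale", 1.0)
    return lr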
def loss_of_one_batch(batch, model, criterion, device, amp="", ret=None):
    view1, view2 = batch
    # move every tensor the model/criterion may need onto the target device
    for view in batch:
        for name in 'img pts3d valid_mask pts3d_t0 pts3d_t1 valid_mask_t0 valid_mask_t1 camera_pose camera_intrinsics F_matrix corres'.split():  # pseudo_focal
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if amp == "bfloat16":
        autocast_amp = torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)
    else:
        autocast_amp = torch.amp.autocast("cuda", enabled=False)

    with autocast_amp:
        pred1, pred2 = model(view1, view2)
        # loss is supposed to be symmetric; no autocast for the loss in DUSt3R
        with torch.amp.autocast("cuda", enabled=False):
            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result
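# --- Illustrative sketch, not part of the original gist ---------------------
# loss_of_one_batch can also be used without a criterion (loss stays None),
# or with ret="loss" during validation. The dummy views below carry only
# "img"; a real DUSt3R/MASt3R batch also holds keys from the list above
# (pts3d, valid_mask, camera_pose, ...), so whether a bare "img" dict is
# sufficient depends on the model being called.
@torch.no_grad()
def predict_pair_sketch(model, device, H=224, W=224):
    view1 = {"img": torch.randn(1, 3, H, W)}
    view2 = {"img": torch.randn(1, 3, H, W)}
    out = loss_of_one_batch((view1, view2), model, criterion=None,
                            device=device, amp="bfloat16")
    return out["pred1"], out["pred2"]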
class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self, enabled=True):
        self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)  # helper from the same misc utilities, not shown in this gist
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
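# --- Illustrative sketch, not part of the original gist ---------------------
# The scaler exposes state_dict/load_state_dict so the loss-scaling state can
# be checkpointed next to the model and optimizer. The names below (`path`,
# the "model"/"optimizer" keys) are placeholders, not the gist's checkpoint
# format.
def save_checkpoint_sketch(model, optimizer, loss_scaler, path="checkpoint.pth"):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        NativeScaler.state_dict_key: loss_scaler.state_dict(),  # "amp_scaler"
    }, path)

def load_scaler_sketch(path="checkpoint.pth"):
    loss_scaler = NativeScaler()
    ckpt = torch.load(path, map_location="cpu")
    loss_scaler.load_state_dict(ckpt[NativeScaler.state_dict_key])
    return loss_scaler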
def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]):
    parameter_group_names = {}
    parameter_group_vars = {}
    enc_depth, dec_depth = None, None
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        # Assign weight decay values
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
            group_name = "no_decay"
            this_weight_decay = 0.
        else:
            group_name = "decay"
            this_weight_decay = weight_decay

        # leftovers from the layer-decay variant; unused when layer_decay == 1.0
        layer_id = 0
        skip_scale = True
        if group_name not in parameter_group_names:
            scale = 1.
            parameter_group_names[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }
            parameter_group_vars[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "lr_scale": scale
            }

        parameter_group_vars[group_name]["params"].append(param)
        parameter_group_names[group_name]["params"].append(name)

    return list(parameter_group_vars.values())
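# --- Illustrative sketch, not part of the original gist ---------------------
# Quick check of the grouping above on a toy model: 1-D parameters (biases,
# norm weights) land in the zero-weight-decay group, 2-D weights keep the
# requested decay, and the extra "lr_scale" key is simply carried along in
# AdamW's param groups.
if __name__ == "__main__":
    toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
    groups = get_parameter_groups(toy, weight_decay=0.05)
    for group in groups:
        print(group["weight_decay"], [tuple(p.shape) for p in group["params"]])
    # expected: 0.05 [(8, 8)]  and  0.0 [(8,), (8,), (8,)]
    optimizer = torch.optim.AdamW(groups, lr=0.00005, betas=(0.9, 0.95))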