Skip to content

Instantly share code, notes, and snippets.

@wassname2
Last active September 2, 2025 07:28
Show Gist options
  • Select an option

  • Save wassname2/80b6e8a36bfc2b0738412c85e12b43ac to your computer and use it in GitHub Desktop.

Select an option

Save wassname2/80b6e8a36bfc2b0738412c85e12b43ac to your computer and use it in GitHub Desktop.
When using UPGrad (or any torchjd aggregator with a weighting), you can print the aggregation weights with a `with` context manager, which is tidier than registering and removing hooks manually.
import torch
from torch import Tensor
from torch.nn.functional import cosine_similarity
from contextlib import AbstractContextManager
from typing import List, Dict
from functools import partial
class PrintGradWeights(AbstractContextManager):
    """Record (and optionally print) aggregation weights and a cosine similarity.

    While the context is active, forward hooks are attached to
    ``aggregator.weighting`` and to ``aggregator`` itself. Each time the
    aggregator is called, the hooks store into ``self.data``:

    * one entry per label in ``names``: the matching weight returned by
      ``aggregator.weighting``, as a Python float;
    * ``"cosine_similarity"``: the cosine similarity between the aggregator's
      output vector and the mean over dim 0 of its first input (values near 1
      mean the aggregated direction is aligned with the mean gradient).

    If the aggregator runs several times inside one context, later calls
    overwrite earlier values. All hooks are removed on exit, even when the
    body raises.

    Usage::

        with PrintGradWeights(aggregator, names) as tracer:
            torchjd.backward(...)  # any step that calls ``aggregator``
        data = tracer.data
    """

    def __init__(
        self,
        aggregator: torch.nn.Module,
        names: List[str],
        enabled: bool = True,
        verbose: bool = False,
    ):
        """
        Args:
            aggregator: module exposing a ``weighting`` submodule (e.g. a
                torchjd ``UPGrad`` instance -- assumed; verify for others).
            names: labels for the weights, in the order the weighting returns
                them. ``zip`` truncates, so weights beyond ``len(names)`` are
                not recorded.
            enabled: if False, entering the context attaches no hooks.
            verbose: if True, print each value as it is captured.
        """
        self.aggregator = aggregator
        self.names = names
        self.enabled = enabled
        self.verbose = verbose
        self.data: Dict[str, float] = {}
        # RemovableHandle objects returned by register_forward_hook.
        self._handles: list = []

    def _hook_weights(
        self,
        module: torch.nn.Module,
        inputs: tuple,
        output: Tensor,
    ) -> None:
        """Forward hook on ``aggregator.weighting``: store one float per name."""
        # Assumes ``output`` is a 1-D tensor with one weight per objective.
        for name, weight in zip(self.names, output):
            # .item() already transfers to host; an explicit .cpu() is redundant.
            value = weight.detach().item()
            self.data[name] = value
            if self.verbose:
                print(f"{name}: {value:.4f}")

    def _hook_similarity(
        self,
        module: torch.nn.Module,
        inputs: tuple,
        aggregation: Tensor,
    ) -> None:
        """Forward hook on ``aggregator``: cosine(aggregation, mean of input rows)."""
        # Assumes inputs[0] is the (m, n) matrix of per-objective gradients and
        # ``aggregation`` the resulting (n,) vector -- TODO confirm for
        # aggregators other than UPGrad.
        matrix = inputs[0]
        # Only a float is kept, so there is no reason to build an autograd graph.
        with torch.no_grad():
            avg_grad = matrix.mean(dim=0)
            sim = cosine_similarity(aggregation, avg_grad, dim=0).item()
        self.data["cosine_similarity"] = sim
        if self.verbose:
            print(f"Cosine similarity: {sim:.4f} (→ ideally ≈ 1)")

    def _remove_hooks(self) -> None:
        """Remove every registered hook and forget the handles."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self) -> "PrintGradWeights":
        """Reset ``data`` and attach both hooks (no-op when disabled)."""
        if not self.enabled:
            return self
        # A previous __enter__ without __exit__ would otherwise leave hooks
        # attached that nothing can remove any more.
        self._remove_hooks()
        self.data = {}
        try:
            # Hook on the weighting submodule's forward.
            self._handles.append(
                self.aggregator.weighting.register_forward_hook(self._hook_weights)
            )
            # Hook on the aggregator's own forward.
            self._handles.append(
                self.aggregator.register_forward_hook(self._hook_similarity)
            )
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so clean up here.
            self._remove_hooks()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        """Detach all hooks; never suppress exceptions from the with-body."""
        self._remove_hooks()
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment