Jacobian of a gradient update recurrence

import torch
from einops import einsum, rearrange
from typing import NamedTuple

class MLPParams(NamedTuple):
    W1: torch.Tensor
    W2: torch.Tensor
    b1: torch.Tensor

def mlp(params: NamedTuple, x):
    W1, W2, b1 = params
    B, x_dim = x.shape
    B, hidden_size, input_size = W1.shape
    B, output_size, W2_in = W2.shape
    B, b1_dim = b1.shape
    assert hidden_size == W2_in
    assert b1_dim == hidden_size
    assert x_dim == input_size
    h = torch.relu(einsum(W1, x, "B hidden input, B input -> B hidden") + b1)
    y = einsum(W2, h, "B output hidden, B hidden -> B output")
    return y
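
# Quick shape check (a minimal sketch with toy sizes): every sample in the batch carries its
# own W1/W2/b1, so mlp applies a separate two-layer ReLU network per sample via the batched
# einsums above.
_toy_params = MLPParams(torch.randn(2, 5, 3), torch.randn(2, 4, 5), torch.randn(2, 5))
assert mlp(_toy_params, torch.randn(2, 3)).shape == (2, 4)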

def l2_loss(model, params, x, y_true):
    y = model(params, x)
    return torch.sum((y - y_true) ** 2, dim=-1) / 2

def add_non_batch_dims(x, y):
    # x is (B,), y is (B, D1, D2, ..., DN)
    return x.view(-1, *([1] * (y.dim() - 1)))

def gradient_update(model, params, x, y_true, lr):
    loss = l2_loss(model, params, x, y_true)
    # compute the per-sample gradient of the loss with respect to each parameter tensor
    grads = torch.autograd.grad(loss, params, grad_outputs=torch.ones_like(loss), create_graph=True)
    # lr has shape (B,): each sample gets its own learning rate, broadcast over parameter dims
    new_params = tuple(p - add_non_batch_dims(lr, g) * g for p, g in zip(params, grads))
    return new_params

def jacobian_gradient_update(model, params, x, y_true, lr):
    def gradient_update_(*params):
        return gradient_update(model, params, x, y_true, lr)

    # returns a tuple of tuples; jacobians[i][j] = d(new_params[i]) / d(params[j])
    jacobians = torch.autograd.functional.jacobian(gradient_update_, inputs=params)
    return jacobians
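
# Shape sketch (for reference): jacobians[i][j] has shape (*params[i].shape, *params[j].shape),
# e.g. for W1 of shape (B, hidden, input), jacobians[0][0] is (B, hidden, input, B, hidden, input),
# which is exactly the layout that multi_dim_diagonal below expects.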

def block_diagonal_jacobian_gradient_update(model, params, x, y_true, lr):
    # For some reason the per-parameter version below still gives a tuple of tuples,
    # so we take the full Jacobian and keep only its diagonal blocks instead.
    # jacobian_blocks = []
    # for param_name, param in params._asdict().items():
    #     def gradient_update_(p):
    #         new_params = params._replace(**{param_name: p})
    #         return gradient_update(model, new_params, x, y_true, lr)
    #     block = torch.autograd.functional.jacobian(gradient_update_, inputs=param)
    #     jacobian_blocks.append(block)
    # return tuple(jacobian_blocks)
    jacobians = jacobian_gradient_update(model, params, x, y_true, lr)
    return tuple(jacobians[i][i] for i in range(len(jacobians)))

def multi_dim_diagonal(x: torch.Tensor) -> torch.Tensor:
    """
    Extract the multi-dimensional diagonal from a tensor of shape
    (B, D1, D2, ..., DN, B, D1, D2, ..., DN), returning a tensor of shape (B, D1, D2, ..., DN).
    """
    # Validate shape: it should have an even number of dimensions
    if x.dim() % 2 != 0:
        raise ValueError("Input must have an even number of dimensions: "
                         "(B, D1, ..., DN, B, D1, ..., DN).")
    # Split the shape into two halves
    sizes = x.shape
    half = x.dim() // 2
    first_half = sizes[:half]
    second_half = sizes[half:]
    # Check that the two halves match
    if first_half != second_half:
        raise ValueError(
            "The two halves of the shape must be identical, but got "
            f"{first_half} and {second_half}."
        )
    # Create coordinate arrays for the first half, each of shape (B, D1, D2, ..., DN)
    coords_1d = [torch.arange(s, device=x.device) for s in first_half]  # 1D coords
    grids = torch.meshgrid(*coords_1d, indexing="ij")  # tuple of length 'half'
    # Each grids[i] has shape (B, D1, ..., DN)
    # For advanced indexing, we replicate these coords for the second half
    all_coords = []
    for i in range(half):
        all_coords.append(grids[i])  # first half
    for i in range(half):
        all_coords.append(grids[i])  # second half
    # Index into x with a tuple of coordinate tensors to extract the diagonal
    return x[tuple(all_coords)]
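
# Minimal check of multi_dim_diagonal (a small sketch with toy sizes): place a known tensor on
# the "diagonal" of a doubled-shape zero tensor and recover it.
_diag_ref = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
_full = torch.zeros(2, 3, 4, 2, 3, 4)
for _b in range(2):
    for _i in range(3):
        for _j in range(4):
            _full[_b, _i, _j, _b, _i, _j] = _diag_ref[_b, _i, _j]
assert torch.equal(multi_dim_diagonal(_full), _diag_ref)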

def diagonal_jacobian_gradient_update(model, params, x, y_true, lr):
    block_diags = block_diagonal_jacobian_gradient_update(model, params, x, y_true, lr)
    return tuple(multi_dim_diagonal(block) for block in block_diags)

device = "cuda:1"
B, input_size, hidden_size, output_size = 1, 64, 64, 64
W1 = torch.randn(B, hidden_size, input_size, device=device).requires_grad_()
W2 = torch.randn(B, output_size, hidden_size, device=device).requires_grad_()
b1 = torch.randn(B, hidden_size, device=device).requires_grad_()
params = MLPParams(W1, W2, b1)
x = torch.randn(B, input_size, device=device)
y_true = torch.randn(B, output_size, device=device)
lr = torch.ones(B, device=device) * 0.01
# %%
# gradient_update(mlp, params, x, y_true, lr)
# jacs = jacobian_gradient_update(mlp, params, x, y_true, lr)
# jac_blocks = block_diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
# jac_diags = diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
# %%
# jacs = jacobian_gradient_update(mlp, params, x, y_true, lr)
# %%

def batched_diag_kron(A, B):
    A_batch_size, *A_shape = A.shape
    B_batch_size, *B_shape = B.shape
    assert A_batch_size == B_batch_size
    A_diag = torch.diagonal(A, dim1=-2, dim2=-1)
    B_diag = torch.diagonal(B, dim1=-2, dim2=-1)
    # The diagonal of A ⊗ B is the (flattened) outer product of diag(A) and diag(B).
    kron_diag = einsum(A_diag, B_diag, "b n, b m -> b n m").flatten(1)
    return kron_diag
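
# Sanity check of the identity above against torch.kron (small unbatched matrices, looped over
# the batch): diag(A ⊗ B) equals the flattened outer product of diag(A) and diag(B).
_A_kron = torch.randn(3, 4, 4)
_B_kron = torch.randn(3, 5, 5)
_kron_ref = torch.stack([torch.diagonal(torch.kron(a, b)) for a, b in zip(_A_kron, _B_kron)])
assert torch.allclose(batched_diag_kron(_A_kron, _B_kron), _kron_ref)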

def batch_eye(n, batch_size, device=None, dtype=None):
    return torch.eye(n, device=device, dtype=dtype).repeat(batch_size, 1, 1)

def d_grad_loss_d_params(params, x, y_true, diagonal=True):
    W1, W2, b1 = params
    B, x_dim = x.shape
    B, hidden_size, input_size = W1.shape
    B, output_size, W2_in = W2.shape
    B, b1_dim = b1.shape
    assert hidden_size == W2_in
    assert b1_dim == hidden_size
    assert x_dim == input_size
    #### first we get the forward pass
    phi = torch.relu(einsum(W1, x, "B hidden input, B input -> B hidden") + b1)
    m_x = einsum(W2, phi, "B output hidden, B hidden -> B output")
    # derivative of relu
    dphi = (phi > 0).float()  # should be a diagonal matrix but we store it as a vector
    #### now we compute the hessians
    # diff = y_true - m_x  # not needed because when phi is relu, all terms involving diff
    # will be zero (since relu's second derivative is zero)
    # compute for W1
    W2_dphi = W2 * dphi[:, None, :]
    dphi_t_W2_t_W2_dphi = einsum(W2_dphi, W2_dphi, "B j i, B j k -> B i k")
    x_x_t = einsum(x, x, "B i, B j -> B i j")
    d_grad_loss_d_W1 = matricize_linear_op(
        A=dphi_t_W2_t_W2_dphi,
        B=x_x_t,
    )
    # compute for W2
    phi_t_phi = einsum(phi, phi, "B hidden1, B hidden2 -> B hidden1 hidden2")
    d_grad_loss_d_W2 = matricize_linear_op(
        A=batch_eye(W2.shape[-2], B, device=W2.device, dtype=W2.dtype),
        B=phi_t_phi,
    )
    # compute for b1
    d_grad_loss_d_b1 = dphi_t_W2_t_W2_dphi
    if diagonal:
        d_grad_loss_d_W1 = torch.diagonal(d_grad_loss_d_W1, dim1=-2, dim2=-1)
        d_grad_loss_d_W2 = torch.diagonal(d_grad_loss_d_W2, dim1=-2, dim2=-1)
        d_grad_loss_d_b1 = torch.diagonal(d_grad_loss_d_b1, dim1=-2, dim2=-1)
    return d_grad_loss_d_W1, d_grad_loss_d_W2, d_grad_loss_d_b1
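
# Derivation sketch behind the formulas above: with L = ||W2 relu(W1 x + b1) - y||^2 / 2,
# phi = relu(W1 x + b1), Dphi = diag(relu'(W1 x + b1)) and r = W2 phi - y, the per-sample
# gradients are
#   grad_W1 = (Dphi W2^T r) x^T,   grad_W2 = r phi^T,   grad_b1 = Dphi W2^T r.
# Differentiating once more and dropping the terms multiplied by relu'' (zero almost everywhere):
#   d(grad_W1) = (Dphi W2^T W2 Dphi) dW1 (x x^T)  ->  Hessian = (x x^T) ⊗ (Dphi W2^T W2 Dphi)
#   d(grad_W2) = I dW2 (phi phi^T)                ->  Hessian = (phi phi^T) ⊗ I
#   d(grad_b1) = (Dphi W2^T W2 Dphi) db1          ->  Hessian = Dphi W2^T W2 Dphi
# using the column-major identity vec(A X B) = (B^T ⊗ A) vec(X) implemented by
# matricize_linear_op below.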

def vectorize(X):
    # vectorize in column-major order, assuming a batch dim
    batch_size, *X_shape = X.shape
    X = X.permute(0, *reversed(range(1, len(X_shape) + 1)))
    return X.flatten(1)

def matricize_linear_op(A, B):
    # vec(A X B) = (B^T ⊗ A) vec(X)
    # A, X, B are all batched 2D matrices, i.e. (B, _, _)
    B_T = B.transpose(-2, -1)
    B_T_kron_A = rearrange(einsum(B_T, A, "B M N, B P Q -> B M P N Q"), "B M P N Q -> B (M P) (N Q)")
    return B_T_kron_A

# def matmul_linear_op(A, X, B):
#     B_T_kron_A = matricize_linear_op(A, B)
#     vectorized_X = rearrange(X, "B M N -> B (N M)")
#     return einsum(B_T_kron_A, vectorized_X, "B MP NQ, B NQ -> B MP")
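
# Quick numerical check of vec(A X B) = (B^T ⊗ A) vec(X) using the column-major vectorize above
# (small random batched matrices):
_A_vec = torch.randn(2, 3, 3)
_X_vec = torch.randn(2, 3, 4)
_B_vec = torch.randn(2, 4, 4)
_lhs = vectorize(_A_vec @ _X_vec @ _B_vec)
_rhs = einsum(matricize_linear_op(_A_vec, _B_vec), vectorize(_X_vec), "b i j, b j -> b i")
assert torch.allclose(_lhs, _rhs, atol=1e-5)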
# %%

def d_grad_update_d_params(params, x, y_true, lr):
    W1, W2, b1 = params
    B, x_dim = x.shape
    B, hidden_size, input_size = W1.shape
    B, output_size, W2_in = W2.shape
    B, b1_dim = b1.shape
    assert hidden_size == W2_in
    assert b1_dim == hidden_size
    assert x_dim == input_size
    # Jacobian of the update p - lr * grad(p) is I - lr * H, so its diagonal is 1 - lr * diag(H)
    hessians = d_grad_loss_d_params(params, x, y_true)
    return tuple(
        vectorize(torch.ones_like(p)) - add_non_batch_dims(lr, hessian) * hessian
        for p, hessian in zip(params, hessians)
    )
# %%
analytic_jac_diags = d_grad_update_d_params(params, x, y_true, lr)
jac_diags = diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
assert all(torch.allclose(vectorize(j1), j2, atol=1e-5) for j1, j2 in zip(jac_diags, analytic_jac_diags))
# %%
%%timeit -n 5
_ = d_grad_update_d_params(params, x, y_true, lr)
torch.cuda.synchronize(device=device)
# %%
%%timeit -n 5
_ = diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
torch.cuda.synchronize(device=device)
# %%