Jacobian of a gradient update recurrence

import torch
from einops import einsum, rearrange
from typing import NamedTuple


class MLPParams(NamedTuple):
    W1: torch.Tensor
    W2: torch.Tensor
    b1: torch.Tensor


def mlp(params: NamedTuple, x):
    W1, W2, b1 = params
    B, x_dim = x.shape
    B, hidden_size, input_size = W1.shape
    B, output_size, W2_in = W2.shape
    B, b1_dim = b1.shape
    assert hidden_size == W2_in
    assert b1_dim == hidden_size
    assert x_dim == input_size
    h = torch.relu(einsum(W1, x, "B hidden input, B input -> B hidden") + b1)
    y = einsum(W2, h, "B output hidden, B hidden -> B output")
    return y


def l2_loss(model, params, x, y_true):
    y = model(params, x)
    return torch.sum((y - y_true) ** 2, dim=-1) / 2


def add_non_batch_dims(x, y):
    # x is (B,), y is (B, D1, D2, ..., DN); reshape x so it broadcasts against y
    return x.view(-1, *([1] * (y.dim() - 1)))


def gradient_update(model, params, x, y_true, lr):
    loss = l2_loss(model, params, x, y_true)
    # compute per-sample gradients with respect to the parameters
    grads = torch.autograd.grad(loss, params, grad_outputs=torch.ones_like(loss), create_graph=True)
    new_params = tuple(p - add_non_batch_dims(lr, g) * g for p, g in zip(params, grads))
    return new_params
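
# Sketch of a sanity check (toy shapes, added for illustration, not part of the
# original gist): for a single batch element and a constant lr, the batched,
# per-sample update should reduce to a plain SGD step computed directly with autograd.
_p = MLPParams(
    torch.randn(1, 4, 3, requires_grad=True),
    torch.randn(1, 2, 4, requires_grad=True),
    torch.randn(1, 4, requires_grad=True),
)
_x, _y = torch.randn(1, 3), torch.randn(1, 2)
_lr = torch.full((1,), 0.1)
_new = gradient_update(mlp, _p, _x, _y, _lr)
_grads = torch.autograd.grad(l2_loss(mlp, _p, _x, _y).sum(), _p)
assert all(torch.allclose(n, p - 0.1 * g) for n, p, g in zip(_new, _p, _grads))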


def jacobian_gradient_update(model, params, x, y_true, lr):
    def gradient_update_(*params):
        return gradient_update(model, params, x, y_true, lr)

    # returns a tuple of tuples; jacobians[i][j] is the Jacobian of the i-th
    # updated parameter with respect to params[j]
    jacobians = torch.autograd.functional.jacobian(gradient_update_, inputs=params)
    return jacobians


def block_diagonal_jacobian_gradient_update(model, params, x, y_true, lr):
    # For some reason this still gives a tuple of tuples
    # jacobian_blocks = []
    # for param_name, param in params._asdict().items():
    #     def gradient_update_(p):
    #         new_params = params._replace(**{param_name: p})
    #         return gradient_update(model, new_params, x, y_true, lr)
    #     block = torch.autograd.functional.jacobian(gradient_update_, inputs=param)
    #     jacobian_blocks.append(block)
    # return tuple(jacobian_blocks)
    jacobians = jacobian_gradient_update(model, params, x, y_true, lr)
    return tuple(jacobians[i][i] for i in range(len(jacobians)))


def multi_dim_diagonal(x: torch.Tensor) -> torch.Tensor:
    """
    Extract the multi-dimensional diagonal from a tensor of shape
    (B, D1, D2, ..., DN, B, D1, D2, ..., DN), returning a tensor of shape (B, D1, D2, ..., DN).
    """
    # Validate shape: it should have an even number of dimensions
    if x.dim() % 2 != 0:
        raise ValueError("Input must have an even number of dimensions: "
                         "(B, D1, ..., DN, B, D1, ..., DN).")
    # Split the shape into two halves
    sizes = x.shape
    half = x.dim() // 2
    first_half = sizes[:half]
    second_half = sizes[half:]
    # Check that the two halves match
    if first_half != second_half:
        raise ValueError(
            "The two halves of the shape must be identical, but got "
            f"{first_half} and {second_half}."
        )
    # Create coordinate arrays for the first half, each of shape (B, D1, D2, ..., DN)
    coords_1d = [torch.arange(s, device=x.device) for s in first_half]  # 1D coords
    grids = torch.meshgrid(*coords_1d, indexing="ij")  # tuple of length 'half'
    # Each grids[i] has shape (B, D1, ..., DN)
    # For advanced indexing, we replicate these coords for the second half
    all_coords = []
    for i in range(half):
        all_coords.append(grids[i])  # first half
    for i in range(half):
        all_coords.append(grids[i])  # second half
    # Index into x with a tuple of coordinate grids to extract the diagonal
    return x[tuple(all_coords)]
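
# Sketch of a usage check (toy shapes, added for illustration): for a
# (B, D, B, D) tensor the multi-dimensional diagonal should pick out x[b, d, b, d].
_x = torch.randn(2, 3, 2, 3)
_diag = multi_dim_diagonal(_x)
assert _diag.shape == (2, 3)
assert torch.equal(_diag[1, 2], _x[1, 2, 1, 2])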


def diagonal_jacobian_gradient_update(model, params, x, y_true, lr):
    block_diags = block_diagonal_jacobian_gradient_update(model, params, x, y_true, lr)
    return tuple(multi_dim_diagonal(block) for block in block_diags)
| device = "cuda:1" | |
| B, input_size, hidden_size, output_size = 1, 64, 64, 64 | |
| W1 = torch.randn(B, hidden_size, input_size, device=device).requires_grad_() | |
| W2 = torch.randn(B, output_size, hidden_size, device=device).requires_grad_() | |
| b1 = torch.randn(B, hidden_size, device=device).requires_grad_() | |
| params = MLPParams(W1, W2, b1) | |
| x = torch.randn(B, input_size, device=device) | |
| y_true = torch.randn(B, output_size, device=device) | |
| lr = torch.ones(B, device=device) * 0.01 | |

# %%
# gradient_update(mlp, params, x, y_true, lr)
# jacs = jacobian_gradient_update(mlp, params, x, y_true, lr)
# jac_blocks = block_diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
# jac_diags = diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
# %%
# jacs = jacobian_gradient_update(mlp, params, x, y_true, lr)
# %%
def batched_diag_kron(A, B):
    A_batch_size, *A_shape = A.shape
    B_batch_size, *B_shape = B.shape
    assert A_batch_size == B_batch_size
    A_diag = torch.diagonal(A, dim1=-2, dim2=-1)
    B_diag = torch.diagonal(B, dim1=-2, dim2=-1)
    # The diagonal of A ⊗ B is the (flattened) outer product of diag(A) and diag(B).
    kron_diag = einsum(A_diag, B_diag, "b n, b m -> b n m").flatten(1)
    return kron_diag
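
# Sketch of a sanity check (toy shapes, added for illustration): the fast diagonal
# should agree with the diagonal of the explicit Kronecker product from torch.kron.
_A, _B = torch.randn(2, 3, 3), torch.randn(2, 4, 4)
_expected = torch.stack([torch.diagonal(torch.kron(a, b)) for a, b in zip(_A, _B)])
assert torch.allclose(batched_diag_kron(_A, _B), _expected, atol=1e-6)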


def batch_eye(n, batch_size, device=None, dtype=None):
    return torch.eye(n, device=device, dtype=dtype).repeat(batch_size, 1, 1)


def d_grad_loss_d_params(params, x, y_true, diagonal=True):
    W1, W2, b1 = params
    B, x_dim = x.shape
    B, hidden_size, input_size = W1.shape
    B, output_size, W2_in = W2.shape
    B, b1_dim = b1.shape
    assert hidden_size == W2_in
    assert b1_dim == hidden_size
    assert x_dim == input_size
    #### first we get the forward pass
    phi = torch.relu(einsum(W1, x, "B hidden input, B input -> B hidden") + b1)
    m_x = einsum(W2, phi, "B output hidden, B hidden -> B output")
    # derivative of relu
    dphi = (phi > 0).float()  # should be a diagonal matrix but we store it as a vector
    #### now we compute the hessians
    # diff = y_true - m_x  # not needed because when phi is relu, all terms involving
    # diff will be zero (since relu's second derivative is zero)
    # compute for W1
    W2_dphi = W2 * dphi[:, None, :]
    dphi_t_W2_t_W2_dphi = einsum(W2_dphi, W2_dphi, "B j i, B j k -> B i k")
    x_x_t = einsum(x, x, "B i, B j -> B i j")
    d_grad_loss_d_W1 = matricize_linear_op(
        A=dphi_t_W2_t_W2_dphi,
        B=x_x_t,
    )
    # compute for W2
    phi_t_phi = einsum(phi, phi, "B hidden1, B hidden2 -> B hidden1 hidden2")
    d_grad_loss_d_W2 = matricize_linear_op(
        A=batch_eye(W2.shape[-2], B, device=W2.device, dtype=W2.dtype),
        B=phi_t_phi,
    )
    # compute for b1
    d_grad_loss_d_b1 = dphi_t_W2_t_W2_dphi
    if diagonal:
        d_grad_loss_d_W1 = torch.diagonal(d_grad_loss_d_W1, dim1=-2, dim2=-1)
        d_grad_loss_d_W2 = torch.diagonal(d_grad_loss_d_W2, dim1=-2, dim2=-1)
        d_grad_loss_d_b1 = torch.diagonal(d_grad_loss_d_b1, dim1=-2, dim2=-1)
    return d_grad_loss_d_W1, d_grad_loss_d_W2, d_grad_loss_d_b1


def vectorize(X):
    # vectorize in column-major order, assuming a batch dim
    batch_size, *X_shape = X.shape
    X = X.permute(0, *reversed(range(1, len(X_shape) + 1)))
    return X.flatten(1)


def matricize_linear_op(A, B):
    # vec(AXB) = (B^T ⊗ A) vec(X)
    # A, X, B are all batched 2D matrices, i.e. (B, _, _)
    B_T = B.transpose(-2, -1)
    B_T_kron_A = rearrange(einsum(B_T, A, "B M N, B P Q -> B M P N Q"), "B M P N Q -> B (M P) (N Q)")
    return B_T_kron_A


# def matmul_linear_op(A, X, B):
#     B_T_kron_A = matricize_linear_op(A, B)
#     vectorized_X = rearrange(X, "B M N -> B (N M)")
#     return einsum(B_T_kron_A, vectorized_X, "B MP NQ, B NQ -> B MP")
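
# %%
# Sketch of a check for the vec identity used above (toy shapes, added for
# illustration): vec(A X B) = (B^T ⊗ A) vec(X), with vec taken in column-major
# order as in `vectorize`.
_A, _X, _B = torch.randn(2, 3, 3), torch.randn(2, 3, 4), torch.randn(2, 4, 4)
_lhs = vectorize(_A @ _X @ _B)
_rhs = einsum(matricize_linear_op(_A, _B), vectorize(_X), "b mn pq, b pq -> b mn")
assert torch.allclose(_lhs, _rhs, atol=1e-4)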
# %%
def d_grad_update_d_params(params, x, y_true, lr):
    W1, W2, b1 = params
    B, x_dim = x.shape
    B, hidden_size, input_size = W1.shape
    B, output_size, W2_in = W2.shape
    B, b1_dim = b1.shape
    assert hidden_size == W2_in
    assert b1_dim == hidden_size
    assert x_dim == input_size
    hessians = d_grad_loss_d_params(params, x, y_true)
    return tuple(
        vectorize(torch.ones_like(p)) - add_non_batch_dims(lr, hessian) * hessian
        for p, hessian in zip(params, hessians)
    )


# %%
analytic_jac_diags = d_grad_update_d_params(params, x, y_true, lr)
jac_diags = diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
assert all(torch.allclose(vectorize(j1), j2, atol=1e-5) for j1, j2 in zip(jac_diags, analytic_jac_diags))
# %%
%%timeit -n 5
_ = d_grad_update_d_params(params, x, y_true, lr)
torch.cuda.synchronize(device=device)
# %%
%%timeit -n 5
_ = diagonal_jacobian_gradient_update(mlp, params, x, y_true, lr)
torch.cuda.synchronize(device=device)
# %%