
@cloneofsimo
Created August 7, 2024 09:07
Demonstrate ABC invariance
# Suppose you have a neural network where
#   x_l = a_l * W_l x_{l-1},   W_l[i,j] ~ N(0, b_l^2),   learning rate of W_l := c_l.
# If you are training with Adam, you can rescale
#   a_l <- a_l * A,   b_l <- b_l / A,   c_l <- c_l / A
# and the training dynamics will be exactly identical to before.
# This is known as the ABC (ABCD) redundancy. For the general case, see https://arxiv.org/abs/2308.01814
# Let me show you what I mean:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import functional as F
torch.manual_seed(0)
x = torch.randn(100, 10)
y = torch.randn(100, 10)
dataset = TensorDataset(x, y)
data_loader = DataLoader(dataset, batch_size=10, shuffle=False)
class ModelA(nn.Module):
    def __init__(self, a):
        super(ModelA, self).__init__()
        self.a = a
        self.linear = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.a * self.linear(x)


class ModelB(nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.linear = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.linear(x)
def train_model(model, optimizer, data_loader, loss_func, epochs):
    loss_history = []
    for epoch in range(epochs):
        for x_batch, y_batch in data_loader:
            optimizer.zero_grad()
            output = model(x_batch)
            loss = loss_func(output, y_batch)
            loss.backward()
            optimizer.step()
        loss_history.append(loss.item())  # loss of the last batch in each epoch
    return loss_history
A = 3.1415
model_a = ModelA(A)            # a_l = A
model_b = ModelB()             # a_l = 1, i.e. a_l scaled down by A
# b rescaled: model B starts from A times model A's weights
model_b.linear.weight.data = A * model_a.linear.weight.data.clone()
epochs = 10
lr_a = 0.01
lr_b = lr_a * A                # c rescaled: learning rate scaled up by A
# eps=0 so that Adam's update is exactly invariant to the gradient scale
optimizer_a = optim.Adam(model_a.parameters(), lr=lr_a, eps=0.0)
optimizer_b = optim.Adam(model_b.parameters(), lr=lr_b, eps=0.0)
loss_func = nn.MSELoss()
# Train both models
loss_history_a = train_model(model_a, optimizer_a, data_loader, loss_func, epochs)
loss_history_b = train_model(model_b, optimizer_b, data_loader, loss_func, epochs)
print("Loss history for Model A:", loss_history_a)
print("Loss history for Model B:", loss_history_b)
# After training, model B's weights should still be exactly A times model A's weights
assert torch.allclose(model_a.linear.weight.data, model_b.linear.weight.data / A, atol=1e-3)
print("Max error Weight: ", torch.max(torch.abs(model_a.linear.weight.data - model_b.linear.weight.data / A)))
@cloneofsimo (Author):
Loss history for Model A: [1.026654839515686, 1.0110559463500977, 0.9968965649604797, 0.9845868945121765, 0.973927915096283, 0.9646140933036804, 0.9563595056533813, 0.9489277005195618, 0.9421310424804688, 0.9358218908309937]
Loss history for Model B: [1.026654839515686, 1.0110559463500977, 0.9968964457511902, 0.9845868945121765, 0.973927915096283, 0.9646140933036804, 0.9563595652580261, 0.9489277601242065, 0.9421310424804688, 0.9358218908309937]
