@cloneofsimo
Created August 7, 2024 09:07
Demonstrate ABC invariance
# Suppose you have a neural network whose layers compute
# x_l = a_l * W_l x_{l-1}, with W_l_{i,j} ~ N(0, b_l^2) and learning rate of W_l := c_l.
# If you are using Adam, you can rescale
# a_l <- a_l * A, b_l <- b_l / A, c_l <- c_l / A
# and it will have exactly the same training dynamics as before:
# with eps = 0, Adam's update m_hat / sqrt(v_hat) is unchanged when all of W_l's gradients
# are multiplied by the constant A, so only the products a_l * b_l (initialization scale)
# and a_l * c_l (effective step size) matter.
# This is known as ABC (ABCD) redundancy. For the more general case: https://arxiv.org/abs/2308.01814
# Let me show you what I mean:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import functional as F

torch.manual_seed(0)

x = torch.randn(100, 10)
y = torch.randn(100, 10)

dataset = TensorDataset(x, y)
data_loader = DataLoader(dataset, batch_size=10, shuffle=False)


class ModelA(nn.Module):
    def __init__(self, a):
        super(ModelA, self).__init__()
        self.a = a
        self.linear = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.a * self.linear(x)


class ModelB(nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.linear = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.linear(x)


def train_model(model, optimizer, data_loader, loss_func, epochs):
    loss_history = []
    for epoch in range(epochs):
        for x_batch, y_batch in data_loader:
            optimizer.zero_grad()
            output = model(x_batch)
            loss = loss_func(output, y_batch)
            loss.backward()
            optimizer.step()
        loss_history.append(loss.item())
    return loss_history


A = 3.1415
model_a = ModelA(A)
model_b = ModelB()

# Model B absorbs the multiplier A into its weights: W_b = A * W_a
model_b.linear.weight.data = A * model_a.linear.weight.data.clone()

epochs = 10
lr_a = 0.01
lr_b = lr_a * A  # the scaled-up weights get a scaled-up learning rate
optimizer_a = optim.Adam(model_a.parameters(), lr=lr_a, eps=0.0)
optimizer_b = optim.Adam(model_b.parameters(), lr=lr_b, eps=0.0)
loss_func = nn.MSELoss()

# Train both models
loss_history_a = train_model(model_a, optimizer_a, data_loader, loss_func, epochs)
loss_history_b = train_model(model_b, optimizer_b, data_loader, loss_func, epochs)
print(loss_history_a, loss_history_b)

# assert that the two weight trajectories stayed exactly a factor of A apart
assert torch.allclose(model_a.linear.weight.data, model_b.linear.weight.data / A, atol=1e-3)
print("Max error Weight: ", torch.max(torch.abs(model_a.linear.weight.data - model_b.linear.weight.data / A)))
@cloneofsimo (Author) commented:

This is completely identical in the general case as well, in case you are doubting that it only works in this toy case:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import functional as F


torch.manual_seed(0)


x = torch.randn(100, 10)
y = torch.randn(100, 10)

dataset = TensorDataset(x, y)
data_loader = DataLoader(dataset, batch_size=10, shuffle=False)

class ModelA(nn.Module):
    def __init__(self, a):
        super(ModelA, self).__init__()
        self.a = a
        self.prev = nn.Linear(10, 10)
        self.linear = nn.Linear(10, 10, bias=False)
        self.after = nn.Linear(10, 10)
    
    def forward(self, x):
        x = nn.SiLU()(self.prev(x))
        x = nn.SiLU()(self.a * self.linear(x))
        return self.after(x)

class ModelB(nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.prev = nn.Linear(10, 10)
        self.linear = nn.Linear(10, 10, bias=False)
        self.after = nn.Linear(10, 10)
    
    def forward(self, x):
        x = nn.SiLU()(self.prev(x))
        x = nn.SiLU()(self.linear(x))
        return self.after(x)

def train_model(model, optimizer, data_loader, loss_func, epochs):
    loss_history = []
    for epoch in range(epochs):
        for x_batch, y_batch in data_loader:
            optimizer.zero_grad()
            output = model(x_batch)
            loss = loss_func(output, y_batch)
            loss.backward()
            optimizer.step()
        loss_history.append(loss.item())
    return loss_history

A = 3.1415
model_a = ModelA(A)
model_b = ModelB()

model_b.load_state_dict(model_a.state_dict())
model_b.linear.weight.data = A * model_a.linear.weight.data.clone()


epochs = 10
lr_a = 0.001
lr_b = lr_a * A 
optimizer_a = optim.Adam(model_a.parameters(), lr=lr_a, eps=0.0)

# only lr_b for linear layer
opt_configs = [{'params': model_b.linear.parameters(), 'lr': lr_b},
               {'params': model_b.prev.parameters(), 'lr': lr_a},
               {'params': model_b.after.parameters(), 'lr': lr_a}]

optimizer_b = optim.Adam(opt_configs, eps=0.0)

loss_func = nn.MSELoss()

# Train both models
loss_history_a = train_model(model_a, optimizer_a, data_loader, loss_func, epochs)
loss_history_b = train_model(model_b, optimizer_b, data_loader, loss_func, epochs)

print("Loss history for Model A:", loss_history_a)
print("Loss history for Model B:", loss_history_b)

# assert that the rescaled layer's trajectories stayed a factor of A apart
assert torch.allclose(model_a.linear.weight.data, model_b.linear.weight.data / A, atol=1e-3)

print("Max error Weight: ", torch.max(torch.abs(model_a.linear.weight.data - model_b.linear.weight.data / A)))

@cloneofsimo (Author) commented:

Loss history for Model A: [1.026654839515686, 1.0110559463500977, 0.9968965649604797, 0.9845868945121765, 0.973927915096283, 0.9646140933036804, 0.9563595056533813, 0.9489277005195618, 0.9421310424804688, 0.9358218908309937]
Loss history for Model B: [1.026654839515686, 1.0110559463500977, 0.9968964457511902, 0.9845868945121765, 0.973927915096283, 0.9646140933036804, 0.9563595652580261, 0.9489277601242065, 0.9421310424804688, 0.9358218908309937]
