Created August 7, 2024 09:07
Demonstrate ABC invariance
# Suppose you have a neural network whose layers compute
#   x_l = a_l * W_l x_{l-1},  with W_l_{i,j} ~ N(0, b_l^2) and learning rate c_l for W_l.
# If you are training with Adam, you can rescale
#   a_l <- a_l * A,  b_l <- b_l / A,  c_l <- c_l / A
# and the training dynamics will be exactly identical to before.
# This is known as the ABC (ABCD) redundancy. For the more general case, see https://arxiv.org/abs/2308.01814
# Let me show you what I mean:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.nn import functional as F

torch.manual_seed(0)

x = torch.randn(100, 10)
y = torch.randn(100, 10)
dataset = TensorDataset(x, y)
data_loader = DataLoader(dataset, batch_size=10, shuffle=False)


class ModelA(nn.Module):
    def __init__(self, a):
        super(ModelA, self).__init__()
        self.a = a
        self.linear = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.a * self.linear(x)


class ModelB(nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.linear = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        return self.linear(x)


def train_model(model, optimizer, data_loader, loss_func, epochs):
    loss_history = []
    for epoch in range(epochs):
        for x_batch, y_batch in data_loader:
            optimizer.zero_grad()
            output = model(x_batch)
            loss = loss_func(output, y_batch)
            loss.backward()
            optimizer.step()
        loss_history.append(loss.item())
    return loss_history


A = 3.1415
model_a = ModelA(A)
model_b = ModelB()
# Give model B the weights A * W_a (the b_l <- b_l / A step in reverse),
# so both models compute exactly the same function at initialization.
model_b.linear.weight.data = A * model_a.linear.weight.data.clone()

epochs = 10
lr_a = 0.01
lr_b = lr_a * A  # c_l <- c_l / A when going from model B to model A

# eps=0 keeps Adam's update exactly invariant to rescaling the gradient.
optimizer_a = optim.Adam(model_a.parameters(), lr=lr_a, eps=0.0)
optimizer_b = optim.Adam(model_b.parameters(), lr=lr_b, eps=0.0)

loss_func = nn.MSELoss()

# Train both models
loss_history_a = train_model(model_a, optimizer_a, data_loader, loss_func, epochs)
loss_history_b = train_model(model_b, optimizer_b, data_loader, loss_func, epochs)

print("Loss history for Model A:", loss_history_a)
print("Loss history for Model B:", loss_history_b)

# assert the rescaled weights still match after training
assert torch.allclose(model_a.linear.weight.data, model_b.linear.weight.data / A, atol=1e-3)
print("Max error Weight: ", torch.max(torch.abs(model_a.linear.weight.data - model_b.linear.weight.data / A)))
Loss history for Model A: [1.026654839515686, 1.0110559463500977, 0.9968965649604797, 0.9845868945121765, 0.973927915096283, 0.9646140933036804, 0.9563595056533813, 0.9489277005195618, 0.9421310424804688, 0.9358218908309937]
Loss history for Model B: [1.026654839515686, 1.0110559463500977, 0.9968964457511902, 0.9845868945121765, 0.973927915096283, 0.9646140933036804, 0.9563595652580261, 0.9489277601242065, 0.9421310424804688, 0.9358218908309937]
This is exactly identical in the general case as well, in case you are doubting it because this is a toy case:
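Not from the gist: below is a minimal sketch of what that check might look like beyond the single-layer toy, assuming a two-layer MLP with a separate multiplier A_l per layer (the names ScaledMLP and scales are mine, not the author's). Each layer of model A gets a_l <- a_l * A_l, weights that are 1/A_l times the baseline's, and a per-layer Adam learning rate divided by A_l; with eps=0 the two runs should again track each other to float precision.

# Hypothetical extension (not from the gist): the same per-layer A-rescaling
# applied to a two-layer MLP, again trained with Adam and eps=0.
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
x = torch.randn(100, 10)
y = torch.randn(100, 10)

scales = [2.0, 0.5]  # arbitrary per-layer multipliers A_l (my choice)

class ScaledMLP(nn.Module):
    def __init__(self, a):
        super().__init__()
        self.a = a  # per-layer output multipliers a_l
        self.layers = nn.ModuleList([nn.Linear(10, 32, bias=False),
                                     nn.Linear(32, 10, bias=False)])

    def forward(self, x):
        for a_l, layer in zip(self.a, self.layers):
            x = a_l * layer(x)
        return x

model_a = ScaledMLP(a=scales)        # a_l = A_l
model_b = ScaledMLP(a=[1.0, 1.0])    # a_l = 1 (baseline)

# Give the baseline the weights A_l * W_l (b_l <- b_l / A_l in reverse),
# so both models compute the same function at initialization.
for A_l, la, lb in zip(scales, model_a.layers, model_b.layers):
    lb.weight.data = A_l * la.weight.data.clone()

loss_func = nn.MSELoss()
lr = 0.01
# c_l <- c_l / A_l: per-layer learning rates via Adam parameter groups.
opt_a = optim.Adam([{"params": l.parameters(), "lr": lr / A_l}
                    for A_l, l in zip(scales, model_a.layers)], eps=0.0)
opt_b = optim.Adam(model_b.parameters(), lr=lr, eps=0.0)

for step in range(50):
    for model, opt in ((model_a, opt_a), (model_b, opt_b)):
        opt.zero_grad()
        loss_func(model(x), y).backward()
        opt.step()

# After training, the rescaled weights should still match layer by layer.
for A_l, la, lb in zip(scales, model_a.layers, model_b.layers):
    print("Max error Weight:", (la.weight.data - lb.weight.data / A_l).abs().max().item())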