Last active
November 26, 2025 21:30
-
-
Save qsh-zh/e0e16c4348852887df66a3fe10cd076a to your computer and use it in GitHub Desktop.
Adam parameter update (reference: https://kexue.fm/archives/11267)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # %% | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import TensorDataset, DataLoader | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.datasets import make_swiss_roll, make_moons | |
| from tqdm import tqdm | |
| import os | |
torch.set_float32_matmul_precision('high')
# ==========================================
# 1. Configuration
# ==========================================
# Seed the torch and NumPy global RNGs so runs are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
# Toy 2-D data is linearly embedded into a higher-dimensional ambient space.
INTRINSIC_DIM = 2  # true manifold dimension
AMBIENT_DIM = 32  # high-dimensional ambient space
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on {DEVICE} with Ambient Dimension D={AMBIENT_DIM}")
| # %% | |
| # ========================================== | |
| # 2. Data + Random Orthogonal Embedding | |
| # ========================================== | |
def get_data(batch_size: int, data_type: str) -> np.ndarray:
    """Draw ``batch_size`` points from a 2-D toy distribution.

    Args:
        batch_size: Number of points to sample.
        data_type: Either ``"swiss_roll"`` or ``"moons"``.

    Returns:
        Array of shape ``(batch_size, 2)``.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
    """
    if data_type == "swiss_roll":
        data, _ = make_swiss_roll(batch_size)
        # sklearn's swiss roll is 3-D: columns 0 and 2 trace the spiral and
        # column 1 is the roll's height. Keep the spiral and shrink it by 10x.
        return data[:, [0, 2]] / 10.0
    if data_type == "moons":
        data, _ = make_moons(batch_size)
        return data
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'swiss_roll' or 'moons'."
    )
# Random orthonormal embedding R^INTRINSIC_DIM -> R^AMBIENT_DIM: QR-factorize a
# Gaussian (AMBIENT_DIM, INTRINSIC_DIM) matrix and keep the reduced Q factor,
# whose columns are orthonormal.
Q = torch.linalg.qr(torch.randn(AMBIENT_DIM, INTRINSIC_DIM, device=DEVICE))[0]


def embed(x2d):
    """Lift points with last dimension INTRINSIC_DIM into the ambient space (last dim AMBIENT_DIM)."""
    return torch.matmul(x2d, Q.T)


def project_back(xD):
    """Return the coordinates of ambient points in the embedding basis (last dim INTRINSIC_DIM)."""
    return torch.matmul(xD, Q)
# Sample the 2-D training set and take a quick look before it gets embedded.
data_type = "swiss_roll"
n_points = 8192
data = get_data(n_points, data_type)
plt.scatter(*data.T, s=5, c="black", alpha=0.5)
plt.tight_layout()
plt.show()
| # %% | |
batch_size = 128
# batch_size = 4  # alternative setting; output file names include batch_size
global_learning_rate = 5e-4
# The whole (small) dataset is moved to DEVICE once, up front.
dataset = TensorDataset(torch.from_numpy(data).float().to(DEVICE))
# shuffle=True reshuffles every epoch. Without it, every pass replays exactly
# the same mini-batches in the same order.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
| # %% | |
| # ========================================== | |
| # 3. Model (MLP) | |
| # ========================================== | |
class DenoisingMLP(nn.Module):
    """Time-conditioned MLP that predicts the clean sample ``x1`` from ``z_t``.

    The time ``t`` is lifted by a Linear+ReLU encoder and concatenated to the
    noisy input, then fed through a ReLU MLP with four hidden layers. The
    output layer is zero-initialized, so at initialization the model outputs
    zeros, or exactly ``z_t`` when the skip connection is enabled.

    Args:
        dim: Size of the last dimension of ``z_t`` and of the output.
        hidden_dim: Width of each hidden layer.
        use_skip_connection: If True, add ``z_t`` to the MLP output.
        time_emb_dim: Width of the time embedding (default 20).
    """

    def __init__(self, dim, hidden_dim, use_skip_connection: bool = False, time_emb_dim: int = 20):
        super().__init__()
        self.use_skip_connection = use_skip_connection
        self.time_encoder = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.ReLU(),
        )
        self.net = nn.Sequential(
            nn.Linear(dim + time_emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        )
        # Zero the output layer via nn.init (which runs under no_grad) rather
        # than writing through ``.data``, which bypasses autograd's tracking.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, z_t, t):
        """Predict ``x1`` from the noisy input ``z_t`` at time ``t``.

        Args:
            z_t: Noisy input whose last dimension equals ``dim``.
            t: Times with a trailing dimension of size 1, sharing ``z_t``'s
                leading dimensions.

        Returns:
            Tensor with the same shape as ``z_t``.
        """
        t_emb = self.time_encoder(t)
        out = self.net(torch.cat((z_t, t_emb), dim=-1))
        if self.use_skip_connection:
            out = out + z_t  # residual skip from the noisy input
        return out
def format_arch_label(use_skip_connection: bool) -> str:
    """Return the short label ("skip" / "no_skip") used in logs and plot titles."""
    if not use_skip_connection:
        return "no_skip"
    return "skip"
| # %% | |
| # ================================================== | |
| # 4. Loss (x prediction + v loss) | |
| # ================================================== | |
def compute_loss(x1, eps, xt, t, pred_raw):
    """Velocity-space MSE for a model that predicts the clean sample.

    The prediction ``pred_raw`` (an estimate of ``x1``) is converted to a
    velocity ``(pred_raw - xt) / (1 - t)`` and compared with the target
    velocity ``x1 - eps``. ``t`` must stay below 1; callers clip it.
    """
    velocity_target = x1 - eps
    velocity_pred = (pred_raw - xt) / (1 - t)
    residual = velocity_pred - velocity_target
    return residual.pow(2).mean()
def clone_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Snapshot a state dict as detached copies, so later in-place optimizer updates leave it untouched."""
    snapshot = {}
    for name, tensor in state_dict.items():
        snapshot[name] = tensor.detach().clone()
    return snapshot
| # %% | |
| # ========================================== | |
| # 5. Training (x prediction only) | |
| # ========================================== | |
def train_model(model_arch: bool = False, hidden_dim=256, train_steps=1000):
    """Train a DenoisingMLP with x-prediction + v-loss and log Adam update sizes.

    Each outer-loop iteration is one full pass (epoch) over the module-level
    ``dataloader``, taking one optimizer step per batch. After every step, the
    change of each state-dict tensor divided by the learning rate is reduced to
    an RMS value; those values are averaged over the epoch.

    Args:
        model_arch: Passed to ``DenoisingMLP`` as ``use_skip_connection``.
        hidden_dim: Hidden width of the MLP.
        train_steps: Number of epochs (full passes over ``dataloader``), not
            the number of optimizer steps.

    Returns:
        ``(model, losses, training_state_diff)``:
        - ``losses[e]`` is the mean batch loss of epoch ``e``.
        - ``training_state_diff[name][e]`` is the epoch-mean of
          RMS(Δparam / lr) for state-dict entry ``name``.
    """
    model = DenoisingMLP(AMBIENT_DIM, hidden_dim, use_skip_connection=model_arch).to(DEVICE)
    optim_ = optim.Adam(model.parameters(), lr=global_learning_rate, betas=(0.9, 0.99), eps=1e-12)
    arch_label = format_arch_label(model_arch)
    print(f"Training DenoisingMLP [{arch_label}] (x-prediction + v-loss)...")
    t_eps = 1e-3  # keeps t away from 1, where the v-loss divides by (1 - t)
    losses = []

    def train_step(x1_low, model_state_dict: dict[str, torch.Tensor] | None = None):
        """Run one optimizer step.

        If a previous snapshot is given, also return the lr-normalized
        parameter change and a fresh snapshot.
        """
        x1 = embed(x1_low)
        eps = torch.randn_like(x1)
        t = torch.rand((x1.shape[0], 1), device=DEVICE)
        t = t.clip(t_eps, 1 - t_eps)
        # Linear interpolation: t=0 is pure noise, t=1 is the data point.
        xt = t * x1 + (1 - t) * eps
        pred_raw = model(xt, t)
        loss = compute_loss(x1, eps, xt, t, pred_raw)
        optim_.zero_grad()
        loss.backward()
        optim_.step()
        if model_state_dict is None:
            return loss, None, None
        # Normalize by the lr the optimizer actually used, not the module-level
        # constant, so the metric stays correct if the two ever differ.
        lr = optim_.param_groups[0]["lr"]
        current_model_state_dict = clone_state_dict(model.state_dict())
        diff_model_state_dict = {
            k: (current_model_state_dict[k] - model_state_dict[k]) / lr
            for k in current_model_state_dict
        }
        return loss, diff_model_state_dict, current_model_state_dict

    model_state_dict = clone_state_dict(model.state_dict())
    training_state_diff = {key: [] for key in model_state_dict}
    for _ in tqdm(range(train_steps)):
        epoch_losses = []
        epoch_rms = {key: [] for key in model_state_dict}
        for batch in dataloader:
            x1_low = batch[0].to(DEVICE, non_blocking=True)
            loss, diff_model_state_dict, model_state_dict = train_step(x1_low, model_state_dict)
            if diff_model_state_dict is not None:
                for key, diff in diff_model_state_dict.items():
                    # RMS of this step's lr-normalized update for one tensor.
                    epoch_rms[key].append(torch.sqrt(torch.mean(diff**2)).item())
            epoch_losses.append(loss.item())
        losses.append(np.mean(epoch_losses))
        for key, values in epoch_rms.items():
            training_state_diff[key].append(np.mean(values))
    return model, losses, training_state_diff
HIDDEN_DIM = 256
TRAIN_STEPS = 100  # forwarded to train_model as train_steps (passes over the dataloader)
# One run per architecture flag: False = plain MLP; True = MLP with a residual
# skip from the noisy input inside DenoisingMLP.
experiments = [False, True]
results = {}
for model_arch in experiments:
    model, losses, training_state_diff = train_model(
        model_arch=model_arch,
        hidden_dim=HIDDEN_DIM,
        train_steps=TRAIN_STEPS,
    )
    results[model_arch] = (model, losses, training_state_diff)
| # %% | |
| # ========================================== | |
| # 6. Sampling | |
| # ========================================== | |
@torch.no_grad()
def sample(model, num_samples=2000, steps=100, dim=None, device=None):
    """Draw samples with an Euler sampler on the x-prediction velocity field.

    Starts from standard Gaussian noise. On a uniform grid of ``steps`` times
    from ``eps_t`` to ``1 - eps_t`` (``eps_t = 1e-2``), it applies ``steps``
    Euler updates ``x += v_hat * dt``, where ``v_hat = (x1_hat - x) / (1 - t)``
    and ``dt`` is the grid spacing.

    Args:
        model: Callable ``model(x_t, t)`` returning an estimate of the clean
            sample shaped like ``x_t``; ``t`` has shape ``(num_samples, 1)``.
        num_samples: Number of samples to draw.
        steps: Number of grid points / Euler steps. Must be at least 2 so that
            the grid spacing is defined.
        dim: Sample dimensionality; defaults to ``AMBIENT_DIM``.
        device: Device to sample on; defaults to ``DEVICE``.

    Returns:
        Tensor of shape ``(num_samples, dim)``.

    Raises:
        ValueError: If ``steps < 2``.
    """
    if steps < 2:
        raise ValueError(f"steps must be >= 2 to define a step size, got {steps}")
    dim = AMBIENT_DIM if dim is None else dim
    device = DEVICE if device is None else device
    x_t = torch.randn(num_samples, dim, device=device)
    eps_t = 1e-2
    ts = torch.linspace(eps_t, 1 - eps_t, steps, device=device)
    dt = ts[1] - ts[0]
    for i in range(steps):
        # Broadcast the scalar time as a (num_samples, 1) view rather than
        # materializing a (num_samples, steps) grid up front.
        t = ts[i].expand(num_samples, 1)
        x1_hat = model(x_t, t)
        v_hat = (x1_hat - x_t) / (1 - t)
        x_t = x_t + v_hat * dt
    return x_t
# Draw n_points samples (in the ambient space) from every trained model.
samples = {}
for model_arch, (model, _, _) in results.items():
    samples[model_arch] = sample(model, num_samples=n_points)
# Project back to 2D for plotting. The loop variable is deliberately NOT named
# ``sample``: that would rebind the module-level ``sample`` function to a
# tensor and break any later call to it (e.g. when re-running a later cell).
x_orig = data
x_preds = {}
for model_arch, ambient_samples in samples.items():
    x_preds[model_arch] = project_back(ambient_samples).cpu().numpy()
| # %% | |
| # ========================================== | |
| # 7. Plot | |
| # ========================================== | |
# Create the output directory; exist_ok avoids a check-then-create race.
os.makedirs("outputs", exist_ok=True)
# Unique architectures in first-seen order, one figure row each.
model_arches = list(dict.fromkeys(experiments))
cols = ["original", "prediction", "loss"]
# squeeze=False always returns a 2-D array of Axes, so axes[row, col] works
# for any number of rows or columns.
fig, axes = plt.subplots(
    len(model_arches),
    len(cols),
    figsize=(6 * len(cols), 5 * len(model_arches)),
    squeeze=False,
)
for row_idx, model_arch in enumerate(model_arches):
    arch_label = format_arch_label(model_arch)
    for col_idx, col_name in enumerate(cols):
        ax = axes[row_idx, col_idx]
        if col_name == "loss":
            if model_arch in results:
                losses = results[model_arch][1]
                ax.plot(losses)
            else:
                ax.text(0.5, 0.5, "No loss data", ha="center", va="center", transform=ax.transAxes)
            ax.set_title(f"{arch_label} - Training Loss")
            # train_model records one mean loss per pass over the dataloader,
            # so the x axis counts epochs, not optimizer steps.
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss")
            ax.grid(True)
            continue
        # The scatter panels ("original" / "prediction") share fixed limits.
        if col_name == "original":
            ax.scatter(x_orig[:, 0], x_orig[:, 1], s=5, c="black", alpha=0.5)
            ax.set_title(f"{arch_label} - Ground Truth")
        elif model_arch in x_preds:
            x_pred = x_preds[model_arch]
            ax.scatter(x_pred[:, 0], x_pred[:, 1], s=5, alpha=0.5)
            ax.set_title(f"{arch_label} - {col_name}")
        else:
            ax.text(0.5, 0.5, "No predictions", ha="center", va="center", transform=ax.transAxes)
            ax.set_title(f"{arch_label} - {col_name}")
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-1.5, 1.5)
plt.tight_layout()
plt.savefig(f"outputs/adam_0pt2_update_{batch_size}.png")
# One column per parameter kind; each predicate selects state-dict keys.
splits = [
    ("Weight", lambda key: key.endswith("weight")),
    ("Bias", lambda key: key.endswith("bias")),
]
fig, axes = plt.subplots(
    len(model_arches),
    len(splits),
    figsize=(6 * len(splits), 4.5 * len(model_arches)),
    squeeze=False,  # always 2-D, so axes[row, col] also works for a single row
)
for row_idx, model_arch in enumerate(model_arches):
    arch_label = format_arch_label(model_arch)
    model_result = results.get(model_arch)
    arch_state_diff = model_result[2] if model_result is not None else None
    for col_idx, (split_name, include_param) in enumerate(splits):
        ax = axes[row_idx, col_idx]
        ax.set_title(f"{arch_label} - {split_name} State Diff")
        # train_model stores one value per epoch: the epoch-mean RMS of the
        # lr-normalized parameter update.
        ax.set_xlabel("Epoch")
        ax.set_ylabel("State Diff")
        ax.grid(True)
        if not arch_state_diff:
            # Draw a single placeholder and skip the per-split message, so the
            # two texts are not stacked on top of each other.
            ax.text(0.5, 0.5, "No state diff data", ha="center", va="center", transform=ax.transAxes)
            continue
        has_curve = False
        for key, values in arch_state_diff.items():
            if include_param(key):
                ax.plot(values, label=key)
                has_curve = True
        if has_curve:
            ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5))
        else:
            ax.text(0.5, 0.5, f"No {split_name.lower()} params tracked", ha="center", va="center", transform=ax.transAxes)
fig.tight_layout(rect=[0, 0, 0.85, 1])
plt.savefig(f"outputs/adam_0pt2_update_state_diff_{batch_size}.png", bbox_inches="tight")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment