Skip to content

Instantly share code, notes, and snippets.

@qsh-zh
Last active November 26, 2025 21:30
Show Gist options
  • Select an option

  • Save qsh-zh/e0e16c4348852887df66a3fe10cd076a to your computer and use it in GitHub Desktop.

Select an option

Save qsh-zh/e0e16c4348852887df66a3fe10cd076a to your computer and use it in GitHub Desktop.
# %%
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll, make_moons
from tqdm import tqdm
import os
# Select the 'high' float32 matmul precision mode (per the PyTorch docs this
# lets backends trade some precision for speed where supported).
torch.set_float32_matmul_precision('high')
# ==========================================
# 1. Configuration
# ==========================================
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed the global torch and NumPy random number generators.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
INTRINSIC_DIM = 2 # true manifold dimension
AMBIENT_DIM = 32 # high-dimensional ambient space
print(f"Running on {DEVICE} with Ambient Dimension D={AMBIENT_DIM}")
# %%
# ==========================================
# 2. Data + Random Orthogonal Embedding
# ==========================================
def get_data(batch_size: int, data_type: str) -> np.ndarray:
    """Generate a 2-D toy dataset.

    Args:
        batch_size: Number of points to generate.
        data_type: Either ``"swiss_roll"`` or ``"moons"``.

    Returns:
        The generated points, one row per point and two columns.

    Raises:
        ValueError: If ``data_type`` is not a supported dataset name.
    """
    if data_type == "swiss_roll":
        points, _ = make_swiss_roll(batch_size)
        # Keep columns 0 and 2 (dropping column 1) and shrink by a factor of 10.
        return points[:, [0, 2]] / 10.0
    if data_type == "moons":
        points, _ = make_moons(batch_size)
        return points
    # Previously an unknown name reached `return data` with `data` unbound.
    raise ValueError(
        f"Unknown data_type {data_type!r}; expected 'swiss_roll' or 'moons'"
    )
# random orthogonal projection 2 → D
# Draw a Gaussian matrix of shape (AMBIENT_DIM, INTRINSIC_DIM) and keep only the
# Q factor of its QR decomposition; the R factor is discarded.
Q = torch.randn(AMBIENT_DIM, INTRINSIC_DIM, device=DEVICE)
Q, _ = torch.linalg.qr(Q)
def embed(x2d):
    """Lift low-dimensional points into the ambient space.

    Right-multiplies ``x2d`` by the transpose of the module-level matrix ``Q``.
    """
    lift = Q.T
    return x2d @ lift
def project_back(xD):
    """Map ambient-space points back to low-dimensional coordinates.

    Right-multiplies ``xD`` by the module-level matrix ``Q``.
    """
    coords = xD @ Q
    return coords
# Number of training points; the sampling cell later reuses it as the number
# of samples to generate.
n_points = 8192
# Dataset name understood by get_data ("swiss_roll" or "moons").
data_type = "swiss_roll"
data = get_data(n_points, data_type)
# Quick visual check of the raw 2-D points.
plt.scatter(data[:,0], data[:,1], s=5, c='black', alpha=0.5)
plt.tight_layout()
plt.show()
# %%
# Mini-batch size; it is also embedded in the output figure file names.
batch_size = 128
# batch_size = 4
# Adam learning rate; train_model also divides parameter deltas by this value.
global_learning_rate = 5e-4
# Convert the NumPy data to a float32 tensor on DEVICE, then wrap it for batching.
dataset = torch.from_numpy(data).float().to(DEVICE)
dataset = TensorDataset(dataset)
# NOTE(review): shuffle is not requested, so batches presumably arrive in the
# same order on every pass — confirm this is intended.
dataloader = DataLoader(dataset, batch_size=batch_size)
# %%
# ==========================================
# 3. Model (MLP)
# ==========================================
class DenoisingMLP(nn.Module):
    """Time-conditioned MLP whose output has the same width as its input.

    The scalar time is expanded to 20 features by a small Linear+ReLU encoder,
    concatenated with the noisy input and passed through five linear layers
    with ReLU activations in between. The last linear layer starts at zero, so
    a fresh model outputs zeros (or the input itself when the skip connection
    is enabled).

    Args:
        dim: Width of the input ``z_t`` and of the output.
        hidden_dim: Width of the hidden layers.
        use_skip_connection: When true, ``z_t`` is added to the network output.
    """

    def __init__(self, dim, hidden_dim, use_skip_connection: bool = False):
        super().__init__()
        self.use_skip_connection = use_skip_connection

        time_features = 20
        self.time_encoder = nn.Sequential(nn.Linear(1, time_features), nn.ReLU())

        # Layers are created in the same order as they appear in the network so
        # that parameter initialisation consumes the RNG identically.
        trunk = [nn.Linear(dim + time_features, hidden_dim), nn.ReLU()]
        for _ in range(3):
            trunk.append(nn.Linear(hidden_dim, hidden_dim))
            trunk.append(nn.ReLU())
        head = nn.Linear(hidden_dim, dim)
        trunk.append(head)
        self.net = nn.Sequential(*trunk)

        # zero out last layer weights and bias
        head.weight.data.zero_()
        head.bias.data.zero_()

    def forward(self, z_t, t):
        """Return the network output for noisy input ``z_t`` at time ``t``."""
        features = torch.cat((z_t, self.time_encoder(t)), dim=-1)
        prediction = self.net(features)
        if not self.use_skip_connection:
            return prediction
        # optional residual skip from the noisy input
        return prediction + z_t
def format_arch_label(use_skip_connection: bool) -> str:
    """Return the short tag used in log lines and plot titles for an architecture flag."""
    if use_skip_connection:
        return "skip"
    return "no_skip"
# %%
# ==================================================
# 4. Loss (x prediction + v loss)
# ==================================================
def compute_loss(x1, eps, xt, t, pred_raw):
    """Mean squared velocity error for a model that predicts the clean sample.

    The raw prediction is treated as an estimate of ``x1`` and converted to a
    velocity via ``(pred_raw - xt) / (1 - t)``; the target velocity is
    ``x1 - eps``. Callers must keep ``t`` away from 1.
    """
    target_velocity = x1 - eps
    predicted_velocity = (pred_raw - xt) / (1 - t)
    residual = predicted_velocity - target_velocity
    return (residual ** 2).mean()
def clone_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Copy every tensor into fresh, detached storage so later in-place updates leave the copy untouched."""
    snapshot = {}
    for name, tensor in state_dict.items():
        snapshot[name] = tensor.detach().clone()
    return snapshot
# %%
# ==========================================
# 5. Training (x prediction only)
# ==========================================
def train_model(model_arch: bool = False, hidden_dim=256, train_steps=1000):
    """Train one DenoisingMLP and record how far each parameter moves per update.

    Args:
        model_arch: Forwarded to the model as ``use_skip_connection``.
        hidden_dim: Width of the model's hidden layers.
        train_steps: Number of outer-loop iterations; each one consumes the
            whole module-level ``dataloader``.

    Returns:
        A tuple ``(model, losses, training_state_diff)``: the trained model, the
        mean batch loss of every outer iteration, and, for each state-dict key,
        the per-iteration mean of the RMS parameter change divided by
        ``global_learning_rate``.
    """
    net = DenoisingMLP(AMBIENT_DIM, hidden_dim, use_skip_connection=model_arch).to(DEVICE)
    adam = optim.Adam(net.parameters(), lr=global_learning_rate, betas=(0.9, 0.99), eps=1e-12)
    print(f"Training DenoisingMLP [{format_arch_label(model_arch)}] (x-prediction + v-loss)...")
    # t is clipped to [t_margin, 1 - t_margin]; the upper bound keeps the
    # division by (1 - t) inside compute_loss finite.
    t_margin = 1e-3

    def train_step(x1_low, previous_state: dict[str, torch.Tensor] | None = None):
        """Run one optimizer update and, if a snapshot is given, measure the parameter change."""
        x1 = embed(x1_low)
        # Noise is drawn before the time values; this order fixes the random stream.
        eps = torch.randn_like(x1)
        t = torch.rand((x1.shape[0], 1), device=DEVICE).clip(t_margin, 1 - t_margin)
        xt = t * x1 + (1 - t) * eps
        loss = compute_loss(x1, eps, xt, t, net(xt, t))
        adam.zero_grad()
        loss.backward()
        adam.step()
        if previous_state is None:
            return loss, None, None
        new_state = clone_state_dict(net.state_dict())
        scaled_delta = {
            name: (new_state[name] - previous_state[name]) / global_learning_rate
            for name in new_state
        }
        return loss, scaled_delta, new_state

    snapshot = clone_state_dict(net.state_dict())
    epoch_losses = []
    update_history = {name: [] for name in snapshot}
    for _ in tqdm(range(train_steps)):
        batch_losses = []
        batch_rms = {name: [] for name in snapshot}
        for batch in dataloader:
            x1_low = batch[0].to(DEVICE, non_blocking=True)
            loss, scaled_delta, snapshot = train_step(x1_low, snapshot)
            if scaled_delta is not None:
                for name, delta in scaled_delta.items():
                    # log the rms norm of the difference
                    rms = torch.sqrt(torch.mean(delta ** 2))
                    batch_rms[name].append(rms.item())
            batch_losses.append(loss.item())
        epoch_losses.append(np.mean(batch_losses))
        for name, values in batch_rms.items():
            update_history[name].append(np.mean(values))
    return net, epoch_losses, update_history
# Number of outer training iterations (each one is a full pass over the dataloader)
# and the hidden width shared by all runs.
TRAIN_STEPS = 100
HIDDEN_DIM = 256
# Architectures to compare, expressed as the use_skip_connection flag.
experiments = [
    False,  # no skip connection
    True,  # add skip connection inside DenoisingMLP
]
# Maps each architecture flag to its (model, losses, training_state_diff) tuple.
results = {}
for model_arch in experiments:
    run = train_model(model_arch=model_arch, hidden_dim=HIDDEN_DIM, train_steps=TRAIN_STEPS)
    model, losses, training_state_diff = run
    results[model_arch] = (model, losses, training_state_diff)
# %%
# ==========================================
# 6. Sampling
# ==========================================
@torch.no_grad()
def sample(model, num_samples=2000, steps=100):
    """Generate ambient-space samples by Euler-integrating the model's velocity.

    Starts from standard Gaussian noise and visits ``steps`` evenly spaced times
    between ``eps_t`` and ``1 - eps_t``. At each time the model output is taken
    as an estimate of the clean sample, converted to a velocity with
    ``(x1_hat - x_t) / (1 - t)``, and applied with a fixed step equal to the gap
    between the first two grid times.

    Args:
        model: Callable taking ``(x_t, t)`` and returning the clean-sample estimate.
        num_samples: Number of samples to draw.
        steps: Number of Euler updates.

    Returns:
        Tensor with ``num_samples`` rows and ``AMBIENT_DIM`` columns.
    """
    eps_t = 1e-2
    x_t = torch.randn(num_samples, AMBIENT_DIM, device=DEVICE)
    time_grid = torch.linspace(eps_t, 1 - eps_t, steps, device=DEVICE)
    # Slices (not plain indexing) so the step keeps a broadcastable leading axis.
    step_size = time_grid[1:2] - time_grid[0:1]
    for t_now in time_grid:
        t = t_now.repeat(num_samples, 1)
        x1_hat = model(x_t, t)
        velocity = (x1_hat - x_t) / (1 - t)
        x_t = x_t + velocity * step_size
    return x_t
# Draw as many samples from each trained model as there are training points.
samples = {}
for model_arch, (model, _, _) in results.items():
    samples[model_arch] = sample(model, num_samples=n_points)
# Project back to 2D
x_orig = data
x_preds = {}
# The loop variable is deliberately NOT named `sample`: doing so rebinds the
# module-level sample() function to a tensor, which breaks any later call
# (for example when this cell is executed a second time).
for model_arch, sample_batch in samples.items():
    x_preds[model_arch] = project_back(sample_batch).cpu().numpy()
# %%
# ==========================================
# 7. Plot
# ==========================================
# Create the output directory; exist_ok avoids the check-then-create race.
os.makedirs("outputs", exist_ok=True)
# De-duplicate the experiment list while keeping its order.
model_arches = list(dict.fromkeys(experiments))
cols = ["original", "prediction", "loss"]
# squeeze=False keeps `axes` two-dimensional even for a single architecture row.
fig, axes = plt.subplots(
    len(model_arches),
    len(cols),
    figsize=(6 * len(cols), 5 * len(model_arches)),
    squeeze=False,
)
for row_idx, model_arch in enumerate(model_arches):
    arch_label = format_arch_label(model_arch)
    for col_idx, col_name in enumerate(cols):
        ax = axes[row_idx, col_idx]
        if col_name == "loss":
            if model_arch in results:
                losses = results[model_arch][1]
                ax.plot(losses)
            else:
                ax.text(0.5, 0.5, "No loss data", ha="center", va="center", transform=ax.transAxes)
            ax.set_title(f"{arch_label} - Training Loss")
            ax.set_xlabel("Iteration")
            ax.set_ylabel("Loss")
            ax.grid(True)
            continue
        # The two scatter panels share one fixed window so they can be compared directly.
        if col_name == "original":
            ax.scatter(x_orig[:, 0], x_orig[:, 1], s=5, c="black", alpha=0.5)
            ax.set_title(f"{arch_label} - Ground Truth")
        else:
            if model_arch in x_preds:
                x_pred = x_preds[model_arch]
                ax.scatter(x_pred[:, 0], x_pred[:, 1], s=5, alpha=0.5)
            else:
                ax.text(0.5, 0.5, "No predictions", ha="center", va="center", transform=ax.transAxes)
            ax.set_title(f"{arch_label} - {col_name}")
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-1.5, 1.5)
fig.tight_layout()
fig.savefig(f"outputs/adam_0pt2_update_{batch_size}.png")
# One panel per parameter kind; a state-dict key belongs to a panel when its
# name ends with the matching suffix.
splits = [
    ("Weight", lambda key: key.endswith("weight")),
    ("Bias", lambda key: key.endswith("bias")),
]
# squeeze=False keeps `axes` two-dimensional even for a single architecture row.
fig, axes = plt.subplots(
    len(model_arches),
    len(splits),
    figsize=(6 * len(splits), 4.5 * len(model_arches)),
    squeeze=False,
)
for row_idx, model_arch in enumerate(model_arches):
    arch_label = format_arch_label(model_arch)
    model_result = results.get(model_arch)
    arch_state_diff = model_result[2] if model_result else None
    for col_idx, (split_name, include_param) in enumerate(splits):
        ax = axes[row_idx, col_idx]
        ax.set_title(f"{arch_label} - {split_name} State Diff")
        ax.set_xlabel("Iteration")
        ax.set_ylabel("State Diff")
        ax.grid(True)
        if not arch_state_diff:
            # Only this placeholder is drawn; previously a second message was
            # rendered on top of it at the same position.
            ax.text(0.5, 0.5, "No state diff data", ha="center", va="center", transform=ax.transAxes)
            continue
        curves = {key: values for key, values in arch_state_diff.items() if include_param(key)}
        for key, values in curves.items():
            ax.plot(values, label=key)
        if curves:
            # Legend sits outside the axes; tight_layout below reserves room for it.
            ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5))
        else:
            ax.text(0.5, 0.5, f"No {split_name.lower()} params tracked", ha="center", va="center", transform=ax.transAxes)
fig.tight_layout(rect=[0, 0, 0.85, 1])
fig.savefig(f"outputs/adam_0pt2_update_state_diff_{batch_size}.png", bbox_inches="tight")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment