Demo of activation space geometry in a simple Toy Model of Superposition setup
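"""Toy Model of Superposition demo.

An 8-feature, 3-dimensional autoencoder with tied weights and a ReLU on the
output (in the style of Anthropic's "Toy Models of Superposition") is trained
on sparse data; its 3D hidden activations are then visualized with Plotly and
matplotlib.
"""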
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
from jaxtyping import Float
from mpl_toolkits.mplot3d import Axes3D
from torch import Tensor
# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = 8
latent_dim = 3
# ==== 1. Generate Sparse Data ====
def generate_sparse_data(batch_size: int = 100000, input_dim: int = 8, sparsity_level: float = 0.95) -> Float[Tensor, "batch input_dim"]:
    """Generate sparse data where each feature is active with probability 1 - sparsity_level."""
    # Uniform feature magnitudes in [0, 1)
    data = np.random.rand(batch_size, input_dim)
    # Bernoulli sparsity mask (1 = keep, 0 = zero out)
    mask = np.random.binomial(1, 1 - sparsity_level, data.shape)
    sparse_data = data * mask
    return torch.tensor(sparse_data, dtype=torch.float32)
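# Quick sanity check (our addition, not in the original gist): with the default
# sparsity_level=0.95, each feature should be active ~5% of the time. Left
# commented out so the seeded RNG stream used for training below is unchanged.
# print((generate_sparse_data(10_000) > 0).float().mean())  # expect ~0.05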
# Generate data with exactly n_active features active
def generate_n_active_data(batch_size: int = 100000, input_dim: int = 8, n_active: int = 2) -> Float[Tensor, "batch input_dim"]:
    """Generate data with exactly n_active features active per sample."""
    data = np.random.rand(batch_size, input_dim)
    # For each sample, randomly zero out input_dim - n_active features
    for i in range(batch_size):
        inactive_features = np.random.choice(input_dim, input_dim - n_active, replace=False)
        data[i, inactive_features] = 0
    return torch.tensor(data, dtype=torch.float32)
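# A vectorized alternative (our sketch, not used below): pick n_active random
# columns per row with argpartition instead of a Python loop over the batch.
# The name generate_n_active_data_vectorized is ours, not from the gist.
def generate_n_active_data_vectorized(batch_size: int = 100000, input_dim: int = 8, n_active: int = 2) -> Float[Tensor, "batch input_dim"]:
    data = np.random.rand(batch_size, input_dim)
    # Random scores per row; the n_active smallest mark the active positions
    scores = np.random.rand(batch_size, input_dim)
    active = np.argpartition(scores, n_active, axis=1)[:, :n_active]
    mask = np.zeros((batch_size, input_dim), dtype=bool)
    mask[np.arange(batch_size)[:, None], active] = True
    return torch.tensor(data * mask, dtype=torch.float32)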
# ==== 2. Define Model ====
class Autoencoder(nn.Module):
    def __init__(self, input_dim: int = input_dim, latent_dim: int = latent_dim):
        super().__init__()
        # Tied weights: W projects into the latent space, W.T projects back out
        self.W = nn.Parameter(torch.randn(input_dim, latent_dim))

    def forward(self, x: Float[Tensor, "batch input_dim"]) -> tuple[Float[Tensor, "batch input_dim"], Float[Tensor, "batch latent_dim"]]:
        h = x @ self.W
        y = torch.relu(h @ self.W.T)
        return y, h
# ==== 3. Setup Training ====
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# Training loop: one fresh batch per epoch
num_epochs = 10000
batch_size = 32
for epoch in range(num_epochs):
    data = generate_sparse_data(batch_size).to(device)
    optimizer.zero_grad()
    output, _ = model(data)
    loss = criterion(output, data)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        # loss is already a mean over the batch, so report it directly
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.6f}")
# %% Visualizations (Plotly)
# Plot the learned weight vectors (one per input feature) from the origin
W = model.W.detach().cpu().numpy().T  # shape (latent_dim, input_dim); column i is feature i's direction
fig = go.Figure()
for i in range(input_dim):
    fig.add_trace(go.Scatter3d(x=[0, W[0, i]], y=[0, W[1, i]], z=[0, W[2, i]], mode="lines", name=f"Weight {i}"))
fig.update_layout(scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"), width=400, height=400)
fig.show()
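# Illustrative geometry check (our addition): cosine similarities between the
# learned feature directions. With 8 features squeezed into 3 dimensions, the
# off-diagonal entries measure interference between superposed features.
norms = np.linalg.norm(W, axis=0)
cosine_sims = (W.T @ W) / np.outer(norms, norms)
print(np.round(cosine_sims, 2))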
# Show activations with probabilistic features, plus for exactly 1, 2, and 3 features active
def plot_data(data: Float[Tensor, "batch input_dim"], subplot_idx: int | None = None, fig: go.Figure | None = None, title: str | None = None) -> go.Figure:
    _, hidden = model(data.to(device))
    hidden = hidden.detach().cpu().numpy()
    if fig is None:
        fig = go.Figure()
    scene_ref = f"scene{subplot_idx + 1}" if subplot_idx is not None else "scene"
    scatter = go.Scatter3d(x=hidden[:, 0], y=hidden[:, 1], z=hidden[:, 2], mode="markers", marker=dict(size=3, color=data[:, 0], colorscale="Viridis"), scene=scene_ref)
    fig.add_trace(scatter)
    if subplot_idx is not None:
        fig.layout[scene_ref].update(xaxis_title="X", yaxis_title="Y", zaxis_title="Z")
        if title:
            fig.layout[scene_ref].update(annotations=[dict(text=title, x=0.5, y=0.9, showarrow=False)])
    return fig
# Create figure for probabilistic distribution of data
data = generate_sparse_data(100_000)
fig = plot_data(data)
fig.show()
# Create subplots, one scene per n_active value
fig = go.Figure()
fig.update_layout(width=800, height=300, scene1=dict(domain_x=[0, 0.33]), scene2=dict(domain_x=[0.33, 0.66]), scene3=dict(domain_x=[0.66, 1.0]))
# Generate and plot data for each n_active value
for i, n_active in enumerate(range(1, 4)):
    data = generate_n_active_data(10000, n_active=n_active)
    fig = plot_data(data, subplot_idx=i, fig=fig, title=f"n_active={n_active}")
fig.show()
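# What to look for (our gloss, not part of the original gist): with exactly
# one active feature, hidden points lie along the 8 learned directions; with
# two, they fill the planes spanned by pairs of directions; with three, the
# corresponding 3D regions. This is the activation-space geometry the demo
# is about.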
# %% Matplotlib version of the above 3 plots (3D projection) with Axes3D
def get_colors(data: Float[Tensor, "batch input_dim"]) -> np.ndarray:
    """Mix an RGB colour per sample from the Tableau palette, weighted by feature magnitudes."""
    colours = list(mpl.colors.TABLEAU_COLORS.values())[:8]
    colours_rgb = np.array([mpl.colors.to_rgb(c) for c in colours])
    data = data.detach().cpu().numpy()

    def get_color(feature_values):
        color = np.array([colours_rgb[i] * mag for i, mag in enumerate(feature_values)]).sum(axis=0)
        return np.clip(color, 0, 1)

    return np.array([get_color(data[i]) for i in range(data.shape[0])])
def plot_data_matplotlib(data: Float[Tensor, "batch input_dim"], ax: plt.Axes, title: str | None = None) -> plt.Axes:
    _, hidden = model(data.to(device))
    hidden = hidden.detach().cpu().numpy()
    ax.scatter(hidden[:, 0], hidden[:, 1], hidden[:, 2], c=get_colors(data), alpha=0.5, s=1)
    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    return ax
fig, axes = plt.subplots(1, 3, figsize=(15, 5), subplot_kw={"projection": "3d"})
fig.suptitle("TMS (8 feats, 3 dims) hidden activations for exactly 1, 2, or 3 active features. Colored by feature magnitudes.")
for i, n_active in enumerate(range(1, 4)):
    data = generate_n_active_data(10000, n_active=n_active)
    plot_data_matplotlib(data, ax=axes[i], title=f"n_active={n_active}")
fig.savefig("tms_3d_colored_by_feature_magnitudes.png", dpi=300)
plt.show()