Demo of activation space geometry in a simple Toy Model of Superposition setup
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
from jaxtyping import Float
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the "3d" projection on older matplotlib)
from torch import Tensor

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = 8
latent_dim = 3
# ==== 1. Generate Sparse Data ====
def generate_sparse_data(
    batch_size: int = 100_000, input_dim: int = 8, sparsity_level: float = 0.95
) -> Float[Tensor, "batch input_dim"]:
    """Generate sparse data where each feature is independently active with probability 1 - sparsity_level."""
    # Uniform [0, 1) feature magnitudes
    data = np.random.rand(batch_size, input_dim)
    # Sparsity mask (1 = keep, 0 = zero out)
    mask = np.random.binomial(1, 1 - sparsity_level, data.shape)
    sparse_data = data * mask
    return torch.tensor(sparse_data, dtype=torch.float32)


def generate_n_active_data(
    batch_size: int = 100_000, input_dim: int = 8, n_active: int = 2
) -> Float[Tensor, "batch input_dim"]:
    """Generate data with exactly n_active features active per sample."""
    data = np.random.rand(batch_size, input_dim)
    # For each sample, zero out a random set of input_dim - n_active features
    for i in range(batch_size):
        inactive_features = np.random.choice(input_dim, input_dim - n_active, replace=False)
        data[i, inactive_features] = 0
    return torch.tensor(data, dtype=torch.float32)
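# Optional sanity check (an addition, not part of the original gist): with
# sparsity_level=0.95 and input_dim=8, samples should average 8 * 0.05 = 0.4
# active features.
_check = generate_sparse_data(10_000)
print(f"Mean active features per sample: {(_check != 0).float().sum(dim=1).mean():.3f}")  # ~0.4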
# ==== 2. Define Model ====
class Autoencoder(nn.Module):
    """Tied-weights toy model, y = ReLU(W W^T x): a simplified Toy Model of Superposition setup (no bias term)."""

    def __init__(self, input_dim: int = input_dim, latent_dim: int = latent_dim):
        super().__init__()
        self.W = nn.Parameter(torch.randn(input_dim, latent_dim))

    def forward(
        self, x: Float[Tensor, "batch input_dim"]
    ) -> tuple[Float[Tensor, "batch input_dim"], Float[Tensor, "batch latent_dim"]]:
        h = x @ self.W  # project the 8 features into the 3D latent space
        y = torch.relu(h @ self.W.T)  # reconstruct with the transposed (tied) weights
        return y, h
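# Quick shape check (added; not in the original gist): the tied-weight map sends
# (batch, 8) -> hidden (batch, 3) -> reconstruction (batch, 8).
_m = Autoencoder()
_y, _h = _m(torch.zeros(5, input_dim))
assert _y.shape == (5, input_dim) and _h.shape == (5, latent_dim)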
# ==== 3. Setup Training ====
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop: one fresh batch of sparse data per step
num_epochs = 10000
batch_size = 32
for epoch in range(num_epochs):
    data = generate_sparse_data(batch_size).to(device)
    optimizer.zero_grad()
    output, _ = model(data)
    loss = criterion(output, data)
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        # MSELoss already averages over the batch, so report loss.item() directly
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")
# %% Visualizations (Plotly)
# Plot the 8 learned weight vectors as lines from the origin in the 3D latent space
W = model.W.detach().cpu().numpy().T  # (latent_dim, input_dim)
fig = go.Figure()
for i in range(input_dim):
    fig.add_trace(go.Scatter3d(x=[0, W[0, i]], y=[0, W[1, i]], z=[0, W[2, i]], mode="lines", name=f"Weight {i}"))
fig.update_layout(scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"), width=400, height=400)
fig.show()
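# Geometry sketch (an addition to the gist): superposition packs 8 feature
# directions into 3 dimensions, so W W^T should show roughly unit-norm rows for
# represented features and small nonzero off-diagonals (interference between
# features), assuming training converged.
with torch.no_grad():
    WWt = (model.W @ model.W.T).cpu().numpy()  # (input_dim, input_dim)
print("Feature norms ||W_i||:", np.sqrt(np.diag(WWt)).round(2))
print("Max |interference|:", np.abs(WWt - np.diag(np.diag(WWt))).max().round(2))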
# Show hidden activations for the probabilistic sparse distribution, plus for exactly 1, 2, and 3 active features
def plot_data(
    data: Float[Tensor, "batch input_dim"], subplot_idx: int = None, fig: go.Figure = None, title: str = None
) -> go.Figure:
    _, hidden = model(data.to(device))
    hidden = hidden.detach().cpu().numpy()
    if fig is None:
        fig = go.Figure()
    scatter = go.Scatter3d(
        x=hidden[:, 0], y=hidden[:, 1], z=hidden[:, 2],
        mode="markers",
        marker=dict(size=3, color=data[:, 0], colorscale="Viridis"),  # colored by feature 0's magnitude
        scene=f"scene{subplot_idx + 1}" if subplot_idx is not None else "scene",
    )
    fig.add_trace(scatter)
    if subplot_idx is not None:
        fig.layout[f"scene{subplot_idx + 1}"].update(xaxis_title="X", yaxis_title="Y", zaxis_title="Z")
        if title:
            fig.layout[f"scene{subplot_idx + 1}"].update(annotations=[dict(text=title, x=0.5, y=0.9, showarrow=False)])
    return fig


# Hidden activations for the full probabilistic sparse distribution
data = generate_sparse_data(100_000)
fig = plot_data(data)
fig.show()

# Side-by-side subplots for exactly 1, 2, and 3 active features
fig = go.Figure()
fig.update_layout(width=800, height=300, scene1=dict(domain_x=[0, 0.33]), scene2=dict(domain_x=[0.33, 0.66]), scene3=dict(domain_x=[0.66, 1.0]))
for i, n_active in enumerate(range(1, 4)):
    data = generate_n_active_data(10000, n_active=n_active)
    fig = plot_data(data, subplot_idx=i, fig=fig, title=f"n_active={n_active}")
fig.show()
# %% Matplotlib version of the above 3 plots (3D projection) with Axes3D
def get_colors(data: Float[Tensor, "batch input_dim"]) -> np.ndarray:
    """Mix one Tableau color per feature, weighted by that feature's magnitude."""
    colours = list(mpl.colors.TABLEAU_COLORS.values())[:8]
    colours_rgb = np.array([mpl.colors.to_rgb(c) for c in colours])
    data = data.detach().cpu().numpy()

    def get_color(feature_values):
        color = np.array([colours_rgb[i] * mag for i, mag in enumerate(feature_values)]).sum(axis=0)
        return np.clip(color, 0, 1)

    return np.array([get_color(data[i]) for i in range(data.shape[0])])


def plot_data_matplotlib(data: Float[Tensor, "batch input_dim"], ax: plt.Axes, title: str = None) -> plt.Axes:
    _, hidden = model(data.to(device))
    hidden = hidden.detach().cpu().numpy()
    ax.scatter(hidden[:, 0], hidden[:, 1], hidden[:, 2], c=get_colors(data), alpha=0.5, s=1)
    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    return ax


fig, axes = plt.subplots(1, 3, figsize=(15, 5), subplot_kw={"projection": "3d"})
fig.suptitle("TMS (8 feats, 3 dims) hidden activations for exactly 1, 2, or 3 active features. Colored by feature magnitudes.")
for i, n_active in enumerate(range(1, 4)):
    data = generate_n_active_data(10000, n_active=n_active)
    plot_data_matplotlib(data, ax=axes[i], title=f"n_active={n_active}")
fig.savefig("tms_3d_colored_by_feature_magnitudes.png", dpi=300)