import math
import os
from abc import abstractmethod
from collections import defaultdict
from functools import partial
from typing import Optional
import joblib
import numpy as np
import sae_bench.custom_saes.base_sae as base_sae
import torch
import torch.nn.functional as F
from cuml.cluster import KMeans
from datasets import load_dataset
from jaxtyping import Float
from sklearn.decomposition import PCA
from torch import Tensor
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt
# %%
def to_numpy(t: Tensor) -> np.ndarray:
    return t.detach().cpu().numpy()


def from_numpy(x: np.ndarray, ref: Tensor) -> Tensor:
    return torch.from_numpy(x).to(ref.device)


class BaseCompressor(base_sae.BaseSAE):
    """
    Abstract base class for compression methods.
    """

    def __init__(self, d_sae, hook_layer):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        hook_name = f"blocks.{hook_layer}.hook_resid_post"
        self.d_in = 512
        self.d_sae = d_sae
        super().__init__(
            d_in=self.d_in,
            d_sae=d_sae,
            model_name="pythia-70m-deduped",
            hook_layer=hook_layer,
            device=self.device,
            dtype=torch.float32,
            hook_name=hook_name,
        )

    @abstractmethod
    def compress(self, x: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch d_sae"]:
        pass

    @abstractmethod
    def uncompress(self, latents: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_model"]:
        pass

    # encode, decode, forward for compatibility with SAEBench
    def encode(self, x: Float[Tensor, "batch n_ctx d_model"]) -> Float[Tensor, "batch n_ctx d_sae"]:
        batch, ctx, d_model = x.shape
        flattened_x = x.reshape(-1, self.d_in)
        latents = self.compress(flattened_x)
        return latents.reshape(batch, ctx, self.d_sae)

    def decode(
        self, latents: Float[Tensor, "batch n_ctx d_sae"]
    ) -> Float[Tensor, "batch n_ctx d_model"]:
        batch, ctx, d_sae = latents.shape
        flattened_latents = latents.reshape(-1, d_sae)
        recon = self.uncompress(flattened_latents)
        return recon.reshape(batch, ctx, self.d_in)

    def forward(self, x):
        return self.decode(self.encode(x))
class PCACompressor(BaseCompressor):
    """
    PCA compressor based on scikit-learn's PCA.
    """

    def __init__(self, n_components: int, hook_layer: int):
        self.n_components = n_components
        self._pca: Optional[PCA] = None
        super().__init__(d_sae=n_components, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        self._pca = PCA(n_components=self.n_components)
        self._pca.fit(to_numpy(data))

    def compress(self, data: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch n_components"]:
        if self._pca is None:
            raise RuntimeError("PCACompressor: call fit() before compress()")
        latents_np = self._pca.transform(to_numpy(data))
        return from_numpy(latents_np, data)

    def uncompress(
        self, latents: Float[Tensor, "batch n_components"]
    ) -> Float[Tensor, "batch d_model"]:
        if self._pca is None:
            raise RuntimeError("PCACompressor: call fit() before uncompress()")
        recon_np = self._pca.inverse_transform(to_numpy(latents))
        return from_numpy(recon_np, latents)

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        torch.save(self._pca, os.path.join(path, "pca.pt"))

    def load(self, path: str) -> None:
        self._pca = torch.load(
            os.path.join(path, "pca.pt"), weights_only=False, map_location=self.device
        )
        self.n_components = self._pca.n_components
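

# Minimal usage sketch for PCACompressor (illustrative only; the component count, layer index,
# and `acts` — assumed to be a [n_samples, 512] tensor of residual-stream activations — are
# placeholder choices, not values used by the experiments below):
def _example_pca_usage(acts: Float[Tensor, "n_samples d_model"]) -> Tensor:
    pca = PCACompressor(n_components=100, hook_layer=3)
    pca.fit(acts)                      # fit PCA on training activations
    latents = pca.compress(acts)       # [n_samples, 100] dense PCA coordinates
    return pca.uncompress(latents)     # [n_samples, 512] reconstruction in model space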
class KMeansCompressor(BaseCompressor):
    """
    KMeans compressor using cuML's KMeans.
    (This version does not output additional attributes like normalized centroids.)
    """

    def __init__(self, n_clusters: int, hook_layer: int):
        self.n_clusters = n_clusters
        self._kmeans: Optional[KMeans] = None
        super().__init__(d_sae=n_clusters, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        # Idea: Consider norming data first (kinda what I do in SAElikeKMeansCompressor)
        self._kmeans = KMeans(n_clusters=self.n_clusters).fit(to_numpy(data))
        self.cluster_centers = torch.tensor(self._kmeans.cluster_centers_, device=self.device)

    def compress(self, data: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch n_clusters"]:
        if self._kmeans is None:
            raise RuntimeError("KMeansCompressor: call fit() before compress()")
        labels = torch.tensor(self._kmeans.predict(to_numpy(data)), device=data.device)
        return F.one_hot(labels.long(), num_classes=self.n_clusters).float()

    def uncompress(
        self, latents: Float[Tensor, "batch n_clusters"]
    ) -> Float[Tensor, "batch d_model"]:
        if self._kmeans is None:
            raise RuntimeError("KMeansCompressor: call fit() before uncompress()")
        labels = latents.argmax(dim=-1)
        return self.cluster_centers[labels]

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        joblib.dump(self._kmeans, os.path.join(path, "kmeans.joblib"))

    def load(self, path: str) -> None:
        self._kmeans = joblib.load(os.path.join(path, "kmeans.joblib"))
        self.n_clusters = self._kmeans.n_clusters
        self.cluster_centers = torch.tensor(self._kmeans.cluster_centers_, device=self.device)
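

# Sketch of what KMeansCompressor latents look like (placeholder cluster count and layer;
# requires a CUDA device for cuML). Each latent row is a one-hot cluster indicator, so the
# L0 per token is exactly 1:
def _example_kmeans_l0(acts: Float[Tensor, "n_samples d_model"]) -> Tensor:
    kmeans = KMeansCompressor(n_clusters=4096, hook_layer=3)
    kmeans.fit(acts)
    latents = kmeans.compress(acts)            # [n_samples, 4096], one-hot rows
    return (latents != 0).float().sum(dim=-1)  # == 1 for every sample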
class SAElikeKMeansCompressor(BaseCompressor):
    """
    A KMeans compressor that centers data and then normalizes the centroids.
    In compression, it computes a similarity matrix (via dot products) between the centered data
    and the normalized centroids, then returns a one-hot (but weighted) latent.
    This compressor exposes the attributes `normalized_centroids` and `dataset_mean` for saving.
    """

    def __init__(self, n_clusters: int, hook_layer: int):
        self.n_clusters = n_clusters
        self.dataset_mean: Optional[Tensor] = None
        self.normalized_centroids: Optional[Tensor] = None
        super().__init__(d_sae=n_clusters, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        self.dataset_mean = data.mean(dim=0).to(self.device)
        data_centered = to_numpy(data - self.dataset_mean)
        kmeans = KMeans(n_clusters=self.n_clusters).fit(data_centered)
        centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32, device=data.device)
        self.normalized_centroids = centroids / (centroids.norm(dim=-1, keepdim=True) + 1e-9)

    def compress(self, data: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch n_clusters"]:
        if self.dataset_mean is None or self.normalized_centroids is None:
            raise RuntimeError("SAElikeKMeansCompressor: call fit() before compress()")
        centered_data = data - self.dataset_mean
        alpha = torch.einsum("bd,nd->bn", centered_data, self.normalized_centroids)
        labels = alpha.argmax(dim=-1, keepdim=True)
        latent = torch.zeros_like(alpha)
        latent.scatter_(1, labels, alpha.gather(1, labels))
        return latent

    def uncompress(
        self, latents: Float[Tensor, "batch n_clusters"]
    ) -> Float[Tensor, "batch d_model"]:
        if self.normalized_centroids is None or self.dataset_mean is None:
            raise RuntimeError("SAElikeKMeansCompressor: call fit() before uncompress()")
        labels = latents.argmax(dim=-1)
        magnitudes = latents.max(dim=-1).values
        return self.normalized_centroids[labels] * magnitudes.unsqueeze(1) + self.dataset_mean

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        torch.save(
            {"normalized_centroids": self.normalized_centroids, "dataset_mean": self.dataset_mean},
            os.path.join(path, "sae_kmeans.pt"),
        )

    def load(self, path: str) -> None:
        state = torch.load(
            os.path.join(path, "sae_kmeans.pt"), weights_only=True, map_location=self.device
        )
        self.normalized_centroids = state["normalized_centroids"]
        self.dataset_mean = state["dataset_mean"]
        self.n_clusters = self.normalized_centroids.shape[0]
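

# Sketch of the SAE-like encoder/decoder above, for a fitted compressor with dataset mean mu
# and unit-norm centroids c_j (this restates compress()/uncompress() in formula form):
#   alpha_j = <x - mu, c_j>          # dot-product "pre-activations"
#   k       = argmax_j alpha_j       # single active latent (top-1)
#   latent  = alpha_k * e_k          # one-hot vector scaled by alpha_k
#   recon   = alpha_k * c_k + mu     # rank-1 reconstruction plus the dataset mean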
class PCAKMeansPipelineCompressor(BaseCompressor):
    """
    A composite compressor that first applies PCA and then clusters the residual.
    """

    def __init__(self, n_pca: int, n_clusters: int, hook_layer: int):
        super().__init__(d_sae=n_pca + n_clusters, hook_layer=hook_layer)
        self.n_pca = n_pca
        self.pca_comp = PCACompressor(n_components=n_pca, hook_layer=hook_layer)
        self.kmeans_comp = SAElikeKMeansCompressor(n_clusters=n_clusters, hook_layer=hook_layer)
        self.n_clusters = self.kmeans_comp.n_clusters

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        self.pca_comp.fit(data)
        pca_recon = self.pca_comp.uncompress(self.pca_comp.compress(data))
        self.kmeans_comp.fit(data - pca_recon)

    def compress(
        self, data: Float[Tensor, "batch d_model"]
    ) -> Float[Tensor, "batch (n_pca + n_clusters)"]:
        pca_latents = self.pca_comp.compress(data)
        pca_recon = self.pca_comp.uncompress(pca_latents)
        residual = data - pca_recon
        clustering_latents = self.kmeans_comp.compress(residual)
        latents = torch.cat([pca_latents, clustering_latents], dim=-1)
        assert latents.shape == (
            data.shape[0],
            self.d_sae,
        ), f"latents.shape = {latents.shape}, expected {data.shape[0], self.d_sae}"
        return latents

    def uncompress(
        self, latents: Float[Tensor, "batch (n_pca + n_clusters)"]
    ) -> Float[Tensor, "batch d_model"]:
        pca_latents = latents[:, : self.n_pca]
        clustering_latents = latents[:, self.n_pca :]
        return self.pca_comp.uncompress(pca_latents) + self.kmeans_comp.uncompress(
            clustering_latents
        )

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        self.pca_comp.save(os.path.join(path, "pca/"))
        self.kmeans_comp.save(os.path.join(path, "kmeans/"))

    def load(self, path: str) -> None:
        self.pca_comp.load(os.path.join(path, "pca/"))
        self.n_pca = self.pca_comp.n_components
        self.kmeans_comp.load(os.path.join(path, "kmeans/"))
        self.n_clusters = self.kmeans_comp.n_clusters
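

# Sketch of the pipeline's latent layout (placeholder sizes): the first n_pca entries are dense
# PCA coordinates and the remaining n_clusters entries are top-1 cluster activations, so the
# L0 per token is roughly n_pca + 1:
def _example_pipeline_l0(acts: Float[Tensor, "n_samples d_model"]) -> Tensor:
    comp = PCAKMeansPipelineCompressor(n_pca=50, n_clusters=4046, hook_layer=3)
    comp.fit(acts)
    latents = comp.compress(acts)              # [n_samples, 50 + 4046]
    return (latents != 0).float().sum(dim=-1)  # ~ 51 (= n_pca + 1) per sample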
class KMeansPCACompressor(BaseCompressor):
    """
    A composite compressor that first applies KMeans and then PCA.
    """

    def __init__(self, n_pca: int, kmeans_comp: SAElikeKMeansCompressor, hook_layer: int):
        """Initialize from an existing, already fitted SAElikeKMeansCompressor."""
        super().__init__(d_sae=n_pca + kmeans_comp.n_clusters, hook_layer=hook_layer)
        self.n_pca = n_pca
        assert (
            kmeans_comp.dataset_mean is not None and kmeans_comp.normalized_centroids is not None
        ), "Need to provide a fitted SAElikeKMeansCompressor"
        self.kmeans_comp = kmeans_comp
        self.pca_comp = PCACompressor(n_components=n_pca, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        kmeans_recon = self.kmeans_comp.uncompress(self.kmeans_comp.compress(data))
        self.pca_comp.fit(data - kmeans_recon)

    def compress(
        self, data: Float[Tensor, "batch d_model"]
    ) -> Float[Tensor, "batch (n_pca + n_clusters)"]:
        kmeans_latents = self.kmeans_comp.compress(data)
        kmeans_recon = self.kmeans_comp.uncompress(kmeans_latents)
        residual = data - kmeans_recon
        pca_latents = self.pca_comp.compress(residual)
        return torch.cat([kmeans_latents, pca_latents], dim=1)

    def uncompress(
        self, latents: Float[Tensor, "batch (n_pca + n_clusters)"]
    ) -> Float[Tensor, "batch d_model"]:
        kmeans_latents = latents[:, : self.kmeans_comp.n_clusters]
        pca_latents = latents[:, self.kmeans_comp.n_clusters :]
        return self.pca_comp.uncompress(pca_latents) + self.kmeans_comp.uncompress(kmeans_latents)

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        self.kmeans_comp.save(os.path.join(path, "kmeans/"))
        self.pca_comp.save(os.path.join(path, "pca/"))

    def load(self, path: str) -> None:
        self.kmeans_comp.load(os.path.join(path, "kmeans/"))
        self.pca_comp.load(os.path.join(path, "pca/"))
        self.n_clusters = self.kmeans_comp.n_clusters
        self.n_pca = self.pca_comp.n_components
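

# Sketch of constructing the KMeans-then-PCA variant (placeholder sizes; the constructor expects
# a SAElikeKMeansCompressor that has already been fitted):
def _example_kmeans_then_pca(acts: Float[Tensor, "n_samples d_model"]) -> Tensor:
    kmeans = SAElikeKMeansCompressor(n_clusters=4046, hook_layer=3)
    kmeans.fit(acts)
    comp = KMeansPCACompressor(n_pca=50, kmeans_comp=kmeans, hook_layer=3)
    comp.fit(acts)  # fits PCA on the residual left by the cluster reconstruction
    return comp.uncompress(comp.compress(acts))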
# %% === Main Block (Data Collection & Experiments) ===
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Which K-Means and PCA combinations to try
    layer_indices = [3, 4]
    n_latents_list = [4096, 16384]
    n_pca_list = [1, 2, 5, 10, 20, 50, 100, 200, 500]
    outfile_csv = "pythia_70m_clustering_results_v6d_1.csv"

    # Model settings
    n_ctx = 128
    batch_size = 1000
    n_train_list = [100_000]
    n_eval = 100_000
    n_batches = math.ceil((max(n_train_list) + n_eval) / (batch_size * n_ctx))

    # Load dataset
    dataset_name = "apollo-research/Skylion007-openwebtext-tokenizer-EleutherAI-gpt-neox-20b"
    dataset = load_dataset(dataset_name, split="train", streaming=True).with_format("torch")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Load model & test prompt to confirm tokenization is correct.
    model = HookedTransformer.from_pretrained_no_processing("EleutherAI/pythia-70m-deduped")
    sample = next(iter(dataset))
    sample_prompt_tokens = sample["input_ids"][:18]
    sample_answer_token = sample["input_ids"][18]
    sample_prompt_text = model.tokenizer.decode(sample_prompt_tokens)
    sample_answer_text = model.tokenizer.decode(sample_answer_token)
    test_prompt(sample_prompt_text, sample_answer_text, model, prepend_space_to_answer=False)

    # === Collect Activations ===
    activation_cache = defaultdict(list)
    n_tokens_available = 0

    def hook(act, hook, d):
        assert act.shape == (batch_size, n_ctx, model.cfg.d_model)
        act_reshaped = act[:, :, :].reshape(-1, model.cfg.d_model)
        d[hook.name].append(act_reshaped.cpu())

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(dataloader), desc="Collecting activations"):
            hooks = []
            for layer_idx in layer_indices:
                layer_name = f"blocks.{layer_idx}.hook_resid_post"
                h = partial(hook, d=activation_cache)
                hooks.append((layer_name, h))
            with model.hooks(fwd_hooks=hooks):
                inputs = batch["input_ids"][:, :128]
                model(inputs)
            if batch_idx >= n_batches:
                break

    # Concatenate and shuffle activations.
    activation_cache = {k: torch.cat(v) for k, v in activation_cache.items()}
    activation_cache = {k: v[torch.randperm(v.shape[0])] for k, v in activation_cache.items()}

    # === Experiment Helpers ===
    def run_experiment(
        compressor: BaseCompressor,
        train_acts: Float[Tensor, "n_train d_model"],
        eval_acts: Float[Tensor, "n_eval d_model"],
    ):
        delta_to_mean_squared = (eval_acts - eval_acts.mean(dim=0)).pow(2).sum(dim=-1)
        compressor.fit(train_acts)
        latents = compressor.compress(eval_acts)
        recon = compressor.uncompress(latents)
        delta_to_recon_squared = (eval_acts - recon).pow(2).sum(dim=-1)
        FVU_conventional = delta_to_recon_squared.mean(dim=0) / delta_to_mean_squared.mean(dim=0)
        FVU_saebench = (delta_to_recon_squared / delta_to_mean_squared).mean(dim=0)
        return FVU_conventional.item(), FVU_saebench.item(), compressor
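
    # The two FVU variants above differ only in where the mean over eval tokens is taken
    # (x_i: eval activation, xhat_i: reconstruction, xbar: mean eval activation):
    #   FVU_conventional = mean_i ||x_i - xhat_i||^2 / mean_i ||x_i - xbar||^2
    #   FVU_saebench     = mean_i ( ||x_i - xhat_i||^2 / ||x_i - xbar||^2 )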
    # Write CSV header.
    with open(outfile_csv, "w") as f:
        f.write("method,layer,training_tokens,n_clusters,n_pca,l0,FVU_saebench,FVU_conventional\n")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    for layer_idx in layer_indices:
        layer = f"blocks.{layer_idx}.hook_resid_post"
        for n_train in n_train_list:
            if n_train + n_eval > len(activation_cache[layer]):
                print(f"Skipping {n_train} training tokens run because not enough data.")
                continue
            train_acts = activation_cache[layer][:n_train].to(device)
            eval_acts = activation_cache[layer][-n_eval:].to(device)
            for n_latents in n_latents_list:
                # === Experiment: Pure Clustering (using KMeansCompressor) ===
                FVU_conventional, FVU_saebench, compressor = run_experiment(
                    KMeansCompressor(n_clusters=n_latents, hook_layer=layer_idx),
                    train_acts,
                    eval_acts,
                )
                compressor.save(f"out_v6d/KMeansCompressor_L{layer_idx}_C{n_latents}_T{n_train}")
                print(
                    f"KMeans ({layer_idx=}, {n_latents=}, {n_train=}): FVU = {FVU_conventional:.3f}"
                )
                with open(outfile_csv, "a") as f:
                    f.write(
                        f"clustering,{layer},{n_train},{n_latents},0,1,"
                        f"{FVU_saebench:.10f},{FVU_conventional:.10f}\n"
                    )

                # === Experiment: KMeans but with SAE-like encoder ===
                compressor = SAElikeKMeansCompressor(n_clusters=n_latents, hook_layer=layer_idx)
                FVU_conventional, FVU_saebench, compressor = run_experiment(
                    compressor,
                    train_acts,
                    eval_acts,
                )
                compressor.save(
                    f"out_v6d/SAElikeKMeansCompressor_L{layer_idx}_C{n_latents}_T{n_train}"
                )
                print(
                    f"Top1SAE ({layer_idx=}, {n_latents=}, {n_train=}): FVU = {FVU_conventional:.3f}"
                )
                with open(outfile_csv, "a") as f:
                    f.write(
                        f"top1sae_clustering,{layer},{n_train},{n_latents},0,1,"
                        f"{FVU_saebench:.10f},{FVU_conventional:.10f}\n"
                    )

                # === Experiment: PCA + Clustering (using PCAKMeansPipelineCompressor) ===
                for n_pca in n_pca_list:
                    n_clusters = n_latents - n_pca
                    if n_clusters <= 0:
                        continue
                    compressor = PCAKMeansPipelineCompressor(
                        n_pca=n_pca, n_clusters=n_clusters, hook_layer=layer_idx
                    )
                    FVU_conventional, FVU_saebench, compressor = run_experiment(
                        compressor,
                        train_acts,
                        eval_acts,
                    )
                    compressor.save(
                        f"out_v6d/PCAKMeansPipelineCompressor_L{layer_idx}_C{n_latents}_P{n_pca}_T{n_train}"
                    )
                    print(
                        f"PCA+KMeans ({layer_idx=}, {n_latents=}, {n_train=}, {n_pca=}, {n_clusters=}): FVU = {FVU_conventional:.3f}"
                    )
                    with open(outfile_csv, "a") as f:
                        f.write(
                            f"pca_clustering,{layer},{n_train},{n_clusters},{n_pca},{n_pca+1},"
                            f"{FVU_saebench:.10f},{FVU_conventional:.10f}\n"
                        )