import math
import os
from abc import abstractmethod
from collections import defaultdict
from functools import partial
from typing import Optional

import joblib
import numpy as np
import sae_bench.custom_saes.base_sae as base_sae
import torch
import torch.nn.functional as F
from cuml.cluster import KMeans
from datasets import load_dataset
from jaxtyping import Float
from sklearn.decomposition import PCA
from torch import Tensor
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt

# %%
def to_numpy(t: Tensor) -> np.ndarray:
    return t.detach().cpu().numpy()


def from_numpy(x: np.ndarray, ref: Tensor) -> Tensor:
    return torch.from_numpy(x).to(ref.device)


class BaseCompressor(base_sae.BaseSAE):
    """
    Abstract base class for compression methods.
    """

    def __init__(self, d_sae, hook_layer):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        hook_name = f"blocks.{hook_layer}.hook_resid_post"
        self.d_in = 512  # d_model of pythia-70m-deduped
        self.d_sae = d_sae
        super().__init__(
            d_in=self.d_in,
            d_sae=d_sae,
            model_name="pythia-70m-deduped",
            hook_layer=hook_layer,
            device=self.device,
            dtype=torch.float32,
            hook_name=hook_name,
        )

    @abstractmethod
    def compress(self, x: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch d_sae"]:
        pass

    @abstractmethod
    def uncompress(self, latents: Float[Tensor, "batch d_sae"]) -> Float[Tensor, "batch d_model"]:
        pass

    # encode, decode, forward for compatibility with SAEBench
    def encode(self, x: Float[Tensor, "batch n_ctx d_model"]) -> Float[Tensor, "batch n_ctx d_sae"]:
        batch, ctx, d_model = x.shape
        flattened_x = x.reshape(-1, self.d_in)
        latents = self.compress(flattened_x)
        return latents.reshape(batch, ctx, self.d_sae)

    def decode(
        self, latents: Float[Tensor, "batch n_ctx d_sae"]
    ) -> Float[Tensor, "batch n_ctx d_model"]:
        batch, ctx, d_sae = latents.shape
        flattened_latents = latents.reshape(-1, d_sae)
        recon = self.uncompress(flattened_latents)
        return recon.reshape(batch, ctx, self.d_in)

    def forward(self, x):
        return self.decode(self.encode(x))
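
# Note: SAEBench's interface hands encode()/decode() activations shaped [batch, n_ctx, d_model],
# while the compressors below operate on flat [n_tokens, d_model] matrices. The base class just
# flattens batch and context dimensions before compress() and restores them after uncompress(),
# so a subclass only needs to implement the 2D compress()/uncompress() pair.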


class PCACompressor(BaseCompressor):
    """
    PCA compressor based on scikit-learn's PCA.
    """

    def __init__(self, n_components: int, hook_layer: int):
        self.n_components = n_components
        self._pca: Optional[PCA] = None
        super().__init__(d_sae=n_components, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        self._pca = PCA(n_components=self.n_components)
        self._pca.fit(to_numpy(data))

    def compress(self, data: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch n_components"]:
        if self._pca is None:
            raise RuntimeError("PCACompressor: call fit() before compress()")
        latents_np = self._pca.transform(to_numpy(data))
        return from_numpy(latents_np, data)

    def uncompress(
        self, latents: Float[Tensor, "batch n_components"]
    ) -> Float[Tensor, "batch d_model"]:
        if self._pca is None:
            raise RuntimeError("PCACompressor: call fit() before uncompress()")
        recon_np = self._pca.inverse_transform(to_numpy(latents))
        return from_numpy(recon_np, latents)

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        torch.save(self._pca, os.path.join(path, "pca.pt"))

    def load(self, path: str) -> None:
        self._pca = torch.load(
            os.path.join(path, "pca.pt"), weights_only=False, map_location=self.device
        )
        self.n_components = self._pca.n_components
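
# Illustrative usage sketch; this helper is not part of the original experiment and is never
# called. The name `_example_pca_roundtrip`, the default n_components, and the assumption that
# `acts` is a [n_samples, 512] float32 tensor of residual-stream activations are all made up
# for the example.
def _example_pca_roundtrip(acts: Tensor, n_components: int = 10) -> Tensor:
    pca_comp = PCACompressor(n_components=n_components, hook_layer=3)
    pca_comp.fit(acts)                   # fit sklearn PCA on the activations
    latents = pca_comp.compress(acts)    # [n_samples, n_components]
    return pca_comp.uncompress(latents)  # [n_samples, 512] reconstruction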


class KMeansCompressor(BaseCompressor):
    """
    KMeans compressor using cuML's KMeans.
    (Unlike SAElikeKMeansCompressor, this version does not expose normalized centroids.)
    """

    def __init__(self, n_clusters: int, hook_layer: int):
        self.n_clusters = n_clusters
        self._kmeans: Optional[KMeans] = None
        super().__init__(d_sae=n_clusters, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        # Idea: consider normalizing the data first (similar to what SAElikeKMeansCompressor does)
        self._kmeans = KMeans(n_clusters=self.n_clusters).fit(to_numpy(data))
        self.cluster_centers = torch.tensor(self._kmeans.cluster_centers_, device=self.device)

    def compress(self, data: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch n_clusters"]:
        if self._kmeans is None:
            raise RuntimeError("KMeansCompressor: call fit() before compress()")
        labels = torch.tensor(self._kmeans.predict(to_numpy(data)), device=data.device)
        return F.one_hot(labels.long(), num_classes=self.n_clusters).float()

    def uncompress(
        self, latents: Float[Tensor, "batch n_clusters"]
    ) -> Float[Tensor, "batch d_model"]:
        if self._kmeans is None:
            raise RuntimeError("KMeansCompressor: call fit() before uncompress()")
        labels = latents.argmax(dim=-1)
        return self.cluster_centers[labels]

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        joblib.dump(self._kmeans, os.path.join(path, "kmeans.joblib"))

    def load(self, path: str) -> None:
        self._kmeans = joblib.load(os.path.join(path, "kmeans.joblib"))
        self.n_clusters = self._kmeans.n_clusters
        self.cluster_centers = torch.tensor(self._kmeans.cluster_centers_, device=self.device)
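
# Note: the latent produced by KMeansCompressor is a plain one-hot cluster assignment, so the
# per-token L0 is exactly 1 and uncompress() simply returns the assigned centroid. This is why
# the main block below records l0=1 for the "clustering" rows in the CSV.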


class SAElikeKMeansCompressor(BaseCompressor):
    """
    A KMeans compressor that centers the data and then normalizes the centroids.
    For compression, it computes a similarity matrix (via dot products) between the centered data
    and the normalized centroids, then returns a one-hot (but weighted) latent.
    This compressor exposes the attributes `normalized_centroids` and `dataset_mean` for saving.
    """

    def __init__(self, n_clusters: int, hook_layer: int):
        self.n_clusters = n_clusters
        self.dataset_mean: Optional[Tensor] = None
        self.normalized_centroids: Optional[Tensor] = None
        super().__init__(d_sae=n_clusters, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        self.dataset_mean = data.mean(dim=0).to(self.device)
        data_centered = to_numpy(data - self.dataset_mean)
        kmeans = KMeans(n_clusters=self.n_clusters).fit(data_centered)
        centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32, device=data.device)
        self.normalized_centroids = centroids / (centroids.norm(dim=-1, keepdim=True) + 1e-9)

    def compress(self, data: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch n_clusters"]:
        if self.dataset_mean is None or self.normalized_centroids is None:
            raise RuntimeError("SAElikeKMeansCompressor: call fit() before compress()")
        centered_data = data - self.dataset_mean
        alpha = torch.einsum("bd,nd->bn", centered_data, self.normalized_centroids)
        labels = alpha.argmax(dim=-1, keepdim=True)
        latent = torch.zeros_like(alpha)
        latent.scatter_(1, labels, alpha.gather(1, labels))
        return latent

    def uncompress(
        self, latents: Float[Tensor, "batch n_clusters"]
    ) -> Float[Tensor, "batch d_model"]:
        if self.normalized_centroids is None or self.dataset_mean is None:
            raise RuntimeError("SAElikeKMeansCompressor: call fit() before uncompress()")
        labels = latents.argmax(dim=-1)
        magnitudes = latents.max(dim=-1).values
        return self.normalized_centroids[labels] * magnitudes.unsqueeze(1) + self.dataset_mean

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        torch.save(
            {"normalized_centroids": self.normalized_centroids, "dataset_mean": self.dataset_mean},
            os.path.join(path, "sae_kmeans.pt"),
        )

    def load(self, path: str) -> None:
        state = torch.load(
            os.path.join(path, "sae_kmeans.pt"), weights_only=True, map_location=self.device
        )
        self.normalized_centroids = state["normalized_centroids"]
        self.dataset_mean = state["dataset_mean"]
        self.n_clusters = self.normalized_centroids.shape[0]
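
# Note: the encoder is "SAE-like" in that the single active latent carries a magnitude rather
# than a 1. For a token activation x it computes
#   alpha_i = (x - dataset_mean) . normalized_centroid_i,
# keeps only the largest alpha_i, and decodes as
#   x_hat = alpha_max * normalized_centroid_argmax + dataset_mean,
# i.e. a projection of the centered activation onto its best-matching centroid direction.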


class PCAKMeansPipelineCompressor(BaseCompressor):
    """
    A composite compressor that first applies PCA and then clusters the PCA residual.
    """

    def __init__(self, n_pca: int, n_clusters: int, hook_layer: int):
        super().__init__(d_sae=n_pca + n_clusters, hook_layer=hook_layer)
        self.n_pca = n_pca
        self.pca_comp = PCACompressor(n_components=n_pca, hook_layer=hook_layer)
        self.kmeans_comp = SAElikeKMeansCompressor(n_clusters=n_clusters, hook_layer=hook_layer)
        self.n_clusters = self.kmeans_comp.n_clusters

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        self.pca_comp.fit(data)
        pca_recon = self.pca_comp.uncompress(self.pca_comp.compress(data))
        self.kmeans_comp.fit(data - pca_recon)

    def compress(
        self, data: Float[Tensor, "batch d_model"]
    ) -> Float[Tensor, "batch (n_pca + n_clusters)"]:
        pca_latents = self.pca_comp.compress(data)
        pca_recon = self.pca_comp.uncompress(pca_latents)
        residual = data - pca_recon
        clustering_latents = self.kmeans_comp.compress(residual)
        latents = torch.cat([pca_latents, clustering_latents], dim=-1)
        assert latents.shape == (
            data.shape[0],
            self.d_sae,
        ), f"latents.shape = {latents.shape}, expected {(data.shape[0], self.d_sae)}"
        return latents

    def uncompress(
        self, latents: Float[Tensor, "batch (n_pca + n_clusters)"]
    ) -> Float[Tensor, "batch d_model"]:
        pca_latents = latents[:, : self.n_pca]
        clustering_latents = latents[:, self.n_pca :]
        return self.pca_comp.uncompress(pca_latents) + self.kmeans_comp.uncompress(
            clustering_latents
        )

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        self.pca_comp.save(os.path.join(path, "pca/"))
        self.kmeans_comp.save(os.path.join(path, "kmeans/"))

    def load(self, path: str) -> None:
        self.pca_comp.load(os.path.join(path, "pca/"))
        self.n_pca = self.pca_comp.n_components
        self.kmeans_comp.load(os.path.join(path, "kmeans/"))
        self.n_clusters = self.kmeans_comp.n_clusters
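
# Note: the combined latent is laid out as [pca_latents | clustering_latents]. All n_pca PCA
# components are generically non-zero and the clustering part has exactly one active entry, so
# the effective L0 per token is n_pca + 1, which matches the l0 column written to the CSV in the
# main block below.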


class KMeansPCACompressor(BaseCompressor):
    """
    A composite compressor that first applies KMeans and then fits PCA on the KMeans residual.
    """

    def __init__(self, n_pca: int, kmeans_comp: SAElikeKMeansCompressor, hook_layer: int):
        """Initialize from an already-fitted SAElikeKMeansCompressor."""
        super().__init__(d_sae=n_pca + kmeans_comp.n_clusters, hook_layer=hook_layer)
        self.n_pca = n_pca
        assert (
            kmeans_comp.dataset_mean is not None and kmeans_comp.normalized_centroids is not None
        ), "Need to provide a fitted SAElikeKMeansCompressor"
        self.kmeans_comp = kmeans_comp
        self.pca_comp = PCACompressor(n_components=n_pca, hook_layer=hook_layer)

    def fit(self, data: Float[Tensor, "n_samples d_model"]) -> None:
        kmeans_recon = self.kmeans_comp.uncompress(self.kmeans_comp.compress(data))
        self.pca_comp.fit(data - kmeans_recon)

    def compress(
        self, data: Float[Tensor, "batch d_model"]
    ) -> Float[Tensor, "batch (n_pca + n_clusters)"]:
        kmeans_latents = self.kmeans_comp.compress(data)
        kmeans_recon = self.kmeans_comp.uncompress(kmeans_latents)
        residual = data - kmeans_recon
        pca_latents = self.pca_comp.compress(residual)
        return torch.cat([kmeans_latents, pca_latents], dim=1)

    def uncompress(
        self, latents: Float[Tensor, "batch (n_pca + n_clusters)"]
    ) -> Float[Tensor, "batch d_model"]:
        kmeans_latents = latents[:, : self.kmeans_comp.n_clusters]
        pca_latents = latents[:, self.kmeans_comp.n_clusters :]
        return self.pca_comp.uncompress(pca_latents) + self.kmeans_comp.uncompress(kmeans_latents)

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)
        self.kmeans_comp.save(os.path.join(path, "kmeans/"))
        self.pca_comp.save(os.path.join(path, "pca/"))

    def load(self, path: str) -> None:
        self.kmeans_comp.load(os.path.join(path, "kmeans/"))
        self.pca_comp.load(os.path.join(path, "pca/"))
        self.n_clusters = self.kmeans_comp.n_clusters
        self.n_pca = self.pca_comp.n_components
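
# Note: unlike PCAKMeansPipelineCompressor, this class expects an already-fitted
# SAElikeKMeansCompressor and fits PCA on the clustering residual; its latent layout is
# [kmeans_latents | pca_latents] (clusters first). It is defined here but not used in the
# experiment loop below.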


# %% === Main Block (Data Collection & Experiments) ===
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Which K-Means and PCA combinations to try
    layer_indices = [3, 4]
    n_latents_list = [4096, 16384]
    n_pca_list = [1, 2, 5, 10, 20, 50, 100, 200, 500]
    outfile_csv = "pythia_70m_clustering_results_v6d_1.csv"

    # Model settings
    n_ctx = 128
    batch_size = 1000
    n_train_list = [100_000]
    n_eval = 100_000
    n_batches = math.ceil((max(n_train_list) + n_eval) / (batch_size * n_ctx))

    # Load dataset
    dataset_name = "apollo-research/Skylion007-openwebtext-tokenizer-EleutherAI-gpt-neox-20b"
    dataset = load_dataset(dataset_name, split="train", streaming=True).with_format("torch")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Load model & test prompt to confirm tokenization is correct.
    model = HookedTransformer.from_pretrained_no_processing("EleutherAI/pythia-70m-deduped")
    sample = next(iter(dataset))
    sample_prompt_tokens = sample["input_ids"][:18]
    sample_answer_token = sample["input_ids"][18]
    sample_prompt_text = model.tokenizer.decode(sample_prompt_tokens)
    sample_answer_text = model.tokenizer.decode(sample_answer_token)
    test_prompt(sample_prompt_text, sample_answer_text, model, prepend_space_to_answer=False)

    # === Collect Activations ===
    activation_cache = defaultdict(list)
    n_tokens_available = 0

    def hook(act, hook, d):
        # `hook` is the TransformerLens HookPoint; `d` is bound to activation_cache via partial.
        assert act.shape == (batch_size, n_ctx, model.cfg.d_model)
        act_reshaped = act.reshape(-1, model.cfg.d_model)
        d[hook.name].append(act_reshaped.cpu())

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(dataloader), desc="Collecting activations"):
            hooks = []
            for layer_idx in layer_indices:
                layer_name = f"blocks.{layer_idx}.hook_resid_post"
                h = partial(hook, d=activation_cache)
                hooks.append((layer_name, h))
            with model.hooks(fwd_hooks=hooks):
                inputs = batch["input_ids"][:, :n_ctx]
                model(inputs)
            if batch_idx >= n_batches:
                break

    # Concatenate and shuffle activations.
    activation_cache = {k: torch.cat(v) for k, v in activation_cache.items()}
    activation_cache = {k: v[torch.randperm(v.shape[0])] for k, v in activation_cache.items()}
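
    # Rough token accounting: each batch contributes batch_size * n_ctx = 1000 * 128 = 128,000
    # token activations per hooked layer, and n_batches = ceil(200,000 / 128,000) = 2. Since the
    # loop breaks only after batch_idx >= n_batches, roughly 3 * 128,000 = 384,000 activations
    # per layer are collected, comfortably more than the 100k train + 100k eval tokens used below.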

    # === Experiment Helpers ===
    def run_experiment(
        compressor: BaseCompressor,
        train_acts: Float[Tensor, "n_train d_model"],
        eval_acts: Float[Tensor, "n_eval d_model"],
    ):
        delta_to_mean_squared = (eval_acts - eval_acts.mean(dim=0)).pow(2).sum(dim=-1)
        compressor.fit(train_acts)
        latents = compressor.compress(eval_acts)
        recon = compressor.uncompress(latents)
        delta_to_recon_squared = (eval_acts - recon).pow(2).sum(dim=-1)
        FVU_conventional = delta_to_recon_squared.mean(dim=0) / delta_to_mean_squared.mean(dim=0)
        FVU_saebench = (delta_to_recon_squared / delta_to_mean_squared).mean(dim=0)
        return FVU_conventional.item(), FVU_saebench.item(), compressor
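
    # Note: two fraction-of-variance-unexplained variants are reported.
    #   FVU_conventional = E[||x - x_hat||^2] / E[||x - x_bar||^2]   (ratio of means)
    #   FVU_saebench     = E[||x - x_hat||^2 / ||x - x_bar||^2]      (mean of per-token ratios)
    # where x_bar is the mean eval activation. The two can differ noticeably when the per-token
    # distances to the mean are heavy-tailed.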

    # Write CSV header.
    with open(outfile_csv, "w") as f:
        f.write("method,layer,training_tokens,n_clusters,n_pca,l0,FVU_saebench,FVU_conventional\n")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    for layer_idx in layer_indices:
        layer = f"blocks.{layer_idx}.hook_resid_post"
        for n_train in n_train_list:
            if n_train + n_eval > len(activation_cache[layer]):
                print(f"Skipping run with {n_train} training tokens: not enough data collected.")
                continue
            train_acts = activation_cache[layer][:n_train].to(device)
            eval_acts = activation_cache[layer][-n_eval:].to(device)
            for n_latents in n_latents_list:
                # === Experiment: Pure Clustering (using KMeansCompressor) ===
                FVU_conventional, FVU_saebench, compressor = run_experiment(
                    KMeansCompressor(n_clusters=n_latents, hook_layer=layer_idx),
                    train_acts,
                    eval_acts,
                )
                compressor.save(f"out_v6d/KMeansCompressor_L{layer_idx}_C{n_latents}_T{n_train}")
                print(
                    f"KMeans ({layer_idx=}, {n_latents=}, {n_train=}): FVU = {FVU_conventional:.3f}"
                )
                with open(outfile_csv, "a") as f:
                    f.write(
                        f"clustering,{layer},{n_train},{n_latents},0,1,"
                        f"{FVU_saebench:.10f},{FVU_conventional:.10f}\n"
                    )

                # === Experiment: KMeans but with SAE-like encoder ===
                compressor = SAElikeKMeansCompressor(n_clusters=n_latents, hook_layer=layer_idx)
                FVU_conventional, FVU_saebench, compressor = run_experiment(
                    compressor,
                    train_acts,
                    eval_acts,
                )
                compressor.save(
                    f"out_v6d/SAElikeKMeansCompressor_L{layer_idx}_C{n_latents}_T{n_train}"
                )
                print(
                    f"Top1SAE ({layer_idx=}, {n_latents=}, {n_train=}): FVU = {FVU_conventional:.3f}"
                )
                with open(outfile_csv, "a") as f:
                    f.write(
                        f"top1sae_clustering,{layer},{n_train},{n_latents},0,1,"
                        f"{FVU_saebench:.10f},{FVU_conventional:.10f}\n"
                    )

                # === Experiment: PCA + Clustering (using PCAKMeansPipelineCompressor) ===
                for n_pca in n_pca_list:
                    n_clusters = n_latents - n_pca
                    if n_clusters <= 0:
                        continue
                    compressor = PCAKMeansPipelineCompressor(
                        n_pca=n_pca, n_clusters=n_clusters, hook_layer=layer_idx
                    )
                    FVU_conventional, FVU_saebench, compressor = run_experiment(
                        compressor,
                        train_acts,
                        eval_acts,
                    )
                    compressor.save(
                        f"out_v6d/PCAKMeansPipelineCompressor_L{layer_idx}_C{n_latents}_P{n_pca}_T{n_train}"
                    )
                    print(
                        f"PCA+KMeans ({layer_idx=}, {n_latents=}, {n_train=}, {n_pca=}, {n_clusters=}): "
                        f"FVU = {FVU_conventional:.3f}"
                    )
                    with open(outfile_csv, "a") as f:
                        f.write(
                            f"pca_clustering,{layer},{n_train},{n_clusters},{n_pca},{n_pca + 1},"
                            f"{FVU_saebench:.10f},{FVU_conventional:.10f}\n"
                        )