Code: [Interim research report] Activation plateaus & sensitive directions in GPT2
# 1.py (Figure 1 is a zoom into one panel of Figure 3/4)
import os
import random
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass

import matplotlib
import matplotlib.pyplot as plt

matplotlib.use("Agg")
import numpy as np
import scipy.interpolate as sip
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm
from transformer_lens import HookedTransformer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("plots", exist_ok=True)


# %%
def lp_norm(direction, p=2):
    """Lp norm along the last dimension (keepdim=True for broadcasting)."""
    return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p)


def mean(direction):
    return torch.mean(direction, dim=-1, keepdim=True)


def dot(x, y):
    return torch.einsum("...k,...k->...", x, y).unsqueeze(-1)


def compute_angle_dist(ref_act, new_act, datasetmean=None):
    angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1)
    # acos returns NaN at the endpoints when rounding pushes the cosine
    # slightly outside [-1, 1]; patch the first and last steps.
    if torch.all(angle[0].isnan()):
        angle[0] = torch.zeros_like(angle[0])
    if torch.all(angle[-1].isnan()):
        angle[-1] = np.pi * torch.ones_like(angle[-1])
    dist = lp_norm(new_act - ref_act).squeeze(-1)
    if datasetmean is None:
        return angle, dist
    else:
        angle_wrt_datasetmean = torch.acos(
            dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))
        ).squeeze(-1)
        return angle, dist, angle_wrt_datasetmean


class Reference:
    def __init__(
        self,
        model: HookedTransformer,
        prompt: torch.Tensor,
        replacement_layer: str,
        read_layer: str,
        replacement_pos: slice,
        n_ctx: int,
    ):
        self.model = model
        n_batch_prompt, n_ctx_prompt = prompt.shape
        assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}"
        self.prompt = prompt
        logits, cache = model.run_with_cache(prompt)
        self.logits = logits.to("cpu").detach()
        self.cache = cache.to("cpu")
        self.act = self.cache[replacement_layer][:, replacement_pos]
        self.replacement_layer = replacement_layer
        self.read_layer = read_layer
        self.replacement_pos = replacement_pos
        self.n_ctx = n_ctx


@dataclass
class Result:
    angle: float
    angle_wrt_datasetmean: float
    dist: float
    norm: float
    kl_div: float
    out_angle: float
    l2_diff: float
    logit_l2_diff: float
    dim0_diff: float
    dim1_diff: float


def set_seed(seed: int):
    """Set the random seed for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
    """Sample a batch of prompts (truncated to n_ctx tokens) from the dataset."""
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
    return next(iter(dataloader))["input_ids"][:, :n_ctx]


def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor:
    """Get a random activation from the dataset."""
    rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
    _, cache = model.run_with_cache(rand_prompt)
    return cache[layer][:, pos, :].to("cpu").detach()


def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
    """Compute the KL divergence KL(ref || pert) between reference and perturbed logprobs."""
    logprobs_ref = F.log_softmax(logits_ref, dim=-1)
    logprobs_pert = F.log_softmax(logits_pert, dim=-1)
    return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1)
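
# Illustrative sanity check (added for exposition, not part of the original
# analysis): with log_target=True and this argument order, F.kl_div computes
# KL(ref || pert), so identical logits give (numerically) zero divergence.
_same_logits = torch.randn(2, 5)
assert compute_kl_div(_same_logits, _same_logits).abs().max().item() < 1e-6
del _same_logits
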
def compute_Lp_metric(
    cache: dict,
    cache_pert: dict,
    p,
    read_layer,
    read_pos,
):
    """Lp distance between reference and perturbed activations at the read-off layer/position."""
    ref_readoff = cache[read_layer][:, read_pos]
    pert_readoff = cache_pert[read_layer][:, read_pos]
    Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1)
    return Lp_diff


def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference):
    """Run the model with the replacement-layer activation overwritten by perturbed_act."""
    pos = ref.replacement_pos
    layer = ref.replacement_layer

    def hook(act, hook):
        act[:, pos, :] = perturbed_act

    with ref.model.hooks(fwd_hooks=[(layer, hook)]):
        prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))])
        logits_pert, cache = ref.model.run_with_cache(prompts)
    return logits_pert.to("cpu").detach(), cache.to("cpu")


def eval_activation(
    perturbed_act: torch.Tensor,
    base_ref: Reference,
    unrelated_ref: Reference,
    read_pos=-1,
):
    read_layer = base_ref.read_layer
    angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None)
    angle_wrt_datasetmean = angle  # PLACEHOLDER
    norm = lp_norm(perturbed_act).squeeze(-1)
    logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref)
    base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos]
    base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
    base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
    base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1)
    base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
    base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
    base_out_angle, _ = compute_angle_dist(
        base_ref.cache[read_layer][:, read_pos, :],
        cache[read_layer][:, read_pos, :],
        datasetmean=None,
    )
    unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos]
    unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
    unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
    unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1)
    unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
    unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
    unrelated_out_angle, _ = compute_angle_dist(
        unrelated_ref.cache[read_layer][:, read_pos, :],
        cache[read_layer][:, read_pos, :],
        datasetmean=None,
    )
    return (
        Result(
            angle,
            angle_wrt_datasetmean,
            dist,
            norm,
            base_kl_div,
            base_out_angle,
            base_l2_diff,
            base_logit_l2_diff,
            base_dim0_diff,
            base_dim1_diff,
        ),
        Result(
            angle,
            angle_wrt_datasetmean,
            dist,
            norm,
            unrelated_kl_div,
            unrelated_out_angle,
            unrelated_l2_diff,
            unrelated_logit_l2_diff,
            unrelated_dim0_diff,
            unrelated_dim1_diff,
        ),
    )


class Slerp:
    def __init__(self, start, direction, datasetmean=None):
        """Interpolate along the sphere from `start` towards `direction`,
        optionally centered on `datasetmean` instead of the origin."""
        self.datasetmean = datasetmean
        self.A = start - self.datasetmean if self.datasetmean is not None else start
        self.A_norm = lp_norm(self.A)  # magnitude
        self.a = self.A / self.A_norm  # unit vector
        d = direction / lp_norm(direction)
        # Gram-Schmidt: keep only the component of the direction orthogonal to `a`.
        self.B = d - dot(d, self.a) * self.a
        self.B_norm = lp_norm(self.B)
        self.b = self.B / self.B_norm  # unit vector

    def __call__(self, alpha):
        result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b)
        if self.datasetmean is not None:
            return result + self.datasetmean
        else:
            return result

    def get_alpha(self, X):
        x = X / lp_norm(X)
        return torch.acos(dot(x, self.a))
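
# Illustrative sanity check (toy 2-D example, added for exposition): slerp
# preserves the norm of the start point while rotating toward the
# orthogonalized direction, so alpha = pi/2 lands on the perpendicular axis.
_slerp = Slerp(torch.tensor([[3.0, 0.0]]), torch.tensor([[1.0, 1.0]]))
assert torch.allclose(lp_norm(_slerp(torch.tensor(0.3))), torch.tensor([[3.0]]))
assert torch.allclose(_slerp(torch.tensor(np.pi / 2, dtype=torch.float32)), torch.tensor([[0.0, 3.0]]), atol=1e-5)
del _slerp
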
class Perturbation:
    def gen_direction(self):
        raise NotImplementedError

    def __init__(
        self,
        base_ref: Reference,
        unrelated_ref: Reference,
        ensure_mean=True,
        ensure_norm=True,
    ):
        self.ensure_mean = ensure_mean
        self.ensure_norm = ensure_norm
        self.base_ref = base_ref
        self.norm_base = lp_norm(base_ref.act)
        self.unrelated_ref = unrelated_ref

    def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
        direction = self.gen_direction()
        return self.scan_dir(direction, n_steps, range, step_angular)

    def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
        direction = direction.clone()
        if self.ensure_mean:
            direction -= mean(direction)
        if self.ensure_norm:
            direction *= self.norm_base / lp_norm(direction)
        range = np.array(range)
        if step_angular:
            s = Slerp(self.base_ref.act, direction, datasetmean=None)
            self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)]
        else:
            self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)]
        act = torch.cat(self.activation_steps, dim=0)
        torch.cuda.empty_cache()
        base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref)
        return base_result, unrelated_result, direction


class RandomUniformPerturbation(Perturbation):
    """Naive isotropic Gaussian direction."""

    def gen_direction(self):
        return torch.randn_like(self.base_ref.act)


class RandomPerturbation(Perturbation):
    """Direction towards a sample from a covariance-matched Gaussian (set `.distrib` first)."""

    def gen_target(self):
        return self.distrib.sample(self.base_ref.act.shape[:-1])

    def gen_direction(self):
        self.target = self.gen_target()
        return self.target - self.base_ref.act


class RandomActDirPerturbation(Perturbation):
    """Direction towards the activation of a random other prompt."""

    def gen_target(self):
        return get_random_activation(
            self.base_ref.model,
            dataset,
            self.base_ref.n_ctx,
            self.base_ref.replacement_layer,
            self.base_ref.replacement_pos,
        )

    def gen_direction(self):
        self.target = self.gen_target()
        return self.target - self.base_ref.act


class NeuronDirPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference, negate=False, on_only=None):
        super().__init__(ref, unrelated_ref)
        self.base_ref = ref
        self.negate = 1 if not negate else -1
        self.active_features = None
        if on_only is True:
            feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :]
            self.active_features = feature_acts / feature_acts.max() > 0.1

    def gen_direction(self):
        if self.active_features is None:
            random_int = random.randint(0, 767)  # 768-dimensional residual stream
        else:
            random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
        one_hot = torch.zeros_like(self.base_ref.act)
        one_hot[..., random_int] = 1
        single_direction = self.negate * one_hot
        return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])])


class OptimizedPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference, learning_rate: float = 0.1):
        super().__init__(ref, unrelated_ref)
        self.learning_rate = learning_rate

    def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None):
        kl_divs = []
        directions = []
        if init is not None:
            init = np.array(init.cpu().detach())
            direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device)
        else:
            direction = torch.randn_like(self.base_ref.act, requires_grad=True)
        optimizer = torch.optim.SGD([direction], lr=self.learning_rate)
        for _ in tqdm(range(num_steps), desc="Finding perturbation direction"):
            kl_div, _, _ = eval_direction(direction, self.base_ref, step_size)
            kl_divs.append(kl_div.item())
            directions.append(direction.detach().clone())
            (sign * kl_div).backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
        return kl_divs, directions
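
# `eval_direction` is not defined in this gist; OptimizedPerturbation is not
# used below and presumably relies on a helper from the full codebase. A rough
# sketch of what it plausibly does, judging from the call site (all names and
# details here are assumptions, so it is kept commented out):
#
# def eval_direction(direction, ref, step_size):
#     perturbed = ref.act + step_size * direction / lp_norm(direction)
#     def hook(act, hook):
#         act[:, ref.replacement_pos, :] = perturbed
#     with ref.model.hooks(fwd_hooks=[(ref.replacement_layer, hook)]):
#         logits = ref.model(ref.prompt)  # keep the graph for backward()
#     kl_div = compute_kl_div(ref.logits, logits)[:, -1].mean()
#     return kl_div, logits, perturbed
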
def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True):
    # All batch entries share the same angle/dist schedule; check, then use column 0.
    assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0))
    assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0))
    angles = results.angle[:, 0]
    dists = results.dist[:, 0].detach()
    ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls)
    ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls)
    ax1.set_ylabel("KL divergence to base logits")
    ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}")
    ax1.legend()
    ax2.legend()
    ax1.set_xlim(min(x), max(x))
    ax2.set_xlim(min(x), max(x))
    if len(angles) < 40:
        tick_angles = angles
        tick_dists = dists
    else:
        tick_angles = angles[::10]
        tick_dists = dists[::10]
    if x_is_angle:
        ax1.set_xlabel(f"Angle from base activation at {layer}")
        ax2.set_xlabel(f"Angle from base activation at {layer}")
        ax1.set_xticks(tick_angles)
        ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
        ax2.set_xticks(tick_angles)
        ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
        if label is not None:
            ax1t.set_xlabel("Distance")
            ax2t.set_xlabel("Distance")
            ax1t.set_xticks(ax1.get_xticks())
            ax2t.set_xticks(ax2.get_xticks())
            ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
            ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
    else:
        ax1.set_xlabel(f"Distance from base activation at {layer}")
        ax2.set_xlabel(f"Distance from base activation at {layer}")
        ax1.set_xticks(tick_dists)
        ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
        ax2.set_xticks(tick_dists)
        ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
        if label is not None:
            ax1t.set_xlabel("Angle")
            ax2t.set_xlabel("Angle")
            ax1t.set_xticks(ax1.get_xticks())
            ax2t.set_xticks(ax2.get_xticks())
            ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
            ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)


def find_sensitivity(x, y, threshold=0.5):
    """Return the first x at which the (linearly interpolated) y crosses `threshold`."""
    f = sip.interp1d(x, y, kind="linear")
    angles_interp = np.linspace(x.min(), x.max(), 100000)
    index_interp = np.argmax(f(angles_interp) > threshold)
    return angles_interp[index_interp]


# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
seed = 0
step_angular = False
on_only = True
negate = on_only
set_seed(seed)
n_ctx = 10
layer = "blocks.1.hook_resid_pre"
pos = slice(-1, None, 1)
num_steps = 40
step_size = 1
learning_rate = 1
dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True)

# %%
model = HookedTransformer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"

# %%
base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)

# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000)
mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu")
plt.figure()
plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu())
plt.title(f"Mean activations for {layer}")
plt.xlabel("Dimension")
plt.ylabel("Activation")
plt.savefig(f"plots/mean_activations_seed_{seed}.png")
plt.close()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
data = mean_act_cache[layer][:, -1, :]
data_mean = data.mean(dim=0, keepdim=True)
data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0]
distrib = MultivariateNormal(data_mean.squeeze(0), data_cov)
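# `distrib` matches the empirical mean and covariance of real final-position
# activations, so RandomPerturbation draws covariance-matched Gaussian
# activations rather than isotropic noise (contrast RandomUniformPerturbation,
# which uses plain torch.randn directions).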

# %%
random_perturbation = RandomPerturbation(base_ref, unrelated_ref)
random_perturbation.distrib = distrib
randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref)
perturbation = RandomPerturbation(base_ref, unrelated_ref)
perturbation.distrib = distrib
results_list = defaultdict(list)
for c, p in [
    ("C0", random_perturbation),
    ("C1", randomact_perturbation),
]:
    # Re-center the scans on a fresh target activation for each family.
    target = p.gen_target()
    perturbation.base_ref.act = target
    perturbation.base_ref.logits, perturbation.base_ref.cache = run_perturbed_activation(target, perturbation.base_ref)
    for i in tqdm(range(20)):
        results, _, _ = perturbation.scan(n_steps=37, step_angular=True, range=(0, 1))
        results_list[c].append(results)

# %%
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
ax1t = ax1.twiny()
ax2t = ax2.twiny()
prompt_str = "|" + "|".join(model.to_str_tokens(base_ref.prompt)).replace("\n", "⏎") + "|"
title_str = "Sphere mode (slerping angle around origin)" if step_angular else "Straight mode"
fig.suptitle(f"{title_str}. Perturbing {layer} at pos {pos}.\nSeed = {seed}. Prompt = {prompt_str}. Norm = {lp_norm(base_ref.act).item():.2f}.")
for c in results_list:
    label = "Random perturbation around random activation" if c == "C0" else "Random perturbation around (random) real activation" if c == "C1" else None
    for i, r in enumerate(results_list[c]):
        format_ax(
            ax1,
            ax1t,
            ax2,
            ax2t,
            r,
            label if i == 0 else None,
            c,
            r.angle[:, 0] if step_angular else r.dist[:, 0].detach(),
            base_ref,
            x_is_angle=step_angular,
        )
mode_str = "sphere" if step_angular else "straight"
fig.savefig(f"plots/perturbations_{mode_str}_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()
# 2_naive.py
# %%
import os
import random
from collections.abc import Callable
from dataclasses import dataclass

import matplotlib
import matplotlib.pyplot as plt

matplotlib.use("Agg")
import numpy as np
import scipy.interpolate as sip
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm
from transformer_lens import HookedTransformer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("plots", exist_ok=True)


# %%
def lp_norm(direction, p=2):
    """Lp norm along the last dimension (keepdim=True for broadcasting)."""
    return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p)


def mean(direction):
    return torch.mean(direction, dim=-1, keepdim=True)


def dot(x, y):
    return torch.einsum("...k,...k->...", x, y).unsqueeze(-1)


def compute_angle_dist(ref_act, new_act, datasetmean=None):
    angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1)
    if torch.all(angle[0].isnan()):
        angle[0] = torch.zeros_like(angle[0])
    if torch.all(angle[-1].isnan()):
        angle[-1] = np.pi * torch.ones_like(angle[-1])
    dist = lp_norm(new_act - ref_act).squeeze(-1)
    if datasetmean is None:
        return angle, dist
    else:
        angle_wrt_datasetmean = torch.acos(
            dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))
        ).squeeze(-1)
        return angle, dist, angle_wrt_datasetmean


class Reference:
    def __init__(
        self,
        model: HookedTransformer,
        prompt: torch.Tensor,
        replacement_layer: str,
        read_layer: str,
        replacement_pos: slice,
        n_ctx: int,
    ):
        self.model = model
        n_batch_prompt, n_ctx_prompt = prompt.shape
        assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}"
        self.prompt = prompt
        logits, cache = model.run_with_cache(prompt)
        self.logits = logits.to("cpu").detach()
        self.cache = cache.to("cpu")
        self.act = self.cache[replacement_layer][:, replacement_pos]
        self.replacement_layer = replacement_layer
        self.read_layer = read_layer
        self.replacement_pos = replacement_pos
        self.n_ctx = n_ctx


@dataclass
class Result:
    angle: float
    angle_wrt_datasetmean: float
    dist: float
    norm: float
    kl_div: float
    out_angle: float
    l2_diff: float
    logit_l2_diff: float
    dim0_diff: float
    dim1_diff: float


def set_seed(seed: int):
    """Set the random seed for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
    """Sample a batch of prompts (truncated to n_ctx tokens) from the dataset."""
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
    return next(iter(dataloader))["input_ids"][:, :n_ctx]


def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor:
    """Get a random activation from the dataset."""
    rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
    _, cache = model.run_with_cache(rand_prompt)
    return cache[layer][:, pos, :].to("cpu").detach()


def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
    """Compute the KL divergence KL(ref || pert) between reference and perturbed logprobs."""
    logprobs_ref = F.log_softmax(logits_ref, dim=-1)
    logprobs_pert = F.log_softmax(logits_pert, dim=-1)
    return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1)


def compute_Lp_metric(
    cache: dict,
    cache_pert: dict,
    p,
    read_layer,
    read_pos,
):
    ref_readoff = cache[read_layer][:, read_pos]
    pert_readoff = cache_pert[read_layer][:, read_pos]
    Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1)
    return Lp_diff


def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference):
    pos = ref.replacement_pos
    layer = ref.replacement_layer

    def hook(act, hook):
        act[:, pos, :] = perturbed_act

    with ref.model.hooks(fwd_hooks=[(layer, hook)]):
        prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))])
        logits_pert, cache = ref.model.run_with_cache(prompts)
    return logits_pert.to("cpu").detach(), cache.to("cpu")


def eval_activation(
    perturbed_act: torch.Tensor,
    base_ref: Reference,
    unrelated_ref: Reference,
    read_pos=-1,
):
    read_layer = base_ref.read_layer
    angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None)
    angle_wrt_datasetmean = angle  # PLACEHOLDER
    norm = lp_norm(perturbed_act).squeeze(-1)
    logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref)
    base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos]
    base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
    base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
    base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1)
    base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
    base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
    base_out_angle, _ = compute_angle_dist(
        base_ref.cache[read_layer][:, read_pos, :],
        cache[read_layer][:, read_pos, :],
        datasetmean=None,
    )
    unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos]
    unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
    unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
    unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1)
    unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
    unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
    unrelated_out_angle, _ = compute_angle_dist(
        unrelated_ref.cache[read_layer][:, read_pos, :],
        cache[read_layer][:, read_pos, :],
        datasetmean=None,
    )
    return (
        Result(
            angle,
            angle_wrt_datasetmean,
            dist,
            norm,
            base_kl_div,
            base_out_angle,
            base_l2_diff,
            base_logit_l2_diff,
            base_dim0_diff,
            base_dim1_diff,
        ),
        Result(
            angle,
            angle_wrt_datasetmean,
            dist,
            norm,
            unrelated_kl_div,
            unrelated_out_angle,
            unrelated_l2_diff,
            unrelated_logit_l2_diff,
            unrelated_dim0_diff,
            unrelated_dim1_diff,
        ),
    )


class Slerp:
    def __init__(self, start, direction, datasetmean=None):
        """Interpolate along the sphere from `start` towards `direction`,
        optionally centered on `datasetmean` instead of the origin."""
        self.datasetmean = datasetmean
        self.A = start - self.datasetmean if self.datasetmean is not None else start
        self.A_norm = lp_norm(self.A)  # magnitude
        self.a = self.A / self.A_norm  # unit vector
        d = direction / lp_norm(direction)
        self.B = d - dot(d, self.a) * self.a
        self.B_norm = lp_norm(self.B)
        self.b = self.B / self.B_norm  # unit vector

    def __call__(self, alpha):
        result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b)
        if self.datasetmean is not None:
            return result + self.datasetmean
        else:
            return result

    def get_alpha(self, X):
        x = X / lp_norm(X)
        return torch.acos(dot(x, self.a))


class Perturbation:
    def gen_direction(self):
        raise NotImplementedError

    def __init__(
        self,
        base_ref: Reference,
        unrelated_ref: Reference,
        ensure_mean=True,
        ensure_norm=True,
    ):
        self.ensure_mean = ensure_mean
        self.ensure_norm = ensure_norm
        self.base_ref = base_ref
        self.norm_base = lp_norm(base_ref.act)
        self.unrelated_ref = unrelated_ref

    def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
        direction = self.gen_direction()
        return self.scan_dir(direction, n_steps, range, step_angular)

    def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
        direction = direction.clone()
        if self.ensure_mean:
            direction -= mean(direction)
        if self.ensure_norm:
            direction *= self.norm_base / lp_norm(direction)
        range = np.array(range)
        if step_angular:
            s = Slerp(self.base_ref.act, direction, datasetmean=None)
            self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)]
        else:
            self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)]
        act = torch.cat(self.activation_steps, dim=0)
        torch.cuda.empty_cache()
        base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref)
        return base_result, unrelated_result, direction


class RandomUniformPerturbation(Perturbation):
    def gen_direction(self):
        return torch.randn_like(self.base_ref.act)


class RandomPerturbation(Perturbation):
    def gen_target(self):
        return self.distrib.sample(self.base_ref.act.shape[:-1])

    def gen_direction(self):
        self.target = self.gen_target()
        return self.target - self.base_ref.act


class RandomActDirPerturbation(Perturbation):
    def gen_target(self):
        return get_random_activation(
            self.base_ref.model,
            dataset,
            self.base_ref.n_ctx,
            self.base_ref.replacement_layer,
            self.base_ref.replacement_pos,
        )

    def gen_direction(self):
        self.target = self.gen_target()
        return self.target - self.base_ref.act


class NeuronDirPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference, negate=False, on_only=None):
        super().__init__(ref, unrelated_ref)
        self.base_ref = ref
        self.negate = 1 if not negate else -1
        self.active_features = None
        if on_only is True:
            feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :]
            self.active_features = feature_acts / feature_acts.max() > 0.1

    def gen_direction(self):
        if self.active_features is None:
            random_int = random.randint(0, 767)
        else:
            random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
        one_hot = torch.zeros_like(self.base_ref.act)
        one_hot[..., random_int] = 1
        single_direction = self.negate * one_hot
        return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])])


class OptimizedPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference, learning_rate: float = 0.1):
        super().__init__(ref, unrelated_ref)
        self.learning_rate = learning_rate

    def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None):
        kl_divs = []
        directions = []
        if init is not None:
            init = np.array(init.cpu().detach())
            direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device)
        else:
            direction = torch.randn_like(self.base_ref.act, requires_grad=True)
        optimizer = torch.optim.SGD([direction], lr=self.learning_rate)
        for _ in tqdm(range(num_steps), desc="Finding perturbation direction"):
            kl_div, _, _ = eval_direction(direction, self.base_ref, step_size)
            kl_divs.append(kl_div.item())
            directions.append(direction.detach().clone())
            (sign * kl_div).backward(retain_graph=True)
            optimizer.step()
            optimizer.zero_grad()
        return kl_divs, directions


def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True):
    assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0))
    assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0))
    angles = results.angle[:, 0]
    dists = results.dist[:, 0]
    ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls)
    ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls)
    ax1.set_ylabel("KL divergence to base logits")
    ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}")
    ax1.legend()
    ax2.legend()
    ax1.set_xlim(min(x), max(x))
    ax2.set_xlim(min(x), max(x))
    if len(angles) < 40:
        tick_angles = angles
        tick_dists = dists
    else:
        tick_angles = angles[::10]
        tick_dists = dists[::10]
    if x_is_angle:
        ax1.set_xlabel(f"Angle from base activation at {layer}")
        ax2.set_xlabel(f"Angle from base activation at {layer}")
        ax1.set_xticks(tick_angles)
        ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
        ax2.set_xticks(tick_angles)
        ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
        if label is not None:
            ax1t.set_xlabel("Distance")
            ax2t.set_xlabel("Distance")
            ax1t.set_xticks(ax1.get_xticks())
            ax2t.set_xticks(ax2.get_xticks())
            ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
            ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
    else:
        ax1.set_xlabel(f"Distance from base activation at {layer}")
        ax2.set_xlabel(f"Distance from base activation at {layer}")
        ax1.set_xticks(tick_dists)
        ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
        ax2.set_xticks(tick_dists)
        ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
        if label is not None:
            ax1t.set_xlabel("Angle")
            ax2t.set_xlabel("Angle")
            ax1t.set_xticks(ax1.get_xticks())
            ax2t.set_xticks(ax2.get_xticks())
            ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
            ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)


def find_sensitivity(x, y, threshold=0.5):
    f = sip.interp1d(x, y, kind="linear")
    angles_interp = np.linspace(x.min(), x.max(), 100000)
    index_interp = np.argmax(f(angles_interp) > threshold)
    return angles_interp[index_interp]


# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
seed = 1
step_angular = False
on_only = True
negate = on_only
set_seed(seed)
n_ctx = 10
layer = "blocks.1.hook_resid_pre"
pos = slice(-1, None, 1)
num_steps = 40
step_size = 1
learning_rate = 1
dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True)

# %%
model = HookedTransformer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"

# %%
base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)

# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000)
mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu")
plt.figure()
plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu())
plt.title(f"Mean activations for {layer}")
plt.xlabel("Dimension")
plt.ylabel("Activation")
plt.savefig(f"plots/2_naive_mean_activations_seed_{seed}.png")
plt.close()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
data = mean_act_cache[layer][:, -1, :]
data_mean = data.mean(dim=0, keepdim=True)
data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0]
distrib = MultivariateNormal(data_mean.squeeze(0), data_cov)

# %%
random_perturbation = RandomPerturbation(base_ref, unrelated_ref)
naiverandom_perturbation = RandomUniformPerturbation(base_ref, unrelated_ref)
random_perturbation.distrib = distrib
randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref)

# %%
results_list0 = []
results_list1 = []
results_list2 = []
results_list3 = []
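# Three perturbation families are compared below: covariance-matched Gaussian
# targets (C0), directions toward other real activations (C1), and naive
# isotropic Gaussian directions (C4); results_list2 is unused in this script.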
for _ in tqdm(range(20)):
    random_perturbation.gen_direction()
    results, _, dir = random_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
    results_list0.append(results)
    randomact_perturbation.gen_direction()
    results, _, dir = randomact_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
    results_list1.append(results)
    naiverandom_perturbation.gen_direction()
    results, _, dir = naiverandom_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
    results_list3.append(results)

# %%
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
ax1t = ax1.twiny()
ax2t = ax2.twiny()
prompt_str = "|" + "|".join(model.to_str_tokens(base_ref.prompt)).replace("\n", "⏎") + "|"
title_str = "Sphere mode (slerping angle around origin)" if step_angular else "Straight mode"
fig.suptitle(f"{title_str}. Perturbing {layer} at pos {pos}.\nSeed = {seed}. Prompt = {prompt_str}. Norm = {lp_norm(base_ref.act).item():.2f}.")
for i, r in enumerate(results_list0):
    format_ax(
        ax1,
        ax1t,
        ax2,
        ax2t,
        r,
        "Random direction perturbation" if i == 0 else None,
        "C0",
        r.angle[:, 0] if step_angular else r.dist[:, 0],
        base_ref,
        x_is_angle=step_angular,
    )
for i, r in enumerate(results_list1):
    format_ax(
        ax1,
        ax1t,
        ax2,
        ax2t,
        r,
        "Direction to random other activation" if i == 0 else None,
        "C1",
        r.angle[:, 0] if step_angular else r.dist[:, 0],
        base_ref,
        x_is_angle=step_angular,
    )
for i, r in enumerate(results_list3):
    format_ax(
        ax1,
        ax1t,
        ax2,
        ax2t,
        r,
        "Naive random direction perturbation" if i == 0 else None,
        "C4",
        r.angle[:, 0] if step_angular else r.dist[:, 0],
        base_ref,
        x_is_angle=step_angular,
    )
mode_str = "sphere" if step_angular else "straight"
fig.savefig(f"plots/2_naive_perturbations_{mode_str}_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()

# %%
plt.figure(figsize=(8, 4))
plt.plot(results_list0[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list0]), dim=0))
plt.plot(results_list1[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list1]), dim=0))
plt.fill_between(
    results_list0[0].angle[:, 0],
    torch.min(torch.stack([r.kl_div for r in results_list0]), dim=0).values,
    torch.max(torch.stack([r.kl_div for r in results_list0]), dim=0).values,
    alpha=0.3,
)
plt.fill_between(
    results_list1[0].angle[:, 0],
    torch.min(torch.stack([r.kl_div for r in results_list1]), dim=0).values,
    torch.max(torch.stack([r.kl_div for r in results_list1]), dim=0).values,
    alpha=0.3,
)
plt.xlabel("Angle from base activation")
tick_angles = results_list0[0].angle[::10, 0]
plt.xticks(tick_angles, [f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
plt.ylabel("Average KL divergence to base logits")
plt.title("Average KL divergence comparison")
plt.savefig(f"plots/2_naive_average_kl_div_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()
# See code in figure1.py
| # 2_2a.py | |
| # %% | |
| import os | |
| import random | |
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| matplotlib.use("Agg") # Use non-interactive backend | |
| import numpy as np | |
| import scipy.interpolate as sip | |
| import torch | |
| import torch.nn.functional as F | |
| from datasets import load_dataset | |
| from torch.distributions.multivariate_normal import MultivariateNormal | |
| from tqdm import tqdm | |
| from transformer_lens import HookedTransformer | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| os.makedirs("plots", exist_ok=True) | |
| # %% | |
| def lp_norm(direction, p=2): | |
| """Pytorch norm""" | |
| return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p) | |
| def mean(direction): | |
| return torch.mean(direction, dim=-1, keepdim=True) | |
| def dot(x, y): | |
| return torch.einsum("...k,...k->...", x, y).unsqueeze(-1) | |
| def compute_angle_dist(ref_act, new_act, datasetmean=None): | |
| angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1) | |
| if torch.all(angle[0].isnan()): | |
| angle[0] = torch.zeros_like(angle[0]) | |
| if torch.all(angle[-1].isnan()): | |
| angle[-1] = np.pi * torch.ones_like(angle[-1]) | |
| dist = lp_norm(new_act - ref_act).squeeze(-1) | |
| if datasetmean is None: | |
| return angle, dist | |
| else: | |
| angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze( | |
| -1 | |
| ) | |
| return angle, dist, angle_wrt_datasetmean | |
| class Reference: | |
| def __init__( | |
| self, | |
| model: HookedTransformer, | |
| prompt: torch.Tensor, | |
| replacement_layer: str, | |
| read_layer: str, | |
| replacement_pos: slice, | |
| n_ctx: int, | |
| ): | |
| self.model = model | |
| n_batch_prompt, n_ctx_prompt = prompt.shape | |
| assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}" | |
| self.prompt = prompt | |
| logits, cache = model.run_with_cache(prompt) | |
| self.logits = logits.to("cpu").detach() | |
| self.cache = cache.to("cpu") | |
| self.act = self.cache[replacement_layer][:, replacement_pos] | |
| self.replacement_layer = replacement_layer | |
| self.read_layer = read_layer | |
| self.replacement_pos = replacement_pos | |
| self.n_ctx = n_ctx | |
| @dataclass | |
| class Result: | |
| angle: float | |
| angle_wrt_datasetmean: float | |
| dist: float | |
| norm: float | |
| kl_div: float | |
| out_angle: float | |
| l2_diff: float | |
| logit_l2_diff: float | |
| dim0_diff: float | |
| dim1_diff: float | |
| def set_seed(seed: int): | |
| """Set the random seed for reproducibility.""" | |
| torch.manual_seed(seed) | |
| np.random.seed(seed) | |
| random.seed(seed) | |
| def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor: | |
| """Generate a prompt from the dataset.""" | |
| dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True) | |
| return next(iter(dataloader))["input_ids"][:, :n_ctx] | |
| tokens = [0, *sample[: n_ctx - 1]] | |
| return torch.tensor([tokens[:n_ctx]]) | |
| def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor: | |
| """Get a random activation from the dataset.""" | |
| rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| _, cache = model.run_with_cache(rand_prompt) | |
| return cache[layer][:, pos, :].to("cpu").detach() | |
| def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor: | |
| """Compute the KL divergence between the reference and perturbed logprobs.""" | |
| logprobs_ref = F.log_softmax(logits_ref, dim=-1) | |
| logprobs_pert = F.log_softmax(logits_pert, dim=-1) | |
| return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1) | |
| def compute_Lp_metric( | |
| cache: dict, | |
| cache_pert: dict, | |
| p, | |
| read_layer, | |
| read_pos, | |
| ): | |
| ref_readoff = cache[read_layer][:, read_pos] | |
| pert_readoff = cache_pert[read_layer][:, read_pos] | |
| Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1) | |
| return Lp_diff | |
| def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference): | |
| pos = ref.replacement_pos | |
| layer = ref.replacement_layer | |
| def hook(act, hook): | |
| act[:, pos, :] = perturbed_act | |
| with ref.model.hooks(fwd_hooks=[(layer, hook)]): | |
| prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))]) | |
| logits_pert, cache = ref.model.run_with_cache(prompts) | |
| return logits_pert.to("cpu").detach(), cache.to("cpu") | |
| def eval_activation( | |
| perturbed_act: torch.Tensor, | |
| base_ref: Reference, | |
| unrelated_ref: Reference, | |
| read_pos=-1, | |
| ): | |
| read_layer = base_ref.read_layer | |
| angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None) | |
| angle_wrt_datasetmean = angle # PLACEHOLDER | |
| norm = lp_norm(perturbed_act).squeeze(-1) | |
| logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref) | |
| base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos] | |
| base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos) | |
| base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos) | |
| base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1) | |
| base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0] | |
| base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1] | |
| base_out_angle, _ = compute_angle_dist( | |
| base_ref.cache[read_layer][:, read_pos, :], | |
| cache[read_layer][:, read_pos, :], | |
| datasetmean=None, | |
| ) | |
| unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos] | |
| unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos) | |
| unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos) | |
| unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1) | |
| unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0] | |
| unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1] | |
| unrelated_out_angle, _ = compute_angle_dist( | |
| unrelated_ref.cache[read_layer][:, read_pos, :], | |
| cache[read_layer][:, read_pos, :], | |
| datasetmean=None, | |
| ) | |
| return ( | |
| Result( | |
| angle, | |
| angle_wrt_datasetmean, | |
| dist, | |
| norm, | |
| base_kl_div, | |
| base_out_angle, | |
| base_l2_diff, | |
| base_logit_l2_diff, | |
| base_dim0_diff, | |
| base_dim1_diff, | |
| ), | |
| Result( | |
| angle, | |
| angle_wrt_datasetmean, | |
| dist, | |
| norm, | |
| unrelated_kl_div, | |
| unrelated_out_angle, | |
| unrelated_l2_diff, | |
| unrelated_logit_l2_diff, | |
| unrelated_dim0_diff, | |
| unrelated_dim1_diff, | |
| ), | |
| ) | |
| class Slerp: | |
| def __init__(self, start, direction, datasetmean=None): | |
| """Return interpolation points along the sphere from | |
| A towards direction B""" | |
| self.datasetmean = datasetmean | |
| self.A = start - self.datasetmean if self.datasetmean is not None else start | |
| self.A_norm = lp_norm(self.A) # magnitude | |
| self.a = self.A / self.A_norm # unit vectors | |
| d = direction / lp_norm(direction) | |
| self.B = d - dot(d, self.a) * self.a | |
| self.B_norm = lp_norm(self.B) # magnitude | |
| self.b = self.B / self.B_norm # unit vectors | |
| def __call__(self, alpha): | |
| result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b) | |
| if self.datasetmean is not None: | |
| return result + self.datasetmean | |
| else: | |
| return result | |
| def get_alpha(self, X): | |
| x = X / lp_norm(X) | |
| return torch.acos(dot(x, self.a)) | |
| class Perturbation: | |
| def gen_direction(self): | |
| raise NotImplementedError | |
| def __init__( | |
| self, | |
| base_ref: Reference, | |
| unrelated_ref: Reference, | |
| ensure_mean=True, | |
| ensure_norm=True, | |
| ): | |
| self.ensure_mean = ensure_mean | |
| self.ensure_norm = ensure_norm | |
| self.base_ref = base_ref | |
| self.norm_base = lp_norm(base_ref.act) | |
| self.unrelated_ref = unrelated_ref | |
| def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True): | |
| direction = self.gen_direction() | |
| return self.scan_dir(direction, n_steps, range, step_angular) | |
| def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True): | |
| direction = direction.clone() | |
| if self.ensure_mean: | |
| direction -= mean(direction) | |
| if self.ensure_norm: | |
| direction *= self.norm_base / lp_norm(direction) | |
| range = np.array(range) | |
| if step_angular: | |
| s = Slerp(self.base_ref.act, direction, datasetmean=None) | |
| self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)] | |
| else: | |
| self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)] | |
| act = torch.cat(self.activation_steps, dim=0) | |
| torch.cuda.empty_cache() | |
| base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref) | |
| return base_result, unrelated_result, direction | |
| class RandomUniformPerturbation(Perturbation): | |
| def gen_direction(self): | |
| return torch.randn_like(self.base_ref.act) | |
| class RandomPerturbation(Perturbation): | |
| def gen_target(self): | |
| return self.distrib.sample(self.base_ref.act.shape[:-1]) | |
| def gen_direction(self): | |
| self.target = self.gen_target() | |
| return self.target - self.base_ref.act | |
| class RandomActDirPerturbation(Perturbation): | |
| def gen_target(self): | |
| return get_random_activation( | |
| self.base_ref.model, | |
| dataset, | |
| self.base_ref.n_ctx, | |
| self.base_ref.replacement_layer, | |
| self.base_ref.replacement_pos, | |
| ) | |
| def gen_direction(self): | |
| self.target = self.gen_target() | |
| return self.target - self.base_ref.act | |
| class NeuronDirPerturbation(Perturbation): | |
| def __init__(self, ref: Reference, negate=False, on_only=None): | |
| super().__init__(ref) | |
| self.base_ref = ref | |
| self.negate = 1 if not negate else -1 | |
| if on_only is True: | |
| feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :] | |
| self.active_features = feature_acts / feature_acts.max() > 0.1 | |
| def gen_direction(self): | |
| if self.active_features is None: | |
| random_int = random.randint(0, 768) | |
| else: | |
| random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0]) | |
| one_hot = torch.zeros_like(self.base_ref.act) | |
| one_hot[..., random_int] = 1 | |
| single_direction = self.negate * one_hot | |
| return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])]) | |
| class OptimizedPerturbation(Perturbation): | |
| def __init__(self, ref: Reference, learning_rate: float = 0.1): | |
| super().__init__(ref) | |
| self.learning_rate = learning_rate | |
| def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None): | |
| kl_divs = [] | |
| directions = [] | |
| if init is not None: | |
| init = np.array(init.cpu().detach()) | |
| direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device) | |
| else: | |
| direction = torch.randn_like(self.base_ref.act, requires_grad=True) | |
| optimizer = torch.optim.SGD([direction], lr=self.learning_rate) | |
| for _ in tqdm(range(num_steps), desc="Finding perturbation direction"): | |
| kl_div, _, _ = eval_direction(direction, self.base_ref, step_size) | |
| kl_divs.append(kl_div.item()) | |
| directions.append(direction.detach().clone()) | |
| (sign * kl_div).backward(retain_graph=True) | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| return kl_divs, directions | |
| def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True): | |
| assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0)) | |
| assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0)) | |
| angles = r.angle[:, 0] | |
| dists = results.dist[:, 0] | |
| ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls) | |
| ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls) | |
| ax1.set_ylabel("KL divergence to base logits") | |
| ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}") | |
| ax1.legend() | |
| ax2.legend() | |
| ax1.set_xlim(min(x), max(x)) | |
| ax2.set_xlim(min(x), max(x)) | |
| if len(angles) < 40: | |
| tick_angles = angles | |
| tick_dists = dists | |
| else: | |
| tick_angles = angles[::10] | |
| tick_dists = dists[::10] | |
| if x_is_angle: | |
| ax1.set_xlabel(f"Angle from base activation at {layer}") | |
| ax2.set_xlabel(f"Angle from base activation at {layer}") | |
| ax1.set_xticks(tick_angles) | |
| ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| ax2.set_xticks(tick_angles) | |
| ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| if label is not None: | |
| ax1t.set_xlabel("Distance") | |
| ax2t.set_xlabel("Distance") | |
| ax1t.set_xticks(ax1.get_xticks()) | |
| ax2t.set_xticks(ax2.get_xticks()) | |
| ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| else: | |
| ax1.set_xlabel(f"Distance from base activation at {layer}") | |
| ax2.set_xlabel(f"Distance from base activation at {layer}") | |
| ax1.set_xticks(tick_dists) | |
| ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| ax2.set_xticks(tick_dists) | |
| ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| if label is not None: | |
| ax1t.set_xlabel("Angle") | |
| ax2t.set_xlabel("Angle") | |
| ax1t.set_xticks(ax1.get_xticks()) | |
| ax2t.set_xticks(ax2.get_xticks()) | |
| ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| def find_sensitivity(x, y, threshold=0.5): | |
| f = sip.interp1d(x, y, kind="linear") | |
| angles_interp = np.linspace(x.min(), x.max(), 100000) | |
| index_interp = np.argmax(f(angles_interp) > threshold) | |
| return angles_interp[index_interp] | |
| # %% | |
| torch.cuda.empty_cache() | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| ################################ | |
| seed = 0 | |
| step_angular = True | |
| on_only = True | |
| negate = on_only | |
| ################ | |
| set_seed(seed) | |
| n_ctx = 10 | |
| layer = "blocks.1.hook_resid_pre" | |
| pos = slice(-1, None, 1) | |
| num_steps = 40 | |
| step_size = 1 | |
| learning_rate = 1 | |
| dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch") | |
| dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True) | |
| # %% | |
| model = HookedTransformer.from_pretrained("gpt2") | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # %% | |
| base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx) | |
| unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx) | |
| # %% | |
| torch.cuda.empty_cache() | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000) | |
| mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu") | |
| plt.figure() | |
| plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu()) | |
| plt.title(f"Mean activations for {layer}") | |
| plt.xlabel("Dimension") | |
| plt.ylabel("Activation") | |
| plt.savefig(f"plots/2_2a_mean_activations_seed_{seed}.png") | |
| plt.close() | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| data = mean_act_cache[layer][:, -1, :] | |
| data_mean = data.mean(dim=0, keepdim=True) | |
| data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0] | |
| distrib = MultivariateNormal(data_mean.squeeze(0), data_cov) | |
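| # Added sanity check (not in the original gist): samples from the fitted | |
| # Gaussian should have norms comparable to real last-token activations at | |
| # this layer. RNG state is saved and restored so downstream draws match. | |
| _rng_state = torch.get_rng_state() | |
| print("Fitted-Gaussian sample norms:", lp_norm(distrib.sample((4,))).squeeze(-1)) | |
| torch.set_rng_state(_rng_state) | |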
| # %% | |
| random_perturbation = RandomPerturbation(base_ref, unrelated_ref) | |
| random_perturbation.distrib = distrib | |
| randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref) | |
| # %% | |
| results_list0 = [] | |
| results_list1 = [] | |
| results_list2 = [] | |
| for _ in tqdm(range(20)): | |
| random_perturbation.gen_direction() | |
| results, _, dir = random_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular) | |
| results_list0.append(results) | |
| randomact_perturbation.gen_direction() | |
| results, _, dir = randomact_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular) | |
| results_list1.append(results) | |
| results_list2.append(results) # NOTE: duplicates results_list1, so two of the averaged curves below coincide | |
| # %% | |
| fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True) | |
| ax1t = ax1.twiny() | |
| ax2t = ax2.twiny() | |
| prompt_str = "|" + "|".join(model.to_str_tokens(base_ref.prompt)).replace("\n", "⏎") + "|" | |
| title_str = "Sphere mode (slerping angle around origin)" if step_angular else "Straight mode" | |
| fig.suptitle(f"{title_str}. Perturbing {layer} at pos {pos}.\nSeed = {seed}. Prompt = {prompt_str}. Norm = {lp_norm(base_ref.act).item():.2f}.") | |
| for i, r in enumerate(results_list0): | |
| format_ax( | |
| ax1, | |
| ax1t, | |
| ax2, | |
| ax2t, | |
| r, | |
| "Random direction perturbation" if i == 0 else None, | |
| "C0", | |
| r.angle[:, 0] if step_angular else r.dist[:, 0], | |
| base_ref, | |
| x_is_angle=step_angular, | |
| ) | |
| for i, r in enumerate(results_list1): | |
| format_ax( | |
| ax1, | |
| ax1t, | |
| ax2, | |
| ax2t, | |
| r, | |
| "Direction to random other activation" if i == 0 else None, | |
| "C1", | |
| r.angle[:, 0] if step_angular else r.dist[:, 0], | |
| base_ref, | |
| x_is_angle=step_angular, | |
| ) | |
| mode_str = "sphere" if step_angular else "straight" | |
| fig.savefig(f"plots/2_2a_perturbations_{mode_str}_seed_{seed}.png", dpi=150, bbox_inches="tight") | |
| plt.close() | |
| # %% | |
| plt.figure(figsize=(8, 4)) | |
| plt.plot(results_list0[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list0]), dim=0)) | |
| plt.plot(results_list1[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list1]), dim=0)) | |
| plt.plot(results_list2[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list2]), dim=0)) | |
| plt.fill_between( | |
| results_list0[0].angle[:, 0], | |
| torch.min(torch.stack([r.kl_div for r in results_list0]), dim=0).values, | |
| torch.max(torch.stack([r.kl_div for r in results_list0]), dim=0).values, | |
| alpha=0.3, | |
| ) | |
| plt.fill_between( | |
| results_list1[0].angle[:, 0], | |
| torch.min(torch.stack([r.kl_div for r in results_list1]), dim=0).values, | |
| torch.max(torch.stack([r.kl_div for r in results_list1]), dim=0).values, | |
| alpha=0.3, | |
| ) | |
| plt.xlabel("Angle from base activation") | |
| tick_angles = results_list0[0].angle[::10, 0] | |
| plt.xticks(tick_angles, [f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| plt.ylabel("Average KL divergence to base logits") | |
| plt.title("Average KL divergence comparison") | |
| plt.savefig(f"plots/2_2a_average_kl_div_seed_{seed}.png", dpi=150, bbox_inches="tight") | |
| plt.close() |
| import os | |
| import random | |
| from collections import defaultdict | |
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import scipy.interpolate as sip | |
| import torch | |
| import torch.nn.functional as F | |
| from datasets import load_dataset | |
| from sae_lens import SAE | |
| from torch.distributions.multivariate_normal import MultivariateNormal | |
| from tqdm import tqdm | |
| from transformer_lens import HookedTransformer | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| def lp_norm(direction, p=2): | |
| """Pytorch norm""" | |
| return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p) | |
| def mean(direction): | |
| return torch.mean(direction, dim=-1, keepdim=True) | |
| def dot(x, y): | |
| return torch.einsum("...k,...k->...", x, y).unsqueeze(-1) | |
| def compute_angle_dist(ref_act, new_act, datasetmean=None): | |
| angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1) | |
| if torch.all(angle[0].isnan()): | |
| angle[0] = torch.zeros_like(angle[0]) | |
| if torch.all(angle[-1].isnan()): | |
| angle[-1] = np.pi * torch.ones_like(angle[-1]) | |
| dist = lp_norm(new_act - ref_act).squeeze(-1) | |
| if datasetmean is None: | |
| return angle, dist | |
| else: | |
| angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze( | |
| -1 | |
| ) | |
| return angle, dist, angle_wrt_datasetmean | |
| class Reference: | |
| def __init__( | |
| self, | |
| model: HookedTransformer, | |
| prompt: torch.Tensor, | |
| replacement_layer: str, | |
| replacement_pos: slice, | |
| n_ctx: int, | |
| ): | |
| self.model = model | |
| n_batch_prompt, n_ctx_prompt = prompt.shape | |
| assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}" | |
| self.prompt = prompt | |
| logits, cache = model.run_with_cache(prompt) | |
| self.logits = logits.to("cpu").detach() | |
| self.cache = cache.to("cpu") | |
| self.act = self.cache[replacement_layer][:, replacement_pos] | |
| self.layer = replacement_layer | |
| self.pos = replacement_pos | |
| self.n_ctx = n_ctx | |
| @dataclass | |
| class Result: | |
| angle: float | |
| angle_wrt_datasetmean: float | |
| dist: float | |
| norm: float | |
| kl_div: float | |
| out_angle: float | |
| l2_diff: float | |
| logit_l2_diff: float | |
| dim0_diff: float | |
| dim1_diff: float | |
| def set_seed(seed: int): | |
| """Set the random seed for reproducibility.""" | |
| torch.manual_seed(seed) | |
| np.random.seed(seed) | |
| random.seed(seed) | |
| def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor: | |
| """Generate a prompt from the dataset.""" | |
| dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True) | |
| return next(iter(dataloader))["input_ids"][:, :n_ctx] | |
| def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor: | |
| """Get a random activation from the dataset.""" | |
| rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| _, cache = model.run_with_cache(rand_prompt) | |
| return cache[layer][:, pos, :].to("cpu").detach() | |
| def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor: | |
| """Compute the KL divergence between the reference and perturbed logprobs.""" | |
| logprobs_ref = F.log_softmax(logits_ref, dim=-1) | |
| logprobs_pert = F.log_softmax(logits_pert, dim=-1) | |
| return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1) | |
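| # Added property check (not in the original gist): the KL divergence between | |
| # identical logits is zero, with the vocab dimension summed out. | |
| _z = torch.randn(1, 3, 50) | |
| assert torch.allclose(compute_kl_div(_z, _z), torch.zeros(1, 3)) | |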
| def compute_Lp_metric( | |
| cache: dict, | |
| cache_pert: dict, | |
| p, | |
| read_layer, | |
| read_pos, | |
| ): | |
| ref_readoff = cache[read_layer][:, read_pos] | |
| pert_readoff = cache_pert[read_layer][:, read_pos] | |
| Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1) | |
| return Lp_diff | |
| def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference): | |
| pos = ref.pos | |
| layer = ref.layer | |
| def hook(act, hook): | |
| act[:, pos, :] = perturbed_act | |
| with ref.model.hooks(fwd_hooks=[(layer, hook)]): | |
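| # The base prompt is repeated once per perturbation step so a single forward | |
| # pass evaluates all steps; the hook overwrites position `pos` in each batch | |
| # row with the corresponding row of `perturbed_act`. | |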
| prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))]) | |
| logits_pert, cache = ref.model.run_with_cache(prompts) | |
| return logits_pert.to("cpu").detach(), cache.to("cpu") | |
| def eval_activation( | |
| perturbed_act: torch.Tensor, | |
| base_ref: Reference, | |
| unrelated_ref: Reference, | |
| read_layer="ln_final.hook_normalized", | |
| read_pos=-1, | |
| ): | |
| angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None) | |
| angle_wrt_datasetmean = angle # no dataset mean passed here; reuse angle as a placeholder | |
| norm = lp_norm(perturbed_act).squeeze(-1) | |
| logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref) | |
| base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos] | |
| base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos) | |
| base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos) | |
| base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1) | |
| base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0] | |
| base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1] | |
| base_out_angle, _ = compute_angle_dist( | |
| base_ref.cache[read_layer][:, read_pos, :], | |
| cache[read_layer][:, read_pos, :], | |
| datasetmean=None, | |
| ) | |
| unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos] | |
| unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos) | |
| unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos) | |
| unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1) | |
| unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0] | |
| unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1] | |
| unrelated_out_angle, _ = compute_angle_dist( | |
| unrelated_ref.cache[read_layer][:, read_pos, :], | |
| cache[read_layer][:, read_pos, :], | |
| datasetmean=None, | |
| ) | |
| return ( | |
| Result( | |
| angle, | |
| angle_wrt_datasetmean, | |
| dist, | |
| norm, | |
| base_kl_div, | |
| base_out_angle, | |
| base_l2_diff, | |
| base_logit_l2_diff, | |
| base_dim0_diff, | |
| base_dim1_diff, | |
| ), | |
| Result( | |
| angle, | |
| angle_wrt_datasetmean, | |
| dist, | |
| norm, | |
| unrelated_kl_div, | |
| unrelated_out_angle, | |
| unrelated_l2_diff, | |
| unrelated_logit_l2_diff, | |
| unrelated_dim0_diff, | |
| unrelated_dim1_diff, | |
| ), | |
| ) | |
| class Slerp: | |
| def __init__(self, start, direction, datasetmean=None): | |
| """Return interpolation points along the sphere from | |
| A towards direction B""" | |
| self.datasetmean = datasetmean | |
| self.A = start - self.datasetmean if self.datasetmean is not None else start | |
| self.A_norm = lp_norm(self.A) # magnitude | |
| self.a = self.A / self.A_norm # unit vectors | |
| d = direction / lp_norm(direction) | |
| self.B = d - dot(d, self.a) * self.a | |
| self.B_norm = lp_norm(self.B) # magnitude | |
| self.b = self.B / self.B_norm # unit vectors | |
| def __call__(self, alpha): | |
| result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b) | |
| if self.datasetmean is not None: | |
| return result + self.datasetmean | |
| else: | |
| return result | |
| def get_alpha(self, X): | |
| x = X / lp_norm(X) | |
| return torch.acos(dot(x, self.a)) | |
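| # Added property check (not in the original gist): slerp stays on the sphere | |
| # of radius ||A||, and get_alpha recovers the step angle (toy 8-dim example). | |
| _A = torch.ones(1, 8) | |
| _slerp = Slerp(_A, torch.arange(8.0).unsqueeze(0)) | |
| _alpha = torch.tensor(0.3) | |
| assert torch.allclose(lp_norm(_slerp(_alpha)), lp_norm(_A), atol=1e-5) | |
| assert torch.allclose(_slerp.get_alpha(_slerp(_alpha)), _alpha, atol=1e-4) | |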
| class Perturbation: | |
| def gen_direction(self): | |
| raise NotImplementedError | |
| def __init__( | |
| self, | |
| base_ref: Reference, | |
| unrelated_ref: Reference, | |
| ensure_mean=True, | |
| ensure_norm=True, | |
| project_sphere=False, | |
| ): | |
| self.ensure_mean = ensure_mean | |
| self.ensure_norm = ensure_norm | |
| self.project_sphere = project_sphere | |
| self.base_ref = base_ref | |
| self.norm_base = lp_norm(base_ref.act) | |
| self.unrelated_ref = unrelated_ref | |
| def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True): | |
| direction = self.gen_direction() | |
| if self.ensure_mean: | |
| direction -= mean(direction) | |
| if self.ensure_norm: | |
| direction *= self.norm_base / lp_norm(direction) | |
| range = np.array(range) | |
| if step_angular: | |
| s = Slerp(self.base_ref.act, direction, datasetmean=None) | |
| self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)] | |
| else: | |
| self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)] | |
| act = torch.cat(self.activation_steps, dim=0) | |
| torch.cuda.empty_cache() | |
| base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref) | |
| return base_result, unrelated_result, direction | |
| class RandomUniformPerturbation(Perturbation): | |
| def gen_direction(self): | |
| return torch.randn_like(self.base_ref.act) | |
| class RandomPerturbation(Perturbation): | |
| def gen_target(self): | |
| self.target = self.distrib.sample(self.base_ref.act.shape[:-1]) | |
| def gen_direction(self): | |
| self.gen_target() | |
| return self.target - self.base_ref.act | |
| class RandomActDirPerturbation(Perturbation): | |
| def gen_target(self): | |
| self.target = get_random_activation( | |
| self.base_ref.model, | |
| dataset, | |
| self.base_ref.n_ctx, | |
| self.base_ref.layer, | |
| self.base_ref.pos, | |
| ) | |
| def gen_direction(self): | |
| self.gen_target() | |
| return self.target - self.base_ref.act | |
| class NeuronDirPerturbation(Perturbation): | |
| def __init__(self, ref: Reference, unrelated_ref=None, negate=False, on_only=None): | |
| # Perturbation.__init__ requires an unrelated reference; fall back to ref itself. | |
| super().__init__(ref, unrelated_ref if unrelated_ref is not None else ref) | |
| self.base_ref = ref | |
| self.negate = 1 if not negate else -1 | |
| self.active_features = None # default, so gen_direction works when on_only is not True | |
| if on_only is True: | |
| feature_acts = self.base_ref.cache[ref.layer][0, -1, :] | |
| self.active_features = feature_acts / feature_acts.max() > 0.1 | |
| def gen_direction(self): | |
| if self.active_features is None: | |
| random_int = random.randint(0, 767) # randint is inclusive; resid dim is 768 | |
| else: | |
| random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0]) | |
| one_hot = torch.zeros_like(self.base_ref.act) | |
| one_hot[..., random_int] = 1 | |
| single_direction = self.negate * one_hot | |
| return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])]) | |
| class SAEDecoderDirPerturbation(Perturbation): | |
| def __init__( | |
| self, | |
| base_ref: Reference, | |
| unrelated_ref: Reference, | |
| sae, | |
| ensure_mean=True, | |
| ensure_norm=True, | |
| project_sphere=False, | |
| negate=False, | |
| on_only=None, | |
| ): | |
| super().__init__( | |
| base_ref=base_ref, | |
| unrelated_ref=unrelated_ref, | |
| ensure_mean=ensure_mean, | |
| ensure_norm=ensure_norm, | |
| project_sphere=project_sphere, | |
| ) | |
| self.sae = sae | |
| self.active_features = None | |
| self.negate = 1 if not negate else -1 | |
| if on_only is True: | |
| feature_acts = sae.encode(self.base_ref.cache[self.sae.cfg.hook_name])[0, -1, :] | |
| self.active_features = feature_acts / feature_acts.max() > 0.1 | |
| self.active_features = self.active_features.to("cpu") | |
| print("Using active features:", self.active_features.nonzero(as_tuple=True)[0]) | |
| def gen_direction(self): | |
| if self.active_features is None: | |
| random_int = random.randint(0, self.sae.W_dec.shape[0] - 1) # randint is inclusive; cover all decoder rows | |
| else: | |
| random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0]) | |
| single_direction = self.negate * self.sae.W_dec[random_int, :].to("cpu").detach() | |
| if isinstance(self.base_ref.pos, slice): | |
| self.direction = torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])]).unsqueeze(0) | |
| else: | |
| self.direction = single_direction.unsqueeze(0) | |
| return self.direction # scan() consumes the returned direction | |
| class SAEReplacementPerturbation(Perturbation): | |
| def __init__(self, ref: Reference, sae, unrelated_ref=None): | |
| super().__init__(ref, unrelated_ref if unrelated_ref is not None else ref) | |
| self.sae = sae | |
| def gen_target(self): | |
| sae_out = self.sae(self.base_ref.cache[self.sae.cfg.hook_name]) | |
| return sae_out[0, self.base_ref.pos, :] | |
| def gen_direction(self): | |
| sae_out = self.gen_target() | |
| self.target = sae_out | |
| self.direction = sae_out - self.base_ref.act | |
| return self.direction # scan() consumes the returned direction | |
| class OptimizedPerturbation(Perturbation): | |
| def __init__(self, ref: Reference, unrelated_ref=None, learning_rate: float = 0.1): | |
| super().__init__(ref, unrelated_ref if unrelated_ref is not None else ref) | |
| self.learning_rate = learning_rate | |
| def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None): | |
| kl_divs = [] | |
| directions = [] | |
| if init is not None: | |
| init = np.array(init.cpu().detach()) | |
| direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device) | |
| else: | |
| direction = torch.randn_like(self.base_ref.act, requires_grad=True) | |
| optimizer = torch.optim.SGD([direction], lr=self.learning_rate) | |
| for _ in tqdm(range(num_steps), desc="Finding perturbation direction"): | |
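| # NOTE: eval_direction is not defined anywhere in this file; this loop is kept | |
| # as in the write-up and will not run without supplying that helper. | |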
| kl_div, _, _ = eval_direction(direction, self.base_ref, step_size) | |
| kl_divs.append(kl_div.item()) | |
| directions.append(direction.detach().clone()) | |
| (sign * kl_div).backward(retain_graph=True) | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| return kl_divs, directions | |
| def format_ax(ax1, results, label, color, angles): | |
| assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0)) | |
| assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0)) | |
| dists = results.dist[:, 0] | |
| ax1.plot(angles, results.kl_div, label=label, color=color) | |
| ax1.set_ylabel("KL Divergence") | |
| ax1.grid() | |
| ax1.legend() | |
| ax1.set_xlabel("Angle") | |
| ax1.set_xlim(min(angles), max(angles)) | |
| ax1.set_xticks(angles) | |
| ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in angles], rotation=-80) | |
| if label is not None: | |
| ax2 = ax1.twiny() | |
| ax2.set_xlabel("Distance") | |
| ax2.set_xticks(ax1.get_xticks()) | |
| ax2.set_xticklabels([f"{dist:.1f}" for dist in dists], rotation=-80) | |
| torch.cuda.empty_cache() | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| seed = 0 | |
| set_seed(seed) | |
| n_ctx = 10 | |
| layer = "blocks.1.hook_resid_pre" | |
| pos = slice(-1, None, 1) | |
| num_steps = 40 | |
| step_size = 1 | |
| learning_rate = 1 | |
| dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch") | |
| dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True) | |
| model = HookedTransformer.from_pretrained("gpt2") | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| sae = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id=layer, device="cuda").to("cpu") | |
| base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| base_ref = Reference(model, base_prompt, layer, pos, n_ctx) | |
| unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| unrelated_ref = Reference(model, unrelated_prompt, layer, pos, n_ctx) | |
| torch.cuda.empty_cache() | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000) | |
| mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu") | |
| plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu()) | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| data = mean_act_cache[layer][:, -1, :] | |
| data_mean = data.mean(dim=0, keepdim=True) | |
| data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0] | |
| distrib = MultivariateNormal(data_mean.squeeze(0), data_cov) | |
| random_perturbation = RandomPerturbation(base_ref, unrelated_ref) | |
| random_perturbation.distrib = distrib | |
| randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref) | |
| results, _, _ = random_perturbation.scan(n_steps=21, step_angular=False, range=(0, 1)) | |
| plt.scatter(results.angle, results.out_angle * 180 / np.pi, label="Random direction") | |
| plt.figure() | |
| plt.scatter(results.angle, results.l2_diff, label="Random direction") | |
| plt.figure() | |
| plt.scatter(results.angle, results.kl_div, label="Random direction") | |
| def find_sensitivity(x, y, threshold=0.1): | |
| f = sip.interp1d(x, y, kind="linear") | |
| angles_interp = np.linspace(x.min(), x.max(), 1000) | |
| index_interp = np.argmax(f(angles_interp) > threshold) | |
| return angles_interp[index_interp] | |
| dd = defaultdict(list) | |
| fig0, ax0 = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True) | |
| for sample_index in tqdm(range(100)): | |
| for perturbation in ["randdir", "randact"]: | |
| for style in ["straight", "angle"]: | |
| pert = random_perturbation if perturbation == "randdir" else randomact_perturbation | |
| if style == "angle": | |
| results, _, _ = pert.scan(n_steps=91, step_angular=True, range=(0, np.pi / 2)) | |
| else: | |
| results, _, _ = pert.scan(n_steps=21, step_angular=False, range=(0, 1)) | |
| dd[perturbation + "_" + style + "_kldiv_at_1"].append(results.kl_div[1]) | |
| dd[perturbation + "_" + style + "_l2diff_at_1"].append(results.l2_diff[1]) | |
| dd[perturbation + "_" + style + "_cossim_at_1"].append(np.cos(results.out_angle[1])) | |
| dd[perturbation + "_" + style + "_kldiv_at_2"].append(results.kl_div[1]) | |
| dd[perturbation + "_" + style + "_l2diff_at_2"].append(results.l2_diff[1]) | |
| dd[perturbation + "_" + style + "_cossim_at_2"].append(np.cos(results.out_angle[1])) | |
| dd[perturbation + "_" + style + "_kldiv_at_5"].append(results.kl_div[5]) | |
| dd[perturbation + "_" + style + "_l2diff_at_5"].append(results.l2_diff[5]) | |
| dd[perturbation + "_" + style + "_cossim_at_5"].append(np.cos(results.out_angle[5])) | |
| dd[perturbation + "_" + style + "_kldiv_at_10"].append(results.kl_div[10]) | |
| dd[perturbation + "_" + style + "_l2diff_at_10"].append(results.l2_diff[10]) | |
| dd[perturbation + "_" + style + "_cossim_at_10"].append(np.cos(results.out_angle[10])) | |
| dd[perturbation + "_" + style + "_kldiv_at_20"].append(results.kl_div[20]) | |
| dd[perturbation + "_" + style + "_l2diff_at_20"].append(results.l2_diff[20]) | |
| dd[perturbation + "_" + style + "_cossim_at_20"].append(np.cos(results.out_angle[20])) | |
| if style == "angle": | |
| f = lambda threshold: find_sensitivity(results.angle[:, 0], results.kl_div, threshold) | |
| else: | |
| f = lambda threshold: find_sensitivity(results.l2_diff.squeeze(), results.kl_div, threshold) | |
| dd[perturbation + "_" + style + "_where01KL"].append(f(0.1)) | |
| dd[perturbation + "_" + style + "_where05KL"].append(f(0.5)) | |
| dd[perturbation + "_" + style + "_where1KL"].append(f(1)) | |
| dd[perturbation + "_" + style + "_where2KL"].append(f(2)) | |
| if style == "angle": | |
| f = sip.interp1d( | |
| results.l2_diff, | |
| results.angle.squeeze(), | |
| kind="linear", | |
| fill_value=0, | |
| bounds_error=False, | |
| ) | |
| else: | |
| f = sip.interp1d( | |
| results.l2_diff, | |
| results.dist.squeeze(), | |
| kind="linear", | |
| fill_value=0, | |
| bounds_error=False, | |
| ) | |
| dd[perturbation + "_" + style + "_where10L2"].append(f(10)) | |
| dd[perturbation + "_" + style + "_where20L2"].append(f(20)) | |
| if style == "angle": | |
| color = "C0" if perturbation == "randdir" else "C1" | |
| ax0.plot(results.angle, results.kl_div, color=color, label=perturbation, lw=0.2) | |
| fig, axes = plt.subplots(6, 6, figsize=(20, 20), constrained_layout=True) | |
| fig2, axes2 = plt.subplots(2, 6, figsize=(20, 6), constrained_layout=True) | |
| for i, style in enumerate(["straight", "angle"]): | |
| for j, metric in enumerate(["kldiv", "l2diff", "cossim"]): | |
| for k, dist in enumerate([1, 2, 5, 10, 20]): | |
| axes[2 * j + i, k].hist( | |
| dd["randdir_" + style + "_" + metric + "_at_" + str(dist)], | |
| bins=20, | |
| alpha=0.5, | |
| label="randdir", | |
| density=True, | |
| ) | |
| axes[2 * j + i, k].hist( | |
| dd["randact_" + style + "_" + metric + "_at_" + str(dist)], | |
| bins=20, | |
| alpha=0.5, | |
| label="randact", | |
| density=True, | |
| ) | |
| axes[2 * j + i, k].set_title(f"{style} {metric} at {dist}") | |
| axes[2 * j + i, k].legend() | |
| axes2[i, 0].hist(dd["randdir_" + style + "_where01KL"], bins=20, alpha=0.5, label="randdir", density=True) | |
| axes2[i, 0].hist(dd["randact_" + style + "_where01KL"], bins=20, alpha=0.5, label="randact", density=True) | |
| axes2[i, 0].set_title(f"{style} where KL=0.1") | |
| axes2[i, 0].legend() | |
| axes2[i, 1].hist(dd["randdir_" + style + "_where05KL"], bins=20, alpha=0.5, label="randdir", density=True) | |
| axes2[i, 1].hist(dd["randact_" + style + "_where05KL"], bins=20, alpha=0.5, label="randact", density=True) | |
| axes2[i, 1].set_title(f"{style} where KL=0.5") | |
| axes2[i, 1].legend() | |
| axes2[i, 2].hist(dd["randdir_" + style + "_where1KL"], bins=20, alpha=0.5, label="randdir", density=True) | |
| axes2[i, 2].hist(dd["randact_" + style + "_where1KL"], bins=20, alpha=0.5, label="randact", density=True) | |
| axes2[i, 2].set_title(f"{style} where KL=1") | |
| axes2[i, 2].legend() | |
| axes2[i, 3].hist(dd["randdir_" + style + "_where2KL"], bins=20, alpha=0.5, label="randdir", density=True) | |
| axes2[i, 3].hist(dd["randact_" + style + "_where2KL"], bins=20, alpha=0.5, label="randact", density=True) | |
| axes2[i, 3].set_title(f"{style} where KL=2") | |
| axes2[i, 3].legend() | |
| axes2[i, 4].hist(dd["randdir_" + style + "_where10L2"], bins=20, alpha=0.5, label="randdir", density=True) | |
| axes2[i, 4].hist(dd["randact_" + style + "_where10L2"], bins=20, alpha=0.5, label="randact", density=True) | |
| axes2[i, 4].set_title(f"{style} where L2=10") | |
| axes2[i, 4].legend() | |
| axes2[i, 5].hist(dd["randdir_" + style + "_where20L2"], bins=20, alpha=0.5, label="randdir", density=True) | |
| axes2[i, 5].hist(dd["randact_" + style + "_where20L2"], bins=20, alpha=0.5, label="randact", density=True) | |
| axes2[i, 5].set_title(f"{style} where L2=20") | |
| axes2[i, 5].legend() |
| # I didn't include preliminary SAE figures & optimization for KL-div in the post | |
| # because I wasn't confident in those experiments (and I don't recommend relying | |
| # on that code). For good SAE evaluations see https://arxiv.org/abs/2409.15019 | |
| # and https://arxiv.org/abs/2410.12555, and for optimizing for KL-div see the | |
| # concurrent work of Andrew Mack & Alex Turner: | |
| # https://www.lesswrong.com/posts/ioPnHKFyy4Cw2Gr2x/mechanistically-eliciting-latent-behaviors-in-language-1 |
| # %% | |
| import os | |
| import random | |
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| matplotlib.use("Agg") # Use non-interactive backend | |
| import numpy as np | |
| import plotly.figure_factory as ff | |
| import scipy.interpolate as sip | |
| import torch | |
| import torch.nn.functional as F | |
| from datasets import load_dataset | |
| from plotly.subplots import make_subplots | |
| from torch.distributions.multivariate_normal import MultivariateNormal | |
| from tqdm import tqdm | |
| from transformer_lens import HookedTransformer | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # Ensure plots directory exists | |
| os.makedirs("plots", exist_ok=True) | |
| # %% | |
| def lp_norm(direction, p=2): | |
| """Pytorch norm""" | |
| return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p) | |
| def mean(direction): | |
| return torch.mean(direction, dim=-1, keepdim=True) | |
| def dot(x, y): | |
| return torch.einsum("...k,...k->...", x, y).unsqueeze(-1) | |
| def compute_angle_dist(ref_act, new_act, datasetmean=None): | |
| new_act = new_act.cpu().detach() | |
| angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1) | |
| if torch.all(angle[0].isnan()): | |
| angle[0] = torch.zeros_like(angle[0]) | |
| if torch.all(angle[-1].isnan()): | |
| angle[-1] = np.pi * torch.ones_like(angle[-1]) | |
| dist = lp_norm(new_act - ref_act).squeeze(-1) | |
| if datasetmean is None: | |
| return angle, dist | |
| else: | |
| angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze( | |
| -1 | |
| ) | |
| return angle, dist, angle_wrt_datasetmean | |
| class Reference: | |
| def __init__( | |
| self, | |
| model: HookedTransformer, | |
| prompt: torch.Tensor, | |
| replacement_layer: str, | |
| read_layer: str, | |
| replacement_pos: slice, | |
| n_ctx: int, | |
| ): | |
| self.model = model | |
| n_batch_prompt, n_ctx_prompt = prompt.shape | |
| assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}" | |
| self.prompt = prompt | |
| logits, cache = model.run_with_cache(prompt) | |
| self.logits = logits.to("cpu").detach() | |
| self.cache = cache.to("cpu") | |
| self.act = self.cache[replacement_layer][:, replacement_pos] | |
| self.replacement_layer = replacement_layer | |
| self.read_layer = read_layer | |
| self.replacement_pos = replacement_pos | |
| self.n_ctx = n_ctx | |
| @dataclass | |
| class Result: | |
| angle: float | |
| angle_wrt_datasetmean: float | |
| dist: float | |
| norm: float | |
| kl_div: float | |
| out_angle: float | |
| l2_diff: float | |
| logit_l2_diff: float | |
| dim0_diff: float | |
| dim1_diff: float | |
| def set_seed(seed: int): | |
| """Set the random seed for reproducibility.""" | |
| torch.manual_seed(seed) | |
| np.random.seed(seed) | |
| random.seed(seed) | |
| def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor: | |
| """Generate a prompt from the dataset.""" | |
| dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True) | |
| return next(iter(dataloader))["input_ids"][:, :n_ctx] | |
| def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor: | |
| """Get a random activation from the dataset.""" | |
| rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| _, cache = model.run_with_cache(rand_prompt) | |
| return cache[layer][:, pos, :].to("cpu").detach() | |
| def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor: | |
| """Compute the KL divergence between the reference and perturbed logprobs.""" | |
| logits_pert = logits_pert.cpu().detach() | |
| logprobs_ref = F.log_softmax(logits_ref, dim=-1) | |
| logprobs_pert = F.log_softmax(logits_pert, dim=-1) | |
| return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1) | |
| def compute_Lp_metric( | |
| cache: dict, | |
| cache_pert: dict, | |
| p, | |
| read_layer, | |
| read_pos, | |
| ): | |
| ref_readoff = cache[read_layer][:, read_pos] | |
| pert_readoff = cache_pert[read_layer][:, read_pos].cpu().detach() | |
| Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1) | |
| return Lp_diff | |
| def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference): | |
| pos = ref.replacement_pos | |
| layer = ref.replacement_layer | |
| def hook(act, hook): | |
| act[:, pos, :] = perturbed_act | |
| with ref.model.hooks(fwd_hooks=[(layer, hook)]): | |
| prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))]) | |
| logits_pert, cache = ref.model.run_with_cache(prompts) | |
| return logits_pert, cache | |
| def eval_activation( | |
| perturbed_act: torch.Tensor, | |
| base_ref: Reference, | |
| unrelated_ref: Reference, | |
| read_pos=-1, | |
| ): | |
| read_layer = base_ref.read_layer | |
| angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None) | |
| angle_wrt_datasetmean = angle # no dataset mean passed here; reuse angle as a placeholder | |
| norm = lp_norm(perturbed_act).squeeze(-1) | |
| logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref) | |
| logits_pert = logits_pert.to("cpu").detach() | |
| base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos] | |
| base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos) | |
| base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos) | |
| base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1) | |
| base_dim0_diff = (base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0].cpu()).detach() | |
| base_dim1_diff = (base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1].cpu()).detach() | |
| base_out_angle, _ = compute_angle_dist( | |
| base_ref.cache[read_layer][:, read_pos, :], | |
| cache[read_layer][:, read_pos, :], | |
| datasetmean=None, | |
| ) | |
| unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos] | |
| unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos) | |
| unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos) | |
| unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1) | |
| unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0].cpu().detach() | |
| unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1].cpu().detach() | |
| unrelated_out_angle, _ = compute_angle_dist( | |
| unrelated_ref.cache[read_layer][:, read_pos, :], | |
| cache[read_layer][:, read_pos, :], | |
| datasetmean=None, | |
| ) | |
| return ( | |
| Result( | |
| angle, | |
| angle_wrt_datasetmean, | |
| dist, | |
| norm, | |
| base_kl_div, | |
| base_out_angle, | |
| base_l2_diff, | |
| base_logit_l2_diff, | |
| base_dim0_diff, | |
| base_dim1_diff, | |
| ), | |
| Result( | |
| angle, | |
| angle_wrt_datasetmean, | |
| dist, | |
| norm, | |
| unrelated_kl_div, | |
| unrelated_out_angle, | |
| unrelated_l2_diff, | |
| unrelated_logit_l2_diff, | |
| unrelated_dim0_diff, | |
| unrelated_dim1_diff, | |
| ), | |
| ) | |
| class Slerp: | |
| def __init__(self, start, direction, datasetmean=None): | |
| """Return interpolation points along the sphere from | |
| A towards direction B""" | |
| self.datasetmean = datasetmean | |
| self.A = start - self.datasetmean if self.datasetmean is not None else start | |
| self.A_norm = lp_norm(self.A) | |
| self.a = self.A / self.A_norm | |
| d = direction / lp_norm(direction) | |
| self.B = d - dot(d, self.a) * self.a | |
| self.B_norm = lp_norm(self.B) | |
| self.b = self.B / self.B_norm | |
| def __call__(self, alpha): | |
| result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b) | |
| if self.datasetmean is not None: | |
| return result + self.datasetmean | |
| else: | |
| return result | |
| def get_alpha(self, X): | |
| x = X / lp_norm(X) | |
| return torch.acos(dot(x, self.a)) | |
| class Perturbation: | |
| def gen_direction(self): | |
| raise NotImplementedError | |
| def __init__( | |
| self, | |
| base_ref: Reference, | |
| unrelated_ref: Reference, | |
| ensure_mean=True, | |
| ensure_norm=True, | |
| ): | |
| self.ensure_mean = ensure_mean | |
| self.ensure_norm = ensure_norm | |
| self.base_ref = base_ref | |
| self.norm_base = lp_norm(base_ref.act) | |
| self.unrelated_ref = unrelated_ref | |
| def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True): | |
| direction = self.gen_direction() | |
| return self.scan_dir(direction, n_steps, range, step_angular) | |
| def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True): | |
| direction = direction.clone() | |
| if self.ensure_mean: | |
| direction -= mean(direction) | |
| if self.ensure_norm: | |
| direction *= self.norm_base / lp_norm(direction) | |
| range = np.array(range) | |
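| # step_angular=True walks along the sphere of radius ||base act|| via Slerp | |
| # (constant norm, varying angle); otherwise steps move linearly along | |
| # base + alpha * direction (constant direction, growing distance). | |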
| if step_angular: | |
| s = Slerp(self.base_ref.act, direction, datasetmean=None) | |
| self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)] | |
| else: | |
| self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)] | |
| act = torch.cat(self.activation_steps, dim=0) | |
| torch.cuda.empty_cache() | |
| base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref) | |
| return base_result, unrelated_result, direction | |
| class RandomUniformPerturbation(Perturbation): | |
| def gen_direction(self): | |
| return torch.randn_like(self.base_ref.act) | |
| class RandomPerturbation(Perturbation): | |
| def gen_target(self): | |
| return self.distrib.sample(self.base_ref.act.shape[:-1]) | |
| def gen_direction(self): | |
| self.target = self.gen_target() | |
| return self.target - self.base_ref.act | |
| class RandomActDirPerturbation(Perturbation): | |
| def gen_target(self): | |
| return get_random_activation( | |
| self.base_ref.model, | |
| dataset, | |
| self.base_ref.n_ctx, | |
| self.base_ref.replacement_layer, | |
| self.base_ref.replacement_pos, | |
| ) | |
| def gen_direction(self): | |
| self.target = self.gen_target() | |
| return self.target - self.base_ref.act | |
| class NeuronDirPerturbation(Perturbation): | |
| def __init__(self, ref: Reference, unrelated_ref=None, negate=False, on_only=None): | |
| # Perturbation.__init__ requires an unrelated reference; fall back to ref itself. | |
| super().__init__(ref, unrelated_ref if unrelated_ref is not None else ref) | |
| self.base_ref = ref | |
| self.negate = 1 if not negate else -1 | |
| self.active_features = None # default, so gen_direction works when on_only is not True | |
| if on_only is True: | |
| feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :] | |
| self.active_features = feature_acts / feature_acts.max() > 0.1 | |
| def gen_direction(self): | |
| if self.active_features is None: | |
| random_int = random.randint(0, 767) # randint is inclusive; resid dim is 768 | |
| else: | |
| random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0]) | |
| one_hot = torch.zeros_like(self.base_ref.act) | |
| one_hot[..., random_int] = 1 | |
| single_direction = self.negate * one_hot | |
| return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])]) | |
| class OptimizedPerturbation(Perturbation): | |
| def __init__(self, ref: Reference, unrelated_ref=None, learning_rate: float = 0.1): | |
| super().__init__(ref, unrelated_ref if unrelated_ref is not None else ref) | |
| self.learning_rate = learning_rate | |
| def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None): | |
| kl_divs = [] | |
| directions = [] | |
| if init is not None: | |
| init = np.array(init.cpu().detach()) | |
| direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device) | |
| else: | |
| direction = torch.randn_like(self.base_ref.act, requires_grad=True) | |
| optimizer = torch.optim.SGD([direction], lr=self.learning_rate) | |
| for _ in tqdm(range(num_steps), desc="Finding perturbation direction"): | |
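| # NOTE: eval_direction is not defined in this file; this loop needs that | |
| # helper to be supplied before it can run. | |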
| kl_div, _, _ = eval_direction(direction, self.base_ref, step_size) | |
| kl_divs.append(kl_div.item()) | |
| directions.append(direction.detach().clone()) | |
| (sign * kl_div).backward(retain_graph=True) | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| return kl_divs, directions | |
| def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True): | |
| assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0)) | |
| assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0)) | |
| angles = results.angle[:, 0].detach() | |
| dists = results.dist[:, 0].detach() | |
| ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls) | |
| ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls) | |
| ax1.set_ylabel("KL divergence to base logits") | |
| ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}") | |
| ax1.legend() | |
| ax2.legend() | |
| ax1.set_xlim(min(x), max(x)) | |
| ax2.set_xlim(min(x), max(x)) | |
| if len(angles) < 40: | |
| tick_angles = angles | |
| tick_dists = dists | |
| else: | |
| tick_angles = angles[::10] | |
| tick_dists = dists[::10] | |
| if x_is_angle: | |
| ax1.set_xlabel(f"Angle from base activation at {layer}") | |
| ax2.set_xlabel(f"Angle from base activation at {layer}") | |
| ax1.set_xticks(tick_angles) | |
| ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| ax2.set_xticks(tick_angles) | |
| ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| if label is not None: | |
| ax1t.set_xlabel("Distance") | |
| ax2t.set_xlabel("Distance") | |
| ax1t.set_xticks(ax1.get_xticks()) | |
| ax2t.set_xticks(ax2.get_xticks()) | |
| ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| else: | |
| ax1.set_xlabel(f"Distance from base activation at {layer}") | |
| ax2.set_xlabel(f"Distance from base activation at {layer}") | |
| ax1.set_xticks(tick_dists) | |
| ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| ax2.set_xticks(tick_dists) | |
| ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80) | |
| if label is not None: | |
| ax1t.set_xlabel("Angle") | |
| ax2t.set_xlabel("Angle") | |
| ax1t.set_xticks(ax1.get_xticks()) | |
| ax2t.set_xticks(ax2.get_xticks()) | |
| ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80) | |
| def find_sensitivity(x, y, threshold=0.5): | |
| f = sip.interp1d(x, y, kind="linear") | |
| angles_interp = np.linspace(x.min(), x.max(), 100000) | |
| index_interp = np.argmax(f(angles_interp) > threshold) | |
| return angles_interp[index_interp] | |
| # %% | |
| torch.cuda.empty_cache() | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| ################################ | |
| seed = 0 | |
| step_angular = True | |
| ################################ | |
| on_only = True | |
| negate = on_only | |
| set_seed(seed) | |
| n_ctx = 10 | |
| layer = "blocks.1.hook_resid_pre" | |
| pos = slice(-1, None, 1) | |
| num_steps = 40 | |
| step_size = 1 | |
| learning_rate = 1 | |
| dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch") | |
| dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True) | |
| # %% | |
| model = HookedTransformer.from_pretrained("gpt2") | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # %% | |
| base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx) | |
| unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx) | |
| unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx) | |
| old_norm = lp_norm(base_ref.act).item() | |
| # %% | |
| torch.cuda.empty_cache() | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000) | |
| print("Running model with cache") | |
| mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu") | |
| print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") | |
| data = mean_act_cache[layer][:, -1, :] | |
| data_mean = data.mean(dim=0, keepdim=True) | |
| data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0] | |
| distrib = MultivariateNormal(data_mean.squeeze(0), data_cov) | |
| # %% | |
| random_perturbation = RandomPerturbation(base_ref, unrelated_ref) | |
| random_perturbation.distrib = distrib | |
| randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref) | |
| def test_surroundings_parallel(perturbation, distrib, ax, n_side=5, d_side=1, steps=28, n_contours=20): | |
| model.reset_hooks() | |
| orig_target = perturbation.gen_target() | |
| rand_target_1 = distrib.sample(perturbation.base_ref.act.shape[:-1]) | |
| rand_target_2 = distrib.sample(perturbation.base_ref.act.shape[:-1]) | |
| a, b = np.mgrid[0 : d_side : n_side * 1j, 0 : d_side : n_side * 1j] | |
| mask = a + b <= 1 | |
| a, b = a[mask], b[mask] | |
| coords = np.stack((a, b, 1 - a - b)) | |
| assert np.all(coords >= 0) | |
| n_coords = coords.shape[1] | |
| sensitivity = np.zeros(n_coords) | |
| dirs = [] | |
| if len(coords.T) * steps < 1700: | |
| for i, (x, y, z) in enumerate(coords.T): | |
| assert x + y + z - 1 < 1e-5, f"Sum of coordinates must be 1, got {x + y + z}" | |
| target = z * orig_target + x * rand_target_1 + y * rand_target_2 | |
| dirs.append(target - perturbation.base_ref.act) # use the perturbation's own reference, as in the batched branch below | |
| dirs = torch.concatenate(dirs, dim=0) | |
| results, _, _ = perturbation.scan_dir(dirs, n_steps=steps, range=(0, np.pi / 3)) | |
| angles = results.angle[:, 0].reshape(steps, n_coords) | |
| kl_divs = results.kl_div.reshape(steps, n_coords) | |
| else: | |
| print("Warning: High number of items (steps*coords), implementing batching over coords") | |
| batch_size = 1700 // steps | |
| angles = [] | |
| kl_divs = [] | |
| for i in tqdm(range(0, n_coords, batch_size), desc="Batches", total=n_coords // batch_size): | |
| batch_coordsT = coords.T[i : i + batch_size] | |
| batch_dirs = [] | |
| for x, y, z in batch_coordsT: | |
| assert x + y + z - 1 < 1e-5, f"Sum of coordinates must be 1, got {x + y + z}" | |
| target = z * orig_target + x * rand_target_1 + y * rand_target_2 | |
| batch_dirs.append(target - perturbation.base_ref.act) | |
| batch_dirs = torch.concatenate(batch_dirs, dim=0) | |
| batch_results, _, _ = perturbation.scan_dir(batch_dirs, n_steps=steps, range=(0, np.pi / 3)) | |
| angles.append(batch_results.angle[:, 0].reshape(steps, -1)) | |
| kl_divs.append(batch_results.kl_div.reshape(steps, -1)) | |
| angles = torch.concatenate(angles, dim=1) | |
| kl_divs = torch.concatenate(kl_divs, dim=1) | |
| for i, (x, y, z) in enumerate(coords.T): | |
| sensitivity[i] = find_sensitivity(angles[:, i], kl_divs[:, i], threshold=0.5) | |
| sensitivity = sensitivity * 180 / np.pi | |
| fig = ff.create_ternary_contour( | |
| coords[[2, 0, 1]], | |
| -sensitivity, | |
| pole_labels=["orig", "rand1", "rand2"], | |
| colorscale="Viridis", | |
| showscale=False, | |
| coloring=None, | |
| title="Sensitivity to KL-div = 0.5", | |
| ncontours=n_contours, | |
| ) | |
| x_plot = coords[1] + 0.5 * coords[2] | |
| y_plot = np.sqrt(3) / 2 * coords[2] | |
| return fig, x_plot, y_plot, sensitivity | |
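| # Added illustration (not in the original gist): the triangular grid above | |
| # parameterizes barycentric weights (x, y, z) with x + y + z = 1, i.e. convex | |
| # combinations of the original target and the two random targets: | |
| _a, _b = np.mgrid[0:1:5j, 0:1:5j] | |
| _mask = _a + _b <= 1 | |
| _coords = np.stack((_a[_mask], _b[_mask], 1 - _a[_mask] - _b[_mask])) | |
| assert np.allclose(_coords.sum(axis=0), 1.0) and np.all(_coords >= 0) | |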
| # %% | |
| px_fig = make_subplots( | |
| rows=3, | |
| cols=5, | |
| specs=[[{"type": "scatterternary"} for _ in range(5)] for _ in range(3)], | |
| ) | |
| fig, axes = plt.subplots(3, 5, figsize=(15, 8), constrained_layout=True) | |
| fig.suptitle( | |
| "Angle until KL-div = 0.5 is reached (in degrees)\n(lower angle = higher sensitivity)", | |
| ) | |
| def make_plot(axs, title, perturbation, distrib, n_side=25, steps=10, px_fig=None, px_idx=1): | |
| axs[0].set_ylabel(title) | |
| sensitivities = [] | |
| for i, ax in tqdm(enumerate(axs), total=5, desc=title): | |
| date_str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") | |
| random_int_from_time = int(date_str[-5:].replace("_", "")) | |
| set_seed(random_int_from_time) | |
| ax.set_title("Seed = " + str(random_int_from_time), fontsize=10) | |
| subfig, x_plot, y_plot, sensitivity = test_surroundings_parallel(perturbation, distrib, ax, n_side=n_side, steps=steps) | |
| sensitivities.append(sensitivity) | |
| for trace in subfig.data: | |
| px_fig.add_trace(trace, row=px_idx, col=i + 1) | |
| for i, ax in enumerate(axs): | |
| vmin = np.min(sensitivities) | |
| vmax = np.max(sensitivities) | |
| im = ax.scatter( | |
| x_plot, | |
| y_plot, | |
| c=sensitivities[i], | |
| cmap="viridis_r", | |
| marker="^", | |
| s=30, | |
| vmin=vmin, | |
| vmax=vmax, | |
| ) | |
| if i == len(axs) - 1: # colorbar on the last panel of this row | |
| cbar = plt.colorbar(im, ax=axs) | |
| n_side = 25 | |
| steps = 100 | |
| make_plot(axes[0], "random", random_perturbation, distrib, n_side, steps, px_fig, 1) | |
| make_plot(axes[1], "r-other", randomact_perturbation, distrib, n_side, steps, px_fig, 2) | |
| fig.savefig(f"plots/3_sensitivity_scan_matplotlib_seed_{seed}.png", dpi=300) | |
| plt.close() | |
| px_fig.write_image(f"plots/3_sensitivity_scan_plotly_seed_{seed}.png") |