@Stefan-Heimersheim · Last active August 18, 2025 11:27
Code: [Interim research report] Activation plateaus & sensitive directions in GPT2
# 1.py (Figure 1 is a zoom into one panel of Figure 3/4)
import os
import random
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg")
import numpy as np
import scipy.interpolate as sip
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm
from transformer_lens import HookedTransformer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("plots", exist_ok=True)
# %%
def lp_norm(direction, p=2):
"""Pytorch norm"""
return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p)
def mean(direction):
return torch.mean(direction, dim=-1, keepdim=True)
def dot(x, y):
return torch.einsum("...k,...k->...", x, y).unsqueeze(-1)
def compute_angle_dist(ref_act, new_act, datasetmean=None):
angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1)
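    # The first/last rows correspond to the interpolation endpoints, where rounding can
    # push the cosine just outside [-1, 1] and acos returns NaN; patch step 0 (identical
    # vectors, angle 0) and the final step (antipodal vectors, angle pi).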
if torch.all(angle[0].isnan()):
angle[0] = torch.zeros_like(angle[0])
if torch.all(angle[-1].isnan()):
angle[-1] = np.pi * torch.ones_like(angle[-1])
dist = lp_norm(new_act - ref_act).squeeze(-1)
if datasetmean is None:
return angle, dist
else:
angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze(
-1
)
return angle, dist, angle_wrt_datasetmean
class Reference:
def __init__(
self,
model: HookedTransformer,
prompt: torch.Tensor,
replacement_layer: str,
read_layer: str,
replacement_pos: slice,
n_ctx: int,
):
self.model = model
n_batch_prompt, n_ctx_prompt = prompt.shape
assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}"
self.prompt = prompt
logits, cache = model.run_with_cache(prompt)
self.logits = logits.to("cpu").detach()
self.cache = cache.to("cpu")
self.act = self.cache[replacement_layer][:, replacement_pos]
self.replacement_layer = replacement_layer
self.read_layer = read_layer
self.replacement_pos = replacement_pos
self.n_ctx = n_ctx
@dataclass
class Result:
angle: float
angle_wrt_datasetmean: float
dist: float
norm: float
kl_div: float
out_angle: float
l2_diff: float
logit_l2_diff: float
dim0_diff: float
dim1_diff: float
def set_seed(seed: int):
"""Set the random seed for reproducibility."""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
"""Generate a prompt from the dataset."""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
return next(iter(dataloader))["input_ids"][:, :n_ctx]
def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor:
"""Get a random activation from the dataset."""
rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
_, cache = model.run_with_cache(rand_prompt)
return cache[layer][:, pos, :].to("cpu").detach()
def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
"""Compute the KL divergence between the reference and perturbed logprobs."""
logprobs_ref = F.log_softmax(logits_ref, dim=-1)
logprobs_pert = F.log_softmax(logits_pert, dim=-1)
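    # F.kl_div(input, target, log_target=True) computes target.exp() * (target - input)
    # pointwise, so summing over the vocabulary gives KL(ref || perturbed) per position.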
return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1)
def compute_Lp_metric(
cache: dict,
cache_pert: dict,
p,
read_layer,
read_pos,
):
ref_readoff = cache[read_layer][:, read_pos]
pert_readoff = cache_pert[read_layer][:, read_pos]
Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1)
return Lp_diff
def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference):
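    # Overwrite the activation at (replacement_layer, replacement_pos) with each row of
    # perturbed_act, repeating the base prompt so all perturbation steps run in one batch.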
pos = ref.replacement_pos
layer = ref.replacement_layer
def hook(act, hook):
act[:, pos, :] = perturbed_act
with ref.model.hooks(fwd_hooks=[(layer, hook)]):
prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))])
logits_pert, cache = ref.model.run_with_cache(prompts)
return logits_pert.to("cpu").detach(), cache.to("cpu")
def eval_activation(
perturbed_act: torch.Tensor,
base_ref: Reference,
unrelated_ref: Reference,
read_pos=-1,
):
read_layer = base_ref.read_layer
angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None)
    angle_wrt_datasetmean = angle  # PLACEHOLDER: no dataset mean is subtracted, so this equals angle
norm = lp_norm(perturbed_act).squeeze(-1)
logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref)
base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos]
base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1)
base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
base_out_angle, _ = compute_angle_dist(
base_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos]
unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1)
unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
unrelated_out_angle, _ = compute_angle_dist(
unrelated_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
return (
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
base_kl_div,
base_out_angle,
base_l2_diff,
base_logit_l2_diff,
base_dim0_diff,
base_dim1_diff,
),
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
unrelated_kl_div,
unrelated_out_angle,
unrelated_l2_diff,
unrelated_logit_l2_diff,
unrelated_dim0_diff,
unrelated_dim1_diff,
),
)
class Slerp:
def __init__(self, start, direction, datasetmean=None):
"""Return interpolation points along the sphere from
A towards direction B"""
self.datasetmean = datasetmean
self.A = start - self.datasetmean if self.datasetmean is not None else start
self.A_norm = lp_norm(self.A)
self.a = self.A / self.A_norm
d = direction / lp_norm(direction)
self.B = d - dot(d, self.a) * self.a
self.B_norm = lp_norm(self.B)
self.b = self.B / self.B_norm
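        # b is the component of `direction` orthogonal to a (Gram-Schmidt), so
        # cos(alpha) * a + sin(alpha) * b traces a circle of radius |A| around the origin.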
def __call__(self, alpha):
result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b)
if self.datasetmean is not None:
return result + self.datasetmean
else:
return result
def get_alpha(self, X):
x = X / lp_norm(X)
return torch.acos(dot(x, self.a))
class Perturbation:
def gen_direction(self):
raise NotImplementedError
def __init__(
self,
base_ref: Reference,
unrelated_ref: Reference,
ensure_mean=True,
ensure_norm=True,
):
self.ensure_mean = ensure_mean
self.ensure_norm = ensure_norm
self.base_ref = base_ref
self.norm_base = lp_norm(base_ref.act)
self.unrelated_ref = unrelated_ref
def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = self.gen_direction()
return self.scan_dir(direction, n_steps, range, step_angular)
def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = direction.clone()
if self.ensure_mean:
direction -= mean(direction)
if self.ensure_norm:
direction *= self.norm_base / lp_norm(direction)
range = np.array(range)
if step_angular:
s = Slerp(self.base_ref.act, direction, datasetmean=None)
self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)]
else:
self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)]
act = torch.cat(self.activation_steps, dim=0)
torch.cuda.empty_cache()
base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref)
return base_result, unrelated_result, direction
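# Direction drawn from an isotropic standard normal (ignores dataset statistics).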
class RandomUniformPerturbation(Perturbation):
def gen_direction(self):
return torch.randn_like(self.base_ref.act)
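# Direction toward a sample from a multivariate normal fitted to real activations;
# `self.distrib` must be assigned after construction.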
class RandomPerturbation(Perturbation):
def gen_target(self):
return self.distrib.sample(self.base_ref.act.shape[:-1])
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
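# Direction toward an actual activation cached from a random dataset prompt
# (relies on the module-level `dataset`).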
class RandomActDirPerturbation(Perturbation):
def gen_target(self):
return get_random_activation(
self.base_ref.model,
dataset,
self.base_ref.n_ctx,
self.base_ref.replacement_layer,
self.base_ref.replacement_pos,
)
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
class NeuronDirPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference = None, negate=False, on_only=None):
        super().__init__(ref, unrelated_ref)
        self.base_ref = ref
        self.negate = 1 if not negate else -1
        self.active_features = None  # default: any neuron direction may be chosen
        if on_only is True:
            feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :]
            self.active_features = feature_acts / feature_acts.max() > 0.1
def gen_direction(self):
        if self.active_features is None:
            random_int = random.randint(0, 767)  # randint is inclusive; GPT-2 has 768 residual dims
else:
random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
one_hot = torch.zeros_like(self.base_ref.act)
one_hot[..., random_int] = 1
single_direction = self.negate * one_hot
return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])])
class OptimizedPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference = None, learning_rate: float = 0.1):
        super().__init__(ref, unrelated_ref)
self.learning_rate = learning_rate
def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None):
kl_divs = []
directions = []
if init is not None:
init = np.array(init.cpu().detach())
direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device)
else:
direction = torch.randn_like(self.base_ref.act, requires_grad=True)
optimizer = torch.optim.SGD([direction], lr=self.learning_rate)
for _ in tqdm(range(num_steps), desc="Finding perturbation direction"):
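            # NOTE: eval_direction is not defined in this gist; presumably it evaluates
            # the KL divergence at base_ref.act + step_size * direction.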
kl_div, _, _ = eval_direction(direction, self.base_ref, step_size)
kl_divs.append(kl_div.item())
directions.append(direction.detach().clone())
(sign * kl_div).backward(retain_graph=True)
optimizer.step()
optimizer.zero_grad()
return kl_divs, directions
def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True):
assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0))
assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0))
angles = results.angle[:, 0]
dists = results.dist[:, 0].detach()
ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls)
ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls)
ax1.set_ylabel("KL divergence to base logits")
ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}")
ax1.legend()
ax2.legend()
ax1.set_xlim(min(x), max(x))
ax2.set_xlim(min(x), max(x))
if len(angles) < 40:
tick_angles = angles
tick_dists = dists
else:
tick_angles = angles[::10]
tick_dists = dists[::10]
if x_is_angle:
ax1.set_xlabel(f"Angle from base activation at {layer}")
ax2.set_xlabel(f"Angle from base activation at {layer}")
ax1.set_xticks(tick_angles)
ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2.set_xticks(tick_angles)
ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
if label is not None:
ax1t.set_xlabel("Distance")
ax2t.set_xlabel("Distance")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
else:
ax1.set_xlabel(f"Distance from base activation at {layer}")
ax2.set_xlabel(f"Distance from base activation at {layer}")
ax1.set_xticks(tick_dists)
ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2.set_xticks(tick_dists)
ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
if label is not None:
ax1t.set_xlabel("Angle")
ax2t.set_xlabel("Angle")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
def find_sensitivity(x, y, threshold=0.5):
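    # Linearly interpolate y(x) on a fine grid and return the first x at which y exceeds `threshold`.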
f = sip.interp1d(x, y, kind="linear")
angles_interp = np.linspace(x.min(), x.max(), 100000)
index_interp = np.argmax(f(angles_interp) > threshold)
return angles_interp[index_interp]
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
seed = 0
step_angular = False
on_only = True
negate = on_only
set_seed(seed)
n_ctx = 10
layer = "blocks.1.hook_resid_pre"
pos = slice(-1, None, 1)
num_steps = 40
step_size = 1
learning_rate = 1
dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True)
# %%
model = HookedTransformer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"
# %%
base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000)
mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu")
plt.figure()
plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu())
plt.title(f"Mean activations for {layer}")
plt.xlabel("Dimension")
plt.ylabel("Activation")
plt.savefig(f"plots/mean_activations_seed_{seed}.png")
plt.close()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
data = mean_act_cache[layer][:, -1, :]
data_mean = data.mean(dim=0, keepdim=True)
data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0]
distrib = MultivariateNormal(data_mean.squeeze(0), data_cov)
# %%
random_perturbation = RandomPerturbation(base_ref, unrelated_ref)
random_perturbation.distrib = distrib
randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref)
perturbation = RandomPerturbation(base_ref, unrelated_ref)
perturbation.distrib = distrib
results_list = defaultdict(list)
for c, p in [
("C0", random_perturbation),
("C1", randomact_perturbation),
]:
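    # Re-center the shared reference on the sampled target: overwrite its activation,
    # logits, and cache with the target's forward pass.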
target = p.gen_target()
perturbation.base_ref.act = target
perturbation.base_ref.logits, perturbation.base_ref.cache = run_perturbed_activation(target, perturbation.base_ref)
for i in tqdm(range(20)):
results, _, _ = perturbation.scan(n_steps=37, step_angular=True, range=(0, 1))
results_list[c].append(results)
# %%
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
ax1t = ax1.twiny()
ax2t = ax2.twiny()
prompt_str = "|" + "|".join(model.to_str_tokens(base_ref.prompt)).replace("\n", "⏎") + "|"
title_str = "Sphere mode (slerping angle around origin)" if step_angular else "Straight mode"
fig.suptitle(f"{title_str}. Perturbing {layer} at pos {pos}.\nSeed = {seed}. Prompt = {prompt_str}. Norm = {lp_norm(base_ref.act).item():.2f}.")
for c in results_list:
label = "Random perturbation around random activation" if c == "C0" else "Random perturbation around (random) real activation" if c == "C1" else None
for i, r in enumerate(results_list[c]):
format_ax(
ax1,
ax1t,
ax2,
ax2t,
r,
label if i == 0 else None,
c,
r.angle[:, 0] if step_angular else r.dist[:, 0].detach(),
base_ref,
x_is_angle=step_angular,
)
mode_str = "sphere" if step_angular else "straight"
fig.savefig(f"plots/perturbations_{mode_str}_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()
# 2_naive.py
# %%
import os
import random
from collections.abc import Callable
from dataclasses import dataclass
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg")
import numpy as np
import scipy.interpolate as sip
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm
from transformer_lens import HookedTransformer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("plots", exist_ok=True)
# %%
def lp_norm(direction, p=2):
"""Pytorch norm"""
return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p)
def mean(direction):
return torch.mean(direction, dim=-1, keepdim=True)
def dot(x, y):
return torch.einsum("...k,...k->...", x, y).unsqueeze(-1)
def compute_angle_dist(ref_act, new_act, datasetmean=None):
angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1)
if torch.all(angle[0].isnan()):
angle[0] = torch.zeros_like(angle[0])
if torch.all(angle[-1].isnan()):
angle[-1] = np.pi * torch.ones_like(angle[-1])
dist = lp_norm(new_act - ref_act).squeeze(-1)
if datasetmean is None:
return angle, dist
else:
angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze(
-1
)
return angle, dist, angle_wrt_datasetmean
class Reference:
def __init__(
self,
model: HookedTransformer,
prompt: torch.Tensor,
replacement_layer: str,
read_layer: str,
replacement_pos: slice,
n_ctx: int,
):
self.model = model
n_batch_prompt, n_ctx_prompt = prompt.shape
assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}"
self.prompt = prompt
logits, cache = model.run_with_cache(prompt)
self.logits = logits.to("cpu").detach()
self.cache = cache.to("cpu")
self.act = self.cache[replacement_layer][:, replacement_pos]
self.replacement_layer = replacement_layer
self.read_layer = read_layer
self.replacement_pos = replacement_pos
self.n_ctx = n_ctx
@dataclass
class Result:
angle: float
angle_wrt_datasetmean: float
dist: float
norm: float
kl_div: float
out_angle: float
l2_diff: float
logit_l2_diff: float
dim0_diff: float
dim1_diff: float
def set_seed(seed: int):
"""Set the random seed for reproducibility."""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
"""Generate a prompt from the dataset."""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
return next(iter(dataloader))["input_ids"][:, :n_ctx]
def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor:
"""Get a random activation from the dataset."""
rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
_, cache = model.run_with_cache(rand_prompt)
return cache[layer][:, pos, :].to("cpu").detach()
def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
"""Compute the KL divergence between the reference and perturbed logprobs."""
logprobs_ref = F.log_softmax(logits_ref, dim=-1)
logprobs_pert = F.log_softmax(logits_pert, dim=-1)
return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1)
def compute_Lp_metric(
cache: dict,
cache_pert: dict,
p,
read_layer,
read_pos,
):
ref_readoff = cache[read_layer][:, read_pos]
pert_readoff = cache_pert[read_layer][:, read_pos]
Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1)
return Lp_diff
def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference):
pos = ref.replacement_pos
layer = ref.replacement_layer
def hook(act, hook):
act[:, pos, :] = perturbed_act
with ref.model.hooks(fwd_hooks=[(layer, hook)]):
prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))])
logits_pert, cache = ref.model.run_with_cache(prompts)
return logits_pert.to("cpu").detach(), cache.to("cpu")
def eval_activation(
perturbed_act: torch.Tensor,
base_ref: Reference,
unrelated_ref: Reference,
read_pos=-1,
):
read_layer = base_ref.read_layer
angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None)
    angle_wrt_datasetmean = angle  # PLACEHOLDER: no dataset mean is subtracted, so this equals angle
norm = lp_norm(perturbed_act).squeeze(-1)
logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref)
base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos]
base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1)
base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
base_out_angle, _ = compute_angle_dist(
base_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos]
unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1)
unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
unrelated_out_angle, _ = compute_angle_dist(
unrelated_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
return (
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
base_kl_div,
base_out_angle,
base_l2_diff,
base_logit_l2_diff,
base_dim0_diff,
base_dim1_diff,
),
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
unrelated_kl_div,
unrelated_out_angle,
unrelated_l2_diff,
unrelated_logit_l2_diff,
unrelated_dim0_diff,
unrelated_dim1_diff,
),
)
class Slerp:
def __init__(self, start, direction, datasetmean=None):
"""Return interpolation points along the sphere from
A towards direction B"""
self.datasetmean = datasetmean
self.A = start - self.datasetmean if self.datasetmean is not None else start
self.A_norm = lp_norm(self.A)
self.a = self.A / self.A_norm
d = direction / lp_norm(direction)
self.B = d - dot(d, self.a) * self.a
self.B_norm = lp_norm(self.B)
self.b = self.B / self.B_norm
def __call__(self, alpha):
result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b)
if self.datasetmean is not None:
return result + self.datasetmean
else:
return result
def get_alpha(self, X):
x = X / lp_norm(X)
return torch.acos(dot(x, self.a))
class Perturbation:
def gen_direction(self):
raise NotImplementedError
def __init__(
self,
base_ref: Reference,
unrelated_ref: Reference,
ensure_mean=True,
ensure_norm=True,
):
self.ensure_mean = ensure_mean
self.ensure_norm = ensure_norm
self.base_ref = base_ref
self.norm_base = lp_norm(base_ref.act)
self.unrelated_ref = unrelated_ref
def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = self.gen_direction()
return self.scan_dir(direction, n_steps, range, step_angular)
def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = direction.clone()
if self.ensure_mean:
direction -= mean(direction)
if self.ensure_norm:
direction *= self.norm_base / lp_norm(direction)
range = np.array(range)
if step_angular:
s = Slerp(self.base_ref.act, direction, datasetmean=None)
self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)]
else:
self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)]
act = torch.cat(self.activation_steps, dim=0)
torch.cuda.empty_cache()
base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref)
return base_result, unrelated_result, direction
class RandomUniformPerturbation(Perturbation):
def gen_direction(self):
return torch.randn_like(self.base_ref.act)
class RandomPerturbation(Perturbation):
def gen_target(self):
return self.distrib.sample(self.base_ref.act.shape[:-1])
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
class RandomActDirPerturbation(Perturbation):
def gen_target(self):
return get_random_activation(
self.base_ref.model,
dataset,
self.base_ref.n_ctx,
self.base_ref.replacement_layer,
self.base_ref.replacement_pos,
)
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
class NeuronDirPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference = None, negate=False, on_only=None):
        super().__init__(ref, unrelated_ref)
        self.base_ref = ref
        self.negate = 1 if not negate else -1
        self.active_features = None  # default: any neuron direction may be chosen
        if on_only is True:
            feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :]
            self.active_features = feature_acts / feature_acts.max() > 0.1
def gen_direction(self):
        if self.active_features is None:
            random_int = random.randint(0, 767)  # randint is inclusive; GPT-2 has 768 residual dims
else:
random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
one_hot = torch.zeros_like(self.base_ref.act)
one_hot[..., random_int] = 1
single_direction = self.negate * one_hot
return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])])
class OptimizedPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference = None, learning_rate: float = 0.1):
        super().__init__(ref, unrelated_ref)
self.learning_rate = learning_rate
def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None):
kl_divs = []
directions = []
if init is not None:
init = np.array(init.cpu().detach())
direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device)
else:
direction = torch.randn_like(self.base_ref.act, requires_grad=True)
optimizer = torch.optim.SGD([direction], lr=self.learning_rate)
for _ in tqdm(range(num_steps), desc="Finding perturbation direction"):
kl_div, _, _ = eval_direction(direction, self.base_ref, step_size)
kl_divs.append(kl_div.item())
directions.append(direction.detach().clone())
(sign * kl_div).backward(retain_graph=True)
optimizer.step()
optimizer.zero_grad()
return kl_divs, directions
def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True):
assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0))
assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0))
    angles = results.angle[:, 0]
dists = results.dist[:, 0]
ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls)
ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls)
ax1.set_ylabel("KL divergence to base logits")
ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}")
ax1.legend()
ax2.legend()
ax1.set_xlim(min(x), max(x))
ax2.set_xlim(min(x), max(x))
if len(angles) < 40:
tick_angles = angles
tick_dists = dists
else:
tick_angles = angles[::10]
tick_dists = dists[::10]
if x_is_angle:
ax1.set_xlabel(f"Angle from base activation at {layer}")
ax2.set_xlabel(f"Angle from base activation at {layer}")
ax1.set_xticks(tick_angles)
ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2.set_xticks(tick_angles)
ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
if label is not None:
ax1t.set_xlabel("Distance")
ax2t.set_xlabel("Distance")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
else:
ax1.set_xlabel(f"Distance from base activation at {layer}")
ax2.set_xlabel(f"Distance from base activation at {layer}")
ax1.set_xticks(tick_dists)
ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2.set_xticks(tick_dists)
ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
if label is not None:
ax1t.set_xlabel("Angle")
ax2t.set_xlabel("Angle")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
def find_sensitivity(x, y, threshold=0.5):
f = sip.interp1d(x, y, kind="linear")
angles_interp = np.linspace(x.min(), x.max(), 100000)
index_interp = np.argmax(f(angles_interp) > threshold)
return angles_interp[index_interp]
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
seed = 1
step_angular = False
on_only = True
negate = on_only
set_seed(seed)
n_ctx = 10
layer = "blocks.1.hook_resid_pre"
pos = slice(-1, None, 1)
num_steps = 40
step_size = 1
learning_rate = 1
dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True)
# %%
model = HookedTransformer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"
# %%
base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000)
mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu")
plt.figure()
plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu())
plt.title(f"Mean activations for {layer}")
plt.xlabel("Dimension")
plt.ylabel("Activation")
plt.savefig(f"plots/2_naive_mean_activations_seed_{seed}.png")
plt.close()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
data = mean_act_cache[layer][:, -1, :]
data_mean = data.mean(dim=0, keepdim=True)
data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0]
distrib = MultivariateNormal(data_mean.squeeze(0), data_cov)
# %%
random_perturbation = RandomPerturbation(base_ref, unrelated_ref)
naiverandom_perturbation = RandomUniformPerturbation(base_ref, unrelated_ref)
random_perturbation.distrib = distrib
randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref)
# %%
results_list0 = []
results_list1 = []
results_list2 = []
results_list3 = []
for _ in tqdm(range(20)):
random_perturbation.gen_direction()
results, _, dir = random_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
results_list0.append(results)
randomact_perturbation.gen_direction()
results, _, dir = randomact_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
results_list1.append(results)
naiverandom_perturbation.gen_direction()
results, _, dir = naiverandom_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
results_list3.append(results)
# %%
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
ax1t = ax1.twiny()
ax2t = ax2.twiny()
prompt_str = "|" + "|".join(model.to_str_tokens(base_ref.prompt)).replace("\n", "⏎") + "|"
title_str = "Sphere mode (slerping angle around origin)" if step_angular else "Straight mode"
fig.suptitle(f"{title_str}. Perturbing {layer} at pos {pos}.\nSeed = {seed}. Prompt = {prompt_str}. Norm = {lp_norm(base_ref.act).item():.2f}.")
for i, r in enumerate(results_list0):
format_ax(
ax1,
ax1t,
ax2,
ax2t,
r,
"Random direction perturbation" if i == 0 else None,
"C0",
r.angle[:, 0] if step_angular else r.dist[:, 0],
base_ref,
x_is_angle=step_angular,
)
for i, r in enumerate(results_list1):
format_ax(
ax1,
ax1t,
ax2,
ax2t,
r,
"Direction to random other activation" if i == 0 else None,
"C1",
r.angle[:, 0] if step_angular else r.dist[:, 0],
base_ref,
x_is_angle=step_angular,
)
for i, r in enumerate(results_list3):
format_ax(
ax1,
ax1t,
ax2,
ax2t,
r,
"Naive random direction perturbation" if i == 0 else None,
"C4",
r.angle[:, 0] if step_angular else r.dist[:, 0],
base_ref,
x_is_angle=step_angular,
)
mode_str = "sphere" if step_angular else "straight"
fig.savefig(f"plots/2_naive_perturbations_{mode_str}_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()
# %%
plt.figure(figsize=(8, 4))
plt.plot(results_list0[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list0]), dim=0))
plt.plot(results_list1[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list1]), dim=0))
plt.fill_between(
results_list0[0].angle[:, 0],
torch.min(torch.stack([r.kl_div for r in results_list0]), dim=0).values,
torch.max(torch.stack([r.kl_div for r in results_list0]), dim=0).values,
alpha=0.3,
)
plt.fill_between(
results_list1[0].angle[:, 0],
torch.min(torch.stack([r.kl_div for r in results_list1]), dim=0).values,
torch.max(torch.stack([r.kl_div for r in results_list1]), dim=0).values,
alpha=0.3,
)
plt.xlabel("Angle from base activation")
tick_angles = results_list0[0].angle[::10, 0]
plt.xticks(tick_angles, [f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
plt.ylabel("Average KL divergence to base logits")
plt.title("Average KL divergence comparison")
plt.savefig(f"plots/2_naive_average_kl_div_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()
# See code in figure1.py
# 2_2a.py
# %%
import os
import random
from collections.abc import Callable
from dataclasses import dataclass
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg") # Use non-interactive backend
import numpy as np
import scipy.interpolate as sip
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm
from transformer_lens import HookedTransformer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs("plots", exist_ok=True)
# %%
def lp_norm(direction, p=2):
"""Pytorch norm"""
return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p)
def mean(direction):
return torch.mean(direction, dim=-1, keepdim=True)
def dot(x, y):
return torch.einsum("...k,...k->...", x, y).unsqueeze(-1)
def compute_angle_dist(ref_act, new_act, datasetmean=None):
angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1)
if torch.all(angle[0].isnan()):
angle[0] = torch.zeros_like(angle[0])
if torch.all(angle[-1].isnan()):
angle[-1] = np.pi * torch.ones_like(angle[-1])
dist = lp_norm(new_act - ref_act).squeeze(-1)
if datasetmean is None:
return angle, dist
else:
angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze(
-1
)
return angle, dist, angle_wrt_datasetmean
class Reference:
def __init__(
self,
model: HookedTransformer,
prompt: torch.Tensor,
replacement_layer: str,
read_layer: str,
replacement_pos: slice,
n_ctx: int,
):
self.model = model
n_batch_prompt, n_ctx_prompt = prompt.shape
assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}"
self.prompt = prompt
logits, cache = model.run_with_cache(prompt)
self.logits = logits.to("cpu").detach()
self.cache = cache.to("cpu")
self.act = self.cache[replacement_layer][:, replacement_pos]
self.replacement_layer = replacement_layer
self.read_layer = read_layer
self.replacement_pos = replacement_pos
self.n_ctx = n_ctx
@dataclass
class Result:
angle: float
angle_wrt_datasetmean: float
dist: float
norm: float
kl_div: float
out_angle: float
l2_diff: float
logit_l2_diff: float
dim0_diff: float
dim1_diff: float
def set_seed(seed: int):
"""Set the random seed for reproducibility."""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
"""Generate a prompt from the dataset."""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
return next(iter(dataloader))["input_ids"][:, :n_ctx]
def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor:
"""Get a random activation from the dataset."""
rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
_, cache = model.run_with_cache(rand_prompt)
return cache[layer][:, pos, :].to("cpu").detach()
def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
"""Compute the KL divergence between the reference and perturbed logprobs."""
logprobs_ref = F.log_softmax(logits_ref, dim=-1)
logprobs_pert = F.log_softmax(logits_pert, dim=-1)
return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1)
def compute_Lp_metric(
cache: dict,
cache_pert: dict,
p,
read_layer,
read_pos,
):
ref_readoff = cache[read_layer][:, read_pos]
pert_readoff = cache_pert[read_layer][:, read_pos]
Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1)
return Lp_diff
def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference):
pos = ref.replacement_pos
layer = ref.replacement_layer
def hook(act, hook):
act[:, pos, :] = perturbed_act
with ref.model.hooks(fwd_hooks=[(layer, hook)]):
prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))])
logits_pert, cache = ref.model.run_with_cache(prompts)
return logits_pert.to("cpu").detach(), cache.to("cpu")
def eval_activation(
perturbed_act: torch.Tensor,
base_ref: Reference,
unrelated_ref: Reference,
read_pos=-1,
):
read_layer = base_ref.read_layer
angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None)
    angle_wrt_datasetmean = angle  # PLACEHOLDER: no dataset mean is subtracted, so this equals angle
norm = lp_norm(perturbed_act).squeeze(-1)
logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref)
base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos]
base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1)
base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
base_out_angle, _ = compute_angle_dist(
base_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos]
unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1)
unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
unrelated_out_angle, _ = compute_angle_dist(
unrelated_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
return (
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
base_kl_div,
base_out_angle,
base_l2_diff,
base_logit_l2_diff,
base_dim0_diff,
base_dim1_diff,
),
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
unrelated_kl_div,
unrelated_out_angle,
unrelated_l2_diff,
unrelated_logit_l2_diff,
unrelated_dim0_diff,
unrelated_dim1_diff,
),
)
class Slerp:
def __init__(self, start, direction, datasetmean=None):
"""Return interpolation points along the sphere from
A towards direction B"""
self.datasetmean = datasetmean
self.A = start - self.datasetmean if self.datasetmean is not None else start
self.A_norm = lp_norm(self.A) # magnitude
self.a = self.A / self.A_norm # unit vectors
d = direction / lp_norm(direction)
self.B = d - dot(d, self.a) * self.a
self.B_norm = lp_norm(self.B) # magnitude
self.b = self.B / self.B_norm # unit vectors
def __call__(self, alpha):
result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b)
if self.datasetmean is not None:
return result + self.datasetmean
else:
return result
def get_alpha(self, X):
x = X / lp_norm(X)
return torch.acos(dot(x, self.a))
class Perturbation:
def gen_direction(self):
raise NotImplementedError
def __init__(
self,
base_ref: Reference,
unrelated_ref: Reference,
ensure_mean=True,
ensure_norm=True,
):
self.ensure_mean = ensure_mean
self.ensure_norm = ensure_norm
self.base_ref = base_ref
self.norm_base = lp_norm(base_ref.act)
self.unrelated_ref = unrelated_ref
def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = self.gen_direction()
return self.scan_dir(direction, n_steps, range, step_angular)
def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = direction.clone()
if self.ensure_mean:
direction -= mean(direction)
if self.ensure_norm:
direction *= self.norm_base / lp_norm(direction)
range = np.array(range)
if step_angular:
s = Slerp(self.base_ref.act, direction, datasetmean=None)
self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)]
else:
self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)]
act = torch.cat(self.activation_steps, dim=0)
torch.cuda.empty_cache()
base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref)
return base_result, unrelated_result, direction
class RandomUniformPerturbation(Perturbation):
def gen_direction(self):
return torch.randn_like(self.base_ref.act)
class RandomPerturbation(Perturbation):
def gen_target(self):
return self.distrib.sample(self.base_ref.act.shape[:-1])
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
class RandomActDirPerturbation(Perturbation):
def gen_target(self):
return get_random_activation(
self.base_ref.model,
dataset,
self.base_ref.n_ctx,
self.base_ref.replacement_layer,
self.base_ref.replacement_pos,
)
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
class NeuronDirPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference = None, negate=False, on_only=None):
        super().__init__(ref, unrelated_ref)
        self.base_ref = ref
        self.negate = 1 if not negate else -1
        self.active_features = None  # default: any neuron direction may be chosen
        if on_only is True:
            feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :]
            self.active_features = feature_acts / feature_acts.max() > 0.1
def gen_direction(self):
        if self.active_features is None:
            random_int = random.randint(0, 767)  # randint is inclusive; GPT-2 has 768 residual dims
else:
random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
one_hot = torch.zeros_like(self.base_ref.act)
one_hot[..., random_int] = 1
single_direction = self.negate * one_hot
return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])])
class OptimizedPerturbation(Perturbation):
    def __init__(self, ref: Reference, unrelated_ref: Reference = None, learning_rate: float = 0.1):
        super().__init__(ref, unrelated_ref)
self.learning_rate = learning_rate
def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None):
kl_divs = []
directions = []
if init is not None:
init = np.array(init.cpu().detach())
direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device)
else:
direction = torch.randn_like(self.base_ref.act, requires_grad=True)
optimizer = torch.optim.SGD([direction], lr=self.learning_rate)
for _ in tqdm(range(num_steps), desc="Finding perturbation direction"):
kl_div, _, _ = eval_direction(direction, self.base_ref, step_size)
kl_divs.append(kl_div.item())
directions.append(direction.detach().clone())
(sign * kl_div).backward(retain_graph=True)
optimizer.step()
optimizer.zero_grad()
return kl_divs, directions
def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True):
assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0))
assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0))
    angles = results.angle[:, 0]
dists = results.dist[:, 0]
ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls)
ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls)
ax1.set_ylabel("KL divergence to base logits")
ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}")
ax1.legend()
ax2.legend()
ax1.set_xlim(min(x), max(x))
ax2.set_xlim(min(x), max(x))
if len(angles) < 40:
tick_angles = angles
tick_dists = dists
else:
tick_angles = angles[::10]
tick_dists = dists[::10]
if x_is_angle:
ax1.set_xlabel(f"Angle from base activation at {layer}")
ax2.set_xlabel(f"Angle from base activation at {layer}")
ax1.set_xticks(tick_angles)
ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2.set_xticks(tick_angles)
ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
if label is not None:
ax1t.set_xlabel("Distance")
ax2t.set_xlabel("Distance")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
else:
ax1.set_xlabel(f"Distance from base activation at {layer}")
ax2.set_xlabel(f"Distance from base activation at {layer}")
ax1.set_xticks(tick_dists)
ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2.set_xticks(tick_dists)
ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
if label is not None:
ax1t.set_xlabel("Angle")
ax2t.set_xlabel("Angle")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
def find_sensitivity(x, y, threshold=0.5):
f = sip.interp1d(x, y, kind="linear")
angles_interp = np.linspace(x.min(), x.max(), 100000)
index_interp = np.argmax(f(angles_interp) > threshold)
return angles_interp[index_interp]
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
################################
seed = 0
step_angular = True
on_only = True
negate = on_only
################
set_seed(seed)
n_ctx = 10
layer = "blocks.1.hook_resid_pre"
pos = slice(-1, None, 1)
num_steps = 40
step_size = 1
learning_rate = 1
dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True)
# %%
model = HookedTransformer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"
# %%
base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000)
mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu")
plt.figure()
plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu())
plt.title(f"Mean activations for {layer}")
plt.xlabel("Dimension")
plt.ylabel("Activation")
plt.savefig(f"plots/2_2a_mean_activations_seed_{seed}.png")
plt.close()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
data = mean_act_cache[layer][:, -1, :]
data_mean = data.mean(dim=0, keepdim=True)
data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0]
distrib = MultivariateNormal(data_mean.squeeze(0), data_cov)
# %%
random_perturbation = RandomPerturbation(base_ref, unrelated_ref)
random_perturbation.distrib = distrib
randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref)
# %%
results_list0 = []
results_list1 = []
results_list2 = []
for _ in tqdm(range(20)):
random_perturbation.gen_direction()
results, _, dir = random_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
results_list0.append(results)
randomact_perturbation.gen_direction()
results, _, dir = randomact_perturbation.scan(n_steps=361, range=(0, np.pi), step_angular=step_angular)
results_list1.append(results)
    results_list2.append(results)  # NOTE: duplicates results_list1, so the third averaged curve below repeats the second
# %%
fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
ax1t = ax1.twiny()
ax2t = ax2.twiny()
prompt_str = "|" + "|".join(model.to_str_tokens(base_ref.prompt)).replace("\n", "⏎") + "|"
title_str = "Sphere mode (slerping angle around origin)" if step_angular else "Straight mode"
fig.suptitle(f"{title_str}. Perturbing {layer} at pos {pos}.\nSeed = {seed}. Prompt = {prompt_str}. Norm = {lp_norm(base_ref.act).item():.2f}.")
for i, r in enumerate(results_list0):
format_ax(
ax1,
ax1t,
ax2,
ax2t,
r,
"Random direction perturbation" if i == 0 else None,
"C0",
r.angle[:, 0] if step_angular else r.dist[:, 0],
base_ref,
x_is_angle=step_angular,
)
for i, r in enumerate(results_list1):
format_ax(
ax1,
ax1t,
ax2,
ax2t,
r,
"Direction to random other activation" if i == 0 else None,
"C1",
r.angle[:, 0] if step_angular else r.dist[:, 0],
base_ref,
x_is_angle=step_angular,
)
mode_str = "sphere" if step_angular else "straight"
fig.savefig(f"plots/2_2a_perturbations_{mode_str}_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()
# %%
plt.figure(figsize=(8, 4))
plt.plot(results_list0[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list0]), dim=0))
plt.plot(results_list1[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list1]), dim=0))
plt.plot(results_list2[0].angle, torch.mean(torch.stack([r.kl_div for r in results_list2]), dim=0))
plt.fill_between(
results_list0[0].angle[:, 0],
torch.min(torch.stack([r.kl_div for r in results_list0]), dim=0).values,
torch.max(torch.stack([r.kl_div for r in results_list0]), dim=0).values,
alpha=0.3,
)
plt.fill_between(
results_list1[0].angle[:, 0],
torch.min(torch.stack([r.kl_div for r in results_list1]), dim=0).values,
torch.max(torch.stack([r.kl_div for r in results_list1]), dim=0).values,
alpha=0.3,
)
plt.xlabel("Angle from base activation")
tick_angles = results_list0[0].angle[::10, 0]
plt.xticks(tick_angles, [f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
plt.ylabel("Average KL divergence to base logits")
plt.title("Average KL divergence comparison")
plt.savefig(f"plots/2_2a_average_kl_div_seed_{seed}.png", dpi=150, bbox_inches="tight")
plt.close()
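# Next gist file (name not shown in this extract): a variant of the setup above using sae_lens SAEs.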
import os
import random
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate as sip
import torch
import torch.nn.functional as F
from datasets import load_dataset
from sae_lens import SAE
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm
from transformer_lens import HookedTransformer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def lp_norm(direction, p=2):
"""Pytorch norm"""
return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p)
def mean(direction):
return torch.mean(direction, dim=-1, keepdim=True)
def dot(x, y):
return torch.einsum("...k,...k->...", x, y).unsqueeze(-1)
def compute_angle_dist(ref_act, new_act, datasetmean=None):
angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1)
if torch.all(angle[0].isnan()):
angle[0] = torch.zeros_like(angle[0])
if torch.all(angle[-1].isnan()):
angle[-1] = np.pi * torch.ones_like(angle[-1])
dist = lp_norm(new_act - ref_act).squeeze(-1)
if datasetmean is None:
return angle, dist
else:
angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze(
-1
)
return angle, dist, angle_wrt_datasetmean
class Reference:
def __init__(
self,
model: HookedTransformer,
prompt: torch.Tensor,
replacement_layer: str,
replacement_pos: slice,
n_ctx: int,
):
self.model = model
n_batch_prompt, n_ctx_prompt = prompt.shape
assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}"
self.prompt = prompt
logits, cache = model.run_with_cache(prompt)
self.logits = logits.to("cpu").detach()
self.cache = cache.to("cpu")
self.act = self.cache[replacement_layer][:, replacement_pos]
self.layer = replacement_layer
self.pos = replacement_pos
self.n_ctx = n_ctx
@dataclass
class Result:
angle: float
angle_wrt_datasetmean: float
dist: float
norm: float
kl_div: float
out_angle: float
l2_diff: float
logit_l2_diff: float
dim0_diff: float
dim1_diff: float
def set_seed(seed: int):
"""Set the random seed for reproducibility."""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
"""Generate a prompt from the dataset."""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
return next(iter(dataloader))["input_ids"][:, :n_ctx]
def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor:
"""Get a random activation from the dataset."""
rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
_, cache = model.run_with_cache(rand_prompt)
return cache[layer][:, pos, :].to("cpu").detach()
def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
"""Compute the KL divergence between the reference and perturbed logprobs."""
logprobs_ref = F.log_softmax(logits_ref, dim=-1)
logprobs_pert = F.log_softmax(logits_pert, dim=-1)
return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1)
def compute_Lp_metric(
cache: dict,
cache_pert: dict,
p,
read_layer,
read_pos,
):
ref_readoff = cache[read_layer][:, read_pos]
pert_readoff = cache_pert[read_layer][:, read_pos]
Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1)
return Lp_diff
def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference):
pos = ref.pos
layer = ref.layer
def hook(act, hook):
act[:, pos, :] = perturbed_act
with ref.model.hooks(fwd_hooks=[(layer, hook)]):
prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))])
logits_pert, cache = ref.model.run_with_cache(prompts)
return logits_pert.to("cpu").detach(), cache.to("cpu")
def eval_activation(
perturbed_act: torch.Tensor,
base_ref: Reference,
unrelated_ref: Reference,
read_layer="ln_final.hook_normalized",
read_pos=-1,
):
angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None)
angle_wrt_datasetmean = angle
norm = lp_norm(perturbed_act).squeeze(-1)
logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref)
base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos]
base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1)
base_dim0_diff = base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
base_dim1_diff = base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
base_out_angle, _ = compute_angle_dist(
base_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos]
unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1)
unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0]
unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1]
unrelated_out_angle, _ = compute_angle_dist(
unrelated_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
return (
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
base_kl_div,
base_out_angle,
base_l2_diff,
base_logit_l2_diff,
base_dim0_diff,
base_dim1_diff,
),
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
unrelated_kl_div,
unrelated_out_angle,
unrelated_l2_diff,
unrelated_logit_l2_diff,
unrelated_dim0_diff,
unrelated_dim1_diff,
),
)
class Slerp:
def __init__(self, start, direction, datasetmean=None):
"""Return interpolation points along the sphere from
A towards direction B"""
self.datasetmean = datasetmean
self.A = start - self.datasetmean if self.datasetmean is not None else start
self.A_norm = lp_norm(self.A) # magnitude
self.a = self.A / self.A_norm # unit vectors
d = direction / lp_norm(direction)
self.B = d - dot(d, self.a) * self.a
self.B_norm = lp_norm(self.B) # magnitude
self.b = self.B / self.B_norm # unit vectors
def __call__(self, alpha):
result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b)
if self.datasetmean is not None:
return result + self.datasetmean
else:
return result
def get_alpha(self, X):
x = X / lp_norm(X)
return torch.acos(dot(x, self.a))
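# Minimal usage sketch (illustrative, not part of the original experiments):
#   s = Slerp(base_act, torch.randn_like(base_act))
#   arc = [s(alpha) for alpha in torch.linspace(0, np.pi / 2, 10)]
# Since a and b are orthonormal, every point on the arc keeps the norm of
# base_act; only the direction of the activation changes.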
class Perturbation:
def gen_direction(self):
raise NotImplementedError
def __init__(
self,
base_ref: Reference,
unrelated_ref: Reference,
ensure_mean=True,
ensure_norm=True,
project_sphere=False,
):
self.ensure_mean = ensure_mean
self.ensure_norm = ensure_norm
self.project_sphere = project_sphere
self.base_ref = base_ref
self.norm_base = lp_norm(base_ref.act)
self.unrelated_ref = unrelated_ref
def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = self.gen_direction()
if self.ensure_mean:
direction -= mean(direction)
if self.ensure_norm:
direction *= self.norm_base / lp_norm(direction)
range = np.array(range)
if step_angular:
s = Slerp(self.base_ref.act, direction, datasetmean=None)
self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)]
else:
self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)]
act = torch.cat(self.activation_steps, dim=0)
torch.cuda.empty_cache()
base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref)
return base_result, unrelated_result, direction
class RandomUniformPerturbation(Perturbation):
def gen_direction(self):
return torch.randn_like(self.base_ref.act)
class RandomPerturbation(Perturbation):
def gen_target(self):
self.target = self.distrib.sample(self.base_ref.act.shape[:-1])
def gen_direction(self):
self.gen_target()
return self.target - self.base_ref.act
class RandomActDirPerturbation(Perturbation):
def gen_target(self):
self.target = get_random_activation(
self.base_ref.model,
dataset,
self.base_ref.n_ctx,
self.base_ref.layer,
self.base_ref.pos,
)
def gen_direction(self):
self.gen_target()
return self.target - self.base_ref.act
class NeuronDirPerturbation(Perturbation):
    def __init__(self, ref: Reference, negate=False, on_only=None, unrelated_ref: Reference = None):
        # Perturbation.__init__ requires an unrelated_ref; pass one in when using scan().
        super().__init__(ref, unrelated_ref)
        self.negate = 1 if not negate else -1
        # Default to None so gen_direction works when on_only is not requested.
        self.active_features = None
        if on_only is True:
            feature_acts = self.base_ref.cache[ref.layer][0, -1, :]
            self.active_features = feature_acts / feature_acts.max() > 0.1
    def gen_direction(self):
        if self.active_features is None:
            # randint is inclusive on both ends; 768 would index past GPT-2's d_model
            random_int = random.randint(0, self.base_ref.act.shape[-1] - 1)
else:
random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
one_hot = torch.zeros_like(self.base_ref.act)
one_hot[..., random_int] = 1
single_direction = self.negate * one_hot
return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])])
class SAEDecoderDirPerturbation(Perturbation):
def __init__(
self,
base_ref: Reference,
unrelated_ref: Reference,
sae,
ensure_mean=True,
ensure_norm=True,
project_sphere=False,
negate=False,
on_only=None,
):
super().__init__(
base_ref=base_ref,
unrelated_ref=unrelated_ref,
ensure_mean=ensure_mean,
ensure_norm=ensure_norm,
project_sphere=project_sphere,
)
self.sae = sae
self.active_features = None
self.negate = 1 if not negate else -1
if on_only is True:
feature_acts = sae.encode(self.base_ref.cache[self.sae.cfg.hook_name])[0, -1, :]
self.active_features = feature_acts / feature_acts.max() > 0.1
self.active_features = self.active_features.to("cpu")
print("Using active features:", self.active_features.nonzero(as_tuple=True)[0])
    def gen_direction(self):
        if self.active_features is None:
            # randint is inclusive; stay within the SAE dictionary (rows of W_dec)
            random_int = random.randint(0, self.sae.W_dec.shape[0] - 1)
        else:
            random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
        single_direction = self.negate * self.sae.W_dec[random_int, :].to("cpu").detach()
        if isinstance(self.base_ref.pos, slice):
            self.direction = torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])]).unsqueeze(0)
        else:
            self.direction = single_direction.unsqueeze(0)
        return self.direction  # scan() expects gen_direction to return the direction
class SAEReplacementPerturbation(Perturbation):
    def __init__(self, ref: Reference, sae, unrelated_ref: Reference = None):
        super().__init__(ref, unrelated_ref)
        self.sae = sae
    def gen_target(self):
        sae_out = self.sae(self.base_ref.cache[self.sae.cfg.hook_name])
        return sae_out[0, self.base_ref.pos, :]
    def gen_direction(self):
        self.target = self.gen_target()
        self.direction = self.target - self.base_ref.act
        return self.direction  # scan() expects gen_direction to return the direction
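# NOTE: OptimizedPerturbation.sample references eval_direction, which is not
# defined in the code shown here; the author's note at the end of this script
# says the KL-optimization experiments should not be relied on.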
class OptimizedPerturbation(Perturbation):
    def __init__(self, ref: Reference, learning_rate: float = 0.1, unrelated_ref: Reference = None):
        super().__init__(ref, unrelated_ref)
self.learning_rate = learning_rate
def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None):
kl_divs = []
directions = []
if init is not None:
init = np.array(init.cpu().detach())
direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device)
else:
direction = torch.randn_like(self.base_ref.act, requires_grad=True)
optimizer = torch.optim.SGD([direction], lr=self.learning_rate)
for _ in tqdm(range(num_steps), desc="Finding perturbation direction"):
kl_div, _, _ = eval_direction(direction, self.base_ref, step_size)
kl_divs.append(kl_div.item())
directions.append(direction.detach().clone())
(sign * kl_div).backward(retain_graph=True)
optimizer.step()
optimizer.zero_grad()
return kl_divs, directions
def format_ax(ax1, results, label, color, angles):
assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0))
assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0))
dists = results.dist[:, 0]
ax1.plot(angles, results.kl_div, label=label, color=color)
ax1.set_ylabel("KL Divergence")
ax1.grid()
ax1.legend()
ax1.set_xlabel("Angle")
ax1.set_xlim(min(angles), max(angles))
ax1.set_xticks(angles)
ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in angles], rotation=-80)
if label is not None:
ax2 = ax1.twiny()
ax2.set_xlabel("Distance")
ax2.set_xticks(ax1.get_xticks())
ax2.set_xticklabels([f"{dist:.1f}" for dist in dists], rotation=-80)
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
seed = 0
set_seed(seed)
n_ctx = 10
layer = "blocks.1.hook_resid_pre"
pos = slice(-1, None, 1)
num_steps = 40
step_size = 1
learning_rate = 1
dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True)
model = HookedTransformer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"
from sae_lens import SAE  # assumed import; SAE is otherwise undefined in the code shown
sae = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id=layer, device=device).to("cpu")
base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
base_ref = Reference(model, base_prompt, layer, pos, n_ctx)
unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
unrelated_ref = Reference(model, unrelated_prompt, layer, pos, n_ctx)
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000)
mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu")
plt.plot(mean_act_cache[layer].mean(dim=0)[1:].mean(dim=0).cpu())
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
data = mean_act_cache[layer][:, -1, :]
data_mean = data.mean(dim=0, keepdim=True)
data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0]
distrib = MultivariateNormal(data_mean.squeeze(0), data_cov)
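# Fit a multivariate Gaussian to the last-position activations so that sampled
# targets are random but roughly on-distribution (matching mean and covariance).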
random_perturbation = RandomPerturbation(base_ref, unrelated_ref)
random_perturbation.distrib = distrib
randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref)
results, _, _ = random_perturbation.scan(n_steps=21, step_angular=False, range=(0, 1))
plt.scatter(results.angle, results.out_angle * 180 / np.pi, label="Random direction")
plt.figure()
plt.scatter(results.angle, results.l2_diff, label="Random direction")
plt.figure()
plt.scatter(results.angle, results.kl_div, label="Random direction")
def find_sensitivity(x, y, threshold=0.1):
f = sip.interp1d(x, y, kind="linear")
angles_interp = np.linspace(x.min(), x.max(), 1000)
index_interp = np.argmax(f(angles_interp) > threshold)
return angles_interp[index_interp]
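# find_sensitivity returns the smallest x at which the interpolated y first
# exceeds the threshold; if y never exceeds it, np.argmax over an all-False
# array is 0, so x.min() is returned.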
dd = defaultdict(list)
fig0, ax0 = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
for sample_index in tqdm(range(100)):
for perturbation in ["randdir", "randact"]:
for style in ["straight", "angle"]:
pert = random_perturbation if perturbation == "randdir" else randomact_perturbation
if style == "angle":
results, _, _ = pert.scan(n_steps=91, step_angular=True, range=(0, np.pi / 2))
else:
results, _, _ = pert.scan(n_steps=21, step_angular=False, range=(0, 1))
dd[perturbation + "_" + style + "_kldiv_at_1"].append(results.kl_div[1])
dd[perturbation + "_" + style + "_l2diff_at_1"].append(results.l2_diff[1])
dd[perturbation + "_" + style + "_cossim_at_1"].append(np.cos(results.out_angle[1]))
dd[perturbation + "_" + style + "_kldiv_at_2"].append(results.kl_div[1])
dd[perturbation + "_" + style + "_l2diff_at_2"].append(results.l2_diff[1])
dd[perturbation + "_" + style + "_cossim_at_2"].append(np.cos(results.out_angle[1]))
dd[perturbation + "_" + style + "_kldiv_at_5"].append(results.kl_div[5])
dd[perturbation + "_" + style + "_l2diff_at_5"].append(results.l2_diff[5])
dd[perturbation + "_" + style + "_cossim_at_5"].append(np.cos(results.out_angle[5]))
dd[perturbation + "_" + style + "_kldiv_at_10"].append(results.kl_div[10])
dd[perturbation + "_" + style + "_l2diff_at_10"].append(results.l2_diff[10])
dd[perturbation + "_" + style + "_cossim_at_10"].append(np.cos(results.out_angle[10]))
dd[perturbation + "_" + style + "_kldiv_at_20"].append(results.kl_div[20])
dd[perturbation + "_" + style + "_l2diff_at_20"].append(results.l2_diff[20])
dd[perturbation + "_" + style + "_cossim_at_20"].append(np.cos(results.out_angle[20]))
if style == "angle":
f = lambda threshold: find_sensitivity(results.angle[:, 0], results.kl_div, threshold)
else:
f = lambda threshold: find_sensitivity(results.l2_diff.squeeze(), results.kl_div, threshold)
dd[perturbation + "_" + style + "_where01KL"].append(f(0.1))
dd[perturbation + "_" + style + "_where05KL"].append(f(0.5))
dd[perturbation + "_" + style + "_where1KL"].append(f(1))
dd[perturbation + "_" + style + "_where2KL"].append(f(2))
if style == "angle":
f = sip.interp1d(
results.l2_diff,
results.angle.squeeze(),
kind="linear",
fill_value=0,
bounds_error=False,
)
else:
f = sip.interp1d(
results.l2_diff,
results.dist.squeeze(),
kind="linear",
fill_value=0,
bounds_error=False,
)
dd[perturbation + "_" + style + "_where10L2"].append(f(10))
dd[perturbation + "_" + style + "_where20L2"].append(f(20))
if style == "angle":
color = "C0" if perturbation == "randdir" else "C1"
ax0.plot(results.angle, results.kl_div, color=color, label=perturbation, lw=0.2)
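# dd now holds, per perturbation type (randdir/randact) and stepping style
# (straight/angle): KL, L2, and output-cosine metrics at fixed step indices,
# plus the angle or distance at which KL first crosses each threshold.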
fig, axes = plt.subplots(6, 6, figsize=(20, 20), constrained_layout=True)
fig2, axes2 = plt.subplots(2, 6, figsize=(20, 6), constrained_layout=True)
for i, style in enumerate(["straight", "angle"]):
for j, metric in enumerate(["kldiv", "l2diff", "cossim"]):
for k, dist in enumerate([1, 2, 5, 10, 20]):
axes[2 * j + i, k].hist(
dd["randdir_" + style + "_" + metric + "_at_" + str(dist)],
bins=20,
alpha=0.5,
label="randdir",
density=True,
)
axes[2 * j + i, k].hist(
dd["randact_" + style + "_" + metric + "_at_" + str(dist)],
bins=20,
alpha=0.5,
label="randact",
density=True,
)
axes[2 * j + i, k].set_title(f"{style} {metric} at {dist}")
axes[2 * j + i, k].legend()
axes2[i, 0].hist(dd["randdir_" + style + "_where01KL"], bins=20, alpha=0.5, label="randdir", density=True)
axes2[i, 0].hist(dd["randact_" + style + "_where01KL"], bins=20, alpha=0.5, label="randact", density=True)
axes2[i, 0].set_title(f"{style} where KL=0.1")
axes2[i, 0].legend()
axes2[i, 1].hist(dd["randdir_" + style + "_where05KL"], bins=20, alpha=0.5, label="randdir", density=True)
axes2[i, 1].hist(dd["randact_" + style + "_where05KL"], bins=20, alpha=0.5, label="randact", density=True)
axes2[i, 1].set_title(f"{style} where KL=0.5")
axes2[i, 1].legend()
axes2[i, 2].hist(dd["randdir_" + style + "_where1KL"], bins=20, alpha=0.5, label="randdir", density=True)
axes2[i, 2].hist(dd["randact_" + style + "_where1KL"], bins=20, alpha=0.5, label="randact", density=True)
axes2[i, 2].set_title(f"{style} where KL=1")
axes2[i, 2].legend()
axes2[i, 3].hist(dd["randdir_" + style + "_where2KL"], bins=20, alpha=0.5, label="randdir", density=True)
axes2[i, 3].hist(dd["randact_" + style + "_where2KL"], bins=20, alpha=0.5, label="randact", density=True)
axes2[i, 3].set_title(f"{style} where KL=2")
axes2[i, 3].legend()
axes2[i, 4].hist(dd["randdir_" + style + "_where10L2"], bins=20, alpha=0.5, label="randdir", density=True)
axes2[i, 4].hist(dd["randact_" + style + "_where10L2"], bins=20, alpha=0.5, label="randact", density=True)
axes2[i, 4].set_title(f"{style} where L2=10")
axes2[i, 4].legend()
axes2[i, 5].hist(dd["randdir_" + style + "_where20L2"], bins=20, alpha=0.5, label="randdir", density=True)
axes2[i, 5].hist(dd["randact_" + style + "_where20L2"], bins=20, alpha=0.5, label="randact", density=True)
axes2[i, 5].set_title(f"{style} where L2=20")
axes2[i, 5].legend()
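# The Agg backend never displays figures, so without explicit saves they are
# lost. The original gist does not show save calls for this script; the
# filenames below are assumptions for illustration.
fig0.savefig(f"plots/2_kl_vs_angle_seed_{seed}.png", dpi=300)
fig.savefig(f"plots/2_metric_histograms_seed_{seed}.png", dpi=300)
fig2.savefig(f"plots/2_threshold_histograms_seed_{seed}.png", dpi=300)
plt.close("all")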
# I didn't include preliminary SAE figures & optimization for KL-div in the post
# because I wasn't confident in those experiments (and I don't recommend relying
# on that code). For good SAE evaluations see https://arxiv.org/abs/2409.15019
# and https://arxiv.org/abs/2410.12555, and for optimizing for KL-div see the
# concurrent work of Andrew Mack & Alex Turner:
# https://www.lesswrong.com/posts/ioPnHKFyy4Cw2Gr2x/mechanistically-eliciting-latent-behaviors-in-language-1
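# 3.py (assumed file boundary: a new, self-contained script starts below; its
# figures are saved as plots/3_*.png)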
# %%
import os
import random
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("Agg") # Use non-interactive backend
import numpy as np
import plotly.figure_factory as ff
import scipy.interpolate as sip
import torch
import torch.nn.functional as F
from datasets import load_dataset
from plotly.subplots import make_subplots
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm
from transformer_lens import HookedTransformer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Ensure plots directory exists
os.makedirs("plots", exist_ok=True)
# %%
def lp_norm(direction, p=2):
"""Pytorch norm"""
return torch.linalg.vector_norm(direction, dim=-1, keepdim=True, ord=p)
def mean(direction):
return torch.mean(direction, dim=-1, keepdim=True)
def dot(x, y):
return torch.einsum("...k,...k->...", x, y).unsqueeze(-1)
def compute_angle_dist(ref_act, new_act, datasetmean=None):
new_act = new_act.cpu().detach()
angle = torch.acos(dot(new_act, ref_act) / (lp_norm(new_act) * lp_norm(ref_act))).squeeze(-1)
if torch.all(angle[0].isnan()):
angle[0] = torch.zeros_like(angle[0])
if torch.all(angle[-1].isnan()):
angle[-1] = np.pi * torch.ones_like(angle[-1])
dist = lp_norm(new_act - ref_act).squeeze(-1)
if datasetmean is None:
return angle, dist
else:
angle_wrt_datasetmean = torch.acos(dot(new_act - datasetmean, ref_act - datasetmean) / (lp_norm(new_act - datasetmean) * lp_norm(ref_act - datasetmean))).squeeze(
-1
)
return angle, dist, angle_wrt_datasetmean
class Reference:
def __init__(
self,
model: HookedTransformer,
prompt: torch.Tensor,
replacement_layer: str,
read_layer: str,
replacement_pos: slice,
n_ctx: int,
):
self.model = model
n_batch_prompt, n_ctx_prompt = prompt.shape
assert n_ctx == n_ctx_prompt, f"n_ctx {n_ctx} must match prompt n_ctx {n_ctx_prompt}"
self.prompt = prompt
logits, cache = model.run_with_cache(prompt)
self.logits = logits.to("cpu").detach()
self.cache = cache.to("cpu")
self.act = self.cache[replacement_layer][:, replacement_pos]
self.replacement_layer = replacement_layer
self.read_layer = read_layer
self.replacement_pos = replacement_pos
self.n_ctx = n_ctx
@dataclass
class Result:
angle: float
angle_wrt_datasetmean: float
dist: float
norm: float
kl_div: float
out_angle: float
l2_diff: float
logit_l2_diff: float
dim0_diff: float
dim1_diff: float
def set_seed(seed: int):
"""Set the random seed for reproducibility."""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def generate_prompt(dataset, tokenizer: Callable, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
"""Generate a prompt from the dataset."""
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
return next(iter(dataloader))["input_ids"][:, :n_ctx]
def get_random_activation(model: HookedTransformer, dataset: torch.Tensor, n_ctx: int, layer: str, pos) -> torch.Tensor:
"""Get a random activation from the dataset."""
rand_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
_, cache = model.run_with_cache(rand_prompt)
return cache[layer][:, pos, :].to("cpu").detach()
def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
"""Compute the KL divergence between the reference and perturbed logprobs."""
logits_pert = logits_pert.cpu().detach()
logprobs_ref = F.log_softmax(logits_ref, dim=-1)
logprobs_pert = F.log_softmax(logits_pert, dim=-1)
return F.kl_div(logprobs_pert, logprobs_ref, log_target=True, reduction="none").sum(dim=-1)
def compute_Lp_metric(
cache: dict,
cache_pert: dict,
p,
read_layer,
read_pos,
):
ref_readoff = cache[read_layer][:, read_pos]
pert_readoff = cache_pert[read_layer][:, read_pos].cpu().detach()
Lp_diff = torch.linalg.norm(ref_readoff - pert_readoff, ord=p, dim=-1)
return Lp_diff
def run_perturbed_activation(perturbed_act: torch.Tensor, ref: Reference):
pos = ref.replacement_pos
layer = ref.replacement_layer
def hook(act, hook):
act[:, pos, :] = perturbed_act
with ref.model.hooks(fwd_hooks=[(layer, hook)]):
prompts = torch.cat([ref.prompt for _ in range(len(perturbed_act))])
logits_pert, cache = ref.model.run_with_cache(prompts)
return logits_pert, cache
def eval_activation(
perturbed_act: torch.Tensor,
base_ref: Reference,
unrelated_ref: Reference,
read_pos=-1,
):
read_layer = base_ref.read_layer
angle, dist = compute_angle_dist(base_ref.act, perturbed_act, datasetmean=None)
angle_wrt_datasetmean = angle
norm = lp_norm(perturbed_act).squeeze(-1)
logits_pert, cache = run_perturbed_activation(perturbed_act, base_ref)
logits_pert = logits_pert.to("cpu").detach()
base_kl_div = compute_kl_div(base_ref.logits, logits_pert)[:, read_pos]
base_l1_diff = compute_Lp_metric(base_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
base_l2_diff = compute_Lp_metric(base_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
base_logit_l2_diff = torch.linalg.norm(base_ref.logits - logits_pert, ord=2, dim=-1)
base_dim0_diff = (base_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0].cpu()).detach()
base_dim1_diff = (base_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1].cpu()).detach()
base_out_angle, _ = compute_angle_dist(
base_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
unrelated_kl_div = compute_kl_div(unrelated_ref.logits, logits_pert)[:, read_pos]
unrelated_l1_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=1, read_layer=read_layer, read_pos=read_pos)
unrelated_l2_diff = compute_Lp_metric(unrelated_ref.cache, cache, p=2, read_layer=read_layer, read_pos=read_pos)
unrelated_logit_l2_diff = torch.linalg.norm(unrelated_ref.logits - logits_pert, ord=2, dim=-1)
unrelated_dim0_diff = unrelated_ref.cache[read_layer][:, read_pos, 0] - cache[read_layer][:, read_pos, 0].cpu().detach()
unrelated_dim1_diff = unrelated_ref.cache[read_layer][:, read_pos, 1] - cache[read_layer][:, read_pos, 1].cpu().detach()
unrelated_out_angle, _ = compute_angle_dist(
unrelated_ref.cache[read_layer][:, read_pos, :],
cache[read_layer][:, read_pos, :],
datasetmean=None,
)
return (
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
base_kl_div,
base_out_angle,
base_l2_diff,
base_logit_l2_diff,
base_dim0_diff,
base_dim1_diff,
),
Result(
angle,
angle_wrt_datasetmean,
dist,
norm,
unrelated_kl_div,
unrelated_out_angle,
unrelated_l2_diff,
unrelated_logit_l2_diff,
unrelated_dim0_diff,
unrelated_dim1_diff,
),
)
class Slerp:
def __init__(self, start, direction, datasetmean=None):
"""Return interpolation points along the sphere from
A towards direction B"""
self.datasetmean = datasetmean
self.A = start - self.datasetmean if self.datasetmean is not None else start
self.A_norm = lp_norm(self.A)
self.a = self.A / self.A_norm
d = direction / lp_norm(direction)
self.B = d - dot(d, self.a) * self.a
self.B_norm = lp_norm(self.B)
self.b = self.B / self.B_norm
def __call__(self, alpha):
result = self.A_norm * (torch.cos(alpha) * self.a + torch.sin(alpha) * self.b)
if self.datasetmean is not None:
return result + self.datasetmean
else:
return result
def get_alpha(self, X):
x = X / lp_norm(X)
return torch.acos(dot(x, self.a))
class Perturbation:
def gen_direction(self):
raise NotImplementedError
def __init__(
self,
base_ref: Reference,
unrelated_ref: Reference,
ensure_mean=True,
ensure_norm=True,
):
self.ensure_mean = ensure_mean
self.ensure_norm = ensure_norm
self.base_ref = base_ref
self.norm_base = lp_norm(base_ref.act)
self.unrelated_ref = unrelated_ref
def scan(self, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = self.gen_direction()
return self.scan_dir(direction, n_steps, range, step_angular)
def scan_dir(self, direction, n_steps: int = 33, range: tuple[float, float] = (0, 1.0), step_angular=True):
direction = direction.clone()
if self.ensure_mean:
direction -= mean(direction)
if self.ensure_norm:
direction *= self.norm_base / lp_norm(direction)
range = np.array(range)
if step_angular:
s = Slerp(self.base_ref.act, direction, datasetmean=None)
self.activation_steps = [s(alpha) for alpha in torch.linspace(*range, n_steps)]
else:
self.activation_steps = [self.base_ref.act + alpha * direction for alpha in torch.linspace(*range, n_steps)]
act = torch.cat(self.activation_steps, dim=0)
torch.cuda.empty_cache()
base_result, unrelated_result = eval_activation(act, self.base_ref, self.unrelated_ref)
return base_result, unrelated_result, direction
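# scan() draws a fresh direction from gen_direction(); scan_dir() sweeps a
# batch of caller-supplied directions (used by test_surroundings_parallel below).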
class RandomUniformPerturbation(Perturbation):
def gen_direction(self):
return torch.randn_like(self.base_ref.act)
class RandomPerturbation(Perturbation):
def gen_target(self):
return self.distrib.sample(self.base_ref.act.shape[:-1])
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
class RandomActDirPerturbation(Perturbation):
def gen_target(self):
return get_random_activation(
self.base_ref.model,
dataset,
self.base_ref.n_ctx,
self.base_ref.replacement_layer,
self.base_ref.replacement_pos,
)
def gen_direction(self):
self.target = self.gen_target()
return self.target - self.base_ref.act
class NeuronDirPerturbation(Perturbation):
    def __init__(self, ref: Reference, negate=False, on_only=None, unrelated_ref: Reference = None):
        # Perturbation.__init__ requires an unrelated_ref; pass one in when using scan().
        super().__init__(ref, unrelated_ref)
        self.negate = 1 if not negate else -1
        # Default to None so gen_direction works when on_only is not requested.
        self.active_features = None
        if on_only is True:
            feature_acts = self.base_ref.cache[ref.replacement_layer][0, -1, :]
            self.active_features = feature_acts / feature_acts.max() > 0.1
    def gen_direction(self):
        if self.active_features is None:
            # randint is inclusive on both ends; 768 would index past GPT-2's d_model
            random_int = random.randint(0, self.base_ref.act.shape[-1] - 1)
else:
random_int = random.choice(self.active_features.nonzero(as_tuple=True)[0])
one_hot = torch.zeros_like(self.base_ref.act)
one_hot[..., random_int] = 1
single_direction = self.negate * one_hot
return torch.stack([single_direction for _ in range(self.base_ref.act.shape[0])])
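# NOTE: OptimizedPerturbation.sample references eval_direction, which is not
# defined in this script; per the author's note at the end of the previous
# script, the KL-optimization code is preliminary and not recommended for reuse.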
class OptimizedPerturbation(Perturbation):
    def __init__(self, ref: Reference, learning_rate: float = 0.1, unrelated_ref: Reference = None):
        super().__init__(ref, unrelated_ref)
self.learning_rate = learning_rate
def sample(self, num_steps: int, step_size: float = 0.01, sign=-1, init=None):
kl_divs = []
directions = []
if init is not None:
init = np.array(init.cpu().detach())
direction = torch.tensor(init, requires_grad=True, device=self.base_ref.act.device)
else:
direction = torch.randn_like(self.base_ref.act, requires_grad=True)
optimizer = torch.optim.SGD([direction], lr=self.learning_rate)
for _ in tqdm(range(num_steps), desc="Finding perturbation direction"):
kl_div, _, _ = eval_direction(direction, self.base_ref, step_size)
kl_divs.append(kl_div.item())
directions.append(direction.detach().clone())
(sign * kl_div).backward(retain_graph=True)
optimizer.step()
optimizer.zero_grad()
return kl_divs, directions
def format_ax(ax1, ax1t, ax2, ax2t, results, label, color, x, base_ref, ls="-", x_is_angle=True):
assert torch.allclose(results.angle.T - results.angle[:, 0].T, torch.tensor(0.0))
assert torch.allclose(results.dist.T - results.dist[:, 0].T, torch.tensor(0.0))
    angles = results.angle[:, 0].detach()
dists = results.dist[:, 0].detach()
ax1.plot(x, results.kl_div, label=label, color=color, lw=0.5, ls=ls)
ax2.plot(x, results.l2_diff, label=label, color=color, lw=0.5, ls=ls)
ax1.set_ylabel("KL divergence to base logits")
ax2.set_ylabel(f"L2 difference in {base_ref.read_layer}")
ax1.legend()
ax2.legend()
ax1.set_xlim(min(x), max(x))
ax2.set_xlim(min(x), max(x))
if len(angles) < 40:
tick_angles = angles
tick_dists = dists
else:
tick_angles = angles[::10]
tick_dists = dists[::10]
if x_is_angle:
ax1.set_xlabel(f"Angle from base activation at {layer}")
ax2.set_xlabel(f"Angle from base activation at {layer}")
ax1.set_xticks(tick_angles)
ax1.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2.set_xticks(tick_angles)
ax2.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
if label is not None:
ax1t.set_xlabel("Distance")
ax2t.set_xlabel("Distance")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2t.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
else:
ax1.set_xlabel(f"Distance from base activation at {layer}")
ax2.set_xlabel(f"Distance from base activation at {layer}")
ax1.set_xticks(tick_dists)
ax1.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
ax2.set_xticks(tick_dists)
ax2.set_xticklabels([f"{dist:.1f}" for dist in tick_dists], rotation=-80)
if label is not None:
ax1t.set_xlabel("Angle")
ax2t.set_xlabel("Angle")
ax1t.set_xticks(ax1.get_xticks())
ax2t.set_xticks(ax2.get_xticks())
ax1t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
ax2t.set_xticklabels([f"{angle * 180 / np.pi:.0f}°" for angle in tick_angles], rotation=-80)
def find_sensitivity(x, y, threshold=0.5):
f = sip.interp1d(x, y, kind="linear")
angles_interp = np.linspace(x.min(), x.max(), 100000)
index_interp = np.argmax(f(angles_interp) > threshold)
return angles_interp[index_interp]
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
################################
seed = 0
step_angular = True
################################
on_only = True
negate = on_only
set_seed(seed)
n_ctx = 10
layer = "blocks.1.hook_resid_pre"
pos = slice(-1, None, 1)
num_steps = 40
step_size = 1
learning_rate = 1
dataset = load_dataset("apollo-research/Skylion007-openwebtext-tokenizer-gpt2", split="train", streaming=False).with_format("torch")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=15, shuffle=True)
# %%
model = HookedTransformer.from_pretrained("gpt2")
device = "cuda" if torch.cuda.is_available() else "cpu"
# %%
base_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
base_ref = Reference(model, base_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
unrelated_prompt = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx)
unrelated_ref = Reference(model, unrelated_prompt, layer, "blocks.11.hook_resid_post", pos, n_ctx)
old_norm = lp_norm(base_ref.act).item()
# %%
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
tensor_of_prompts = generate_prompt(dataset, model.tokenizer, n_ctx=n_ctx, batch=4000)
print("Running model with cache")
mean_act_cache = model.run_with_cache(tensor_of_prompts)[1].to("cpu")
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
data = mean_act_cache[layer][:, -1, :]
data_mean = data.mean(dim=0, keepdim=True)
data_cov = torch.einsum("ij,ik->jk", data - data_mean, data - data_mean) / data.shape[0]
distrib = MultivariateNormal(data_mean.squeeze(0), data_cov)
# %%
random_perturbation = RandomPerturbation(base_ref, unrelated_ref)
random_perturbation.distrib = distrib
randomact_perturbation = RandomActDirPerturbation(base_ref, unrelated_ref)
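# test_surroundings_parallel: for each barycentric mixture of three targets
# (the perturbation's own target plus two samples from the fitted Gaussian),
# slerp away from the base activation and record the angle at which KL first
# reaches 0.5. Smaller angles indicate more sensitive directions.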
def test_surroundings_parallel(perturbation, distrib, ax, n_side=5, d_side=1, steps=28, n_contours=20):
model.reset_hooks()
orig_target = perturbation.gen_target()
rand_target_1 = distrib.sample(perturbation.base_ref.act.shape[:-1])
rand_target_2 = distrib.sample(perturbation.base_ref.act.shape[:-1])
a, b = np.mgrid[0 : d_side : n_side * 1j, 0 : d_side : n_side * 1j]
mask = a + b <= 1
a, b = a[mask], b[mask]
coords = np.stack((a, b, 1 - a - b))
assert np.all(coords >= 0)
n_coords = coords.shape[1]
sensitivity = np.zeros(n_coords)
dirs = []
if len(coords.T) * steps < 1700:
for i, (x, y, z) in enumerate(coords.T):
            assert abs(x + y + z - 1) < 1e-5, f"Sum of coordinates must be 1, got {x + y + z}"
            target = z * orig_target + x * rand_target_1 + y * rand_target_2
            dirs.append(target - perturbation.base_ref.act)  # use the perturbation's base_ref, not the global
dirs = torch.concatenate(dirs, dim=0)
results, _, _ = perturbation.scan_dir(dirs, n_steps=steps, range=(0, np.pi / 3))
angles = results.angle[:, 0].reshape(steps, n_coords)
kl_divs = results.kl_div.reshape(steps, n_coords)
else:
print("Warning: High number of items (steps*coords), implementing batching over coords")
batch_size = 1700 // steps
angles = []
kl_divs = []
for i in tqdm(range(0, n_coords, batch_size), desc="Batches", total=n_coords // batch_size):
batch_coordsT = coords.T[i : i + batch_size]
batch_dirs = []
for x, y, z in batch_coordsT:
                assert abs(x + y + z - 1) < 1e-5, f"Sum of coordinates must be 1, got {x + y + z}"
target = z * orig_target + x * rand_target_1 + y * rand_target_2
batch_dirs.append(target - perturbation.base_ref.act)
batch_dirs = torch.concatenate(batch_dirs, dim=0)
batch_results, _, _ = perturbation.scan_dir(batch_dirs, n_steps=steps, range=(0, np.pi / 3))
angles.append(batch_results.angle[:, 0].reshape(steps, -1))
kl_divs.append(batch_results.kl_div.reshape(steps, -1))
angles = torch.concatenate(angles, dim=1)
kl_divs = torch.concatenate(kl_divs, dim=1)
for i, (x, y, z) in enumerate(coords.T):
sensitivity[i] = find_sensitivity(angles[:, i], kl_divs[:, i], threshold=0.5)
sensitivity = sensitivity * 180 / np.pi
fig = ff.create_ternary_contour(
coords[[2, 0, 1]],
-sensitivity,
pole_labels=["orig", "rand1", "rand2"],
colorscale="Viridis",
showscale=False,
coloring=None,
title="Sensitivity to KL-div = 0.5",
ncontours=n_contours,
)
x_plot = coords[1] + 0.5 * coords[2]
y_plot = np.sqrt(3) / 2 * coords[2]
return fig, x_plot, y_plot, sensitivity
# %%
px_fig = make_subplots(
rows=3,
cols=5,
specs=[[{"type": "scatterternary"} for _ in range(5)] for _ in range(3)],
)
fig, axes = plt.subplots(3, 5, figsize=(15, 8), constrained_layout=True)
fig.suptitle(
"Angle until KL-div = 0.5 is reached (in degrees)\n(lower angle = higher sensitivity)",
)
def make_plot(axs, title, perturbation, distrib, n_side=25, steps=10, px_fig=None, px_idx=1):
axs[0].set_ylabel(title)
sensitivities = []
for i, ax in tqdm(enumerate(axs), total=5, desc=title):
date_str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
random_int_from_time = int(date_str[-5:].replace("_", ""))
set_seed(random_int_from_time)
ax.set_title("Seed = " + str(random_int_from_time), fontsize=10)
subfig, x_plot, y_plot, sensitivity = test_surroundings_parallel(perturbation, distrib, ax, n_side=n_side, steps=steps)
sensitivities.append(sensitivity)
for trace in subfig.data:
px_fig.add_trace(trace, row=px_idx, col=i + 1)
for i, ax in enumerate(axs):
vmin = np.min(sensitivities)
vmax = np.max(sensitivities)
im = ax.scatter(
x_plot,
y_plot,
c=sensitivities[i],
cmap="viridis_r",
marker="^",
s=30,
vmin=vmin,
vmax=vmax,
)
        if i == len(axs) - 1:  # attach one shared colorbar on the last panel of the row
cbar = plt.colorbar(im, ax=axs)
n_side = 25
steps = 100
make_plot(axes[0], "random", random_perturbation, distrib, n_side, steps, px_fig, 1)
make_plot(axes[1], "r-other", randomact_perturbation, distrib, n_side, steps, px_fig, 2)
fig.savefig(f"plots/3_sensitivity_scan_matplotlib_seed_{seed}.png", dpi=300)
plt.close()
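# Plotly's static-image export (write_image) requires the kaleido package.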
px_fig.write_image(f"plots/3_sensitivity_scan_plotly_seed_{seed}.png")