# Gist by @Algomancer, created August 9, 2025.
from matplotlib import pyplot as plt
from torch.distributions import Normal
import math
import numpy as np
import torch
import random
from tqdm import trange
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
K_components = 8 # try 4, 8, 16
policy_hidden = 128

def seed_all(seed):
    """Seeds all RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def render(reward, trajectory=None, tight=True):
    """Renders the reward distribution over the 1D env and, optionally, sampled trajectories."""
    x = np.linspace(
        reward.mus.cpu().numpy().min() - reward.n_sd * reward.sigmas.cpu().numpy().max(),
        reward.mus.cpu().numpy().max() + reward.n_sd * reward.sigmas.cpu().numpy().max(),
        1000,
    )
    d = torch.exp(reward.log_reward(torch.tensor(x))).numpy()
    dual_plot = trajectory is not None
    if dual_plot:
        fig, axs = plt.subplots(2, 1)
        axs = axs.ravel()
    else:
        fig, axs = plt.subplots(1, 1)
        axs = [axs]  # Hack to allow indexing.
    if dual_plot:
        ax_dual = axs[0].twinx()  # Second axes for final state counts.
        ax_dual.hist(
            trajectory[:, -1, 0].cpu().numpy(),  # Final X position.
            bins=100,
            density=False,
            alpha=0.5,
            color="red",
        )
        ax_dual.set_ylabel("Samples", color="red")
        ax_dual.tick_params(axis="y", labelcolor="red")
        n, trajectory_length, _ = trajectory.shape
        for i in range(n):
            axs[1].plot(
                trajectory[i, :, 0].cpu().numpy(),
                np.arange(1, trajectory_length + 1),
                alpha=0.1,
                linewidth=0.05,
                color='black',
            )
        axs[1].set_ylabel('Step')
    axs[0].plot(x, d, color="black")
    # Mark the modes.
    for mu in reward.mus:
        axs[0].axvline(mu.item(), color="grey", linestyle="--")
    # S0
    axs[0].plot([reward.init_value], [0], 'ro')
    axs[0].text(reward.init_value + 0.1, 0.01, "$S_0$", rotation=45)
    # Means
    for i, mu in enumerate(reward.mus):
        idx = np.argmin(abs(x - mu.item()))
        axs[0].plot([x[idx]], [d[idx]], 'bo')
        axs[0].text(x[idx] + 0.1, d[idx], "Mode {}".format(i + 1), rotation=0)
    axs[0].spines[['right', 'top']].set_visible(False)
    axs[0].set_ylabel("Reward Value")
    axs[0].set_title("Line Environment")
    axs[0].set_ylim(0, d.max() * 1.1)
    if dual_plot:
        axs[1].set_xlim(axs[0].get_xlim())
        axs[1].set_xlabel("X Position")
    else:
        axs[0].set_xlabel("X Position")
    if tight:
        fig.tight_layout()
    plt.savefig('plot.jpg', dpi=150, bbox_inches='tight')
    plt.close()

class LineEnvironment():
    def __init__(self, mus, variances, n_sd, init_value):
        self.mus = torch.tensor(mus, dtype=torch.float32)
        self.sigmas = torch.tensor([math.sqrt(v) for v in variances])
        self.variances = torch.tensor(variances)
        self.mixture = [Normal(m, s) for m, s in zip(self.mus, self.sigmas)]
        self.n_sd = n_sd
        self.lb = min(self.mus) - self.n_sd * max(self.sigmas)  # Convenience only.
        self.ub = max(self.mus) + self.n_sd * max(self.sigmas)  # Convenience only.
        self.init_value = init_value  # Used for s0.
        assert self.lb < self.init_value < self.ub

    def log_reward(self, x):
        """Log of the unnormalized mixture density: logsumexp over the component log-probs."""
        return torch.logsumexp(torch.stack([m.log_prob(x) for m in self.mixture], 0), 0)

    @property
    def log_partition(self) -> float:
        """Each component is normalized, so the reward integrates to the number of Gaussians."""
        return math.log(len(self.mus))

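# Optional sanity check (not called in the main run; an illustrative sketch only):
# numerically integrate exp(log_reward) over [lb, ub] and compare with exp(log_partition).
# With well-separated, fully-contained modes the two should agree closely.
# Assumes torch.trapezoid is available (PyTorch >= 1.9); the grid size is arbitrary.
def check_log_partition(env, n_points=10_000):
    xs = torch.linspace(float(env.lb), float(env.ub), n_points)
    density = torch.exp(env.log_reward(xs))
    integral = torch.trapezoid(density, xs)  # Unnormalized mass; should be ~= number of modes.
    print(f"integral={integral.item():.4f}, exp(log_partition)={math.exp(env.log_partition):.4f}")
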
class MixtureOfGaussians1D(torch.nn.Module):
    """
    Conditional 1D Gaussian Mixture policy.
    Given state x (B, 2), outputs mixture logits, means, log-scales for K components.
    Sampling returns (action, log_prob_under_policy). Also exposes .log_prob.
    """
    def __init__(self, in_dim=2, hid_dim=128, K=8, min_std=1e-3, max_std=3.0):
        super().__init__()
        self.K = K
        self.min_std = min_std
        self.max_std = max_std
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hid_dim),
            torch.nn.ELU(),
            torch.nn.Linear(hid_dim, hid_dim),
            torch.nn.ELU(),
            torch.nn.Linear(hid_dim, 3 * K),
        )

    def _params(self, x):
        """
        Returns mixture params with shapes:
            logits (B, K), means (B, K), stds (B, K)
        """
        raw = self.net(x)  # (B, 3K)
        logits, means, log_scales = torch.chunk(raw, 3, dim=-1)  # each (B, K)
        # Stabilize scales; keep them in [min_std, max_std].
        stds = torch.nn.functional.softplus(log_scales) + self.min_std
        stds = torch.clamp(stds, max=self.max_std)
        return logits, means, stds

    def log_prob(self, x, a):
        """
        x: (B, 2), a: (B,) action values
        returns log p(a|x): (B,)
        """
        logits, means, stds = self._params(x)  # (B, K)
        # Per-component log prob.
        a_exp = a.unsqueeze(-1)  # (B, 1)
        comp = torch.distributions.Normal(means, stds)
        log_comp = comp.log_prob(a_exp)  # (B, K)
        # Mixture log-sum-exp with weights.
        log_mix = torch.logsumexp(logits + log_comp, dim=-1)  # (B,)
        # Normalize by the mixture partition (i.e. logsumexp of logits).
        log_norm = torch.logsumexp(logits, dim=-1)  # (B,)
        return log_mix - log_norm  # (B,)

    def sample(self, x):
        """
        Sample from the mixture via categorical -> normal.
        Returns action (B,), log_prob (B,)
        """
        logits, means, stds = self._params(x)
        cat = torch.distributions.Categorical(logits=logits)  # (B,)
        k = cat.sample()  # (B,)
        # Gather the selected component's params.
        idx = k.unsqueeze(-1)  # (B, 1)
        m = torch.gather(means, 1, idx).squeeze(-1)  # (B,)
        s = torch.gather(stds, 1, idx).squeeze(-1)  # (B,)
        base = torch.distributions.Normal(m, s)
        a = base.sample()  # (B,)
        # Exact log_prob under the **full** mixture (not just the selected component).
        lp = self.log_prob(x, a)
        return a, lp

    def exploration_sample(self, x, off_policy_noise):
        """
        Off-policy exploration: inflate component stds by noise (additive in variance space).
        """
        logits, means, stds = self._params(x)
        # Variance inflation: sigma' = sqrt(sigma^2 + noise^2).
        stds_exp = torch.sqrt(stds ** 2 + off_policy_noise ** 2 + 1e-12)
        cat = torch.distributions.Categorical(logits=logits)
        k = cat.sample()
        idx = k.unsqueeze(-1)
        m = torch.gather(means, 1, idx).squeeze(-1)
        s = torch.gather(stds_exp, 1, idx).squeeze(-1)
        a = torch.distributions.Normal(m, s).sample()
        # Log-prob under the **original** policy (no inflated stds) for TB.
        lp = self.log_prob(x, a)
        return a, lp

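# Optional shape check (not called in the main run): illustrates the expected tensor shapes
# of the mixture policy on a toy batch. Batch size 4 is an arbitrary illustrative choice;
# policy_hidden and K_components are the globals defined above.
def check_policy_shapes():
    policy = MixtureOfGaussians1D(in_dim=2, hid_dim=policy_hidden, K=K_components)
    x = torch.zeros((4, 2))   # (B, 2) states: [x position, step counter]
    a, lp = policy.sample(x)  # a: (4,), lp: (4,)
    assert a.shape == (4,) and lp.shape == (4,)
    assert torch.allclose(lp, policy.log_prob(x, a))  # sample() returns the full-mixture log-prob.
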
def setup_experiment(hid_dim=128, lr_model=1e-3, lr_logz=1e-1, K=8):
    """
    Forward and backward policies are conditional Gaussian Mixtures.
    logZ is a learnable scalar with a higher LR, as is common practice for Trajectory Balance.
    """
    forward_policy = MixtureOfGaussians1D(in_dim=2, hid_dim=hid_dim, K=K).to(device)
    backward_policy = MixtureOfGaussians1D(in_dim=2, hid_dim=hid_dim, K=K).to(device)
    # logZ parameter.
    logZ = torch.nn.Parameter(torch.tensor(0.0, device=device))
    optimizer = torch.optim.Adam(
        [
            {'params': forward_policy.parameters(), 'lr': lr_model},
            {'params': backward_policy.parameters(), 'lr': lr_model},
            {'params': [logZ], 'lr': lr_logz},
        ]
    )
    return forward_policy, backward_policy, logZ, optimizer

def step(x, action):
    """Takes a forward step in the environment."""
    new_x = torch.zeros_like(x)
    new_x[:, 0] = x[:, 0] + action  # Add action delta.
    new_x[:, 1] = x[:, 1] + 1       # Increment step counter.
    return new_x


def initialize_state(batch_size, device, env):
    """Trajectories start at state (X_0, t=0)."""
    x = torch.zeros((batch_size, 2), device=device)
    x[:, 0] = env.init_value
    return x

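# Quick illustration (not called in the main run) of the state layout used throughout:
# a state is the pair (x position, step counter), so a batch of states is (B, 2) and a
# batch of full trajectories is (B, T + 1, 2). Batch size 3 and these deltas are illustrative.
def check_state_layout(env):
    x = initialize_state(3, torch.device('cpu'), env)  # (3, 2): x = env.init_value, t = 0.
    x = step(x, torch.tensor([0.5, -0.5, 1.0]))        # Adds the action delta, increments t.
    assert x.shape == (3, 2) and torch.all(x[:, 1] == 1)
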
def get_policy_and_exploration_dist(model, x, off_policy_noise):
    """
    Leftover helper from a single-Gaussian policy variant; it is unused by the
    mixture-policy training loop below, which calls get_action_and_logp instead.
    A policy is a distribution whose parameters we predict with a neural network,
    and which we then sample from.
    """
    pf_params = model(x)
    policy_mean = pf_params[:, 0]
    policy_std = torch.sigmoid(pf_params[:, 1]) * (max_policy_std - min_policy_std) + min_policy_std
    policy_dist = torch.distributions.Normal(policy_mean, policy_std)
    # Add some off-policy exploration.
    exploration_dist = torch.distributions.Normal(policy_mean, policy_std + off_policy_noise)
    return policy_dist, exploration_dist

@torch.no_grad()
def inference(trajectory_length, forward_model, env, batch_size=10_000):
    """Rolls out the trained forward policy without exploration noise."""
    trajectory = torch.zeros((batch_size, trajectory_length + 1, 2), device=device)
    trajectory[:, 0, 0] = env.init_value
    x = initialize_state(batch_size, device, env)
    for t in range(trajectory_length):
        a, _ = forward_model.sample(x)  # No exploration at inference.
        new_x = step(x, a)
        trajectory[:, t + 1, :] = new_x
        x = new_x
    return trajectory

def get_action_and_logp(policy, x, off_policy_noise):
    """
    Returns the sampled action and the log-prob under the **policy** (not the exploration
    variant), which is what TB needs for logPF/logPB.
    """
    if off_policy_noise > 0:
        a, logp = policy.exploration_sample(x, off_policy_noise)
    else:
        a, logp = policy.sample(x)
    return a, logp

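# The Trajectory Balance objective (Malkin et al., 2022) used below, written out as a
# minimal sketch: per trajectory, (log Z + log P_F(tau) - log P_B(tau) - log R(x))^2.
# This helper is illustrative only and is not called; train_with_exploration inlines
# the same expression.
def trajectory_balance_loss(logZ, logPF, logPB, log_reward):
    return (logZ + logPF - logPB - log_reward).pow(2).mean()
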
def train_with_exploration(
    seed,
    batch_size,
    trajectory_length,
    env,
    device,
    init_exploration_noise,
    n_iterations=10_000,
    K=16,
    hid_dim=128,
):
    """Trains forward/backward mixture policies with the Trajectory Balance loss,
    annealing the off-policy exploration noise to zero over training."""
    seed_all(seed)
    forward_policy, backward_policy, logZ, optimizer = setup_experiment(
        hid_dim=hid_dim, K=K
    )
    losses = []
    tbar = trange(n_iterations)
    true_logZ = env.log_partition
    exploration_schedule = np.linspace(init_exploration_noise, 0.0, n_iterations)
    for iteration in tbar:
        optimizer.zero_grad()
        x = initialize_state(batch_size, device, env)
        trajectory = torch.zeros((batch_size, trajectory_length + 1, 2), device=device)
        trajectory[:, 0, 0] = env.init_value
        logPF = torch.zeros((batch_size,), device=device)
        logPB = torch.zeros((batch_size,), device=device)
        # --------- Forward rollout -----------
        for t in range(trajectory_length):
            a, lp = get_action_and_logp(
                forward_policy, x, off_policy_noise=float(exploration_schedule[iteration])
            )
            logPF += lp
            new_x = step(x, a)
            trajectory[:, t + 1, :] = new_x
            x = new_x
        # --------- Backward accumulation -----
        # Skip the last step back to S_0: every trajectory starts at S_0, so that
        # transition has probability 1 under P_B.
        for t in range(trajectory_length, 1, -1):
            # The backward policy conditions on the current state S_t;
            # the action is the delta that took S_{t-1} -> S_t.
            dx = (trajectory[:, t, 0] - trajectory[:, t - 1, 0]).detach()
            # Evaluate the log prob under the backward policy (no exploration).
            lp_b = backward_policy.log_prob(trajectory[:, t, :], dx)
            logPB += lp_b
        # The env's mixture lives on the CPU, so evaluate the reward there and move it back.
        log_reward = env.log_reward(trajectory[:, -1, 0].cpu()).to(device)
        # Trajectory Balance loss.
        loss = (logZ + logPF - logPB - log_reward).pow(2).mean()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if iteration % 100 == 0:
            tbar.set_description(
                f"Iter {iteration}: loss={np.mean(losses[-100:]):.3f}, "
                f"est logZ={logZ.item():.3f}, true logZ={true_logZ:.3f}, "
                f"LR={optimizer.param_groups[0]['lr']}, "
                f"off-noise={exploration_schedule[iteration]:.4f}"
            )
    trajectories = inference(trajectory_length, forward_policy, env)
    render(env, trajectories)
    return (forward_policy, backward_policy, logZ)

# Hyperparameters.
batch_size = 256
init_exploration_noise = 1.0
max_policy_std = 1.0
min_policy_std = 0.1
n_iterations = 10_000
seed = 4444
trajectory_length = 5
# Define Environment.
env = LineEnvironment(
    mus=[2, 5],
    variances=[0.2, 0.2],
    n_sd=4.5,
    init_value=0,
)
# Train.
forward_model, backward_model, logZ = train_with_exploration(
    seed,
    batch_size,
    trajectory_length,
    env,
    device,
    init_exploration_noise,
    n_iterations=n_iterations,
)
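# Optional post-training inspection (commented out so the script's behavior is unchanged);
# logZ is the learned scalar returned above, and check_log_partition is the illustrative
# helper defined earlier.
# print(f"learned logZ = {logZ.item():.3f}, true logZ = {env.log_partition:.3f}")
# check_log_partition(env)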