from matplotlib import pyplot as plt
from torch.distributions import Normal
import math
import numpy as np
import torch
import random
from tqdm import trange

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

K_components = 8  # try 4, 8, 16
policy_hidden = 128
def seed_all(seed):
    """Seeds all RNGs (torch, CUDA, numpy, random) for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def render(reward, trajectory=None, tight=True):
    """Renders the reward distribution over the 1D env."""
    x = np.linspace(
        reward.mus.cpu().numpy().min() - reward.n_sd * reward.sigmas.cpu().numpy().max(),
        reward.mus.cpu().numpy().max() + reward.n_sd * reward.sigmas.cpu().numpy().max(),
        1000,
    )
    d = torch.exp(reward.log_reward(torch.tensor(x, dtype=torch.float32))).cpu().numpy()
    dual_plot = trajectory is not None
    if dual_plot:
        fig, axs = plt.subplots(2, 1)
        axs = axs.ravel()
    else:
        fig, axs = plt.subplots(1, 1)
        axs = [axs]  # Hack to allow indexing.

    if dual_plot:
        ax_dual = axs[0].twinx()  # Second axes for final state counts.
        ax_dual.hist(
            trajectory[:, -1, 0].cpu().numpy(),  # Final X Position.
            bins=100,
            density=False,
            alpha=0.5,
            color="red",
        )
        ax_dual.set_ylabel("Samples", color="red")
        ax_dual.tick_params(axis="y", labelcolor="red")

        n, trajectory_length, _ = trajectory.shape
        for i in range(n):
            axs[1].plot(
                trajectory[i, :, 0].cpu().numpy(),
                np.arange(1, trajectory_length + 1),
                alpha=0.1,
                linewidth=0.05,
                color='black',
            )
        axs[1].set_ylabel('Step')

    axs[0].plot(x, d, color="black")
    # Mark the modes.
    for mu in reward.mus:
        axs[0].axvline(float(mu), color="grey", linestyle="--")

    # S0
    axs[0].plot([reward.init_value], [0], 'ro')
    axs[0].text(reward.init_value + 0.1, 0.01, "$S_0$", rotation=45)

    # Mode markers and labels.
    for i, mu in enumerate(reward.mus):
        idx = np.abs(x - float(mu)).argmin()  # Grid point closest to the mode.
        axs[0].plot([x[idx]], [d[idx]], 'bo')
        axs[0].text(x[idx] + 0.1, d[idx], "Mode {}".format(i + 1), rotation=0)

    axs[0].spines[['right', 'top']].set_visible(False)
    axs[0].set_ylabel("Reward Value")
    axs[0].set_title("Line Environment")
    axs[0].set_ylim(0, d.max() * 1.1)
    if dual_plot:
        axs[1].set_xlim(axs[0].get_xlim())
        axs[1].set_xlabel("X Position")
    else:
        axs[0].set_xlabel("X Position")

    if tight:
        fig.tight_layout()

    plt.savefig('plot.jpg', dpi=150, bbox_inches='tight')
    plt.close()
class LineEnvironment():
    def __init__(self, mus, variances, n_sd, init_value):
        self.mus = torch.tensor(mus, dtype=torch.float32)
        self.sigmas = torch.tensor([math.sqrt(v) for v in variances], dtype=torch.float32)
        self.variances = torch.tensor(variances, dtype=torch.float32)
        self.n_sd = n_sd
        self.lb = min(self.mus) - self.n_sd * max(self.sigmas)  # Convenience only.
        self.ub = max(self.mus) + self.n_sd * max(self.sigmas)  # Convenience only.
        self.init_value = init_value  # Used for s0.
        assert self.lb < self.init_value < self.ub

    def log_reward(self, x):
        """Log of the sum of the (unnormalized) mixture component densities at x."""
        dist = Normal(self.mus.to(x.device), self.sigmas.to(x.device))
        return torch.logsumexp(dist.log_prob(x.unsqueeze(-1)), dim=-1)

    @property
    def log_partition(self) -> torch.Tensor:
        """Each component integrates to 1, so the log partition is the log of the number of Gaussians."""
        return torch.tensor(float(len(self.mus))).log()
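
# Optional sanity check (our addition, not called by the training script): since each
# mixture component is a normalized density, integrating exp(log_reward) over the line
# should give the number of components, matching log_partition = log(K). The helper name
# `_check_log_partition` is ours, added only for illustration.
def _check_log_partition(env, n_grid=100_000):
    grid = torch.linspace(float(env.lb) - 5.0, float(env.ub) + 5.0, n_grid)
    density = torch.exp(env.log_reward(grid))
    numeric_logZ = torch.log(torch.trapz(density, grid))  # Trapezoidal integration.
    assert torch.isclose(numeric_logZ, env.log_partition, atol=1e-3)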
class MixtureOfGaussians1D(torch.nn.Module):
    """
    Conditional 1D Gaussian Mixture policy.
    Given state x (B, 2), outputs mixture logits, means, and scales for K components.
    Sampling returns (action, log_prob_under_policy). Also exposes .log_prob.
    """
    def __init__(self, in_dim=2, hid_dim=128, K=8, min_std=1e-3, max_std=3.0):
        super().__init__()
        self.K = K
        self.min_std = min_std
        self.max_std = max_std
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hid_dim),
            torch.nn.ELU(),
            torch.nn.Linear(hid_dim, hid_dim),
            torch.nn.ELU(),
            torch.nn.Linear(hid_dim, 3 * K),
        )

    def _params(self, x):
        """
        Returns mixture params with shapes:
            logits (B, K), means (B, K), stds (B, K)
        """
        raw = self.net(x)  # (B, 3K)
        logits, means, raw_scales = torch.chunk(raw, 3, dim=-1)  # each (B, K)
        # Stabilize scales: softplus keeps them positive, then clamp to [min_std, max_std].
        stds = torch.nn.functional.softplus(raw_scales) + self.min_std
        stds = torch.clamp(stds, max=self.max_std)
        return logits, means, stds
    def log_prob(self, x, a):
        """
        x: (B, 2), a: (B,) action values
        returns log p(a|x): (B,)
        """
        logits, means, stds = self._params(x)  # (B, K)
        # Compute per-component log prob.
        a_exp = a.unsqueeze(-1)  # (B, 1)
        comp = torch.distributions.Normal(means, stds)
        log_comp = comp.log_prob(a_exp)  # (B, K)
        # Mixture log-sum-exp with weights.
        log_mix = torch.logsumexp(logits + log_comp, dim=-1)  # (B,)
        # Normalize by the mixture partition (i.e., logsumexp of logits).
        log_norm = torch.logsumexp(logits, dim=-1)  # (B,)
        return log_mix - log_norm  # (B,)

    def sample(self, x):
        """
        Sample from the mixture using categorical -> normal.
        Returns action (B,), log_prob (B,)
        """
        logits, means, stds = self._params(x)
        cat = torch.distributions.Categorical(logits=logits)  # (B,)
        k = cat.sample()  # (B,)
        # Gather selected component params.
        idx = k.unsqueeze(-1)  # (B, 1)
        m = torch.gather(means, 1, idx).squeeze(-1)  # (B,)
        s = torch.gather(stds, 1, idx).squeeze(-1)  # (B,)
        base = torch.distributions.Normal(m, s)
        a = base.sample()  # (B,)
        # Exact log_prob under the **full** mixture (not just the selected component).
        lp = self.log_prob(x, a)
        return a, lp
    def exploration_sample(self, x, off_policy_noise):
        """
        Off-policy exploration: inflate component stds by noise (additive in variance space).
        """
        logits, means, stds = self._params(x)
        # Variance inflation: sigma' = sqrt(sigma^2 + noise^2).
        stds_exp = torch.sqrt(stds**2 + off_policy_noise**2 + 1e-12)
        cat = torch.distributions.Categorical(logits=logits)
        k = cat.sample()
        idx = k.unsqueeze(-1)
        m = torch.gather(means, 1, idx).squeeze(-1)
        s = torch.gather(stds_exp, 1, idx).squeeze(-1)
        a = torch.distributions.Normal(m, s).sample()
        # Log-prob under the **original** policy (no inflated stds) for TB.
        lp = self.log_prob(x, a)
        return a, lp
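
# Optional sanity check (our addition, not called by the training script): the hand-rolled
# mixture log-density above should agree with torch.distributions.MixtureSameFamily, and
# sample/log_prob should return (B,)-shaped tensors. The helper name
# `_check_mixture_log_prob` is ours, added only for illustration.
def _check_mixture_log_prob(policy, batch_size=4):
    device_ = next(policy.parameters()).device
    x = torch.randn(batch_size, 2, device=device_)
    a, lp = policy.sample(x)
    logits, means, stds = policy._params(x)
    reference = torch.distributions.MixtureSameFamily(
        torch.distributions.Categorical(logits=logits),
        torch.distributions.Normal(means, stds),
    )
    assert torch.allclose(lp, reference.log_prob(a), atol=1e-5)
    assert a.shape == lp.shape == (batch_size,)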
def setup_experiment(hid_dim=128, lr_model=1e-3, lr_logz=1e-1, K=8):
    """
    Forward and backward policies are conditional Gaussian Mixtures.
    logZ is a learnable scalar with a higher LR, as is common practice for Trajectory Balance.
    """
    forward_policy = MixtureOfGaussians1D(in_dim=2, hid_dim=hid_dim, K=K).to(device)
    backward_policy = MixtureOfGaussians1D(in_dim=2, hid_dim=hid_dim, K=K).to(device)
    # logZ parameter.
    logZ = torch.nn.Parameter(torch.tensor(0.0, device=device))
    optimizer = torch.optim.Adam(
        [
            {'params': forward_policy.parameters(), 'lr': lr_model},
            {'params': backward_policy.parameters(), 'lr': lr_model},
            {'params': [logZ], 'lr': lr_logz},
        ]
    )
    return forward_policy, backward_policy, logZ, optimizer
def step(x, action):
    """Takes a forward step in the environment."""
    new_x = torch.zeros_like(x)
    new_x[:, 0] = x[:, 0] + action  # Add action delta.
    new_x[:, 1] = x[:, 1] + 1       # Increment step counter.
    return new_x

def initialize_state(batch_size, device, env):
    """Trajectories start at state = (X_0, t=0)."""
    x = torch.zeros((batch_size, 2), device=device)
    x[:, 0] = env.init_value
    return x
def get_policy_and_exploration_dist(model, x, off_policy_noise):
    """
    Unused by the mixture-policy training below: assumes a model whose head outputs a
    single (mean, raw-std) pair, with exploration simply widening the std.
    Kept only for reference.
    """
    pf_params = model(x)
    policy_mean = pf_params[:, 0]
    policy_std = torch.sigmoid(pf_params[:, 1]) * (max_policy_std - min_policy_std) + min_policy_std
    policy_dist = torch.distributions.Normal(policy_mean, policy_std)
    # Add some off-policy exploration.
    exploration_dist = torch.distributions.Normal(policy_mean, policy_std + off_policy_noise)
    return policy_dist, exploration_dist
@torch.no_grad()
def inference(trajectory_length, forward_model, env, batch_size=10_000):
    trajectory = torch.zeros((batch_size, trajectory_length + 1, 2), device=device)
    trajectory[:, 0, 0] = env.init_value
    x = initialize_state(batch_size, device, env)
    for t in range(trajectory_length):
        a, _ = forward_model.sample(x)  # No exploration at inference.
        new_x = step(x, a)
        trajectory[:, t + 1, :] = new_x
        x = new_x
    return trajectory
def get_action_and_logp(policy, x, off_policy_noise):
    """
    Returns the sampled action and its log-prob under the **policy** (not the exploration
    variant), which is what TB needs for logPF/logPB.
    """
    if off_policy_noise > 0:
        a, logp = policy.exploration_sample(x, off_policy_noise)
    else:
        a, logp = policy.sample(x)
    return a, logp
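
# Optional check (our addition, not called during training): even when sampling
# off-policy with inflated stds, the returned log-prob is evaluated under the
# unperturbed policy, which is what the Trajectory Balance loss requires.
# The helper name `_check_off_policy_logp` is ours.
def _check_off_policy_logp(policy, off_policy_noise=0.5, batch_size=4):
    device_ = next(policy.parameters()).device
    x = torch.zeros(batch_size, 2, device=device_)
    a, lp = get_action_and_logp(policy, x, off_policy_noise)
    assert torch.allclose(lp, policy.log_prob(x, a), atol=1e-6)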
def train_with_exploration(
    seed,
    batch_size,
    trajectory_length,
    env,
    device,
    init_exploration_noise,
    n_iterations=10_000,
    K=16,
    hid_dim=128,
):
    seed_all(seed)
    forward_policy, backward_policy, logZ, optimizer = setup_experiment(
        hid_dim=hid_dim, K=K
    )
    losses = []
    tbar = trange(n_iterations)
    true_logZ = env.log_partition.item()
    exploration_schedule = np.linspace(init_exploration_noise, 0.0, n_iterations)

    for iteration in tbar:
        optimizer.zero_grad()
        x = initialize_state(batch_size, device, env)
        trajectory = torch.zeros((batch_size, trajectory_length + 1, 2), device=device)
        trajectory[:, 0, 0] = env.init_value
        logPF = torch.zeros((batch_size,), device=device)
        logPB = torch.zeros((batch_size,), device=device)

        # --------- Forward rollout -----------
        for t in range(trajectory_length):
            a, lp = get_action_and_logp(
                forward_policy, x, off_policy_noise=float(exploration_schedule[iteration])
            )
            logPF += lp
            new_x = step(x, a)
            trajectory[:, t + 1, :] = new_x
            x = new_x

        # --------- Backward accumulation -----
        # Skip the last step back to S_0, which is deterministic.
        for t in range(trajectory_length, 1, -1):
            # The backward policy conditions on the current state S_t; the action is the
            # delta taken to go from S_{t-1} -> S_t.
            dx = (trajectory[:, t, 0] - trajectory[:, t - 1, 0]).detach()
            # Evaluate the log-prob under the backward policy (no exploration).
            lp_b = backward_policy.log_prob(trajectory[:, t, :], dx)
            logPB += lp_b

        # Trajectory Balance loss.
        log_reward = env.log_reward(trajectory[:, -1, 0])
        loss = (logZ + logPF - logPB - log_reward).pow(2).mean()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if iteration % 100 == 0:
            tbar.set_description(
                f"Iter {iteration}: loss={np.mean(losses[-100:]):.3f}, "
                f"est logZ={logZ.item():.3f}, true logZ={true_logZ:.3f}, "
                f"LR={optimizer.param_groups[0]['lr']}, "
                f"off-noise={exploration_schedule[iteration]:.4f}"
            )

    trajectories = inference(trajectory_length, forward_policy, env)
    render(env, trajectories)
    return (forward_policy, backward_policy, logZ)
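
# The TB loss above drives logZ + sum_t log P_F(a_t|s_t) towards
# log R(x) + sum_t log P_B(a_t|s_{t+1}) for every sampled trajectory. A minimal sketch
# (our addition, not part of the original training loop) of an independent logZ estimate
# via importance sampling over forward rollouts:
#     Z ~= (1/N) * sum_i R(x_i) * P_B(tau_i | x_i) / P_F(tau_i)
@torch.no_grad()
def estimate_logZ(forward_policy, backward_policy, env, trajectory_length, n_samples=10_000):
    x = initialize_state(n_samples, device, env)
    trajectory = torch.zeros((n_samples, trajectory_length + 1, 2), device=device)
    trajectory[:, 0, 0] = env.init_value
    logPF = torch.zeros((n_samples,), device=device)
    logPB = torch.zeros((n_samples,), device=device)
    # Forward rollout under the learned policy (no exploration).
    for t in range(trajectory_length):
        a, lp = forward_policy.sample(x)
        logPF += lp
        x = step(x, a)
        trajectory[:, t + 1, :] = x
    # Backward log-prob, skipping the deterministic step back to S_0 as in training.
    for t in range(trajectory_length, 1, -1):
        dx = trajectory[:, t, 0] - trajectory[:, t - 1, 0]
        logPB += backward_policy.log_prob(trajectory[:, t, :], dx)
    log_w = env.log_reward(trajectory[:, -1, 0]) + logPB - logPF
    return torch.logsumexp(log_w, dim=0) - math.log(n_samples)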
# Hyperparameters.
batch_size = 256
init_exploration_noise = 1.0
max_policy_std = 1.0  # Only used by the unused single-Gaussian helper above.
min_policy_std = 0.1  # Only used by the unused single-Gaussian helper above.
n_iterations = 10_000
seed = 4444
trajectory_length = 5

# Define Environment.
env = LineEnvironment(
    mus=[2, 5],
    variances=[0.2, 0.2],
    n_sd=4.5,
    init_value=0,
)
# Train.
forward_model, backward_model, logZ = train_with_exploration(
    seed,
    batch_size,
    trajectory_length,
    env,
    device,
    init_exploration_noise,
    n_iterations=n_iterations,
    K=K_components,
    hid_dim=policy_hidden,
)
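
# A quick post-training check (our addition): the learned logZ should approach
# env.log_partition = log(2) ~= 0.693 for this two-mode environment, and the
# importance-sampling estimate from `estimate_logZ` (sketched above) should roughly
# agree once training has converged.
print(f"learned logZ      = {logZ.item():.3f}")
print(f"true logZ         = {env.log_partition.item():.3f}")
print(f"IS-estimated logZ = {estimate_logZ(forward_model, backward_model, env, trajectory_length).item():.3f}")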