```python
#!/usr/bin/env python

import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm
from typing import *
from zuko.utils import odeint


def log_normal(x: Tensor) -> Tensor:
    # Log-density of the standard normal, summed over the last dimension
    return -(x.square() + math.log(2 * math.pi)).sum(dim=-1) / 2


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: List[int] = [64, 64],
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            layers.extend([nn.Linear(a, b), nn.ELU()])

        super().__init__(*layers[:-1])


class CNF(nn.Module):
    def __init__(self, features: int, freqs: int = 3, **kwargs):
        super().__init__()

        self.net = MLP(2 * freqs + features, features, **kwargs)

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        # Sinusoidal embedding of the time t, concatenated with the state x
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)

        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        return odeint(self, z, 1.0, 0.0, phi=self.parameters())

    def log_prob(self, x: Tensor) -> Tensor:
        # Batched identity used to extract the full Jacobian
        # through vector-Jacobian products
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        return log_normal(z) + ladj * 1e2


class FlowMatchingLoss(nn.Module):
    def __init__(self, v: nn.Module):
        super().__init__()

        self.v = v

    def forward(self, x: Tensor) -> Tensor:
        # Interpolate between data (t = 0) and noise (t = 1), then regress
        # the vector field v onto the time derivative of the interpolation
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z
        u = (1 - 1e-4) * z - x

        return (self.v(t.squeeze(-1), y) - u).square().mean()


if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)

    # Training
    loss = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    data, _ = make_moons(16384, noise=0.05)
    data = torch.from_numpy(data).float()

    for epoch in tqdm(range(16384), ncols=88):
        subset = torch.randint(0, len(data), (256,))
        x = data[subset]

        loss(x).backward()

        optimizer.step()
        optimizer.zero_grad()

    # Sampling
    with torch.no_grad():
        z = torch.randn(16384, 2)
        x = flow.decode(z)

    plt.figure(figsize=(4.8, 4.8), dpi=150)
    plt.hist2d(*x.T, bins=64)
    plt.savefig('moons_fm.pdf')

    # Log-likelihood
    with torch.no_grad():
        log_p = flow.log_prob(data[:4])

    print(log_p)
```
Hi @jenkspt, yes, this operation is indeed expensive. Instead of computing the full Jacobian, it is common to use the (unbiased) Hutchinson trace estimator. I have not implemented it here, but I can point you to an implementation if you want.

Note that computing the Jacobian "per-pixel" is the same as computing the diagonal of the Jacobian, which would be enough to compute the trace, but I don't think there is an algorithm that does this cheaply.
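For concreteness, here is a minimal sketch of what a Hutchinson-based log_prob could look like, assuming it is added as a method of the CNF class above (so log_normal and zuko's odeint are in scope); the name log_prob_hutchinson and the single Gaussian probe per sample are illustrative choices, not the implementation mentioned above.

```python
def log_prob_hutchinson(self, x: Tensor) -> Tensor:
    # One Gaussian probe per sample, kept fixed for the whole integration;
    # since E[eps^T J eps] = tr(J), the resulting estimate is unbiased.
    eps = torch.randn_like(x)

    def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
        with torch.enable_grad():
            x = x.requires_grad_()
            dx = self(t, x)

        # A single vector-Jacobian product eps^T J, instead of one per dimension
        vjp = torch.autograd.grad(dx, x, eps, create_graph=True)[0]
        trace = (vjp * eps).sum(dim=-1)  # stochastic estimate of tr(J)

        return dx, trace * 1e-2

    ladj = torch.zeros_like(x[..., 0])
    z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

    return log_normal(z) + ladj * 1e2
```

A single probe is usually enough for training or rough likelihood estimates; averaging over several probes reduces the variance at a proportional cost.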
To demystify odeint for people new to ODEs like myself, I tried to implement a simple forward Euler version. It seems to generate a similar moons plot at the end, but I couldn't figure out how to make it work with log_prob.
```python
def odeint(
    f: Callable[[Tensor, Tensor], Tensor],
    x: Tensor,
    t0: float,
    t1: float,
    phi: Iterable[Tensor] = (),  # unused; kept to match the call sites above
    dt: float = 0.01,
) -> Tensor:
    # Number of steps and signed step size that cover [t0, t1] exactly
    n_steps = max(1, round(abs(t1 - t0) / dt))
    dt = (t1 - t0) / n_steps

    # Forward Euler: evaluate f at the start of each step and
    # advance the state by dt * f(t, state)
    t = torch.tensor(t0, dtype=torch.float32)
    state = x

    for _ in range(n_steps):
        state = state + dt * f(t, state)
        t = t + dt

    return state
```
Hi @AlienKevin. If you want your odeint to work with log_prob, you will need to pack x and ladj as a single tensor representing the state of the ODE and unpack it inside the function to integrate.
For example, if you want to integrate a function f(t, x1, x2) whose state is a pair of tensors (x1, x2), you can flatten and concatenate them:
```python
s1, s2 = x1.shape, x2.shape
n1, n2 = x1.numel(), x2.numel()

def g(t, x):
    x1, x2 = x[:n1].reshape(s1), x[n1:].reshape(s2)
    dx1, dx2 = f(t, x1, x2)
    return torch.cat((dx1.flatten(), dx2.flatten()))

x = torch.cat((x1.flatten(), x2.flatten()))
y = odeint(g, x, ...)  # instead of odeint(f, (x1, x2), ...)
```

Got it, thanks!
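For instance, here is a sketch of that packing applied to log_prob, assuming the fixed-step Euler odeint above; the helper name log_prob_euler is hypothetical, and the 1e-2/1e2 rescaling of the trace used in the gist is omitted for simplicity.

```python
def log_prob_euler(flow: CNF, x: Tensor) -> Tensor:
    # Pack the position x and the log-det-Jacobian accumulator ladj into one
    # flat vector, so the tensor-only Euler odeint can integrate both at once.
    n, shape = x.numel(), x.shape

    # Batched identity used to extract the full Jacobian, as in the gist
    I = torch.eye(shape[-1], dtype=x.dtype, device=x.device)
    I = I.expand(*shape, shape[-1]).movedim(-1, 0)

    def augmented(t: Tensor, state: Tensor) -> Tensor:
        xt = state[:n].reshape(shape)  # state[n:] holds ladj, unused by the dynamics

        with torch.enable_grad():
            xt = xt.requires_grad_()
            dx = flow(t, xt)

        jacobian = torch.autograd.grad(dx, xt, I, create_graph=True, is_grads_batched=True)[0]
        trace = torch.einsum('i...i', jacobian)

        return torch.cat((dx.flatten(), trace.flatten()))

    ladj = torch.zeros_like(x[..., 0])
    state = odeint(augmented, torch.cat((x.flatten(), ladj.flatten())), 0.0, 1.0)

    z = state[:n].reshape(shape)
    ladj = state[n:].reshape(shape[:-1])

    return log_normal(z) + ladj
```

With this, log_prob_euler(flow, data[:4]) under torch.no_grad() should give values comparable to flow.log_prob(data[:4]), up to the discretization error of the fixed Euler steps.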
> …is common to use the (unbiased) Hutchinson trace estimator instead. I have not implemented this here, but I can point you to an implementation if you want

Yes, I'm interested!
I'm looking at the log_prob function. For an image dataset, for example, this is quite expensive. Is it reasonable to treat pixels as independent predictions in this case, and only compute the Jacobian per pixel?