@Algomancer
Created April 7, 2025 03:32
import torch
import torch.nn.functional as F

# Quick sketch, would need to validate correctness and numerical stability, but I think it's right?
# N elements if you're doing something like patch AR, say 768 for a 16x16x3 patch.
# For each mixture component => 9 params for L, plus 1 mixture logit => total 10.
# If num_mixtures=K, each point has 10*K parameters => (B, 768, 10K) output.
# self.fc_out = nn.Linear(hidden_dim, 10 * num_mixtures * N_Elements)
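# Sketch of the head wiring implied above (my reading; `h`, `backbone`, `hidden_dim`
# are illustrative names, not from a real model):
#   fc_out = nn.Linear(hidden_dim, 10 * num_mixtures * n_elements)
#   raw = fc_out(h)                                      # h: [B, hidden_dim]
#   params = raw.view(B, n_elements, 10 * num_mixtures)  # [B, N, 10K], as consumed below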
def unpack_3d_logistic_params(param_9):
    """
    param_9: [..., 9] => returns (mu, L)
    param_9[..., 0:3] = mu_x, mu_y, mu_z
    param_9[..., 3] = a11
    param_9[..., 4] = a21
    param_9[..., 5] = a31
    param_9[..., 6] = a22
    param_9[..., 7] = a32
    param_9[..., 8] = a33
    L is lower-triangular:
    L[0,0] = exp(a11), L[1,0] = a21, L[2,0] = a31,
    L[1,1] = exp(a22), L[2,1] = a32,
    L[2,2] = exp(a33).
    """
    mu = param_9[..., 0:3]
    a11 = param_9[..., 3]
    a21 = param_9[..., 4]
    a31 = param_9[..., 5]
    a22 = param_9[..., 6]
    a32 = param_9[..., 7]
    a33 = param_9[..., 8]
    L11 = torch.exp(a11)
    L22 = torch.exp(a22)
    L33 = torch.exp(a33)
    L = torch.stack([
        L11, torch.zeros_like(L11), torch.zeros_like(L11),
        a21, L22, torch.zeros_like(L11),
        a31, a32, L33
    ], dim=-1)
    # reshape last dim from 9 -> (3, 3)
    L = L.view(*L.shape[:-1], 3, 3)
    return mu, L
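
# Quick structural sanity check for the unpacking (my addition, on random params;
# not part of the original sketch):
def _check_unpack_shapes():
    p = torch.randn(2, 4, 9)
    mu, L = unpack_3d_logistic_params(p)
    assert mu.shape == (2, 4, 3) and L.shape == (2, 4, 3, 3)
    # strictly upper-triangular entries must be zero, diagonal must be positive
    assert torch.all(torch.triu(L, diagonal=1) == 0)
    assert torch.all(torch.diagonal(L, dim1=-2, dim2=-1) > 0)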
def logdet_lower_tri(L):
    # sum of log(diagonal) for L
    diag_elems = torch.diagonal(L, dim1=-2, dim2=-1)
    return torch.log(diag_elems).sum(dim=-1)
def log_pdf_3d_logistic(x, mu, L):
    """
    x: [..., 3]
    mu: [..., 3]
    L: [..., 3, 3]
    log p(x) = -log|det(L)| + sum_i[ ell_i - 2 log(1+exp(ell_i)) ]
    ell = L^{-1}(x - mu)
    """
    diff = x - mu  # [..., 3]
    # solve triangular
    ell = torch.linalg.solve_triangular(L, diff.unsqueeze(-1), upper=False).squeeze(-1)
    logpdf_each = ell - 2.0 * F.softplus(ell)
    logpdf_sum = logpdf_each.sum(dim=-1)
    ld = logdet_lower_tri(L)
    return logpdf_sum - ld
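
# Sanity check (my addition): with L = I and mu = 0, log_pdf_3d_logistic should
# reduce to the sum of three independent standard-logistic log-densities,
# log p(z) = z - 2*softplus(z).
def _check_logpdf_identity_scale():
    x = torch.randn(5, 3)
    mu = torch.zeros(5, 3)
    L = torch.eye(3).repeat(5, 1, 1)
    got = log_pdf_3d_logistic(x, mu, L)
    expected = (x - 2.0 * F.softplus(x)).sum(dim=-1)
    assert torch.allclose(got, expected, atol=1e-5)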
def sample_3d_logistic_mixture(params, K):
    """
    params: [B, N, K*10]
    => reshape => [B, N, K, 10]
    => for each (b, n), pick a component k from pi_k, sample x from that 3D logistic.
    Returns => [B, N, 3]
    """
    B, N, _ = params.shape
    params_4d = params.view(B, N, K, 10)
    logistic_9 = params_4d[..., :9]  # [B,N,K,9]
    logit_mix = params_4d[..., 9]    # [B,N,K]
    # mixture weights
    weights = torch.softmax(logit_mix, dim=-1)  # [B,N,K]
    # sample component index => [B*N]
    weights_2d = weights.view(-1, K)
    comp_idx = torch.multinomial(weights_2d, 1).squeeze(-1)
    comp_idx = comp_idx.view(B, N)  # => [B,N]
    idx = comp_idx.unsqueeze(-1).unsqueeze(-1)  # => [B, N, 1, 1]
    idx = idx.expand(-1, -1, 1, 9)              # => [B, N, 1, 9]
    param_k = torch.gather(logistic_9, dim=2, index=idx)
    # param_k now => [B, N, 1, 9]
    param_k = param_k.squeeze(2)
    mu, L = unpack_3d_logistic_params(param_k)  # mu => [B,N,3], L => [B,N,3,3]
    # ell ~ logistic(0,1)^3 => [B,N,3]; clamp u away from {0, 1} so the logit stays finite
    u = torch.rand(B, N, 3, device=params.device).clamp(1e-6, 1.0 - 1e-6)
    ell = torch.log(u / (1. - u))
    ell_4d = ell.unsqueeze(-1)                  # => [B,N,3,1]
    diff = torch.matmul(L, ell_4d).squeeze(-1)  # => [B,N,3]
    x_sample = mu + diff
    return x_sample
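
# Sanity check (my addition): the standard logistic has mean 0, so for a single
# component (K=1) with raw a* params of 0 (=> L = I) the sample mean should
# approach mu. Seeded so the tolerance is deterministic.
def _check_sampler_mean():
    torch.manual_seed(0)
    mu_true = torch.tensor([1.0, -2.0, 0.5])
    params = torch.zeros(1, 20000, 10)  # B=1, N=20000 i.i.d. draws, K=1
    params[..., 0:3] = mu_true          # set mu; a* stay 0 => L = I
    x = sample_3d_logistic_mixture(params, K=1)  # [1, 20000, 3]
    assert torch.allclose(x.mean(dim=1).squeeze(0), mu_true, atol=0.1)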
def log_likelihood_3d_logistic_mixture(x, params, K):
    """
    x: [B, N, 3]
    params: [B, N, K*10]
    => reshape => [B, N, K, 10]
    last dimension => 9 logistic params + 1 mixture logit
    Return log p(x_i) => [B, N]
    """
    B, N, _ = x.shape
    # reshape => [B, N, K, 10]
    params_4d = params.view(B, N, K, 10)
    logistic_9 = params_4d[..., :9]  # [B,N,K,9]
    logit_mix = params_4d[..., 9]    # [B,N,K]
    # 1) mixture weights => log softmax => log pi_k
    log_weights = torch.log_softmax(logit_mix, dim=-1)  # [B,N,K]
    # 2) for each component, compute 3D logistic log pdf
    #    We'll broadcast x => [B,N,1,3]
    x_expand = x.unsqueeze(2)  # => [B, N, 1, 3]
    # logistic_9 => [B,N,K,9]. Unpack => mu, L => each [B,N,K,...]
    mu, L = unpack_3d_logistic_params(logistic_9)  # mu => [B,N,K,3], L => [B,N,K,3,3]
    x_broadcast = x_expand.expand(-1, -1, K, -1)   # => [B,N,K,3]
    log_pdf_k = log_pdf_3d_logistic(x_broadcast, mu, L)  # => [B,N,K]
    # 3) combine with mixture weights => log-sum-exp (out-of-place add, autograd-safe)
    log_pdf_k = log_pdf_k + log_weights  # => [B,N,K]
    log_prob = torch.logsumexp(log_pdf_k, dim=-1)  # => [B,N]
    return log_prob
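
# End-to-end usage sketch (my addition; random tensors stand in for a real model's
# output head, shapes match the header comments above):
def _demo():
    B, N, K = 2, 768, 4
    params = torch.randn(B, N, 10 * K, requires_grad=True)  # e.g. fc_out output
    x = torch.rand(B, N, 3)                                 # e.g. RGB values of a patch
    nll = -log_likelihood_3d_logistic_mixture(x, params, K).mean()
    nll.backward()                                          # gradients flow to params
    with torch.no_grad():
        x_hat = sample_3d_logistic_mixture(params, K)       # [B, N, 3]
    print(nll.item(), x_hat.shape)

if __name__ == "__main__":
    _demo()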