Created
April 7, 2025 03:32
-
-
Save Algomancer/023cc76913e8c991158c97762e1e6f0f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| # Quick sketch: correctness and numerical stability still need validating, but I think it's right. | |
| # N = number of 3D points, e.g. if you're doing patch AR: a 16x16x3 patch is 768 scalars = 256 points of 3 values each. | |
| # Each mixture component has 9 distribution params (3 for the mean mu, 6 for the lower-triangular L) plus 1 mixture logit => 10 total. | |
| # With num_mixtures=K, each point has 10*K parameters => a (B, N, 10K) output, e.g. (B, 256, 10K). | |
| # self.fc_out = nn.Linear(hidden_dim, 10 * num_mixtures * N_Elements) | |
def unpack_3d_logistic_params(param_9):
    """Split a packed 9-vector into a mean and a lower-triangular scale matrix.

    Args:
        param_9: tensor of shape [..., 9] laid out as
            [mu_x, mu_y, mu_z, a11, a21, a31, a22, a32, a33].

    Returns:
        (mu, L): mu has shape [..., 3]; L has shape [..., 3, 3] and equals

            [[exp(a11), 0,        0       ],
             [a21,      exp(a22), 0       ],
             [a31,      a32,      exp(a33)]]

        Exponentiating the diagonal keeps it strictly positive, so L is
        always invertible.
    """
    mean = param_9[..., 0:3]
    raw_11, off_21, off_31, raw_22, off_32, raw_33 = param_9[..., 3:9].unbind(dim=-1)

    s11 = torch.exp(raw_11)
    s22 = torch.exp(raw_22)
    s33 = torch.exp(raw_33)
    zero = torch.zeros_like(s11)

    # Row-major entries of the 3x3 lower-triangular matrix.
    entries = [
        s11,    zero,   zero,
        off_21, s22,    zero,
        off_31, off_32, s33,
    ]
    flat = torch.stack(entries, dim=-1)
    scale = flat.view(*flat.shape[:-1], 3, 3)
    return mean, scale
def logdet_lower_tri(L):
    """Return log(det(L)) for a (batch of) lower-triangular matrices.

    The determinant of a triangular matrix is the product of its diagonal,
    so the log-determinant is the sum of the log diagonal entries.

    Args:
        L: tensor of shape [..., n, n], assumed lower-triangular with a
            positive diagonal. A zero diagonal entry yields -inf and a
            negative one yields nan.

    Returns:
        Tensor of shape [...].
    """
    return torch.diagonal(L, dim1=-2, dim2=-1).log().sum(dim=-1)
def log_pdf_3d_logistic(x, mu, L):
    """Log-density of a 3D logistic distribution with affine scale L.

    The model is x = mu + L @ ell, where ell holds 3 iid standard logistic
    variables. By change of variables:

        log p(x) = sum_i [ ell_i - 2 * softplus(ell_i) ] - log|det(L)|
        ell      = L^{-1} (x - mu)

    Args:
        x: [..., 3] points. Broadcasts against mu, e.g. [B, N, 1, 3] against
            [B, N, K, 3] to score every mixture component at once.
        mu: [..., 3] location.
        L: [..., 3, 3] lower-triangular, invertible scale matrix.

    Returns:
        Tensor with the broadcast batch shape of (x - mu), without the
        last axis.
    """
    diff = x - mu  # [..., 3]
    # A triangular solve is cheaper and more stable than forming L^{-1}.
    ell = torch.linalg.solve_triangular(L, diff.unsqueeze(-1), upper=False).squeeze(-1)
    # Standard logistic log-density, log(e^t / (1 + e^t)^2), written with
    # softplus so it cannot overflow for large |t|. The original called
    # `F.softplus`, but `F` is never imported in this file.
    log_density = (ell - 2.0 * torch.nn.functional.softplus(ell)).sum(dim=-1)
    # log|det L| = sum_i log|L_ii| for a triangular matrix. Taking abs() matches
    # the |det| in the formula; without it a negative diagonal entry gives nan.
    log_abs_det = torch.diagonal(L, dim1=-2, dim2=-1).abs().log().sum(dim=-1)
    return log_density - log_abs_det
def sample_3d_logistic_mixture(params, K):
    """Draw one sample per point from a mixture of K 3D logistic distributions.

    NOTE(review): an identical `sample_3d_logistic_mixture` is defined again
    later in this file. Python keeps the later definition, so this copy is
    shadowed. Keep the two in sync or delete one.

    Args:
        params: [..., K*10] (originally [B, N, K*10]). Each group of 10 holds
            9 logistic params (see unpack_3d_logistic_params) followed by
            1 mixture logit.
        K: number of mixture components.

    Returns:
        Tensor of shape [..., 3] (i.e. [B, N, 3] for [B, N, K*10] input),
        with the same dtype and device as params.

    Raises:
        ValueError: if the last dim of params is not 10*K.
    """
    if params.shape[-1] != 10 * K:
        raise ValueError(
            f"expected last dim of params to be 10*K={10 * K}, got {params.shape[-1]}"
        )
    batch_shape = params.shape[:-1]
    # reshape (not view) so non-contiguous params also work.
    params_k = params.reshape(*batch_shape, K, 10)
    logistic_9 = params_k[..., :9]  # [..., K, 9]
    logit_mix = params_k[..., 9]  # [..., K]

    # 1) Choose one component per point with probability softmax(logits).
    weights = torch.softmax(logit_mix, dim=-1).reshape(-1, K)
    comp_idx = torch.multinomial(weights, 1).reshape(*batch_shape, 1, 1)
    idx = comp_idx.expand(*batch_shape, 1, 9)
    param_k = torch.gather(logistic_9, dim=-2, index=idx).squeeze(-2)  # [..., 9]
    mu, L = unpack_3d_logistic_params(param_k)  # [..., 3], [..., 3, 3]

    # 2) ell ~ Logistic(0, 1)^3 via the inverse CDF logit(u).
    # torch.rand draws from [0, 1), so u can be exactly 0. That gives
    # ell = -inf, and L @ ell then becomes NaN (0 * -inf), so clamp u away
    # from 0. Sample in params.dtype so the matmul below does not mix dtypes.
    u = torch.rand(*batch_shape, 3, device=params.device, dtype=params.dtype)
    u = u.clamp_min(torch.finfo(u.dtype).tiny)
    ell = torch.log(u) - torch.log1p(-u)

    # 3) Affine transform: x = mu + L @ ell.
    return mu + torch.matmul(L, ell.unsqueeze(-1)).squeeze(-1)
def log_likelihood_3d_logistic_mixture(x, params, K):
    """Log-likelihood of points under a K-component 3D logistic mixture.

    log p(x) = logsumexp_k [ log pi_k + log p_k(x) ], where pi = softmax(logits).

    Args:
        x: [..., 3] points (originally [B, N, 3]).
        params: [..., K*10] (originally [B, N, K*10]). Each group of 10 holds
            9 logistic params (see unpack_3d_logistic_params) followed by
            1 mixture logit.
        K: number of mixture components.

    Returns:
        Tensor of shape [...] (i.e. [B, N]) with log p(x) per point.

    Raises:
        ValueError: if the last dim of params is not 10*K.
    """
    if params.shape[-1] != 10 * K:
        raise ValueError(
            f"expected last dim of params to be 10*K={10 * K}, got {params.shape[-1]}"
        )
    # reshape (not view) so non-contiguous params also work.
    params_k = params.reshape(*params.shape[:-1], K, 10)  # [..., K, 10]
    logistic_9 = params_k[..., :9]  # [..., K, 9]
    logit_mix = params_k[..., 9]  # [..., K]

    # 1) Mixture weights in log space: log pi_k.
    log_weights = torch.log_softmax(logit_mix, dim=-1)  # [..., K]

    # 2) Per-component log-density. x gets a singleton component axis and
    # broadcasts against mu / L inside log_pdf_3d_logistic.
    mu, L = unpack_3d_logistic_params(logistic_9)  # [..., K, 3], [..., K, 3, 3]
    log_pdf_k = log_pdf_3d_logistic(x.unsqueeze(-2), mu, L)  # [..., K]

    # 3) Marginalize over components with a numerically stable log-sum-exp.
    return torch.logsumexp(log_pdf_k + log_weights, dim=-1)
def sample_3d_logistic_mixture(params, K):
    """Draw one sample per point from a mixture of K 3D logistic distributions.

    NOTE(review): this definition overrides an identical
    `sample_3d_logistic_mixture` defined earlier in this file. This copy is
    the one Python actually uses. Keep the two in sync or delete one.

    Args:
        params: [..., K*10] (originally [B, N, K*10]). Each group of 10 holds
            9 logistic params (see unpack_3d_logistic_params) followed by
            1 mixture logit.
        K: number of mixture components.

    Returns:
        Tensor of shape [..., 3] (i.e. [B, N, 3] for [B, N, K*10] input),
        with the same dtype and device as params.

    Raises:
        ValueError: if the last dim of params is not 10*K.
    """
    if params.shape[-1] != 10 * K:
        raise ValueError(
            f"expected last dim of params to be 10*K={10 * K}, got {params.shape[-1]}"
        )
    batch_shape = params.shape[:-1]
    # reshape (not view) so non-contiguous params also work.
    params_k = params.reshape(*batch_shape, K, 10)
    logistic_9 = params_k[..., :9]  # [..., K, 9]
    logit_mix = params_k[..., 9]  # [..., K]

    # 1) Choose one component per point with probability softmax(logits).
    weights = torch.softmax(logit_mix, dim=-1).reshape(-1, K)
    comp_idx = torch.multinomial(weights, 1).reshape(*batch_shape, 1, 1)
    idx = comp_idx.expand(*batch_shape, 1, 9)
    param_k = torch.gather(logistic_9, dim=-2, index=idx).squeeze(-2)  # [..., 9]
    mu, L = unpack_3d_logistic_params(param_k)  # [..., 3], [..., 3, 3]

    # 2) ell ~ Logistic(0, 1)^3 via the inverse CDF logit(u).
    # torch.rand draws from [0, 1), so u can be exactly 0. That gives
    # ell = -inf, and L @ ell then becomes NaN (0 * -inf), so clamp u away
    # from 0. Sample in params.dtype so the matmul below does not mix dtypes.
    u = torch.rand(*batch_shape, 3, device=params.device, dtype=params.dtype)
    u = u.clamp_min(torch.finfo(u.dtype).tiny)
    ell = torch.log(u) - torch.log1p(-u)

    # 3) Affine transform: x = mu + L @ ell.
    return mu + torch.matmul(L, ell.unsqueeze(-1)).squeeze(-1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment