Skip to content

Instantly share code, notes, and snippets.

@vincentwoo
Created February 26, 2025 10:22
Show Gist options
  • Select an option

  • Save vincentwoo/4bf1998c114861ec63505bd0b76980d0 to your computer and use it in GitHub Desktop.

Select an option

Save vincentwoo/4bf1998c114861ec63505bd0b76980d0 to your computer and use it in GitHub Desktop.
code used to compress ply with SOGS at https://vincentwoo.com/3d/sutro_tower/
from dataclasses import dataclass, field
from gsplat.compression import PngCompression
from plyfile import PlyData, PlyElement
import numpy as np
import os
import torch
import tyro
from pathlib import Path
@dataclass
class Config:
ply: Path
output_dir: Path
@torch.no_grad()
def read_ply(path):
"""
Reads a .ply file (with columns including x,y,z,nx,ny,nz, f_dc_i, f_rest_j, etc.)
and reconstructs a dictionary of PyTorch tensors on GPU.
Automatically handles a variable number of 'f_rest_*' columns if present.
"""
plydata = PlyData.read(path)
vd = plydata['vertex'].data # This is a structured NumPy array
# Helper to quickly check presence of a column
def has_col(col_name):
return col_name in vd.dtype.names
xyz = np.stack([vd['x'], vd['y'], vd['z']], axis=-1)
normals = np.stack([vd['nx'], vd['ny'], vd['nz']], axis=-1)
f_dc = np.stack([vd[f"f_dc_{i}"] for i in range(3)], axis=-1)
# -- f_rest (variable number: 0 to many) --
# Gather all columns that match 'f_rest_*' pattern and sort them by suffix.
rest_cols = [c for c in vd.dtype.names if c.startswith('f_rest_')]
def get_index(c):
# e.g., 'f_rest_12' -> integer 12
return int(c.split('_')[-1])
rest_cols_sorted = sorted(rest_cols, key=get_index)
if len(rest_cols_sorted) > 0:
f_rest = np.stack([vd[c] for c in rest_cols_sorted], axis=-1)
else:
f_rest = np.empty((len(vd), 0), dtype=np.float32)
opacities = vd['opacity']
scale = np.stack([vd[f"scale_{i}"] for i in range(3)], axis=-1)
rotation = np.stack([vd[f"rot_{i}"] for i in range(4)], axis=-1)
# Convert to torch tensors
splats = {}
splats["means"] = torch.from_numpy(xyz).float().cuda()
splats["opacities"] = torch.from_numpy(opacities).float().cuda()
splats["scales"] = torch.from_numpy(scale).float().cuda()
splats["quats"] = torch.from_numpy(rotation).float().cuda()
# f_dc -> shape: (N, 3) => then unsqueeze => (N,3,1) => transpose => (N,1,3)
sh0_tensor = torch.from_numpy(f_dc).float() # shape (N, 3)
sh0_tensor = sh0_tensor.unsqueeze(-1) # shape (N, 3, 1)
sh0_tensor = sh0_tensor.transpose(1, 2) # shape (N, 1, 3)
splats["sh0"] = sh0_tensor.cuda()
if f_rest.any():
# f_rest -> shape (N, #rest) => reshape => (N, 3, (#rest/3)) => transpose => (N, (#rest/3), 3)
# Make sure total f_rest columns is divisible by 3 if you are expecting a 3-channel grouping.
if f_rest.shape[1] % 3 != 0:
raise ValueError(f"Number of f_rest columns ({f_rest.shape[1]}) not divisible by 3.")
num_rest_per_channel = f_rest.shape[1] // 3
shn_tensor = torch.from_numpy(
f_rest.reshape(-1, 3, num_rest_per_channel)
).float().transpose(1, 2) # shape (N, num_rest_per_channel, 3)
splats["shN"] = shn_tensor.cuda()
return splats
cfg = tyro.cli(Config)
os.makedirs(cfg.output_dir, exist_ok=True)
splats = read_ply(cfg.ply)
PngCompression().compress(cfg.output_dir, splats)
import json
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from gsplat.compression.sort import sort_splats
from gsplat.utils import inverse_log_transform, log_transform
from plyfile import PlyData, PlyElement
from PIL import Image
@dataclass
class PngCompression:
"""Uses quantization and sorting to compress splats into PNG files and uses
K-means clustering to compress the spherical harmonic coefficents.
.. warning::
This class requires the `Pillow <https://pypi.org/project/pillow/>`_,
`plas <https://github.com/fraunhoferhhi/PLAS.git>`_
and `torchpq <https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install>`_ packages to be installed.
.. warning::
This class might throw away a few lowest opacities splats if the number of
splats is not a square number.
.. note::
The splats parameters are expected to be pre-activation values. It expects
the following fields in the splats dictionary: "means", "scales", "quats",
"opacities", "sh0", "shN".
References:
- `Compact 3D Scene Representation via Self-Organizing Gaussian Grids <https://arxiv.org/abs/2312.13299>`_
- `Making Gaussian Splats more smaller <https://aras-p.info/blog/2023/09/27/Making-Gaussian-Splats-more-smaller/>`_
Args:
use_sort (bool, optional): Whether to sort splats before compression. Defaults to True.
verbose (bool, optional): Whether to print verbose information. Default to True.
"""
use_sort: bool = True
verbose: bool = True
def _get_compress_fn(self, param_name: str) -> Callable:
compress_fn_map = {
"means": _compress_png_16bit,
"scales": _compress_png,
"quats": _compress_png,
"opacities": _compress_png,
"sh0": _compress_png,
"shN": _compress_kmeans,
}
return compress_fn_map[param_name]
def _get_decompress_fn(self, param_name: str) -> Callable:
decompress_fn_map = {
"means": _decompress_png_16bit,
"scales": _decompress_png,
"quats": _decompress_png,
"opacities": _decompress_png,
"sh0": _decompress_png,
"shN": _decompress_kmeans,
}
return decompress_fn_map[param_name]
def compress(self, compress_dir: str, splats: Dict[str, Tensor]) -> None:
"""Run compression
Args:
compress_dir (str): directory to save compressed files
splats (Dict[str, Tensor]): Gaussian splats to compress
"""
# Param-specific preprocessing
splats["means"] = log_transform(splats["means"])
splats["quats"] = F.normalize(splats["quats"], dim=-1)
neg_mask = splats["quats"][:, 3] < 0
splats["quats"][neg_mask] *= -1
splats["quats"] = splats["quats"][:, :3]
n_gs = len(splats["means"])
n_sidelen = int(n_gs**0.5)
n_crop = n_gs - n_sidelen**2
if n_crop != 0:
splats = _crop_n_splats(splats, n_crop)
print(
f"Warning: Number of Gaussians was not square. Removed {n_crop} Gaussians."
)
meta = {}
if self.use_sort:
splats = sort_splats(splats)
for param_name in splats.keys():
compress_fn = self._get_compress_fn(param_name)
kwargs = {
"n_sidelen": n_sidelen,
"verbose": self.verbose,
}
meta[param_name] = compress_fn(
compress_dir, param_name, splats[param_name], **kwargs
)
with open(os.path.join(compress_dir, "meta.json"), "w") as f:
json.dump(meta, f, indent = 2)
def decompress(self, compress_dir: str) -> Dict[str, Tensor]:
"""Run decompression
Args:
compress_dir (str): directory that contains compressed files
Returns:
Dict[str, Tensor]: decompressed Gaussian splats
"""
with open(os.path.join(compress_dir, "meta.json"), "r") as f:
meta = json.load(f)
splats = {}
for param_name, param_meta in meta.items():
decompress_fn = self._get_decompress_fn(param_name)
splats[param_name] = decompress_fn(compress_dir, param_name, param_meta)
# Param-specific postprocessing
splats["means"] = inverse_log_transform(splats["means"])
return splats
def write_image(compress_dir, param_name, img):
"""
Compresses the image, currently as lossless webp. Centralized function to change
image encoding in the future if need be.
Args:
compress_dir (str): compression directory
param_name (str): parameter field name
img (np.Array): image data as np array
Returns:
str: filename
"""
filename = f"{param_name}.webp"
Image.fromarray(img).save(
os.path.join(compress_dir, filename),
format="webp", lossless=True, method=6
)
# filename = f"{param_name}.png"
# Image.fromarray(img).save(
# os.path.join(compress_dir, filename),
# format="png", optimize=True
# )
return filename
def _crop_n_splats(splats: Dict[str, Tensor], n_crop: int) -> Dict[str, Tensor]:
opacities = splats["opacities"]
keep_indices = torch.argsort(opacities, descending=True)[:-n_crop]
for k, v in splats.items():
splats[k] = v[keep_indices]
return splats
def _compress_png(
compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs
) -> Dict[str, Any]:
"""Compress parameters with 8-bit quantization and lossless PNG compression.
Args:
compress_dir (str): compression directory
param_name (str): parameter field name
params (Tensor): parameters
n_sidelen (int): image side length
Returns:
Dict[str, Any]: metadata
"""
grid = params.reshape((n_sidelen, n_sidelen, -1))
mins = torch.amin(grid, dim=(0, 1))
maxs = torch.amax(grid, dim=(0, 1))
grid_norm = (grid - mins) / (maxs - mins)
img_norm = grid_norm.detach().cpu().numpy()
img = (img_norm * (2**8 - 1)).round().astype(np.uint8)
img = img.squeeze()
meta = {
"shape": list(params.shape),
"dtype": str(params.dtype).split(".")[1],
"mins": mins.tolist(),
"maxs": maxs.tolist(),
"files": [write_image(compress_dir, param_name, img)]
}
return meta
def _decompress_png(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor:
"""Decompress parameters from PNG file.
Args:
compress_dir (str): compression directory
param_name (str): parameter field name
meta (Dict[str, Any]): metadata
Returns:
Tensor: parameters
"""
if not np.all(meta["shape"]):
params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
return meta
img = imageio.imread(os.path.join(compress_dir, f"{param_name}.png"))
img_norm = img / (2**8 - 1)
grid_norm = torch.tensor(img_norm)
mins = torch.tensor(meta["mins"])
maxs = torch.tensor(meta["maxs"])
grid = grid_norm * (maxs - mins) + mins
params = grid.reshape(meta["shape"])
params = params.to(dtype=getattr(torch, meta["dtype"]))
return params
def _compress_png_16bit(
compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs
) -> Dict[str, Any]:
"""Compress parameters with 16-bit quantization and PNG compression.
Args:
compress_dir (str): compression directory
param_name (str): parameter field name
params (Tensor): parameters
n_sidelen (int): image side length
Returns:
Dict[str, Any]: metadata
"""
grid = params.reshape((n_sidelen, n_sidelen, -1))
mins = torch.amin(grid, dim=(0, 1))
maxs = torch.amax(grid, dim=(0, 1))
grid_norm = (grid - mins) / (maxs - mins)
img_norm = grid_norm.detach().cpu().numpy()
img = (img_norm * (2**16 - 1)).round().astype(np.uint16)
img_l = img & 0xFF
img_u = (img >> 8) & 0xFF
meta = {
"shape": list(params.shape),
"dtype": str(params.dtype).split(".")[1],
"mins": mins.tolist(),
"maxs": maxs.tolist(),
"files": [
write_image(compress_dir, f"{param_name}_l", img_l.astype(np.uint8)),
write_image(compress_dir, f"{param_name}_u", img_u.astype(np.uint8))
]
}
return meta
def _decompress_png_16bit(
compress_dir: str, param_name: str, meta: Dict[str, Any]
) -> Tensor:
"""Decompress parameters from PNG files.
Args:
compress_dir (str): compression directory
param_name (str): parameter field name
meta (Dict[str, Any]): metadata
Returns:
Tensor: parameters
"""
if not np.all(meta["shape"]):
params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
return meta
img_l = imageio.imread(os.path.join(compress_dir, f"{param_name}_l.png"))
img_u = imageio.imread(os.path.join(compress_dir, f"{param_name}_u.png"))
img_u = img_u.astype(np.uint16)
img = (img_u << 8) + img_l
img_norm = img / (2**16 - 1)
grid_norm = torch.tensor(img_norm)
mins = torch.tensor(meta["mins"])
maxs = torch.tensor(meta["maxs"])
grid = grid_norm * (maxs - mins) + mins
params = grid.reshape(meta["shape"])
params = params.to(dtype=getattr(torch, meta["dtype"]))
return params
def _compress_kmeans(
compress_dir: str,
param_name: str,
params: Tensor,
n_sidelen: int,
quantization: int = 8,
verbose: bool = True,
**kwargs,
) -> Dict[str, Any]:
"""Run K-means clustering on parameters and save centroids and labels as images.
Centroids are saved in a 940 x 1024 image where each row of 15 pixels represents
one SH band. Labels are stored as a 16bit image mapping to each of those centroids
by index.
.. warning::
TorchPQ must installed to use K-means clustering.
Args:
compress_dir (str): compression directory
param_name (str): parameter field name
params (Tensor): parameters to compress
n_clusters (int): number of K-means clusters
quantization (int): number of bits in quantization
verbose (bool, optional): Whether to print verbose information. Default to True.
Returns:
Dict[str, Any]: metadata
"""
try:
from torchpq.clustering import KMeans
except:
raise ImportError(
"Please install torchpq with 'pip install torchpq' to use K-means clustering"
)
params = params.reshape(params.shape[0], -1)
dim = params.shape[1]
n_clusters = round((len(params) >> 2) / 64) * 64
n_clusters = min(n_clusters, 2 ** 16)
kmeans = KMeans(n_clusters=n_clusters, distance="manhattan", verbose=verbose)
labels = kmeans.fit(params.permute(1, 0).contiguous())
labels = labels.detach().cpu().numpy()
centroids = kmeans.centroids.permute(1, 0)
mins = torch.min(centroids)
maxs = torch.max(centroids)
centroids_norm = (centroids - mins) / (maxs - mins)
centroids_norm = centroids_norm.detach().cpu().numpy()
centroids_quant = (
(centroids_norm * (2**quantization - 1)).round().astype(np.uint8)
)
# sorting centroids so that like coefficients stack vertically in the image atlas
# pretty minor optimization but essentially free
sorted_indices = np.lexsort(centroids_quant.T)
sorted_indices = sorted_indices.reshape(64, -1).T.reshape(-1)
sorted_centroids_quant = centroids_quant[sorted_indices]
inverse = np.argsort(sorted_indices)
centroids_packed = sorted_centroids_quant.reshape(-1, int(dim * 64 / 3), 3)
labels = inverse[labels].astype(np.uint16).reshape((n_sidelen, n_sidelen))
labels_l = labels & 0xFF
labels_u = (labels >> 8) & 0xFF
meta = {
"shape": list(params.shape),
"dtype": str(params.dtype).split(".")[1],
"mins": mins.tolist(),
"maxs": maxs.tolist(),
"quantization": quantization,
"files": [
write_image(compress_dir, f"{param_name}_centroids", centroids_packed),
write_image(compress_dir, f"{param_name}_labels_l", labels_l.astype(np.uint8)),
write_image(compress_dir, f"{param_name}_labels_u", labels_u.astype(np.uint8))
]
}
return meta
def _decompress_kmeans(
compress_dir: str, param_name: str, meta: Dict[str, Any], **kwargs
) -> Tensor:
"""Decompress parameters from K-means compression.
Args:
compress_dir (str): compression directory
param_name (str): parameter field name
meta (Dict[str, Any]): metadata
Returns:
Tensor: parameters
"""
if not np.all(meta["shape"]):
params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
return meta
npz_dict = np.load(os.path.join(compress_dir, f"{param_name}.npz"))
centroids_quant = npz_dict["centroids"]
labels = npz_dict["labels"]
centroids_norm = centroids_quant / (2 ** meta["quantization"] - 1)
centroids_norm = torch.tensor(centroids_norm)
mins = torch.tensor(meta["mins"])
maxs = torch.tensor(meta["maxs"])
centroids = centroids_norm * (maxs - mins) + mins
params = centroids[labels]
params = params.reshape(meta["shape"])
params = params.to(dtype=getattr(torch, meta["dtype"]))
return params
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment