Code used to compress a .ply file with SOGS for https://vincentwoo.com/3d/sutro_tower/
from dataclasses import dataclass, field
from gsplat.compression import PngCompression
from plyfile import PlyData, PlyElement
import numpy as np
import os
import torch
import tyro
from pathlib import Path


@dataclass
class Config:
    ply: Path
    output_dir: Path


@torch.no_grad()
def read_ply(path):
    """
    Reads a .ply file (with columns including x, y, z, nx, ny, nz, f_dc_i, f_rest_j, etc.)
    and reconstructs a dictionary of PyTorch tensors on the GPU.
    Automatically handles a variable number of 'f_rest_*' columns if present.
    """
    plydata = PlyData.read(path)
    vd = plydata['vertex'].data  # a structured NumPy array

    # Helper to quickly check presence of a column
    def has_col(col_name):
        return col_name in vd.dtype.names

    xyz = np.stack([vd['x'], vd['y'], vd['z']], axis=-1)
    normals = np.stack([vd['nx'], vd['ny'], vd['nz']], axis=-1)
    f_dc = np.stack([vd[f"f_dc_{i}"] for i in range(3)], axis=-1)

    # -- f_rest (variable number: 0 to many) --
    # Gather all columns that match the 'f_rest_*' pattern and sort them by numeric suffix.
    rest_cols = [c for c in vd.dtype.names if c.startswith('f_rest_')]

    def get_index(c):
        # e.g., 'f_rest_12' -> integer 12
        return int(c.split('_')[-1])

    rest_cols_sorted = sorted(rest_cols, key=get_index)
    if len(rest_cols_sorted) > 0:
        f_rest = np.stack([vd[c] for c in rest_cols_sorted], axis=-1)
    else:
        f_rest = np.empty((len(vd), 0), dtype=np.float32)

    opacities = vd['opacity']
    scale = np.stack([vd[f"scale_{i}"] for i in range(3)], axis=-1)
    rotation = np.stack([vd[f"rot_{i}"] for i in range(4)], axis=-1)

    # Convert to torch tensors
    splats = {}
    splats["means"] = torch.from_numpy(xyz).float().cuda()
    splats["opacities"] = torch.from_numpy(opacities).float().cuda()
    splats["scales"] = torch.from_numpy(scale).float().cuda()
    splats["quats"] = torch.from_numpy(rotation).float().cuda()

    # f_dc -> shape (N, 3) => unsqueeze => (N, 3, 1) => transpose => (N, 1, 3)
    sh0_tensor = torch.from_numpy(f_dc).float()  # shape (N, 3)
    sh0_tensor = sh0_tensor.unsqueeze(-1)        # shape (N, 3, 1)
    sh0_tensor = sh0_tensor.transpose(1, 2)      # shape (N, 1, 3)
    splats["sh0"] = sh0_tensor.cuda()

    # Was `if f_rest.any():`, which also skipped all-zero coefficients; testing the
    # column count is what is actually intended here.
    if f_rest.shape[1] > 0:
        # f_rest -> shape (N, #rest) => reshape => (N, 3, #rest/3) => transpose => (N, #rest/3, 3)
        # This assumes the standard 3DGS channel-major layout (all R coefficients,
        # then G, then B); the total f_rest column count must be divisible by 3.
        if f_rest.shape[1] % 3 != 0:
            raise ValueError(f"Number of f_rest columns ({f_rest.shape[1]}) not divisible by 3.")
        num_rest_per_channel = f_rest.shape[1] // 3
        shn_tensor = torch.from_numpy(
            f_rest.reshape(-1, 3, num_rest_per_channel)
        ).float().transpose(1, 2)  # shape (N, num_rest_per_channel, 3)
        splats["shN"] = shn_tensor.cuda()

    return splats


cfg = tyro.cli(Config)
os.makedirs(cfg.output_dir, exist_ok=True)
splats = read_ply(cfg.ply)
PngCompression().compress(cfg.output_dir, splats)
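A minimal sketch of how the script above might be driven; the script filename and paths here are hypothetical, and tyro derives the command-line flags from the Config fields:

# Hypothetical shell invocation (tyro maps Config fields to flags):
#   python compress_ply.py --ply scene.ply --output-dir ./sogs_out
#
# Equivalent programmatic use, assuming the definitions above are in scope:
os.makedirs("./sogs_out", exist_ok=True)
splats = read_ply(Path("scene.ply"))             # pre-activation splat tensors on GPU
PngCompression().compress("./sogs_out", splats)  # writes per-parameter images + meta.json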
The second file is a patched copy of gsplat's PngCompression that writes lossless webp via Pillow instead of PNG:
import json
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from PIL import Image
from plyfile import PlyData, PlyElement

from gsplat.compression.sort import sort_splats
from gsplat.utils import inverse_log_transform, log_transform
@dataclass
class PngCompression:
    """Uses quantization and sorting to compress splats into image files and uses
    K-means clustering to compress the spherical harmonic coefficients.

    .. warning::
        This class requires the `Pillow <https://pypi.org/project/pillow/>`_,
        `plas <https://github.com/fraunhoferhhi/PLAS.git>`_
        and `torchpq <https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install>`_ packages to be installed.

    .. warning::
        This class might throw away a few of the lowest-opacity splats if the
        number of splats is not a square number.

    .. note::
        The splat parameters are expected to be pre-activation values. The
        following fields are expected in the splats dictionary: "means", "scales",
        "quats", "opacities", "sh0", "shN".

    References:
        - `Compact 3D Scene Representation via Self-Organizing Gaussian Grids <https://arxiv.org/abs/2312.13299>`_
        - `Making Gaussian Splats more smaller <https://aras-p.info/blog/2023/09/27/Making-Gaussian-Splats-more-smaller/>`_

    Args:
        use_sort (bool, optional): Whether to sort splats before compression. Defaults to True.
        verbose (bool, optional): Whether to print verbose information. Defaults to True.
    """

    use_sort: bool = True
    verbose: bool = True
    def _get_compress_fn(self, param_name: str) -> Callable:
        compress_fn_map = {
            "means": _compress_png_16bit,
            "scales": _compress_png,
            "quats": _compress_png,
            "opacities": _compress_png,
            "sh0": _compress_png,
            "shN": _compress_kmeans,
        }
        return compress_fn_map[param_name]

    def _get_decompress_fn(self, param_name: str) -> Callable:
        decompress_fn_map = {
            "means": _decompress_png_16bit,
            "scales": _decompress_png,
            "quats": _decompress_png,
            "opacities": _decompress_png,
            "sh0": _decompress_png,
            "shN": _decompress_kmeans,
        }
        return decompress_fn_map[param_name]
    def compress(self, compress_dir: str, splats: Dict[str, Tensor]) -> None:
        """Run compression

        Args:
            compress_dir (str): directory to save compressed files
            splats (Dict[str, Tensor]): Gaussian splats to compress
        """
        # Param-specific preprocessing
        splats["means"] = log_transform(splats["means"])
        # Normalize quaternions and force w >= 0 (q and -q encode the same
        # rotation), then store only xyz; w is recoverable from the unit norm.
        splats["quats"] = F.normalize(splats["quats"], dim=-1)
        neg_mask = splats["quats"][:, 3] < 0
        splats["quats"][neg_mask] *= -1
        splats["quats"] = splats["quats"][:, :3]

        n_gs = len(splats["means"])
        n_sidelen = int(n_gs**0.5)
        n_crop = n_gs - n_sidelen**2
        if n_crop != 0:
            splats = _crop_n_splats(splats, n_crop)
            print(
                f"Warning: Number of Gaussians was not square. Removed {n_crop} Gaussians."
            )

        meta = {}
        if self.use_sort:
            splats = sort_splats(splats)
        for param_name in splats.keys():
            compress_fn = self._get_compress_fn(param_name)
            kwargs = {
                "n_sidelen": n_sidelen,
                "verbose": self.verbose,
            }
            meta[param_name] = compress_fn(
                compress_dir, param_name, splats[param_name], **kwargs
            )

        with open(os.path.join(compress_dir, "meta.json"), "w") as f:
            json.dump(meta, f, indent=2)

    def decompress(self, compress_dir: str) -> Dict[str, Tensor]:
        """Run decompression

        Args:
            compress_dir (str): directory that contains compressed files

        Returns:
            Dict[str, Tensor]: decompressed Gaussian splats
        """
        with open(os.path.join(compress_dir, "meta.json"), "r") as f:
            meta = json.load(f)

        splats = {}
        for param_name, param_meta in meta.items():
            decompress_fn = self._get_decompress_fn(param_name)
            splats[param_name] = decompress_fn(compress_dir, param_name, param_meta)

        # Param-specific postprocessing
        splats["means"] = inverse_log_transform(splats["means"])
        return splats
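# Note: decompress() returns "quats" as the stored xyz components only; a consumer
# is expected to rebuild the scalar part, e.g. with something like this hypothetical
# helper (not part of the gist):
#
#     def restore_quat_w(quats_xyz: Tensor) -> Tensor:
#         w = torch.sqrt(torch.clamp(1.0 - (quats_xyz * quats_xyz).sum(-1), min=0.0))
#         return torch.cat([quats_xyz, w[:, None]], dim=-1)
#
# This works because compress() normalized each quaternion and flipped its sign
# so that w >= 0 before dropping it.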
def write_image(compress_dir, param_name, img):
    """
    Compresses the image, currently as lossless webp. Centralized function so the
    image encoding can be changed in one place if need be.

    Args:
        compress_dir (str): compression directory
        param_name (str): parameter field name
        img (np.ndarray): image data as a NumPy array

    Returns:
        str: filename
    """
    filename = f"{param_name}.webp"
    Image.fromarray(img).save(
        os.path.join(compress_dir, filename),
        format="webp", lossless=True, method=6
    )
    # Alternative: lossless PNG instead of webp.
    # filename = f"{param_name}.png"
    # Image.fromarray(img).save(
    #     os.path.join(compress_dir, filename),
    #     format="png", optimize=True
    # )
    return filename


def _crop_n_splats(splats: Dict[str, Tensor], n_crop: int) -> Dict[str, Tensor]:
    # Drop the n_crop lowest-opacity splats so the count becomes a perfect square.
    opacities = splats["opacities"]
    keep_indices = torch.argsort(opacities, descending=True)[:-n_crop]
    for k, v in splats.items():
        splats[k] = v[keep_indices]
    return splats
def _compress_png(
    compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs
) -> Dict[str, Any]:
    """Compress parameters with 8-bit quantization and lossless image compression.

    Args:
        compress_dir (str): compression directory
        param_name (str): parameter field name
        params (Tensor): parameters
        n_sidelen (int): image side length

    Returns:
        Dict[str, Any]: metadata
    """
    grid = params.reshape((n_sidelen, n_sidelen, -1))

    # Per-channel min/max normalization to [0, 1], then 8-bit quantization.
    mins = torch.amin(grid, dim=(0, 1))
    maxs = torch.amax(grid, dim=(0, 1))
    grid_norm = (grid - mins) / (maxs - mins)
    img_norm = grid_norm.detach().cpu().numpy()
    img = (img_norm * (2**8 - 1)).round().astype(np.uint8)
    img = img.squeeze()

    meta = {
        "shape": list(params.shape),
        "dtype": str(params.dtype).split(".")[1],
        "mins": mins.tolist(),
        "maxs": maxs.tolist(),
        "files": [write_image(compress_dir, param_name, img)],
    }
    return meta
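# Worked example of the quantization round trip above (hypothetical channel with
# min = -4.0 and max = 2.0): the value 0.5 normalizes to (0.5 - -4.0) / 6.0 = 0.75,
# quantizes to round(0.75 * 255) = 191, and dequantizes to 191 / 255 * 6.0 - 4.0
# ~= 0.494 -- a worst-case error of (max - min) / 510 per channel.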
def _decompress_png(compress_dir: str, param_name: str, meta: Dict[str, Any]) -> Tensor:
    """Decompress parameters from an image file.

    Args:
        compress_dir (str): compression directory
        param_name (str): parameter field name
        meta (Dict[str, Any]): metadata

    Returns:
        Tensor: parameters
    """
    if not np.all(meta["shape"]):
        params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
        return params  # was `return meta`, which returned the wrong type

    # Read back the file write_image produced (webp here, recorded in meta). webp
    # has no grayscale mode, so 1-channel data may round-trip as RGB with three
    # equal channels; collapse it if the metadata says the data was 1-channel.
    img = np.asarray(Image.open(os.path.join(compress_dir, meta["files"][0])))
    if img.ndim == 3 and np.prod(meta["shape"]) == img.shape[0] * img.shape[1]:
        img = img[..., 0]
    img_norm = img / (2**8 - 1)

    grid_norm = torch.tensor(img_norm)
    mins = torch.tensor(meta["mins"])
    maxs = torch.tensor(meta["maxs"])
    grid = grid_norm * (maxs - mins) + mins

    params = grid.reshape(meta["shape"])
    params = params.to(dtype=getattr(torch, meta["dtype"]))
    return params
def _compress_png_16bit(
    compress_dir: str, param_name: str, params: Tensor, n_sidelen: int, **kwargs
) -> Dict[str, Any]:
    """Compress parameters with 16-bit quantization and lossless image compression.

    Args:
        compress_dir (str): compression directory
        param_name (str): parameter field name
        params (Tensor): parameters
        n_sidelen (int): image side length

    Returns:
        Dict[str, Any]: metadata
    """
    grid = params.reshape((n_sidelen, n_sidelen, -1))
    mins = torch.amin(grid, dim=(0, 1))
    maxs = torch.amax(grid, dim=(0, 1))
    grid_norm = (grid - mins) / (maxs - mins)
    img_norm = grid_norm.detach().cpu().numpy()
    img = (img_norm * (2**16 - 1)).round().astype(np.uint16)

    # Split each 16-bit value into two 8-bit images (low and high bytes).
    img_l = img & 0xFF
    img_u = (img >> 8) & 0xFF

    meta = {
        "shape": list(params.shape),
        "dtype": str(params.dtype).split(".")[1],
        "mins": mins.tolist(),
        "maxs": maxs.tolist(),
        "files": [
            write_image(compress_dir, f"{param_name}_l", img_l.astype(np.uint8)),
            write_image(compress_dir, f"{param_name}_u", img_u.astype(np.uint8)),
        ],
    }
    return meta
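# Quick sanity check of the byte split above (hypothetical value): for 0xABCD,
# the low image stores 0xABCD & 0xFF == 0xCD and the high image stores
# (0xABCD >> 8) & 0xFF == 0xAB; since (0xAB << 8) + 0xCD == 0xABCD, splitting a
# 16-bit grid into two lossless 8-bit images loses nothing.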
def _decompress_png_16bit(
    compress_dir: str, param_name: str, meta: Dict[str, Any]
) -> Tensor:
    """Decompress parameters from a pair of low/high-byte image files.

    Args:
        compress_dir (str): compression directory
        param_name (str): parameter field name
        meta (Dict[str, Any]): metadata

    Returns:
        Tensor: parameters
    """
    if not np.all(meta["shape"]):
        params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
        return params  # was `return meta`, which returned the wrong type

    # Read back the low/high byte images recorded in meta and recombine.
    img_l = np.asarray(Image.open(os.path.join(compress_dir, meta["files"][0])))
    img_u = np.asarray(Image.open(os.path.join(compress_dir, meta["files"][1])))
    img_u = img_u.astype(np.uint16)
    img = (img_u << 8) + img_l
    img_norm = img / (2**16 - 1)

    grid_norm = torch.tensor(img_norm)
    mins = torch.tensor(meta["mins"])
    maxs = torch.tensor(meta["maxs"])
    grid = grid_norm * (maxs - mins) + mins

    params = grid.reshape(meta["shape"])
    params = params.to(dtype=getattr(torch, meta["dtype"]))
    return params
def _compress_kmeans(
    compress_dir: str,
    param_name: str,
    params: Tensor,
    n_sidelen: int,
    quantization: int = 8,
    verbose: bool = True,
    **kwargs,
) -> Dict[str, Any]:
    """Run K-means clustering on parameters and save centroids and labels as images.

    Centroids are packed 64 per row into an image of dim * 64 / 3 RGB pixels per
    row (e.g. 960 pixels wide for degree-3 SH, where dim = 45, with up to 1024
    rows), so each run of 15 pixels holds one centroid. Labels are stored as a
    16-bit image (split into low/high 8-bit images) indexing into those centroids.

    .. warning::
        TorchPQ must be installed to use K-means clustering.

    Args:
        compress_dir (str): compression directory
        param_name (str): parameter field name
        params (Tensor): parameters to compress
        n_sidelen (int): image side length
        quantization (int): number of bits in quantization
        verbose (bool, optional): Whether to print verbose information. Defaults to True.

    Returns:
        Dict[str, Any]: metadata
    """
    try:
        from torchpq.clustering import KMeans
    except ImportError:
        raise ImportError(
            "Please install torchpq with 'pip install torchpq' to use K-means clustering"
        )

    params = params.reshape(params.shape[0], -1)
    dim = params.shape[1]
    # Roughly one cluster per four splats, rounded to a multiple of 64 and capped
    # at 2**16 so labels fit in 16 bits.
    n_clusters = round((len(params) >> 2) / 64) * 64
    n_clusters = min(n_clusters, 2**16)

    kmeans = KMeans(n_clusters=n_clusters, distance="manhattan", verbose=verbose)
    labels = kmeans.fit(params.permute(1, 0).contiguous())
    labels = labels.detach().cpu().numpy()
    centroids = kmeans.centroids.permute(1, 0)

    mins = torch.min(centroids)
    maxs = torch.max(centroids)
    centroids_norm = (centroids - mins) / (maxs - mins)
    centroids_norm = centroids_norm.detach().cpu().numpy()
    centroids_quant = (
        (centroids_norm * (2**quantization - 1)).round().astype(np.uint8)
    )

    # Sort centroids so that like coefficients stack vertically in the image atlas --
    # a pretty minor compression win, but essentially free.
    sorted_indices = np.lexsort(centroids_quant.T)
    sorted_indices = sorted_indices.reshape(64, -1).T.reshape(-1)
    sorted_centroids_quant = centroids_quant[sorted_indices]
    inverse = np.argsort(sorted_indices)

    centroids_packed = sorted_centroids_quant.reshape(-1, int(dim * 64 / 3), 3)
    labels = inverse[labels].astype(np.uint16).reshape((n_sidelen, n_sidelen))
    labels_l = labels & 0xFF
    labels_u = (labels >> 8) & 0xFF

    meta = {
        "shape": list(params.shape),
        "dtype": str(params.dtype).split(".")[1],
        "mins": mins.tolist(),
        "maxs": maxs.tolist(),
        "quantization": quantization,
        "files": [
            write_image(compress_dir, f"{param_name}_centroids", centroids_packed),
            write_image(compress_dir, f"{param_name}_labels_l", labels_l.astype(np.uint8)),
            write_image(compress_dir, f"{param_name}_labels_u", labels_u.astype(np.uint8)),
        ],
    }
    return meta
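# Note on the label remapping above: since inverse = np.argsort(sorted_indices),
# sorted_centroids_quant[inverse[i]] == centroids_quant[i] for every original
# centroid index i, so the re-indexed labels keep pointing at the same centroid
# values after the sort.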
def _decompress_kmeans(
    compress_dir: str, param_name: str, meta: Dict[str, Any], **kwargs
) -> Tensor:
    """Decompress parameters from K-means compression.

    Args:
        compress_dir (str): compression directory
        param_name (str): parameter field name
        meta (Dict[str, Any]): metadata

    Returns:
        Tensor: parameters
    """
    if not np.all(meta["shape"]):
        params = torch.zeros(meta["shape"], dtype=getattr(torch, meta["dtype"]))
        return params  # was `return meta`, which returned the wrong type

    # Read back the images written by _compress_kmeans. (The upstream gsplat version
    # stored an .npz here; this patched pair stores webp images instead, so we use
    # the filenames recorded in meta.)
    centroids_file, labels_l_file, labels_u_file = meta["files"]
    centroids_quant = np.asarray(Image.open(os.path.join(compress_dir, centroids_file)))
    labels_l = np.asarray(Image.open(os.path.join(compress_dir, labels_l_file)))
    labels_u = np.asarray(Image.open(os.path.join(compress_dir, labels_u_file)))
    if labels_l.ndim == 3:  # grayscale data round-tripped through RGB webp
        labels_l, labels_u = labels_l[..., 0], labels_u[..., 0]
    labels = (labels_u.astype(np.int64) << 8) + labels_l
    labels = torch.from_numpy(labels.reshape(-1))

    # Unpack the (n_clusters / 64, dim * 64 / 3, 3) atlas back to (n_clusters, dim);
    # the stored labels already index this sorted centroid order.
    dim = meta["shape"][1]
    centroids_quant = centroids_quant.reshape(-1, dim)
    centroids_norm = torch.tensor(centroids_quant / (2 ** meta["quantization"] - 1))
    mins = torch.tensor(meta["mins"])
    maxs = torch.tensor(meta["maxs"])
    centroids = centroids_norm * (maxs - mins) + mins

    params = centroids[labels]
    params = params.reshape(meta["shape"])
    params = params.to(dtype=getattr(torch, meta["dtype"]))
    return params
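For reference, a sketch of reading the output back through the patched decompress path (the directory name is hypothetical):

# Hypothetical round-trip check; "./sogs_out" is assumed to hold compress() output.
splats = PngCompression().decompress("./sogs_out")
for name, tensor in splats.items():
    print(name, tuple(tensor.shape), tensor.dtype)
# Expect n_sidelen**2 rows per parameter, "quats" with only xyz columns, and
# "shN" flattened to (N, dim) as recorded in meta.json.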