Created
December 13, 2025 15:01
-
-
Save kerrickstaley/9e63c4f03166e9e85451e1b9a3d555b0 to your computer and use it in GitHub Desktop.
split_nonconvex_meshes helper function
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| import xml.etree.ElementTree as ET | |
| import coacd | |
| import platformdirs | |
| import trimesh | |
| from frozendict import frozendict | |
# Immutable, empty default for ``coacd_kwargs`` so the shared default argument of
# ``split_nonconvex_meshes`` cannot be mutated between calls.
DEFAULT_COACD_KWARGS = frozendict()
# Per-user cache directory for decomposition results; each cache entry lives in a
# subdirectory named by its cache key.
CACHE_ROOT = Path(platformdirs.user_cache_dir(), "mujoco_split_mesh")
def split_nonconvex_meshes(
    xml: str, mesh_names: list[str] | str, coacd_kwargs: dict = DEFAULT_COACD_KWARGS
) -> str:
    """
    Rewrite a MuJoCo XML string so the given mesh(es) are replaced with convex parts.

    Each requested mesh is decomposed with ``coacd.run_coacd``. One ``<mesh>`` asset per
    convex part is inserted next to the original asset. Every ``<geom>`` that used the
    original mesh is replaced by one geom per part, and the other geom attributes are
    copied. The original ``<mesh>`` asset is kept. Results are cached under
    ``platformdirs.user_cache_dir() / 'mujoco_split_mesh'``.

    Args:
        xml: MuJoCo XML document.
        mesh_names: A mesh name or a list of mesh names. A mesh is looked up by its
            ``name`` attribute, or by its ``file`` stem when it is unnamed. Repeated
            names are processed once.
        coacd_kwargs: Keyword arguments forwarded to ``coacd.run_coacd``. They are
            also part of the cache key.

    Returns:
        The rewritten, re-indented XML. If ``mesh_names`` is empty, ``xml`` is
        returned unchanged.

    Raises:
        ValueError: A requested mesh is not found in an ``<asset>`` section, or has no
            ``file`` attribute.
        FileNotFoundError: The mesh file does not exist. Relative ``file`` paths are
            checked against the current working directory; ``<compiler meshdir>`` is
            not consulted.
    """
    requested = [mesh_names] if isinstance(mesh_names, str) else list(mesh_names)
    # Process each mesh once. A repeated name would otherwise insert a second,
    # identically named set of "<name>_part_NN" assets, which MuJoCo rejects.
    target_meshes = list(dict.fromkeys(requested))
    if not target_meshes:
        return xml
    coacd_kwargs = dict(coacd_kwargs)
    root = ET.fromstring(xml)
    mesh_assets = _find_mesh_assets(root)
    part_names_by_mesh: dict[str, list[str]] = {}
    for target in target_meshes:
        mesh_elem = mesh_assets.get(target)
        if mesh_elem is None:
            raise ValueError(f"Mesh named '{target}' not found in <asset> section.")
        mesh_file = mesh_elem.get("file")
        if not mesh_file:
            raise ValueError(f"Mesh '{target}' is missing a 'file' attribute.")
        mesh_path = Path(mesh_file)
        if not mesh_path.exists():
            raise FileNotFoundError(f"Mesh file for '{target}' not found: {mesh_path}")
        cache_dir, part_paths = _get_cached_parts(mesh_path, coacd_kwargs)
        # The cached copy must keep the source extension, because MuJoCo determines
        # the mesh format from the file extension (an OBJ named ".stl" cannot load).
        # If no such copy exists, the asset keeps pointing at the original file.
        combined_path = cache_dir / f"combined{mesh_path.suffix or '.stl'}"
        if combined_path.exists():
            mesh_elem.set("file", str(combined_path))
        if "name" not in mesh_elem.attrib:
            # An unnamed mesh is identified by its file stem, and its file may have
            # just been redirected to the cached copy. Pin the original name explicitly.
            mesh_elem.set("name", target)
        part_names_by_mesh[target] = _inject_asset_meshes(
            # xml.etree elements keep no parent pointer, so the enclosing <asset>
            # is located by searching from the root.
            parent_asset=None,
            fallback_parent=_locate_parent_asset(root, mesh_elem),
            after_mesh=mesh_elem,
            part_paths=part_paths,
            base_name=mesh_elem.get("name", target),
            scale=mesh_elem.get("scale"),
        )
    _rewrite_geoms(root, part_names_by_mesh)
    ET.indent(root, space=" ")
    return ET.tostring(root, encoding="unicode")
| def _find_mesh_assets(root: ET.Element) -> dict[str, ET.Element]: | |
| meshes: dict[str, ET.Element] = {} | |
| for asset in root.findall(".//asset"): | |
| for mesh in asset.findall("mesh"): | |
| name = mesh.get("name") or Path(mesh.get("file", "")).stem | |
| if name: | |
| meshes[name] = mesh | |
| return meshes | |
| def _hash_cache_key(mesh_path: Path, coacd_kwargs: dict) -> str: | |
| hasher = hashlib.sha256() | |
| hasher.update(json.dumps(coacd_kwargs, sort_keys=True, default=str).encode("utf-8")) | |
| hasher.update(mesh_path.read_bytes()) | |
| return hasher.hexdigest() | |
def _get_cached_parts(mesh_path: Path, coacd_kwargs: dict) -> tuple[Path, list[Path]]:
    """Return ``(cache_dir, part_paths)`` for ``mesh_path``, running CoACD on a cache miss.

    ``cache_dir`` is ``CACHE_ROOT / _hash_cache_key(mesh_path, coacd_kwargs)``. On a
    miss, two kinds of file are written into a hidden staging directory under
    ``CACHE_ROOT``:

    * a copy of the source mesh, ``combined<source suffix>``;
    * one ``<idx>.stl`` per convex part.

    The staging directory is then renamed to ``cache_dir``. Because of this, a run
    that fails or is interrupted midway never leaves a partially populated
    ``cache_dir`` that later calls would accept as complete. Leftover staging
    directories are never read as cache entries.

    Raises:
        RuntimeError: CoACD returned no parts, or no parts are found after caching.
    """
    cache_dir = CACHE_ROOT / _hash_cache_key(mesh_path, coacd_kwargs)
    parts = _list_cached_parts(cache_dir)
    if parts:
        return cache_dir, parts

    CACHE_ROOT.mkdir(parents=True, exist_ok=True)
    staging = Path(tempfile.mkdtemp(prefix=f".{cache_dir.name}-", dir=CACHE_ROOT))
    try:
        # Keep the source extension: MuJoCo determines the mesh format from it.
        shutil.copy(mesh_path, staging / f"combined{mesh_path.suffix or '.stl'}")
        mesh = trimesh.load(mesh_path, force="mesh")
        coacd_mesh = coacd.Mesh(mesh.vertices, mesh.faces)
        decomposed = coacd.run_coacd(coacd_mesh, **coacd_kwargs)
        if not decomposed:
            raise RuntimeError(f"coacd produced no parts for {mesh_path}")
        for idx, (verts, faces) in enumerate(decomposed):
            part_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
            part_mesh.export(staging / f"{idx}.stl")
        if cache_dir.exists() and not _list_cached_parts(cache_dir):
            # Incomplete directory left by an earlier interrupted run; replace it.
            shutil.rmtree(cache_dir)
        try:
            staging.rename(cache_dir)
        except OSError:
            # cache_dir was populated by another run in the meantime; reuse it.
            if not _list_cached_parts(cache_dir):
                raise
    finally:
        # No-op after a successful rename; otherwise discards the staged files.
        shutil.rmtree(staging, ignore_errors=True)

    parts = _list_cached_parts(cache_dir)
    if not parts:
        raise RuntimeError(f"Failed to cache convex parts for {mesh_path}")
    return cache_dir, parts
| def _inject_asset_meshes( | |
| parent_asset: ET.Element | None, | |
| fallback_parent: ET.Element, | |
| after_mesh: ET.Element, | |
| part_paths: list[Path], | |
| base_name: str, | |
| scale: str | None, | |
| ) -> list[str]: | |
| asset_parent = parent_asset or fallback_parent | |
| if asset_parent is None: | |
| raise ValueError("<asset> element for mesh is missing.") | |
| children = list(asset_parent) | |
| try: | |
| insert_at = children.index(after_mesh) + 1 | |
| except ValueError: | |
| insert_at = len(children) | |
| new_names: list[str] = [] | |
| for idx, part_path in enumerate(part_paths): | |
| mesh_name = f"{base_name}_part_{idx:02d}" | |
| attrib = {"file": str(part_path), "name": mesh_name} | |
| if scale: | |
| attrib["scale"] = scale | |
| new_mesh = ET.Element("mesh", attrib) | |
| asset_parent.insert(insert_at + idx, new_mesh) | |
| new_names.append(mesh_name) | |
| return new_names | |
| def _locate_parent_asset(root: ET.Element, child: ET.Element) -> ET.Element: | |
| for asset in root.findall(".//asset"): | |
| if child in list(asset): | |
| return asset | |
| raise ValueError("Could not locate parent <asset> element for mesh.") | |
| def _rewrite_geoms(root: ET.Element, parts_by_mesh: dict[str, list[str]]) -> None: | |
| for parent in root.iter(): | |
| children = list(parent) | |
| offset = 0 | |
| for idx, child in enumerate(children): | |
| if child.tag != "geom": | |
| continue | |
| mesh_name = child.get("mesh") | |
| if mesh_name not in parts_by_mesh: | |
| continue | |
| template = {k: v for k, v in child.attrib.items() if k not in {"mesh", "name"}} | |
| base_geom_name = child.get("name") or mesh_name | |
| parent.remove(child) | |
| for part_idx, part_name in enumerate(parts_by_mesh[mesh_name]): | |
| new_name = f"{base_geom_name}_part_{part_idx:02d}" if base_geom_name else part_name | |
| new_attrib = dict(template) | |
| new_attrib["mesh"] = part_name | |
| new_attrib["name"] = new_name | |
| parent.insert(idx + offset + part_idx, ET.Element("geom", new_attrib)) | |
| offset += len(parts_by_mesh[mesh_name]) - 1 | |
| def _list_cached_parts(cache_dir: Path) -> list[Path]: | |
| return sorted( | |
| (p for p in cache_dir.glob("[0-9]*.stl") if p.is_file()), | |
| key=lambda p: int(p.stem) if p.stem.isdigit() else p.stem, | |
| ) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment