Skip to content

Instantly share code, notes, and snippets.

@kerrickstaley
Created December 13, 2025 15:01
Show Gist options
  • Select an option

  • Save kerrickstaley/9e63c4f03166e9e85451e1b9a3d555b0 to your computer and use it in GitHub Desktop.

Select an option

Save kerrickstaley/9e63c4f03166e9e85451e1b9a3d555b0 to your computer and use it in GitHub Desktop.
split_nonconvex_meshes helper function
from __future__ import annotations
import hashlib
import json
import shutil
import tempfile
from pathlib import Path
import xml.etree.ElementTree as ET
import coacd
import platformdirs
import trimesh
from frozendict import frozendict
# Default keyword arguments forwarded to ``coacd.run_coacd``: an empty, immutable
# mapping, so it is safe to use as a function default.
DEFAULT_COACD_KWARGS = frozendict()
# Root of the on-disk cache; each decomposition is stored in a subdirectory of
# this path named after its cache key.
CACHE_ROOT = Path(platformdirs.user_cache_dir()) / "mujoco_split_mesh"
def split_nonconvex_meshes(
    xml: str, mesh_names: list[str] | str, coacd_kwargs: dict = DEFAULT_COACD_KWARGS
) -> str:
    """
    Rewrite a MuJoCo XML string so the given mesh(es) are replaced with convex parts.

    The meshes referenced by ``mesh_names`` are decomposed with ``coacd.run_coacd``.
    Results are cached under ``platformdirs.user_cache_dir() / 'mujoco_split_mesh'``.

    Args:
        xml: MuJoCo model as an XML string.
        mesh_names: Name of one mesh asset, or a list of names. Repeated names
            are processed once.
        coacd_kwargs: Keyword arguments forwarded to ``coacd.run_coacd``.

    Returns:
        The rewritten XML string. When ``mesh_names`` is empty the input is
        returned unchanged (not re-serialized).

    Raises:
        ValueError: A requested mesh is not declared in an ``<asset>`` section,
            or its ``<mesh>`` element has no ``file`` attribute.
        FileNotFoundError: The mesh file does not exist. The path is used as
            written in the XML, so a relative path is resolved against the
            current working directory.
    """
    requested = [mesh_names] if isinstance(mesh_names, str) else list(mesh_names)
    # De-duplicate while keeping first-seen order: handling the same mesh twice
    # would inject its "<name>_part_NN" assets twice under identical names.
    target_meshes = list(dict.fromkeys(requested))
    if not target_meshes:
        return xml

    # Work on a plain-dict copy so the caller's mapping is never touched.
    coacd_kwargs = dict(coacd_kwargs)
    root = ET.fromstring(xml)
    mesh_assets = _find_mesh_assets(root)
    part_names_by_mesh: dict[str, list[str]] = {}

    for target in target_meshes:
        mesh_elem = mesh_assets.get(target)
        if mesh_elem is None:
            raise ValueError(f"Mesh named '{target}' not found in <asset> section.")
        mesh_path = mesh_elem.get("file")
        if not mesh_path:
            raise ValueError(f"Mesh '{target}' is missing a 'file' attribute.")
        mesh_path = Path(mesh_path)
        if not mesh_path.exists():
            raise FileNotFoundError(f"Mesh file for '{target}' not found: {mesh_path}")

        cache_dir, part_paths = _get_cached_parts(mesh_path, coacd_kwargs)

        # Point the original asset at the cached copy of the full mesh, if present.
        combined_path = cache_dir / "combined.stl"
        if combined_path.exists():
            mesh_elem.set("file", str(combined_path))
        # A mesh located through its file stem has no explicit name yet; set one
        # so the part names below are derived from a stable identifier.
        if "name" not in mesh_elem.attrib:
            mesh_elem.set("name", target)

        scale_attr = mesh_elem.get("scale")
        new_mesh_names = _inject_asset_meshes(
            # The standard-library ElementTree has no ``getparent``; it is used
            # only when the element type provides it.
            parent_asset=mesh_elem.getparent() if hasattr(mesh_elem, "getparent") else None,
            fallback_parent=_locate_parent_asset(root, mesh_elem),
            after_mesh=mesh_elem,
            part_paths=part_paths,
            base_name=mesh_elem.get("name", target),
            scale=scale_attr,
        )
        part_names_by_mesh[target] = new_mesh_names

    _rewrite_geoms(root, part_names_by_mesh)
    ET.indent(root, space=" ")
    return ET.tostring(root, encoding="unicode")
def _find_mesh_assets(root: ET.Element) -> dict[str, ET.Element]:
    """Map each ``<mesh>`` declared directly inside an ``<asset>`` to its name.

    A mesh without a (non-empty) ``name`` attribute is keyed by the stem of its
    ``file`` attribute; a mesh with neither is skipped. When two meshes resolve
    to the same key, the one appearing later wins.
    """
    found: dict[str, ET.Element] = {}
    for asset_elem in root.iterfind(".//asset"):
        for mesh_elem in asset_elem.iterfind("mesh"):
            label = mesh_elem.get("name")
            if not label:
                # Fall back to the file name without directory or extension.
                label = Path(mesh_elem.get("file", "")).stem
            if label:
                found[label] = mesh_elem
    return found
def _hash_cache_key(mesh_path: Path, coacd_kwargs: dict) -> str:
    """Return the SHA-256 hex digest identifying one decomposition request.

    The digest covers the decomposition options (serialized as key-sorted JSON,
    with non-JSON values stringified) followed by the raw bytes of the mesh
    file, so changing either yields a different key.
    """
    options_blob = json.dumps(coacd_kwargs, sort_keys=True, default=str).encode("utf-8")
    digest = hashlib.sha256(options_blob)
    digest.update(mesh_path.read_bytes())
    return digest.hexdigest()
def _get_cached_parts(mesh_path: Path, coacd_kwargs: dict) -> tuple[Path, list[Path]]:
    """Return the cache directory and convex part files for ``mesh_path``.

    On a cache hit (numbered part files already exist for this mesh/options
    pair) nothing is recomputed. Otherwise the mesh is loaded with trimesh,
    decomposed with ``coacd.run_coacd(**coacd_kwargs)``, and each part is
    written to ``<cache_dir>/<index>.stl``. A copy of the whole mesh is stored
    as ``<cache_dir>/combined.stl``.

    Returns:
        ``(cache_dir, part_paths)`` with the part paths in index order.

    Raises:
        RuntimeError: The decomposition produced no parts, or no part files
            could be found after writing them.
    """
    cache_dir = CACHE_ROOT / _hash_cache_key(mesh_path, coacd_kwargs)
    parts = _list_cached_parts(cache_dir)
    if parts:
        return cache_dir, parts

    cache_dir.mkdir(parents=True, exist_ok=True)
    try:
        mesh = trimesh.load(mesh_path, force="mesh")

        combined_target = cache_dir / "combined.stl"
        if mesh_path.suffix.lower() == ".stl":
            shutil.copy(mesh_path, combined_target)
        else:
            # A byte-for-byte copy of e.g. an OBJ file would be mislabelled as
            # STL; re-export so the content matches the ".stl" extension.
            mesh.export(combined_target)

        coacd_mesh = coacd.Mesh(mesh.vertices, mesh.faces)
        decomposed = coacd.run_coacd(coacd_mesh, **coacd_kwargs)
        if not decomposed:
            raise RuntimeError(f"coacd produced no parts for {mesh_path}")

        for idx, (verts, faces) in enumerate(decomposed):
            part_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
            part_mesh.export(cache_dir / f"{idx}.stl")
    except BaseException:
        # Any numbered file counts as a cache hit on the next call, so a run
        # that fails or is interrupted midway must not leave a partial set.
        shutil.rmtree(cache_dir, ignore_errors=True)
        raise

    parts = _list_cached_parts(cache_dir)
    if not parts:
        raise RuntimeError(f"Failed to cache convex parts for {mesh_path}")
    return cache_dir, parts
def _inject_asset_meshes(
    parent_asset: ET.Element | None,
    fallback_parent: ET.Element,
    after_mesh: ET.Element,
    part_paths: list[Path],
    base_name: str,
    scale: str | None,
) -> list[str]:
    """Insert one ``<mesh>`` asset per convex part and return the new names.

    Args:
        parent_asset: Element to insert into, or ``None`` to use
            ``fallback_parent``.
        fallback_parent: Element used when ``parent_asset`` is ``None``.
        after_mesh: Existing child after which the parts are inserted; when it
            is not a child of the chosen parent, the parts are appended.
        part_paths: Files of the convex parts, in order.
        base_name: Prefix for the generated names (``<base_name>_part_NN``).
        scale: Value copied to each new mesh's ``scale`` attribute; omitted
            when empty or ``None``.

    Returns:
        Names of the inserted meshes, in the order of ``part_paths``.

    Raises:
        ValueError: Both ``parent_asset`` and ``fallback_parent`` are ``None``.
    """
    # Compare against None explicitly: truth-testing an Element is deprecated
    # and reports a childless element as false, which would wrongly discard a
    # valid (but empty) parent.
    asset_parent = parent_asset if parent_asset is not None else fallback_parent
    if asset_parent is None:
        raise ValueError("<asset> element for mesh is missing.")

    try:
        insert_at = list(asset_parent).index(after_mesh) + 1
    except ValueError:
        insert_at = len(asset_parent)

    new_names: list[str] = []
    for idx, part_path in enumerate(part_paths):
        mesh_name = f"{base_name}_part_{idx:02d}"
        attrib = {"file": str(part_path), "name": mesh_name}
        if scale:
            attrib["scale"] = scale
        # Consecutive positions keep the parts in order right after the anchor.
        asset_parent.insert(insert_at + idx, ET.Element("mesh", attrib))
        new_names.append(mesh_name)
    return new_names
def _locate_parent_asset(root: ET.Element, child: ET.Element) -> ET.Element:
    """Return the ``<asset>`` element under ``root`` that directly contains ``child``.

    Raises:
        ValueError: No ``<asset>`` element has ``child`` among its children.
    """
    owner = next(
        (candidate for candidate in root.iterfind(".//asset") if child in tuple(candidate)),
        None,
    )
    if owner is None:
        raise ValueError("Could not locate parent <asset> element for mesh.")
    return owner
def _rewrite_geoms(root: ET.Element, parts_by_mesh: dict[str, list[str]]) -> None:
    """Replace every geom that uses a decomposed mesh with one geom per part.

    Each ``<geom>`` whose ``mesh`` attribute is a key of ``parts_by_mesh`` is
    removed and, at the same position, one ``<geom>`` per part is inserted.
    The new geoms copy all attributes of the original except ``mesh`` and
    ``name``. They are named ``<base>_part_NN``, where ``<base>`` is the
    original geom's name or, when it has none, the mesh name. If those names
    are already taken (for example two unnamed geoms using the same mesh), a
    numeric suffix is added to the base (``<base>_1_part_NN``, ...) so geom
    names stay unique.

    Args:
        root: Root of the tree; modified in place.
        parts_by_mesh: Maps an original mesh name to its part mesh names.
    """
    used_names = {name for geom in root.iter("geom") if (name := geom.get("name"))}

    # Snapshot the traversal so that editing child lists cannot disturb it.
    for parent in list(root.iter()):
        rebuilt: list[ET.Element] = []
        replaced_any = False
        for child in parent:
            mesh_name = child.get("mesh") if child.tag == "geom" else None
            if mesh_name not in parts_by_mesh:
                rebuilt.append(child)
                continue

            part_names = parts_by_mesh[mesh_name]
            base = child.get("name") or mesh_name

            # Find the first base whose generated names are all unused; the
            # plain base is tried first so the common case is unchanged.
            stem, attempt = base, 0
            while True:
                geom_names = [f"{stem}_part_{i:02d}" for i in range(len(part_names))]
                if used_names.isdisjoint(geom_names):
                    break
                attempt += 1
                stem = f"{base}_{attempt}"
            used_names.update(geom_names)

            shared = {k: v for k, v in child.attrib.items() if k not in {"mesh", "name"}}
            for part_name, geom_name in zip(part_names, geom_names):
                rebuilt.append(
                    ET.Element("geom", {**shared, "mesh": part_name, "name": geom_name})
                )
            replaced_any = True

        if replaced_any:
            parent[:] = rebuilt
def _list_cached_parts(cache_dir: Path) -> list[Path]:
    """Return the numbered part files (``0.stl``, ``1.stl``, ...) in ``cache_dir``.

    Only regular files whose stem consists solely of ASCII digits are returned,
    sorted by numeric value (so ``10.stl`` follows ``2.stl``). Other files that
    merely start with a digit, such as ``1a.stl``, are ignored.
    """
    numbered = (
        path
        for path in cache_dir.glob("[0-9]*.stl")
        # The glob only constrains the first character; require an all-digit
        # stem so every sort key is an int.
        if path.stem.isascii() and path.stem.isdigit() and path.is_file()
    )
    return sorted(numbered, key=lambda path: int(path.stem))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment