Skip to content

Instantly share code, notes, and snippets.

@alisterburt
Created December 2, 2025 23:06
Show Gist options
  • Select an option

  • Save alisterburt/edab9839e11cebb0de8efd4ae0779964 to your computer and use it in GitHub Desktop.

Select an option

Save alisterburt/edab9839e11cebb0de8efd4ae0779964 to your computer and use it in GitHub Desktop.
find symmetry axis in cryo-EM map
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "einops",
# "mrcfile",
# "torch",
# "torch-fourier-slice",
# "torch-fourier-shift",
# "scipy",
# "roma",
# "torch-grid-utils",
# "torch-so3",
# "typer",
# "rich",
# "napari[pyqt5]",
# ]
# [tool.uv]
# exclude-newer = "2025-12-01T00:00:00Z"
# ///
from pathlib import Path
from typing import Literal, Optional
import time
import warnings
import einops
import mrcfile
import numpy as np
import torch
import typer
from torch_fourier_shift import fourier_shift_image_3d
from torch_fourier_slice import project_3d_to_2d
from scipy.spatial.transform import Rotation as R
import roma
from torch_grid_utils import coordinate_grid
from torch_so3 import get_uniform_euler_angles
from rich.console import Console
# Shared rich console used for all progress logging (console.log) in this script.
console = Console()
def get_uniform_rotation_matrices(
    psi_step: float = 1.5,
    theta_step: float = 2.5,
    phi_step: Optional[float] = None,
    phi_min: float = 0.0,
    phi_max: float = 360.0,
    theta_min: float = 0.0,
    theta_max: float = 180.0,
    psi_min: float = 0.0,
    psi_max: float = 360.0,
    base_grid_method: Literal["uniform", "healpix", "cartesian"] = "uniform",
    zyx: bool = False,
) -> torch.Tensor:
    """Generate sets of uniform rotation matrices using Hopf fibration.

    This is a passthrough function that calls ``get_uniform_euler_angles`` and
    converts the resulting ZYZ Euler angles (in degrees) to rotation matrices.

    Parameters
    ----------
    psi_step: float, optional
        Angular step for psi in degrees. Default is 1.5 degrees.
    theta_step: float, optional
        Angular step for theta in degrees. Default is 2.5 degrees.
    phi_step: float, optional
        Angular step for phi rotation in degrees. Only used when
        base_grid_method is "cartesian". Default is None; the value is
        forwarded unchanged to ``get_uniform_euler_angles``.
    phi_min: float, optional
        Minimum value for phi in degrees. Default is 0.0.
    phi_max: float, optional
        Maximum value for phi in degrees. Default is 360.0.
    theta_min: float, optional
        Minimum value for theta in degrees. Default is 0.0.
    theta_max: float, optional
        Maximum value for theta in degrees. Default is 180.0.
    psi_min: float, optional
        Minimum value for psi in degrees. Default is 0.0.
    psi_max: float, optional
        Maximum value for psi in degrees. Default is 360.0.
    base_grid_method: str, optional
        String literal specifying the method to generate the base grid.
        Options are "uniform" (default), "healpix", and "cartesian".
    zyx: bool, optional
        If False (default), matrices rotate xyz coordinates. If True, matrices
        rotate zyx coordinates.

    Returns
    -------
    torch.Tensor
        Tensor of shape (N, 3, 3) containing rotation matrices, where N is the
        number of rotation matrices generated.
    """
    # phi_step is not used by the non-cartesian grids, which warn about it.
    # Silence only that warning and only for the duration of this call,
    # instead of permanently mutating the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message="phi_step is being ignored for uniform method"
        )
        euler_angles = get_uniform_euler_angles(
            psi_step=psi_step,
            theta_step=theta_step,
            phi_step=phi_step,
            phi_min=phi_min,
            phi_max=phi_max,
            theta_min=theta_min,
            theta_max=theta_max,
            psi_min=psi_min,
            psi_max=psi_max,
            base_grid_method=base_grid_method,
        )
    rotation_matrices = roma.euler_to_rotmat("ZYZ", euler_angles, degrees=True)

    # Rotation matrices rotate xyz coordinates; reversing both matrix axes makes
    # them rotate zyx coordinates instead:
    # xyz:
    # [a b c] [x]   [ax + by + cz]   [x']
    # [d e f] [y]   [dx + ey + fz]   [y']
    # [g h i] [z] = [gx + hy + iz] = [z']
    #
    # zyx:
    # [i h g] [z]   [gx + hy + iz]   [z']
    # [f e d] [y]   [dx + ey + fz]   [y']
    # [c b a] [x] = [ax + by + cz] = [x']
    if zyx:
        rotation_matrices = torch.flip(rotation_matrices, dims=(-2, -1))
    return rotation_matrices
def compute_symmetry_correlation(projections: torch.Tensor) -> torch.Tensor:
    """Compute correlation between projections and their symmetry-averaged versions.

    For each projection direction ``b``, the first projection of its symmetry
    group (``projections[b, 0]``) is compared with the mean over all ``s``
    symmetry-related projections using the Pearson correlation coefficient.

    Parameters
    ----------
    projections: torch.Tensor
        Tensor of shape (b, s, h, w): b projection directions, each with s
        symmetry-related projections of size (h, w). Any trailing image
        dimensions are accepted; they are flattened before correlating.

    Returns
    -------
    torch.Tensor
        Tensor of shape (b,) of correlation scores in [-1, 1]. Where either
        image has zero variance (correlation undefined) the score is 0
        instead of NaN, so such directions cannot corrupt a ranking.
    """
    # first projection from each symmetry group, flattened: (b, h*w)
    base_flat = projections[:, 0].flatten(start_dim=1)
    # average over the symmetry group per projection direction, flattened: (b, h*w)
    average_flat = projections.mean(dim=1).flatten(start_dim=1)

    # Pearson correlation = cosine similarity of the mean-centered images
    base_centered = base_flat - base_flat.mean(dim=1, keepdim=True)
    average_centered = average_flat - average_flat.mean(dim=1, keepdim=True)

    numerator = (base_centered * average_centered).sum(dim=1)  # (b,)
    # product of the two norms rather than sqrt of the product of the sums of
    # squares: keeps intermediates in range for very small/large intensities
    denominator = torch.linalg.vector_norm(base_centered, dim=1) * torch.linalg.vector_norm(
        average_centered, dim=1
    )  # (b,)

    # a zero denominator means one image is constant; avoid 0/0 = NaN
    valid = denominator > 0
    safe_denominator = torch.where(valid, denominator, torch.ones_like(denominator))
    return torch.where(valid, numerator / safe_denominator, torch.zeros_like(numerator))
def compute_centering_shift(volume: torch.Tensor) -> torch.Tensor:
    """Compute the center of mass of a volume relative to its geometric center.

    Only positive density contributes to the center of mass. The result is the
    offset of the center of mass from voxel (d // 2, h // 2, w // 2), in
    (z, y, x) order.

    NOTE(review): this is the center-of-mass offset itself, not its negation;
    whether it is the shift that moves the center of mass onto the geometric
    center depends on the sign convention of the shifting function it is
    passed to — confirm.

    Parameters
    ----------
    volume: torch.Tensor
        3D tensor of shape (d, h, w).

    Returns
    -------
    torch.Tensor
        Tensor of shape (3,) in zyx order. If the volume contains no positive
        density the center of mass is undefined and a zero offset is returned
        (instead of NaN, which would poison any subsequent shift).
    """
    d, h, w = volume.shape
    # coordinates relative to the geometric center: (d, h, w, 3) in zyx order
    grid = coordinate_grid(
        image_shape=(d, h, w),
        center=(d // 2, h // 2, w // 2),
    )
    # only positive density contributes to the center of mass
    mass = torch.clamp(volume, min=0)
    total_mass = mass.sum()
    weighted_coords = einops.rearrange(mass, "d h w -> d h w 1") * grid
    weighted_sum = einops.reduce(weighted_coords, "d h w zyx -> zyx", reduction="sum")
    # With zero total mass every weight is 0, so weighted_sum is already all
    # zeros; dividing by 1 then yields a zero offset instead of 0/0 = NaN.
    safe_total_mass = torch.where(total_mass > 0, total_mass, torch.ones_like(total_mass))
    return weighted_sum / safe_total_mass
def main(
    input_volume: Path,
    symmetry_group: str,
):
    """Find symmetry axis of a 3D volume by searching over projection directions.

    The volume is shifted by its (positive-density) center-of-mass offset and
    then projected along a set of uniformly sampled directions, each combined
    with every rotation of ``symmetry_group``. For every direction, the first
    projection is correlated with the average over the symmetry group. The
    best-scoring direction is logged and displayed in napari.

    Parameters
    ----------
    input_volume: Path
        Path to an MRC file containing a 3D volume.
    symmetry_group: str
        Group name passed to ``scipy.spatial.transform.Rotation.create_group``
        (e.g. "C2", "D7", "O").
    """
    # read volume
    console.log(f"reading volume from {input_volume}")
    volume_data = mrcfile.read(str(input_volume))
    volume_tensor = torch.tensor(volume_data).float()
    console.log(f"volume shape (d, h, w): {volume_tensor.shape}")

    # calculate centering shift (zyx order; printed in xyz order)
    centering_shift = compute_centering_shift(volume_tensor)
    centering_shift_xyz_str = f"({centering_shift[-1]:.3f}, {centering_shift[-2]:.3f}, {centering_shift[-3]:.3f})"
    console.log(f"center of mass shift (xyz): {centering_shift_xyz_str}")

    # apply shift
    console.log("shifting volume...")
    volume_tensor = fourier_shift_image_3d(
        image=volume_tensor,
        # NOTE(review): passed without negation; the original author observed
        # that this works ("weird that this doesn't need a * -1") — confirm
        # against fourier_shift_image_3d's sign convention.
        shifts=centering_shift,
    )
    console.log("volume shifted")

    # candidate projection directions, as matrices acting on zyx coordinates
    console.log("generating uniform orientations...")
    direction_matrices = get_uniform_rotation_matrices(
        phi_step=10,
        theta_step=10,
        theta_min=-90,
        theta_max=90,
        psi_min=0,
        psi_max=0,
        zyx=True,
    )  # (b, 3, 3)
    n_directions = len(direction_matrices)
    console.log(f"generated {n_directions} directions to search")

    # Symmetry operators from scipy act on xyz coordinates (float64 numpy).
    # Convert them to the dtype/device of the direction matrices, because torch
    # matmul requires both operands to share a dtype, then flip both matrix
    # axes so they act on zyx coordinates.
    console.log(f"finding symmetry related orientations: {symmetry_group}")
    symmetry_group_rotation_matrices = torch.as_tensor(
        R.create_group(group=symmetry_group).as_matrix(),
        dtype=direction_matrices.dtype,
        device=direction_matrices.device,
    )  # (s, 3, 3)
    symmetry_group_rotation_matrices = torch.flip(symmetry_group_rotation_matrices, dims=(-2, -1))
    symmetry_group_size = len(symmetry_group_rotation_matrices)
    console.log(f"symmetry group size: {symmetry_group_size}")

    # combine every direction with every symmetry operator: (b, s, 3, 3)
    rotation_matrices = (
        einops.rearrange(direction_matrices, "b i j -> b 1 i j") @ symmetry_group_rotation_matrices
    )

    # generate projections
    console.log("generating projections...")
    start_time = time.perf_counter()
    projections = project_3d_to_2d(
        volume=volume_tensor,
        rotation_matrices=rotation_matrices,
        zyx_matrices=True,
    )
    elapsed_time = time.perf_counter() - start_time
    console.log(f"generated {n_directions} x {symmetry_group_size} projections in {elapsed_time:.3f} seconds")

    # compute correlations
    console.log("computing correlations between projections and average of all symmetry related projections...")
    correlations = compute_symmetry_correlation(projections)
    # an undefined (non-finite) score must never be selected as the best direction
    correlations = torch.nan_to_num(correlations, nan=-1.0, posinf=-1.0, neginf=-1.0)
    console.log("correlations computed")

    # find the best-scoring direction
    console.log(f"correlation scores - min: {correlations.min():.4f}, max: {correlations.max():.4f}, mean: {correlations.mean():.4f}")
    best_index = int(torch.argmax(correlations))
    # row 0 of the best direction's zyx matrix, stored in (z, y, x) order
    best_symmetry_axis = direction_matrices[best_index, 0].numpy()
    z_vec_str_xyz = f"({best_symmetry_axis[-1]:.2f}, {best_symmetry_axis[-2]:.2f}, {best_symmetry_axis[-3]:.2f})"
    best_symmetry_axis_msg = f"best symmetry axis (xyz): {z_vec_str_xyz}"
    console.log(best_symmetry_axis_msg)

    # visualize results; napari is imported lazily since only this step needs it
    import napari
    console.log("visualizing best symmetry axis...")
    viewer = napari.Viewer(ndisplay=3)
    viewer.add_image(volume_tensor.numpy(), name=input_volume.name)

    # one vector: [start point, direction], both in array (zyx) axis order
    center = np.asarray(volume_tensor.shape) // 2
    vectors_layer_data = np.stack([center, best_symmetry_axis], axis=-2)
    d = volume_tensor.shape[0]
    viewer.add_vectors(
        data=vectors_layer_data,
        name="best symmetry axis",
        length=d / 2,
        edge_color="cornflowerblue",
        edge_width=5,
    )
    viewer.text_overlay.text = best_symmetry_axis_msg
    viewer.text_overlay.position = "top_left"
    viewer.text_overlay.visible = True
    viewer.text_overlay.font_size = 16
    viewer.camera.angles = (-15, 30, 150)
    napari.run()
def cli(
    input_volume: Path = typer.Option(..., '--input', '-i', help="input MRC volume file"),
    symmetry_group: str = typer.Option(..., '--symmetry', '-s', help="symmetry group (e.g., C2, C3, D2)"),
):
    """Find the symmetry axis of a 3D volume by searching over projection directions."""
    # Normalise the group name to upper case so that e.g. "c7" is accepted
    # and forwarded as "C7".
    normalized_group = symmetry_group.upper()
    main(input_volume=input_volume, symmetry_group=normalized_group)
if __name__ == "__main__":
    # Single-command Typer app wrapping `cli`; invoking it with no arguments
    # prints the help text instead of erroring.
    typer_app = typer.Typer(add_completion=False)
    register_command = typer_app.command(no_args_is_help=True)
    register_command(cli)
    typer_app()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment