Created
December 2, 2025 23:06
-
-
Save alisterburt/edab9839e11cebb0de8efd4ae0779964 to your computer and use it in GitHub Desktop.
Find the symmetry axis in a cryo-EM map
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.11" | |
| # dependencies = [ | |
| # "einops", | |
| # "mrcfile", | |
| # "torch", | |
| # "torch-fourier-slice", | |
| # "torch-fourier-shift", | |
| # "scipy", | |
| # "roma", | |
| # "torch-grid-utils", | |
| # "torch-so3", | |
| # "typer", | |
| # "rich", | |
| # "napari[pyqt5]", | |
| # ] | |
| # [tool.uv] | |
| # exclude-newer = "2025-12-01T00:00:00Z" | |
| # /// | |
| from pathlib import Path | |
| from typing import Literal, Optional | |
| import time | |
| import warnings | |
| import einops | |
| import mrcfile | |
| import numpy as np | |
| import torch | |
| import typer | |
| from torch_fourier_shift import fourier_shift_image_3d | |
| from torch_fourier_slice import project_3d_to_2d | |
| from scipy.spatial.transform import Rotation as R | |
| import roma | |
| from torch_grid_utils import coordinate_grid | |
| from torch_so3 import get_uniform_euler_angles | |
| from rich.console import Console | |
# Shared rich console; every progress message in this script goes through console.log.
console: Console = Console()
def get_uniform_rotation_matrices(
    psi_step: float = 1.5,
    theta_step: float = 2.5,
    phi_step: Optional[float] = None,
    phi_min: float = 0.0,
    phi_max: float = 360.0,
    theta_min: float = 0.0,
    theta_max: float = 180.0,
    psi_min: float = 0.0,
    psi_max: float = 360.0,
    base_grid_method: Literal["uniform", "healpix", "cartesian"] = "uniform",
    zyx: bool = False,
) -> torch.Tensor:
    """Generate sets of uniform rotation matrices using Hopf fibration.

    This is a passthrough function that calls ``get_uniform_euler_angles`` and
    converts the resulting Euler angles (degrees, "ZYZ" convention as passed to
    ``roma.euler_to_rotmat``) to rotation matrices.

    Parameters
    ----------
    psi_step: float, optional
        Angular step for psi in degrees. Default is 1.5 degrees.
    theta_step: float, optional
        Angular step for theta in degrees. Default is 2.5 degrees.
    phi_step: float, optional
        Angular step for phi rotation in degrees. Only used when base_grid_method is
        "cartesian". Default is None, which is forwarded unchanged to
        ``get_uniform_euler_angles`` (presumably it then applies its own default;
        confirm against torch-so3).
    phi_min: float, optional
        Minimum value for phi in degrees. Default is 0.0.
    phi_max: float, optional
        Maximum value for phi in degrees. Default is 360.0.
    theta_min: float, optional
        Minimum value for theta in degrees. Default is 0.0.
    theta_max: float, optional
        Maximum value for theta in degrees. Default is 180.0.
    psi_min: float, optional
        Minimum value for psi in degrees. Default is 0.0.
    psi_max: float, optional
        Maximum value for psi in degrees. Default is 360.0.
    base_grid_method: str, optional
        String literal specifying the method to generate the base grid. Default is
        "uniform". Options are "uniform", "healpix", and "cartesian".
    zyx: bool, optional
        If False (default), matrices rotate xyz coordinates. If True, matrices
        rotate zyx coordinates.

    Returns
    -------
    torch.Tensor
        Tensor of shape (N, 3, 3) containing rotation matrices, where N is the
        number of rotation matrices generated.
    """
    # Suppress the expected phi_step warning only for this call. A bare
    # warnings.filterwarnings() would permanently change the process-wide
    # warning filters as a side effect of calling this function.
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', message='phi_step is being ignored for uniform method')
        euler_angles = get_uniform_euler_angles(
            psi_step=psi_step,
            theta_step=theta_step,
            phi_step=phi_step,
            phi_min=phi_min,
            phi_max=phi_max,
            theta_min=theta_min,
            theta_max=theta_max,
            psi_min=psi_min,
            psi_max=psi_max,
            base_grid_method=base_grid_method,
        )
    rotation_matrices = roma.euler_to_rotmat("ZYZ", euler_angles, degrees=True)
    # rotation matrices rotate xyz coordinates, make them rotate zyx coordinates
    # by reversing both rows and columns (P @ M @ P with P the reversal permutation)
    # xyz:
    # [a b c] [x]   [ax + by + cz]   [x']
    # [d e f] [y]   [dx + ey + fz]   [y']
    # [g h i] [z] = [gx + hy + iz] = [z']
    #
    # zyx:
    # [i h g] [z]   [gx + hy + iz]   [z']
    # [f e d] [y]   [dx + ey + fz]   [y']
    # [c b a] [x] = [ax + by + cz] = [x']
    if zyx:
        rotation_matrices = torch.flip(rotation_matrices, dims=(-2, -1))
    return rotation_matrices
def compute_symmetry_correlation(projections: torch.Tensor) -> torch.Tensor:
    """Compute correlation between projections and their symmetry-averaged versions.

    For each projection direction, the first projection of its symmetry group
    (index 0 along dim 1) is compared with the mean of all symmetry-related
    projections of that direction, using the Pearson correlation coefficient.

    Parameters
    ----------
    projections: torch.Tensor
        Tensor of shape (b, s, h, w): b projection directions, each with s
        symmetry-related projections of size (h, w).

    Returns
    -------
    torch.Tensor
        Tensor of shape (b,) with one correlation score per direction. When
        either image has zero variance the correlation is undefined (0 / 0);
        such directions score 0.0 instead of NaN.
    """
    # first projection from each symmetry group, flattened: (b, h*w)
    base_flat = projections[:, 0].flatten(start_dim=1)
    # average over the symmetry group per direction, flattened: (b, h*w)
    average_flat = projections.mean(dim=1).flatten(start_dim=1)

    # Pearson correlation: center each image, then normalized dot product
    base_centered = base_flat - base_flat.mean(dim=1, keepdim=True)
    average_centered = average_flat - average_flat.mean(dim=1, keepdim=True)
    numerator = (base_centered * average_centered).sum(dim=1)  # (b,)
    denominator = torch.sqrt(
        base_centered.square().sum(dim=1) * average_centered.square().sum(dim=1)
    )  # (b,)

    # A zero denominator would give NaN, which makes the caller's min/max/mean
    # summary NaN and has no meaningful rank. Score those directions as 0.0.
    valid = denominator > 0
    safe_denominator = torch.where(valid, denominator, torch.ones_like(denominator))
    return torch.where(valid, numerator / safe_denominator, torch.zeros_like(numerator))
def compute_centering_shift(volume: torch.Tensor) -> torch.Tensor:
    """Calculate the center-of-mass offset used to center a volume.

    Only positive density contributes to the center of mass.

    Parameters
    ----------
    volume: torch.Tensor
        Tensor of shape (d, h, w).

    Returns
    -------
    torch.Tensor
        Tensor of shape (3,) in zyx order: the center of mass minus the geometric
        center (d // 2, h // 2, w // 2), in voxels. If the volume has no positive
        density, the center of mass is undefined and a zero offset is returned
        instead of NaN.
    """
    d, h, w = volume.shape
    # coordinate grid relative to the geometric center: (d, h, w, 3), zyx order
    grid = coordinate_grid(
        image_shape=(d, h, w),
        center=(d // 2, h // 2, w // 2)
    )
    # only positive density counts toward the center of mass
    positive_density = torch.clamp(volume, min=0)
    total_mass = positive_density.sum()
    if total_mass <= 0:
        # 0 / 0 would yield a NaN offset, which is unusable as a shift
        result_dtype = torch.promote_types(positive_density.dtype, grid.dtype)
        return torch.zeros(3, dtype=result_dtype, device=positive_density.device)
    weighted_coords = einops.rearrange(positive_density, "d h w -> d h w 1") * grid
    center_of_mass = einops.reduce(weighted_coords, "d h w zyx -> zyx", reduction="sum") / total_mass
    return center_of_mass
def main(
    input_volume: Path,
    symmetry_group: str,
):
    """Find symmetry axis of a 3D volume by searching over projection directions.

    The volume is first shifted by its center-of-mass offset. For each candidate
    orientation (psi fixed at 0), the volume is projected once per element of
    the symmetry group. The first projection of each group is then correlated
    with the group average. The candidate with the highest correlation is
    reported as the symmetry axis and displayed in napari.

    Parameters
    ----------
    input_volume: Path
        Path to the MRC file holding the 3D volume.
    symmetry_group: str
        Group name passed to ``scipy.spatial.transform.Rotation.create_group``
        (e.g. "C2", "D3").
    """
    # read volume
    console.log(f"reading volume from {input_volume}")
    volume_data = mrcfile.read(str(input_volume))
    # Convert to native-endian float32 first: MRC data can be stored in
    # non-native byte order, which torch.tensor refuses to convert.
    volume_tensor = torch.tensor(np.asarray(volume_data, dtype=np.float32))
    console.log(f"volume shape (d, h, w): {volume_tensor.shape}")

    # calculate centering shift (zyx order)
    centering_shift = compute_centering_shift(volume_tensor)
    centering_shift_xyz_str = f"({centering_shift[-1]:.3f}, {centering_shift[-2]:.3f}, {centering_shift[-3]:.3f})"
    console.log(f"center of mass shift (xyz): {centering_shift_xyz_str}")

    # apply shift
    console.log("shifting volume...")
    volume_tensor = fourier_shift_image_3d(
        image=volume_tensor,
        # NOTE(review): passed without a * -1; the original author found this works
        # but called it surprising. Confirm against torch-fourier-shift's sign convention.
        shifts=centering_shift,
    )
    console.log("volume shifted")

    # generate uniform rotation matrices over directions (psi fixed at 0)
    console.log("generating uniform orientations...")
    uniform_rotation_matrices = get_uniform_rotation_matrices(
        phi_step=10,
        theta_step=10,
        theta_min=-90,
        theta_max=90,
        psi_min=0,
        psi_max=0,
        zyx=True,
    )  # (b, 3, 3), acting on zyx coordinates
    n_directions = len(uniform_rotation_matrices)
    console.log(f"generated {n_directions} directions to search")

    # create symmetry group matrices
    console.log(f"finding symmetry related orientations: {symmetry_group}")
    symmetry_group_rotation_matrices = R.create_group(group=symmetry_group).as_matrix()  # (s, 3, 3), xyz
    # match dtype so the batched matmul below cannot fail on float32 vs float64
    symmetry_group_rotation_matrices = torch.tensor(
        symmetry_group_rotation_matrices, dtype=uniform_rotation_matrices.dtype
    )
    # xyz -> zyx, same row/column reversal as get_uniform_rotation_matrices(zyx=True)
    symmetry_group_rotation_matrices = torch.flip(symmetry_group_rotation_matrices, dims=(-2, -1))
    symmetry_group_size = len(symmetry_group_rotation_matrices)
    console.log(f"symmetry group size: {symmetry_group_size}")

    # combine rotation matrices: (b, 1, 3, 3) @ (s, 3, 3) -> (b, s, 3, 3).
    # uniform_rotation_matrices itself stays (b, 3, 3) for the axis lookup below.
    rotation_matrices = (
        einops.rearrange(uniform_rotation_matrices, "b i j -> b 1 i j")
        @ symmetry_group_rotation_matrices
    )

    # generate projections
    console.log("generating projections...")
    start_time = time.perf_counter()
    projections = project_3d_to_2d(
        volume=volume_tensor,
        rotation_matrices=rotation_matrices,
        zyx_matrices=True,
    )
    elapsed_time = time.perf_counter() - start_time
    console.log(f"generated {n_directions} x {symmetry_group_size} projections in {elapsed_time:.3f} seconds")

    # compute correlations, one per direction
    console.log("computing correlations between projections and average of all symmetry related projections...")
    correlations = compute_symmetry_correlation(projections)
    console.log("correlations computed")
    console.log(f"correlation scores - min: {correlations.min():.4f}, max: {correlations.max():.4f}, mean: {correlations.mean():.4f}")

    # The searched axis for rotation R is R @ z. In zyx coordinates z is index 0,
    # so that is column 0 of the zyx matrix (components in zyx order). The row
    # (R^T @ z) would be wrong: with psi fixed at 0 it cannot vary with phi
    # (assuming R = Rz(phi) Ry(theta) Rz(psi)).
    best_index = torch.topk(correlations, k=1).indices[0]
    best_rotation = uniform_rotation_matrices[best_index]  # (3, 3), zyx
    best_symmetry_axis = best_rotation[:, 0].numpy()  # (3,), zyx order
    z_vec_str_xyz = f"({best_symmetry_axis[-1]:.2f}, {best_symmetry_axis[-2]:.2f}, {best_symmetry_axis[-3]:.2f})"
    best_symmetry_axis_msg = f"best symmetry axis (xyz): {z_vec_str_xyz}"
    console.log(best_symmetry_axis_msg)

    # visualize results; napari is imported here because only this section uses it
    import napari
    console.log("visualizing best symmetry axis...")
    viewer = napari.Viewer(ndisplay=3)
    viewer.add_image(volume_tensor.numpy(), name=input_volume.name)
    center = np.asarray(volume_tensor.shape) // 2  # zyx voxel center
    # napari vectors data is (N, 2, D): [start position, direction] per vector
    vectors_layer_data = np.stack([center, best_symmetry_axis], axis=0)[np.newaxis]
    d = volume_tensor.shape[0]
    viewer.add_vectors(
        data=vectors_layer_data,
        name="best symmetry axis",
        length=d / 2,
        edge_color="cornflowerblue",
        edge_width=5,
    )
    viewer.text_overlay.text = best_symmetry_axis_msg
    viewer.text_overlay.position = "top_left"
    viewer.text_overlay.visible = True
    viewer.text_overlay.font_size = 16
    viewer.camera.angles = (-15, 30, 150)
    napari.run()
def cli(
    input_volume: Path = typer.Option(..., '--input', '-i', help="input MRC volume file"),
    symmetry_group: str = typer.Option(..., '--symmetry', '-s', help="symmetry group (e.g., C2, C3, D2)"),
):
    """Find the symmetry axis of a 3D volume by searching over projection directions."""
    # Normalize the group name's case (e.g. "c2" -> "C2"); the path is passed on as-is.
    normalized_group = symmetry_group.upper()
    main(input_volume=input_volume, symmetry_group=normalized_group)
if __name__ == "__main__":
    # Wrap `cli` in a single-command Typer app with shell completion disabled;
    # running it with no arguments shows the help text.
    typer_app = typer.Typer(add_completion=False)
    register_command = typer_app.command(no_args_is_help=True)
    register_command(cli)
    typer_app()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment