Skip to content

Instantly share code, notes, and snippets.

@alisterburt
Created August 12, 2025 03:09
Show Gist options
  • Select an option

  • Save alisterburt/5562e1a84e857d75f15bd3f1bdd15ed9 to your computer and use it in GitHub Desktop.

Select an option

Save alisterburt/5562e1a84e857d75f15bd3f1bdd15ed9 to your computer and use it in GitHub Desktop.
Find which symmetry operator best applies to a cubic volume by comparing cyclic symmetries C1–C9.
# /// script
# dependencies = [
# "torch",
# "numpy",
# "scipy",
# "mrcfile",
# "einops",
# "torch-transform-image",
# "torch-affine-utils",
# "matplotlib",
# ]
# ///
import einops
import mrcfile
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
from torch_transform_image import affine_transform_image_3d
from torch_affine_utils.transforms_3d import T
import matplotlib.pyplot as plt
def calculate_correlation(volume1, volume2):
    """
    Compute the Pearson correlation coefficient between two volumes.

    Both inputs are flattened to 1D first, so only their element-wise
    agreement matters, not their shape.

    Args:
        volume1, volume2: numpy arrays (typically 3D volumes) with the same
            number of elements.
    Returns:
        The Pearson correlation coefficient between the flattened inputs.
    """
    # np.corrcoef yields the 2x2 correlation matrix of the two rows; the
    # off-diagonal entry is the coefficient between them.
    matrix = np.corrcoef(np.ravel(volume1), np.ravel(volume2))
    return matrix[0][1]
def generate_symmetrized_volume(volume, symmetry_operator):
    """
    Generate a symmetrized volume by averaging over a symmetry group.

    Every rotation in the group is applied to ``volume`` about the voxel at
    index ``shape // 2``, and the results are averaged. Each group element,
    including the identity, contributes exactly once.

    Args:
        volume: 3D numpy array, torch tensor, or array-like representing the
            volume.
        symmetry_operator: String naming the symmetry group as accepted by
            ``scipy.spatial.transform.Rotation.create_group``
            (e.g. "C6", "D4", "T", "O", "I").
    Returns:
        float32 numpy array of the symmetrized volume.
    """
    # Work on a float32 tensor whatever the input type. The original check was
    # inverted: it called `.to()` on non-tensors and never cast tensors.
    if isinstance(volume, torch.Tensor):
        volume = volume.to(torch.float32)
    else:
        # np.asarray with an explicit dtype also normalizes byte order/dtype.
        volume = torch.as_tensor(np.asarray(volume, dtype=np.float32))

    # One 3x3 rotation per group element. They are inverted before being
    # handed to affine_transform_image_3d (presumably because it maps output
    # coordinates back to input coordinates; verify against its docs).
    rotation_matrices = R.create_group(symmetry_operator).inv().as_matrix()
    n_ops = rotation_matrices.shape[0]

    # Embed each rotation in a 4x4 homogeneous affine matrix.
    affine_matrices = torch.eye(4).repeat(n_ops, 1, 1)
    affine_matrices[:, :3, :3] = torch.as_tensor(rotation_matrices)

    # Rotate about the central voxel: shift the center to the origin, rotate,
    # then shift back.
    volume_center = torch.as_tensor(volume.shape) // 2
    affine_matrices = T(volume_center) @ affine_matrices @ T(-1 * volume_center)
    affine_matrices = affine_matrices.to(device=volume.device, dtype=volume.dtype)

    # Sum exactly one copy per group element. The identity needs no
    # resampling, so it adds the unmodified volume; every other element is
    # resampled with trilinear interpolation. Checking each matrix (rather
    # than assuming element 0 is the identity) keeps this correct for any
    # group ordering.
    identity = np.eye(3)
    symmetrized = torch.zeros_like(volume)
    for rotation, matrix in zip(rotation_matrices, affine_matrices):
        if np.allclose(rotation, identity):
            symmetrized += volume
        else:
            symmetrized += affine_transform_image_3d(
                image=volume,
                matrices=matrix,
                interpolation="trilinear",
            )

    # Average over the group and return a CPU numpy array.
    return (symmetrized / n_ops).detach().cpu().numpy().astype(np.float32)
# Example usage
# Default input map, used when no path is given on the command line.
VOLUME_FILE = "/Users/burta2/Downloads/cryosparc_P68_J1217_006_volume_map.mrc"


def _plot_correlations(symmetry_groups, correlations,
                       output_path="symmetry_correlation_plot.png"):
    """Draw, save and show a bar chart of correlation score per symmetry group."""
    # Tick positions derived from the group list instead of a magic range.
    positions = range(1, len(symmetry_groups) + 1)
    plt.figure(figsize=(10, 6))
    bars = plt.bar(positions, correlations, color='steelblue', alpha=0.8)
    plt.xlabel('Symmetry Order (C_n)', fontsize=12)
    plt.ylabel('Correlation with Original Volume', fontsize=12)
    plt.title('Correlation Scores for Different Symmetry Operations', fontsize=14)
    plt.grid(True, alpha=0.3, axis='y')
    plt.xticks(positions, symmetry_groups)

    # Zoom the y-axis onto the observed range so small differences show up.
    min_corr = min(correlations)
    max_corr = max(correlations)
    spread = max_corr - min_corr
    # A zero spread (single group, or identical scores) would hand matplotlib
    # identical limits; fall back to a fixed margin in that case.
    y_margin = spread * 0.1 if spread > 0 else 0.05
    plt.ylim(min_corr - y_margin, max_corr + y_margin)

    # Label each bar with its value, just above the bar top.
    for bar, corr in zip(bars, correlations):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height + y_margin * 0.1,
                 f'{corr:.3f}',
                 ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300)
    plt.show()
    print(f"\nPlot saved as '{output_path}'")


def main(volume_file=None):
    """
    Score cyclic symmetries C1-C9 against an MRC volume and plot the results.

    For each group the volume is symmetrized, correlated with the original,
    and written to ``symmetrized_<group>.mrc`` with the input's voxel size.
    A bar chart of all scores is then saved and shown.

    Args:
        volume_file: Path to the input MRC volume. If None, the first
            command-line argument is used, falling back to VOLUME_FILE.
    """
    import sys

    if volume_file is None:
        volume_file = sys.argv[1] if len(sys.argv) > 1 else VOLUME_FILE

    # Read the data together with the voxel size, so the symmetrized maps
    # written below keep the original pixel spacing.
    print("Reading volume...")
    with mrcfile.open(volume_file) as mrc:
        # astype copies into a writable, native-endian float32 array.
        volume = mrc.data.astype(np.float32)
        vs = mrc.voxel_size
        voxel_size = (float(vs.x), float(vs.y), float(vs.z))

    # Symmetry groups to test: C1 through C9.
    symmetry_groups = [f"C{i}" for i in range(1, 10)]
    correlations = []

    for sym in symmetry_groups:
        print(f"\nProcessing {sym} symmetry...")
        symmetrized = generate_symmetrized_volume(volume, sym)

        # Score how well the symmetrized map agrees with the original.
        correlation = calculate_correlation(volume, symmetrized)
        correlations.append(correlation)
        print(f"{sym} correlation: {correlation:.4f}")

        output_file = f"symmetrized_{sym}.mrc"
        mrcfile.write(output_file, data=symmetrized, overwrite=True,
                      voxel_size=voxel_size)
        print(f"Saved {output_file}")

    _plot_correlations(symmetry_groups, correlations)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment