Created
August 12, 2025 03:09
-
-
Save alisterburt/5562e1a84e857d75f15bd3f1bdd15ed9 to your computer and use it in GitHub Desktop.
find which symmetry operator best applies to a cubic volume
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # dependencies = [ | |
| # "torch", | |
| # "numpy", | |
| # "scipy", | |
| # "mrcfile", | |
| # "torch-transform-image", | |
| # "torch-affine-utils", | |
| # "matplotlib", | |
| # "einops", | |
| # ] | |
| # /// | |
| import einops | |
| import mrcfile | |
| import numpy as np | |
| import torch | |
| from scipy.spatial.transform import Rotation as R | |
| from torch_transform_image import affine_transform_image_3d | |
| from torch_affine_utils.transforms_3d import T | |
| import matplotlib.pyplot as plt | |
def calculate_correlation(volume1, volume2):
    """
    Return the Pearson correlation coefficient between two volumes.

    Both volumes are collapsed to 1D and compared element by element, so they
    must contain the same number of elements.

    Args:
        volume1, volume2: 3D numpy arrays

    Returns:
        Pearson correlation coefficient
    """
    # Collapse each volume into a single vector of voxel values
    voxels_a = volume1.reshape(-1)
    voxels_b = volume2.reshape(-1)

    # np.corrcoef yields the 2x2 correlation matrix of the two vectors;
    # the off-diagonal entry is the cross-correlation between them
    correlation_matrix = np.corrcoef(voxels_a, voxels_b)
    return correlation_matrix[0, 1]
def generate_symmetrized_volume(volume, symmetry_operator):
    """
    Generate a symmetrized volume by averaging over a symmetry group.

    Every rotation of the group is applied about the volume center
    (``shape // 2``) using trilinear interpolation, and the rotated copies
    are averaged with equal weight.

    Args:
        volume: 3D numpy array or torch tensor representing the volume
        symmetry_operator: String representing the symmetry group
            (e.g., "C6", "D4", "T", "O", "I"), as understood by
            ``scipy.spatial.transform.Rotation.create_group``

    Returns:
        float32 numpy array of the symmetrized volume
    """
    # Convert to a floating point torch tensor if needed. Floating point
    # tensors are left untouched so their dtype is preserved.
    if not isinstance(volume, torch.Tensor):
        volume = torch.as_tensor(volume, dtype=torch.float32)
    elif not volume.is_floating_point():
        volume = volume.to(torch.float32)

    # Get rotation matrices for the symmetry group
    rotation_matrices = R.create_group(symmetry_operator).inv().as_matrix()
    n_operators = rotation_matrices.shape[0]

    # Create one homogeneous 4x4 affine matrix per group element
    affine_matrices = einops.repeat(
        torch.eye(4), "i j -> b i j", b=n_operators
    ).contiguous()
    affine_matrices[:, :3, :3] = torch.as_tensor(rotation_matrices)

    # Center the rotation around the volume center
    volume_center = torch.as_tensor(volume.shape) // 2
    affine_matrices = T(volume_center) @ affine_matrices @ T(-1 * volume_center)
    affine_matrices = affine_matrices.to(volume.dtype)

    # Apply every symmetry operation exactly once. Starting from zeros and
    # summing over all group elements gives each operator the same weight and
    # does not rely on the identity being the first element of the group.
    symmetrized = torch.zeros_like(volume)
    for matrix in affine_matrices:
        symmetrized += affine_transform_image_3d(
            image=volume,
            matrices=matrix,
            interpolation="trilinear",
        )

    # Average and convert to numpy
    symmetrized = symmetrized.numpy().astype(np.float32) / n_operators
    return symmetrized
# Example usage: symmetrize a volume with C1..C9, score each result against
# the original volume, write the symmetrized maps and plot the scores.
if __name__ == "__main__":
    import sys

    # Default input; pass a path as the first command line argument to override
    VOLUME_FILE = "/Users/burta2/Downloads/cryosparc_P68_J1217_006_volume_map.mrc"
    volume_file = sys.argv[1] if len(sys.argv) > 1 else VOLUME_FILE

    # Read volume
    print("Reading volume...")
    volume = mrcfile.read(volume_file)

    # Define symmetry groups to test
    # C1 through C9
    symmetry_groups = [f"C{i}" for i in range(1, 10)]

    # Store correlation scores
    correlations = []

    # Process each symmetry group
    for sym in symmetry_groups:
        print(f"\nProcessing {sym} symmetry...")

        # Generate symmetrized volume
        symmetrized = generate_symmetrized_volume(volume, sym)

        # Calculate correlation with original
        correlation = calculate_correlation(volume, symmetrized)
        correlations.append(correlation)
        print(f"{sym} correlation: {correlation:.4f}")

        # Save symmetrized volume
        output_file = f"symmetrized_{sym}.mrc"
        mrcfile.write(output_file, data=symmetrized, overwrite=True)
        print(f"Saved {output_file}")

    # Create bar chart; bar positions follow the list of tested groups so the
    # plot stays consistent if that list is changed
    positions = range(1, len(symmetry_groups) + 1)
    plt.figure(figsize=(10, 6))
    bars = plt.bar(positions, correlations, color='steelblue', alpha=0.8)
    plt.xlabel('Symmetry Order (C_n)', fontsize=12)
    plt.ylabel('Correlation with Original Volume', fontsize=12)
    plt.title('Correlation Scores for Different Symmetry Operations', fontsize=14)
    plt.grid(True, alpha=0.3, axis='y')
    plt.xticks(positions, symmetry_groups)

    # Set y-axis limits to show differences better
    min_corr = min(correlations)
    max_corr = max(correlations)
    corr_span = max_corr - min_corr
    # Only zoom in when there is a finite, non-zero spread; identical or NaN
    # limits are not valid axis limits, so autoscaling is kept in that case
    has_spread = corr_span > 0
    if has_spread:
        y_margin = corr_span * 0.1
        plt.ylim(min_corr - y_margin, max_corr + y_margin)

    # Add value labels on bars
    label_offset = corr_span * 0.01 if has_spread else 0.0
    for bar, corr in zip(bars, correlations):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height + label_offset,
                 f'{corr:.3f}',
                 ha='center', va='bottom', fontsize=9)

    plt.tight_layout()
    plt.savefig('symmetry_correlation_plot.png', dpi=300)
    plt.show()
    print("\nPlot saved as 'symmetry_correlation_plot.png'")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment