Last active
September 9, 2025 18:47
-
-
Save afrendeiro/eb5b2ab723f89ed64eb81a28b2ad78c4 to your computer and use it in GitHub Desktop.
A quick demo of cell type classification from graphs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "numpy",
#   "pandas",
#   "matplotlib",
#   "tqdm",
#   "squidpy>=1.6.5",
#   "scikit-network",
# ]
# ///
| # For development: | |
| # uv run --with ipython --with squidpy --with scikit-network ipython | |
| import squidpy as sq | |
| import numpy as np | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from tqdm import tqdm | |
| from sknetwork.classification import DiffusionClassifier | |
| from sknetwork.gnn.gnn_classifier import GNNClassifier | |
| from sknetwork.classification import get_accuracy_score, get_average_f1_score | |
# Keyword arguments shared by every figure written to disk.
figkws = {"dpi": 300, "bbox_inches": "tight"}
np.random.seed(42)

# Load the IMC example dataset and build its spatial neighbour graph.
a = sq.datasets.imc()
sq.gr.spatial_neighbors(a, n_neighs=10, coord_type="generic")

# Overview plot: cells coloured by their annotated cell type.
overview_ax = sq.pl.spatial_scatter(a, shape=None, color="cell type", return_ax=True)
fig = overview_ax.figure
fig.savefig("spatial_clustering_demo.data_description.png", **figkws)
def _fit_cell_type_model(model_family, adj, features, labels, n_classes, model_kwargs):
    """Build one sknetwork classifier of ``model_family`` and fit it on ``labels``."""
    if model_family == "diffusion":
        model = DiffusionClassifier(**model_kwargs)
        model.fit(adj, labels=labels)
        return model
    # GNN defaults used by this demo; caller-supplied kwargs override them
    # instead of colliding with them (which would raise TypeError).
    params = {
        "dims": int(n_classes),
        "learning_rate": 1e-1,
        "patience": 100,
        "early_stopping": False,
    } | model_kwargs
    model = GNNClassifier(**params)
    model.fit(adjacency=adj, features=features, labels=labels)
    return model


def predict_cell_types(
    a, fraction: float = 0.5, model_family: str = "diffusion", **kwargs
) -> dict[str, float]:
    """
    Predict cell types from a random subset of known labels on the spatial graph.

    A random ``fraction`` of cells keep their true "cell type" code as a seed
    label and the chosen model predicts a label for every cell. As a null
    baseline, a second model is fit on the same seed cells but with labels
    taken from a random permutation of all cell type codes.

    Parameters
    ----------
    a
        AnnData-like object providing ``obsp["spatial_connectivities"]``
        (graph adjacency), ``X`` (node features; only used by the GNN) and a
        categorical ``obs["cell type"]`` column.
    fraction
        Fraction of cells whose labels are given to the model.
    model_family
        ``"diffusion"`` (sknetwork ``DiffusionClassifier``) or ``"gnn"``
        (sknetwork ``GNNClassifier``).
    **kwargs
        Extra keyword arguments forwarded to the model constructor.

    Returns
    -------
    dict
        Keys ``fraction``, ``accuracy``, ``average_f1``, ``shuffled_accuracy``
        and ``shuffled_average_f1``. Scores are computed over *all* cells,
        including the seed cells whose labels were given to the model.

    Side effects
    ------------
    Writes the non-shuffled model's integer predictions to
    ``a.obs["prediction"]`` and draws from NumPy's global random state.

    Raises
    ------
    ValueError
        If ``model_family`` is not ``"diffusion"`` or ``"gnn"``.
    """
    if model_family not in ("diffusion", "gnn"):
        raise ValueError(
            f"Unknown model_family {model_family!r}; expected 'diffusion' or 'gnn'."
        )

    adj = a.obsp["spatial_connectivities"]
    features = a.X
    cell_type = a.obs["cell type"]
    y = cell_type.cat.codes
    # Size the GNN output by the full category list rather than nunique():
    # codes go up to len(categories) - 1 even if some categories are unused.
    n_classes = len(cell_type.cat.categories)
    n_nodes = a.shape[0]

    n_known_labels = int(fraction * n_nodes)
    known_indices = np.random.choice(n_nodes, size=n_known_labels, replace=False)
    known_labels = {idx: y.iloc[idx] for idx in known_indices}
    # Null baseline: same seed cells, labels from a random permutation of all
    # codes. Drawn after known_indices to keep the seeded RNG order unchanged.
    shuffled_codes = np.random.permutation(y)
    shuffled_labels = {idx: shuffled_codes[idx] for idx in known_indices}

    model = _fit_cell_type_model(
        model_family, adj, features, known_labels, n_classes, kwargs
    )
    shuffled_model = _fit_cell_type_model(
        model_family, adj, features, shuffled_labels, n_classes, kwargs
    )

    pred = pd.Series(model.predict(), index=a.obs_names)
    a.obs["prediction"] = pred
    shuffled_pred = pd.Series(shuffled_model.predict(), index=a.obs_names)

    return dict(
        fraction=fraction,
        accuracy=get_accuracy_score(y, pred),
        average_f1=get_average_f1_score(y, pred),
        shuffled_accuracy=get_accuracy_score(y, shuffled_pred),
        shuffled_average_f1=get_average_f1_score(y, shuffled_pred),
    )
# Benchmark both model families over a grid of known-label fractions.
# The grid is exponentially spaced, then min-max rescaled to [0, 1]; its first
# point (0.0, i.e. no known labels) is skipped.
_res = []
f = np.exp(np.linspace(1e-5, 1.0, 50))
f -= f.min()
f /= f.max()
model_families = ["diffusion", "gnn"]
for model_family in tqdm(model_families, position=0, leave=False, desc="Models"):
    _res.extend(
        {"model_family": model_family}
        | predict_cell_types(a, fraction=fraction, model_family=model_family)
        for fraction in tqdm(f[1:], position=1, leave=False, desc="Fractions")
    )
res = pd.DataFrame(_res)
res.to_csv("spatial_clustering_demo.subsampling_results.csv", index=False)
# Visualize benchmark results: one panel per model family, showing real vs.
# shuffled-label scores for each metric.
fig, axes = plt.subplots(2, 1, figsize=(1 * 6, 2 * 4), sharex=True)
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# y = x reference curve. Scores are computed over all cells, including the
# seed cells whose labels the model was given, so y = x is the accuracy equal
# to the seed fraction itself. Sampled densely (and starting above 0) so it
# renders as one smooth curve on the log-scaled x axis.
reference = np.linspace(0, 1, 200)[2:-1]
for ax, model_family in zip(axes, ["diffusion", "gnn"]):
    r = res.query("model_family == @model_family")
    # Real and shuffled curves of the same metric share a color.
    for metric, color in zip(["accuracy", "average_f1"], colors):
        ax.plot(
            r["fraction"],
            r[metric],
            label=metric.capitalize(),
            color=color,
        )
        ax.plot(
            r["fraction"],
            r[f"shuffled_{metric}"],
            label=f"Shuffled {metric.capitalize()}",
            color=color,
            linestyle="--",
        )
    # A single line through (x, x); the previous `plot(*zip(...))` form was
    # parsed as ~100 separate (x, y) pairs, drawing disjoint offset segments.
    ax.plot(reference, reference, linestyle="--", color="gray")
    ax.set(
        title=f"Model = {model_family}",
        xlabel="Fraction of cells",
        ylabel="Score",
        ylim=(-0.05, 1.05),
        xscale="log",
    )
    ax.legend()
fig.tight_layout()
fig.savefig("spatial_clustering_demo.subsampling_benchmarking.png", **figkws)
# Visualize predictions made with 20% and 50% of cells given known labels.
cell_type_categories = a.obs["cell type"].cat.categories
for fraction in [0.2, 0.5]:
    perc = f"{int(fraction * 100):02d}%"
    fig, axes = plt.subplots(2, 2, figsize=(2 * 6, 2 * 6))
    sq.pl.spatial_scatter(a, shape=None, color="cell type", fig=fig, ax=axes[0, 0])
    panels = [
        ("diffusion", axes[1, 0], "Diffusion model"),
        ("gnn", axes[1, 1], "GNN"),
    ]
    for model_family, ax, model_name in panels:
        r = predict_cell_types(a, fraction=fraction, model_family=model_family)
        # predict_cell_types stores integer cell type codes; map them back to
        # cell type names (a code of -1, if any, becomes missing) so the
        # legend is readable and categories line up with the ground truth.
        a.obs["prediction"] = pd.Categorical.from_codes(
            a.obs["prediction"].to_numpy(), categories=cell_type_categories
        )
        # Reuse the ground-truth palette when one is stored, so each cell
        # type gets the same color in every panel.
        if "cell type_colors" in a.uns:
            a.uns["prediction_colors"] = a.uns["cell type_colors"]
        sq.pl.spatial_scatter(a, shape=None, color="prediction", fig=fig, ax=ax)
        ax.set(
            title=f"{model_name} w/ {perc} cells\nAccuracy: {r['accuracy']:.2f}, F1: {r['average_f1']:.2f}"
        )
    axes[0, 0].set(title="Ground truth: Cell type")
    axes[0, 1].axis("off")
    fig.savefig(f"spatial_clustering_demo.subsampling_results.{perc}.png", **figkws)
    # Release the figure; otherwise every iteration keeps it alive in pyplot.
    plt.close(fig)
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Output:


