Skip to content

Instantly share code, notes, and snippets.

@afrendeiro
Last active September 9, 2025 18:47
Show Gist options
  • Select an option

  • Save afrendeiro/eb5b2ab723f89ed64eb81a28b2ad78c4 to your computer and use it in GitHub Desktop.

Select an option

Save afrendeiro/eb5b2ab723f89ed64eb81a28b2ad78c4 to your computer and use it in GitHub Desktop.
A quick demo of cell type classification from graphs
#!/usr/bin/env -S uv run --script
# /// script
# dependencies = [
# "numpy",
# "squidpy>=1.6.5",
# "scikit-network",
# ]
# requires-python = ">=3.12"
# ///
# For development:
# uv run --with ipython --with squidpy --with scikit-network ipython
import squidpy as sq
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sknetwork.classification import DiffusionClassifier
from sknetwork.gnn.gnn_classifier import GNNClassifier
from sknetwork.classification import get_accuracy_score, get_average_f1_score
# Keyword arguments shared by every `fig.savefig` call in this script.
figkws = {"dpi": 300, "bbox_inches": "tight"}
# Seed NumPy's global RNG; the label subsampling below draws from it.
np.random.seed(42)

# Load the squidpy IMC dataset and attach a spatial neighbour graph to it
# (stored on `a`; read back later from `a.obsp["spatial_connectivities"]`).
a = sq.datasets.imc()
sq.gr.spatial_neighbors(a, coord_type="generic", n_neighs=10)

# Plot the cells coloured by their annotated cell type and save the figure.
scatter_ax = sq.pl.spatial_scatter(a, shape=None, color="cell type", return_ax=True)
fig = scatter_ax.figure
fig.savefig("spatial_clustering_demo.data_description.png", **figkws)
def _fit_classifier(adj, features, labels, model_family, n_classes, **kwargs):
    """Build and fit one classifier of the given family on the seed ``labels``."""
    if model_family == "diffusion":
        model = DiffusionClassifier(**kwargs)
        model.fit(adj, labels=labels)
    else:  # "gnn" -- the family name is validated by predict_cell_types
        model = GNNClassifier(
            dims=int(n_classes),
            learning_rate=1e-1,
            patience=100,
            early_stopping=False,
            **kwargs,
        )
        model.fit(adjacency=adj, features=features, labels=labels)
    return model


def predict_cell_types(
    a, fraction: float = 0.5, model_family: str = "diffusion", **kwargs
) -> dict[str, float]:
    """
    Predict cell types on the spatial graph from a random subset of known labels.

    A random ``fraction`` of the cells keep their ``"cell type"`` label and a
    classifier is fitted on the graph in ``a.obsp["spatial_connectivities"]``.
    A second, identically configured classifier is fitted on the same cells
    but with labels taken from a random permutation of all labels, as a
    negative control.

    Parameters
    ----------
    a
        Annotated data object providing ``obsp["spatial_connectivities"]``,
        ``X``, a categorical ``obs["cell type"]`` column, ``obs_names`` and
        ``shape``.
    fraction
        Fraction of cells whose true label is given to the classifier.
    model_family
        ``"diffusion"`` (``DiffusionClassifier``, graph only) or ``"gnn"``
        (``GNNClassifier``, graph plus ``a.X`` as node features).
    **kwargs
        Forwarded to the classifier constructor.

    Returns
    -------
    dict
        ``fraction``, ``accuracy``, ``average_f1``, ``shuffled_accuracy`` and
        ``shuffled_average_f1``. Scores are computed against the labels of
        all cells, not only the cells that were held out.

    Raises
    ------
    ValueError
        If ``model_family`` is not one of the supported families.

    Notes
    -----
    Draws from NumPy's global random state and overwrites
    ``a.obs["prediction"]`` with the predictions of the unshuffled model.
    """
    # Fail fast: previously an unknown family skipped both branches and only
    # surfaced later as an unbound ``model`` variable.
    if model_family not in ("diffusion", "gnn"):
        raise ValueError(
            f"Unknown model_family {model_family!r}; expected 'diffusion' or 'gnn'."
        )

    adj = a.obsp["spatial_connectivities"]
    features = a.X
    y = a.obs["cell type"].cat.codes
    n_classes = a.obs["cell type"].nunique()
    n_nodes = a.shape[0]

    # Pick which cells expose their label to the classifier.
    n_known_labels = int(fraction * n_nodes)
    known_indices = np.random.choice(n_nodes, size=n_known_labels, replace=False)
    known_labels = {idx: y.iloc[idx] for idx in known_indices}

    # Negative control: same cells, labels drawn from a permutation of all codes.
    shuffled_codes = np.random.permutation(y.to_numpy())
    shuffled_labels = {idx: shuffled_codes[idx] for idx in known_indices}

    model = _fit_classifier(
        adj, features, known_labels, model_family, n_classes, **kwargs
    )
    shuffled_model = _fit_classifier(
        adj, features, shuffled_labels, model_family, n_classes, **kwargs
    )

    pred = pd.Series(model.predict(), index=a.obs_names)
    a.obs["prediction"] = pred  # read back by the plotting code in this script
    shuffled_pred = pd.Series(shuffled_model.predict(), index=a.obs_names)

    return dict(
        fraction=fraction,
        accuracy=get_accuracy_score(y, pred),
        average_f1=get_average_f1_score(y, pred),
        shuffled_accuracy=get_accuracy_score(y, shuffled_pred),
        shuffled_average_f1=get_average_f1_score(y, shuffled_pred),
    )
# Benchmark both model families over a grid of labelled-cell fractions.
_res = []
# Exponentially spaced grid, min-max rescaled so it runs from 0 to 1.
f = np.exp(np.linspace(1e-5, 1.0, 50))
f_lo, f_hi = f.min(), f.max()
f = (f - f_lo) / (f_hi - f_lo)
for model_family in tqdm(["diffusion", "gnn"], position=0, leave=False, desc="Models"):
    # f[0] is exactly 0 after rescaling (no labelled cells), so start at f[1].
    for fraction in tqdm(f[1:], position=1, leave=False, desc="Fractions"):
        metrics = predict_cell_types(a, fraction=fraction, model_family=model_family)
        _res.append({"model_family": model_family, **metrics})
# One row per (model family, fraction) pair.
res = pd.DataFrame(_res)
res.to_csv("spatial_clustering_demo.subsampling_results.csv", index=False)
# Visualize results: one panel per model family, score vs. fraction of
# labelled cells, with the shuffled-label control drawn dashed in the same colour.
fig, axes = plt.subplots(2, 1, figsize=(1 * 6, 2 * 4), sharex=True)
palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# x-values of the identity reference line; the slice trims both ends of the
# grid so that x = 0 never reaches the log-scaled axis.
points = np.linspace(0, 1, 200)[2:-1]
for ax, model_family in zip(axes, ["diffusion", "gnn"]):
    r = res[res["model_family"] == model_family]
    # One colour per metric, taken in order from the default colour cycle.
    for metric, color in zip(["accuracy", "average_f1"], palette):
        ax.plot(
            r["fraction"],
            r[metric],
            label=metric.capitalize(),
            color=color,
        )
        ax.plot(
            r["fraction"],
            r[f"shuffled_{metric}"],
            label=f"Shuffled {metric.capitalize()}",
            color=color,
            linestyle="--",
        )
    ax.set(xlabel="Fraction of cells", ylabel="Score", ylim=(-0.05, 1.05))
    # Single y = x reference line (score equal to the fraction of labelled
    # cells). Previously this was drawn as ~100 separate two-point segments,
    # each shifted up by one grid step.
    ax.plot(points, points, linestyle="--", color="gray")
    ax.set(title=f"Model = {model_family}", xscale="log")
    ax.legend()
fig.tight_layout()
fig.savefig("spatial_clustering_demo.subsampling_benchmarking.png", **figkws)
# Visualize predictions with 20% and 50% of the cells labelled: ground truth
# on the top row, one panel per model on the bottom row.
for fraction in [0.2, 0.5]:
    perc = f"{int(fraction * 100):02d}%"
    fig, axes = plt.subplots(2, 2, figsize=(2 * 6, 2 * 6))
    (truth_ax, blank_ax), (diffusion_ax, gnn_ax) = axes

    sq.pl.spatial_scatter(a, shape=None, color="cell type", fig=fig, ax=truth_ax)

    # Each predict_cell_types call overwrites a.obs["prediction"]; it is cast
    # to categorical before plotting, so every model is plotted right after
    # its own run.
    diffusion_scores = predict_cell_types(a, fraction=fraction, model_family="diffusion")
    a.obs["prediction"] = a.obs["prediction"].astype("category")
    sq.pl.spatial_scatter(a, shape=None, color="prediction", fig=fig, ax=diffusion_ax)

    gnn_scores = predict_cell_types(a, fraction=fraction, model_family="gnn")
    a.obs["prediction"] = a.obs["prediction"].astype("category")
    sq.pl.spatial_scatter(a, shape=None, color=["prediction"], fig=fig, ax=gnn_ax)

    # Titles are set after plotting so they replace whatever squidpy put there.
    truth_ax.set(title="Ground truth: Cell type")
    blank_ax.axis("off")
    diffusion_title = (
        f"Diffusion model w/ {perc} cells\n"
        f"Accuracy: {diffusion_scores['accuracy']:.2f}, F1: {diffusion_scores['average_f1']:.2f}"
    )
    gnn_title = (
        f"GNN w/ {perc} cells\n"
        f"Accuracy: {gnn_scores['accuracy']:.2f}, F1: {gnn_scores['average_f1']:.2f}"
    )
    diffusion_ax.set(title=diffusion_title)
    gnn_ax.set(title=gnn_title)
    fig.savefig(f"spatial_clustering_demo.subsampling_results.{perc}.png", **figkws)
@afrendeiro
Copy link
Author

afrendeiro commented Jul 1, 2025

Output:
spatial_clustering_demo subsampling_benchmarking
spatial_clustering_demo subsampling_results 20%
spatial_clustering_demo subsampling_results 50%

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment