Python implementation of the Infinite Relational Model (IRM)
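In outline, the script below implements the IRM for a binary relation X between two domains: row and column objects each get a Chinese Restaurant Process prior (concentrations alpha1 and alpha2), each block (k, l) of the joint partition gets a Beta(a0, b0)-Bernoulli likelihood, and inference is collapsed Gibbs sampling over the two assignment vectors z1 and z2. When a row i is resampled, the sampler evaluates (in log space, normalizing with logsumexp):

p(z1_i = k | rest) ∝ m_k · ∏_l B(a0 + n+_kl + x+_il, b0 + n-_kl + x-_il) / B(a0 + n+_kl, b0 + n-_kl)

where m_k is the number of rows currently in cluster k (with row i excluded; m_k is replaced by alpha1 for a new cluster, whose block counts are zero), n+_kl / n-_kl count the ones and zeros of X in block (k, l) with row i removed, x+_il / x-_il count the ones and zeros of row i restricted to the columns in cluster l, and B is the Beta function. Columns are resampled symmetrically using alpha2. The demo at the bottom fits the model to Zachary's karate club adjacency matrix and compares the recovered clusters against the two known factions via the adjusted Rand index.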
import random
from collections import defaultdict
from collections.abc import Iterator
from typing import Literal, Optional

import networkx as nx
import numpy as np
import numpy.typing as npt
import tqdm
from scipy.special import betaln, logsumexp
from scipy.stats import mode
from sklearn.metrics import adjusted_rand_score


def chinese_restaurant_process(n: int, alpha: float, seed: Optional[int] = None) -> npt.NDArray:
    """Generate cluster assignments for n elements using the Chinese Restaurant Process.

    Args:
        n (int): Number of elements
        alpha (float): Parameter of the CRP
        seed (int | None): Random seed

    Returns:
        z (npt.NDArray): Cluster assignments for each element
    """
    assignments = []  # Cluster assignments
    table_counts = []  # Number of customers at each table
    rng = np.random.default_rng(seed)
    for i in range(n):
        if i == 0:
            # The first customer sits at a new table
            assignments.append(0)
            table_counts.append(1)
        else:
            probs = np.array([*table_counts, alpha])  # Existing tables + new table
            probs /= probs.sum()  # Convert to probabilities
            choice = rng.choice(len(probs), p=probs)  # Choose a table
            if choice == len(table_counts):
                # Create a new table
                table_counts.append(1)
                assignments.append(len(table_counts) - 1)
            else:
                # Join an existing table
                table_counts[choice] += 1
                assignments.append(choice)
    return np.array(assignments)
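
# Note: chinese_restaurant_process assigns labels in order of first appearance, so the
# returned labels always form a contiguous range 0..K-1. SufficientStatistics relies on
# this when it builds the one-hot matrices np.eye(K)[z1] / np.eye(L)[z2] for the initial
# block counts.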


class Object:
    def __init__(self, domain: Literal[0, 1], index: int) -> None:
        self.domain = domain
        self.index = index

    def is_first_domain(self) -> bool:
        """Return True if the object's domain is 0, otherwise False."""
        return self.domain == 0


class Objects:
    def __init__(self, num_rows: int, num_cols: int) -> None:
        self.objects = [
            *[Object(0, i) for i in range(num_rows)],
            *[Object(1, i) for i in range(num_cols)],
        ]

    def __iter__(self) -> Iterator[Object]:
        """Shuffles the objects and returns an iterator over them."""
        random.shuffle(self.objects)
        return iter(self.objects)
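
# Objects.__iter__ reshuffles in place, so every Gibbs sweep visits the N1 row objects
# and N2 column objects in a fresh random order rather than a fixed one.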


class SufficientStatistics:
    def __init__(self, z1: npt.NDArray, z2: npt.NDArray, X: npt.NDArray) -> None:
        K = len(np.unique(z1))
        L = len(np.unique(z2))
        z1_onehot = np.eye(K)[z1]
        z2_onehot = np.eye(L)[z2]
        n_pos = z1_onehot.T @ X @ z2_onehot
        n_neg = z1_onehot.T @ (1 - X) @ z2_onehot
        self.n_pos = defaultdict(int)
        self.n_neg = defaultdict(int)
        for k in range(K):
            for l in range(L):
                self.n_pos[(k, l)] = n_pos[k, l]
                self.n_neg[(k, l)] = n_neg[k, l]

    def update_row(
        self,
        X: npt.NDArray,
        z1: npt.NDArray,
        z2: npt.NDArray,
        i: int,
        increment: bool = True,
    ) -> None:
        delta = 1 if increment else -1
        k = z1[i]
        for j in range(X.shape[1]):
            l = z2[j]
            self.n_pos[(k, l)] += delta * X[i, j]
            self.n_neg[(k, l)] += delta * (1 - X[i, j])

    def update_col(
        self,
        X: npt.NDArray,
        z1: npt.NDArray,
        z2: npt.NDArray,
        j: int,
        increment: bool = True,
    ) -> None:
        delta = 1 if increment else -1
        l = z2[j]
        for i in range(X.shape[0]):
            k = z1[i]
            self.n_pos[(k, l)] += delta * X[i, j]
            self.n_neg[(k, l)] += delta * (1 - X[i, j])

    def remove_row_cluster(self, k: int, col_clusters: list[int]) -> None:
        for l in col_clusters:
            del self.n_pos[(k, l)]
            del self.n_neg[(k, l)]

    def remove_col_cluster(self, l: int, row_clusters: list[int]) -> None:
        for k in row_clusters:
            del self.n_pos[(k, l)]
            del self.n_neg[(k, l)]
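
# n_pos[(k, l)] / n_neg[(k, l)] hold the number of ones / zeros of X falling in the block
# formed by row cluster k and column cluster l. update_row / update_col add or subtract a
# single row's or column's contribution, which is how the Gibbs sampler removes an object
# from the counts before resampling its assignment and re-adds it afterwards.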


class ClusterManager:
    def __init__(self, z: npt.NDArray) -> None:
        self.clusters = defaultdict(list)
        for i, c in enumerate(z):
            self.clusters[c].append(i)

    def m(self, c: int) -> int:
        return len(self.clusters[c])

    @property
    def new_cluster_id(self) -> int:
        return max(self.clusters.keys()) + 1

    @property
    def cluster_ids(self) -> list[int]:
        return list(self.clusters.keys())

    @property
    def cluster_ids_with_new(self) -> list[int]:
        return [*list(self.clusters.keys()), self.new_cluster_id]

    def add_index(self, c: int, index: int) -> None:
        self.clusters[c].append(index)

    def remove_cluster(self, c: int) -> None:
        self.clusters.pop(c)

    def remove_index(self, c: int, index: int) -> None:
        self.clusters[c].remove(index)

    def is_empty(self, c: int) -> bool:
        return len(self.clusters[c]) == 0
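
# ClusterManager maps each cluster id to the indices of its members. new_cluster_id is
# max(existing id) + 1, so ids are never reused and may become non-contiguous once
# clusters die out; that is harmless here because ids are only used as dictionary keys,
# and fit_predict relabels the final assignments to a contiguous range.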


class InfiniteRelationalModel:
    def __init__(
        self,
        alpha1: float,
        alpha2: float,
        a0: float,
        b0: float,
        num_iter: int,
        burn_in: int,
        interval: int,
        seed: Optional[int] = None,
    ) -> None:
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.a0 = a0
        self.b0 = b0
        self.num_iter = num_iter
        self.burn_in = burn_in
        self.interval = interval
        self.seed = seed
        self.zs1 = []
        self.zs2 = []

    def fit(self, X: npt.NDArray) -> None:
        N1, N2 = X.shape
        z1 = chinese_restaurant_process(N1, self.alpha1, self.seed)
        z2 = chinese_restaurant_process(N2, self.alpha2, self.seed)
        rng = np.random.default_rng(self.seed)
        objects = Objects(N1, N2)
        stats = SufficientStatistics(z1, z2, X)
        row_cluster = ClusterManager(z1)
        col_cluster = ClusterManager(z2)
        for _ in tqdm.tqdm(range(self.num_iter)):
            for o in objects:
                if o.is_first_domain():
                    stats.update_row(X, z1, z2, o.index, increment=False)
                    row_cluster.remove_index(z1[o.index], o.index)
                    if row_cluster.is_empty(z1[o.index]):
                        row_cluster.remove_cluster(z1[o.index])
                        stats.remove_row_cluster(z1[o.index], col_cluster.cluster_ids)
                    probs = self.calculate_first_domain_posterior_prob(
                        X,
                        stats,
                        row_cluster,
                        col_cluster,
                        o.index,
                    )
                    k = rng.choice(row_cluster.cluster_ids_with_new, p=probs)
                    z1[o.index] = k
                    row_cluster.add_index(k, o.index)
                    stats.update_row(X, z1, z2, o.index, increment=True)
                else:
                    stats.update_col(X, z1, z2, o.index, increment=False)
                    col_cluster.remove_index(z2[o.index], o.index)
                    if col_cluster.is_empty(z2[o.index]):
                        col_cluster.remove_cluster(z2[o.index])
                        stats.remove_col_cluster(z2[o.index], row_cluster.cluster_ids)
                    probs = self.calculate_second_domain_posterior_prob(
                        X,
                        stats,
                        row_cluster,
                        col_cluster,
                        o.index,
                    )
                    l = rng.choice(col_cluster.cluster_ids_with_new, p=probs)
                    z2[o.index] = l
                    col_cluster.add_index(l, o.index)
                    stats.update_col(X, z1, z2, o.index, increment=True)
            self.zs1.append(z1.copy())
            self.zs2.append(z2.copy())

    def fit_predict(self, X: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
        self.fit(X)
        z1 = mode(self.zs1[self.burn_in :: self.interval], keepdims=False).mode
        z2 = mode(self.zs2[self.burn_in :: self.interval], keepdims=False).mode
        _, z1 = np.unique(z1, return_inverse=True)
        _, z2 = np.unique(z2, return_inverse=True)
        return z1, z2
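
    # fit_predict summarizes the chain by taking, for every object, its most frequent
    # label among the kept samples (after burn_in, thinned by interval), then relabels
    # with np.unique(..., return_inverse=True) so the final labels are 0..K-1. The
    # element-wise mode is a simple summary and does not account for label switching
    # between MCMC samples.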

    def calculate_first_domain_posterior_prob(
        self,
        X: npt.NDArray,
        stats: SufficientStatistics,
        row_cluster: ClusterManager,
        col_cluster: ClusterManager,
        i: int,
    ) -> npt.NDArray:
        log_probs = np.zeros(len(row_cluster.cluster_ids_with_new))
        for idx, k in enumerate(row_cluster.cluster_ids_with_new):
            log_probs[idx] = np.log(self.alpha1) if k == row_cluster.new_cluster_id else np.log(row_cluster.m(k))
            for l, indices in col_cluster.clusters.items():
                a_hat = self.a0 + stats.n_pos[(k, l)]
                b_hat = self.b0 + stats.n_neg[(k, l)]
                pos = X[i][indices].sum()
                neg = len(indices) - pos
                log_probs[idx] += betaln(
                    a_hat + pos,
                    b_hat + neg,
                ) - betaln(a_hat, b_hat)
        log_probs -= logsumexp(log_probs)
        return np.exp(log_probs)

    def calculate_second_domain_posterior_prob(
        self,
        X: npt.NDArray,
        stats: SufficientStatistics,
        row_cluster: ClusterManager,
        col_cluster: ClusterManager,
        j: int,
    ) -> npt.NDArray:
        log_probs = np.zeros(len(col_cluster.cluster_ids_with_new))
        for idx, l in enumerate(col_cluster.cluster_ids_with_new):
            log_probs[idx] = np.log(self.alpha2) if l == col_cluster.new_cluster_id else np.log(col_cluster.m(l))
            for k, indices in row_cluster.clusters.items():
                a_hat = self.a0 + stats.n_pos[(k, l)]
                b_hat = self.b0 + stats.n_neg[(k, l)]
                pos = X[indices, j].sum()
                neg = len(indices) - pos
                log_probs[idx] += betaln(
                    a_hat + pos,
                    b_hat + neg,
                ) - betaln(a_hat, b_hat)
        log_probs -= logsumexp(log_probs)
        return np.exp(log_probs)
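
    # For the "new cluster" candidate, stats.n_pos / stats.n_neg default to 0 via the
    # defaultdict, so a_hat = a0 and b_hat = b0 and the Beta-ratio term reduces to the
    # prior predictive probability of the object's links under a fresh cluster.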


def load_dataset() -> tuple[npt.NDArray, npt.NDArray]:
    graph = nx.karate_club_graph()
    X = (nx.to_numpy_array(graph) > 0).astype(np.int32)
    np.fill_diagonal(X, 1)
    mapping = {"Mr. Hi": 0, "Officer": 1}
    Z = np.array([mapping[node["club"]] for node in graph.nodes.values()])
    return X, Z


if __name__ == "__main__":
    X, Z = load_dataset()
    model = InfiniteRelationalModel(
        alpha1=1.0,
        alpha2=1.0,
        a0=0.5,
        b0=0.5,
        num_iter=2000,
        burn_in=1500,
        interval=5,
        seed=41,
    )
    z1, z2 = model.fit_predict(X)
    ari_score_z1 = adjusted_rand_score(Z, z1)
    ari_score_z2 = adjusted_rand_score(Z, z2)
    print(f"ARI Score (Z vs z1): {ari_score_z1:.4f}")
    print(f"ARI Score (Z vs z2): {ari_score_z2:.4f}")
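
For a quick end-to-end check that does not depend on the karate-club labels, one option is to run the sampler on a small synthetic block matrix. The snippet below is a minimal sketch, assuming the classes above are already defined in scope; the names z1_true, z2_true, eta, and X_syn are only illustrative.

import numpy as np
from sklearn.metrics import adjusted_rand_score

rng = np.random.default_rng(0)
z1_true = np.repeat([0, 1, 2], 10)       # 30 rows in 3 planted clusters
z2_true = np.repeat([0, 1], 10)          # 20 columns in 2 planted clusters
eta = rng.beta(0.5, 0.5, size=(3, 2))    # per-block link probabilities
X_syn = rng.binomial(1, eta[z1_true][:, z2_true])  # 30 x 20 binary relation

model = InfiniteRelationalModel(
    alpha1=1.0, alpha2=1.0, a0=0.5, b0=0.5,
    num_iter=500, burn_in=300, interval=5, seed=0,
)
z1_hat, z2_hat = model.fit_predict(X_syn)
print(adjusted_rand_score(z1_true, z1_hat))  # row partition recovery
print(adjusted_rand_score(z2_true, z2_hat))  # column partition recovery

With well-separated block probabilities both ARI values should be close to 1; since eta is drawn at random here, some blocks may be ambiguous and results will vary with the seed.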