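"""Compressed k-means-style clustering over the 'train' split of an HDF5 file.

Points are assigned with a custom distance, 1 - exp(-d2 / (r_k(x) * r_k(c))),
where d2 is the squared L2 distance and r_k(.) is the distance to the
K_NEIGHBOR-th nearest neighbor, estimated with a shared HNSW index.
The resulting centroids are saved to OUTPUT_FILE as float32.
"""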
import math
import os
from concurrent.futures import ThreadPoolExecutor

import h5py
import hnswlib
import numpy as np


# Configurable constants (edit in file if needed, no CLI flags are used)
DATA_FILE = "data.hdf5"
OUTPUT_FILE = "centroids_compress.npy"
MAX_ITERS = 25
RANDOM_STATE = 42
TOLERANCE = 1e-4
K_NEIGHBOR = 7
N_CLUSTERS = 256


def load_train_dataset(path: str) -> np.ndarray:
    """Load the train split from the HDF5 file into memory as float32."""
    with h5py.File(path, "r") as f:
        if "train" not in f:
            raise KeyError("HDF5 file does not contain 'train' dataset")
        data = np.asarray(f["train"], dtype=np.float32)
    if data.ndim != 2:
        raise ValueError(f"'train' dataset must be 2D, got shape {data.shape}")
    return data


def infer_cluster_count(data: np.ndarray, attrs: dict) -> int:
    """Infer number of clusters from HDF5 attributes or fall back to a heuristic."""
    for key in ("n_clusters", "k", "n_centroids", "clusters"):
        if key in attrs:
            raw = attrs[key]
            try:
                if isinstance(raw, (bytes, bytearray)):
                    value = int(raw.decode())
                else:
                    value = int(raw)
                if value > 1:
                    return min(value, data.shape[0])
            except Exception:
                pass
    # Heuristic: use sqrt of samples capped at 256, but at least 2
    return max(2, min(256, int(math.sqrt(data.shape[0]))))


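# Note: main() currently uses the fixed N_CLUSTERS cap and does not call
# infer_cluster_count(); the helper is kept for files that store a cluster
# count in the 'train' attributes. As a worked example of the fallback
# heuristic: 100_000 samples -> int(sqrt(100_000)) = 316 -> capped at 256.

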
def l2_distances(batch: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Compute squared L2 distances between a batch of points and centroids."""
    batch_norm = np.sum(batch * batch, axis=1, keepdims=True)
    centroid_norm = np.sum(centroids * centroids, axis=1, keepdims=True).T
    dist_sq = batch_norm + centroid_norm - 2.0 * (batch @ centroids.T)
    np.maximum(dist_sq, 0.0, out=dist_sq)
    return dist_sq


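# l2_distances uses the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * x.c,
# which turns the pairwise computation into a single matrix product. Floating-
# point cancellation can leave tiny negative values, hence the clamp to 0.

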
def parallel_l2_squared_to_point(
    data: np.ndarray, point: np.ndarray, jobs: int
) -> np.ndarray:
    """Compute squared L2 distances from all rows in data to a single point in parallel."""
    n_samples = data.shape[0]
    jobs = max(1, min(jobs, n_samples))
    chunk = (n_samples + jobs - 1) // jobs
    point = point.astype(np.float32, copy=False)
    point_norm = float(np.dot(point, point))

    def worker(start: int, end: int) -> np.ndarray:
        batch = data[start:end]
        dist_sq = np.sum(batch * batch, axis=1) + point_norm - 2.0 * (batch @ point)
        return np.maximum(dist_sq, 0.0, out=dist_sq)

    results = []
    with ThreadPoolExecutor(max_workers=jobs) as pool:
        futures = []
        for start in range(0, n_samples, chunk):
            end = min(start + chunk, n_samples)
            futures.append(pool.submit(worker, start, end))
        for fut in futures:
            results.append(fut.result())
    return np.concatenate(results, axis=0)


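# The thread pool above gives real parallelism because NumPy's vectorised
# arithmetic and BLAS-backed matrix products release the GIL on large arrays;
# for tiny datasets the threading overhead may outweigh the benefit.

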
def compute_kth_neighbor_distances(index: hnswlib.Index, queries: np.ndarray, k: int) -> np.ndarray:
    """Return distance to the k-th nearest neighbor for queries using an existing HNSW index.

    Assumes queries are not indexed elements (e.g., centroids vs. data index), so no self-hit to drop.
    """
    n = index.get_current_count()
    if n == 0 or queries.shape[0] == 0:
        return np.zeros((queries.shape[0],), dtype=np.float32)

    k_eff = min(k, n)
    if k_eff <= 0:
        return np.zeros((queries.shape[0],), dtype=np.float32)

    index.set_ef(max(2 * k_eff, 64))
    _, distances = index.knn_query(queries, k=k_eff)

    kth_sq = distances[:, k_eff - 1]
    kth_sq = np.maximum(kth_sq, 0.0, out=kth_sq)
    return np.sqrt(kth_sq, out=kth_sq)


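# hnswlib's "l2" space reports squared Euclidean distances, which is why the
# k-th column is clamped and passed through sqrt above (and in
# knn_kth_from_index below) to obtain true L2 distances.

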
def build_hnsw_index(data: np.ndarray, ef_construction: int = 200, m: int = 16) -> hnswlib.Index:
    """Build a reusable HNSW index on the given data."""
    n, dim = data.shape
    index = hnswlib.Index(space="l2", dim=dim)
    index.init_index(max_elements=n, ef_construction=ef_construction, M=m)
    ids = np.arange(n, dtype=np.int32)
    index.add_items(data, ids)
    return index


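# The defaults ef_construction=200 and M=16 follow common hnswlib guidance:
# larger values generally improve recall at the cost of memory and build time.
# They are reasonable starting points here, not values tuned for this dataset.

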
def knn_kth_from_index(index: hnswlib.Index, data: np.ndarray, k: int, skip_self: bool) -> np.ndarray:
    """Use an existing HNSW index to fetch k-th neighbor distances.

    If skip_self is True, assumes queries are the same set as the index contents and drops the first hit.
    """
    n = data.shape[0]
    if n <= 1:
        return np.zeros((n,), dtype=np.float32)

    k_eff = min(k, index.get_current_count() - (1 if skip_self else 0))
    if k_eff <= 0:
        return np.zeros((n,), dtype=np.float32)

    top_k = k_eff + 1 if skip_self else k_eff
    index.set_ef(max(2 * top_k, 64))
    _, distances = index.knn_query(data, k=top_k)

    if skip_self:
        kth_sq = distances[:, k_eff]
    else:
        kth_sq = distances[:, k_eff - 1]

    kth_sq = np.maximum(kth_sq, 0.0, out=kth_sq)
    return np.sqrt(kth_sq, out=kth_sq)


def compress_distance(
    batch: np.ndarray,
    centroids: np.ndarray,
    r_k_batch: np.ndarray,
    r_k_centroids: np.ndarray,
) -> np.ndarray:
    """Compute the custom distance 1 - exp(-d2 / (r_kx * r_ky)), with d2 the squared L2 distance."""
    dists = l2_distances(batch, centroids)
    denom = r_k_batch[:, None] * r_k_centroids[None, :]
    denom = np.maximum(denom, 1e-12, out=denom)
    scores = 1.0 - np.exp(-dists / denom)
    return scores


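# Since 1 - exp(-t) is monotonically increasing in t, taking argmin over
# compress_distance yields the same assignments as argmin over
# d2 / (r_k(x) * r_k(c)); the exponential mainly rescales the inertia values
# reported per iteration.

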
def initialize_centroids(
    data: np.ndarray, k: int, rng: np.random.Generator, n_jobs: int
) -> np.ndarray:
    """Simple k-means++ style initialization with parallel distance updates."""
    n_samples = data.shape[0]
    centroids = np.empty((k, data.shape[1]), dtype=np.float32)

    # First centroid chosen uniformly
    first_idx = rng.integers(0, n_samples)
    centroids[0] = data[first_idx]

    # Remaining centroids are sampled with probability proportional to squared distance
    closest_dist_sq = parallel_l2_squared_to_point(data, centroids[0], n_jobs)
    for i in range(1, k):
        probs = closest_dist_sq / np.sum(closest_dist_sq)
        next_idx = rng.choice(n_samples, p=probs)
        centroids[i] = data[next_idx]
        new_dist_sq = parallel_l2_squared_to_point(data, centroids[i], n_jobs)
        closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq)

    return centroids


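# This follows the k-means++ "D^2 sampling" idea: each new centroid is drawn
# with probability proportional to its squared distance to the closest centroid
# chosen so far. A toy example, assuming current closest distances of
# [1, 4, 0, 9] for four points: the sampling probabilities are [1/14, 4/14, 0, 9/14].

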
def recompute_centroids(
    data: np.ndarray,
    centroids: np.ndarray,
    r_k_data: np.ndarray,
    n_jobs: int,
    data_index: hnswlib.Index,
) -> tuple[np.ndarray, float]:
    """Assign points and recompute centroid means over parallel contiguous slices (no sub-batching).

    Returns the updated centroids and the total inertia under the compressed distance.
    """
    n_samples = data.shape[0]
    jobs = max(1, min(n_jobs, n_samples))
    chunk = (n_samples + jobs - 1) // jobs

    # Centroids' k-NN distances are measured against the data via the shared HNSW index
    r_k_centroids = compute_kth_neighbor_distances(data_index, centroids, K_NEIGHBOR)
    # r_k_centroids = np.ones((centroids.shape[0],), dtype=np.float32)

    def process_slice(start: int, end: int) -> tuple[np.ndarray, np.ndarray, float]:
        batch = data[start:end]
        dist = compress_distance(batch, centroids, r_k_data[start:end], r_k_centroids)
        labels = np.argmin(dist, axis=1)

        sums = np.zeros_like(centroids)
        np.add.at(sums, labels, batch)
        counts = np.bincount(labels, minlength=centroids.shape[0]).astype(np.int64)
        inertia = float(np.sum(dist[np.arange(len(batch)), labels]))
        return sums, counts, inertia

    tasks = []
    with ThreadPoolExecutor(max_workers=jobs) as pool:
        for start in range(0, n_samples, chunk):
            end = min(start + chunk, n_samples)
            tasks.append(pool.submit(process_slice, start, end))

    partial_sums = []
    partial_counts = []
    partial_inertia = 0.0
    for fut in tasks:
        sums, counts, inertia = fut.result()
        partial_sums.append(sums)
        partial_counts.append(counts)
        partial_inertia += inertia

    summed = np.sum(partial_sums, axis=0)
    counts = np.sum(partial_counts, axis=0)

    # Empty clusters keep their previous centroid; clamp counts to avoid divide-by-zero warnings
    safe_counts = np.maximum(counts[:, None], 1)
    new_centroids = np.where(counts[:, None] > 0, summed / safe_counts, centroids)
    return new_centroids.astype(np.float32), partial_inertia


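# recompute_centroids follows a small map-reduce pattern: each slice returns
# partial per-cluster sums, counts, and inertia, and the main thread combines
# them before dividing. Note that the new centroid is the plain mean of its
# assigned points, even though the assignment itself uses the compressed distance.

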
def main() -> None:
    data_path = os.path.join(os.path.dirname(__file__), DATA_FILE)
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Cannot find {DATA_FILE} next to the script")

    with h5py.File(data_path, "r") as f:
        if "train" not in f:
            raise KeyError("HDF5 file does not contain 'train' dataset")
        attrs = dict(f["train"].attrs)
        # Load into memory once so threads can share it
        data = np.asarray(f["train"], dtype=np.float32)

    rng = np.random.default_rng(RANDOM_STATE)
    k = min(N_CLUSTERS, data.shape[0])
    n_jobs = max(1, os.cpu_count() or 1)

    print(f"Loaded train data shape {data.shape}, dtype {data.dtype}")
    print(f"Using fixed {k} clusters (cap {N_CLUSTERS}) with up to {n_jobs} threads")

    # Precompute k-th neighbor distance for data points (single global HNSW index)
    print(f"Building HNSW index and computing {K_NEIGHBOR}-th neighbor distances for data...")
    hnsw_index = build_hnsw_index(data)
    r_k_data = knn_kth_from_index(hnsw_index, data, K_NEIGHBOR, skip_self=True)
    print("Neighbor distances computed")

    centroids = initialize_centroids(data, k, rng, n_jobs)

    previous_inertia = float("inf")
    for step in range(1, MAX_ITERS + 1):
        centroids, inertia = recompute_centroids(data, centroids, r_k_data, n_jobs, hnsw_index)
        change = previous_inertia - inertia
        print(f"Iteration {step:02d}: inertia={inertia:.6f}, change={change:.6f}")
        # Stop when improvement is tiny relative to previous inertia
        if change >= 0 and change < TOLERANCE * max(1.0, previous_inertia):
            print("Converged based on inertia improvement threshold.")
            break
        previous_inertia = inertia

    output_path = os.path.join(os.path.dirname(__file__), OUTPUT_FILE)
    np.save(output_path, centroids.astype(np.float32))
    print(f"Saved centroids to {output_path} with shape {centroids.shape}")


if __name__ == "__main__":
    main()
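# A minimal sketch of downstream use (hypothetical, not part of this script):
# the saved array can be loaded and new points assigned to their nearest
# centroid by plain L2, e.g.
#
#   centroids = np.load("centroids_compress.npy")
#   labels = np.argmin(l2_distances(new_points.astype(np.float32), centroids), axis=1)
#
# where new_points is any (n, dim) float array with the same dimensionality.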