@cutecutecat
Last active December 5, 2025 08:23
Hugeness reduction tool

Run on an instance with at least 16 cores, such as AWS c7i.4xlarge.

  1. Prepare the environment
conda create -n data python=3.11
conda install anaconda::h5py anaconda::numpy conda-forge::hnswlib
  2. Download the data file
aws s3 cp s3://vector-bench-data/openai_1536_500k/openai.hdf5 data.hdf5
  3. Run the tests
# Ordinary K-means
python kmeans.py

# Hugeness Reduction K-means
python kmeans_compress.py

python gen_ivf.py
python bench.py
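
For orientation, the scripts only talk to each other through files in the working directory: kmeans.py writes centroids_base.npy, kmeans_compress.py writes centroids_compress.npy, gen_ivf.py converts every centroids_*.npy into a matching ivf_*.npy, and bench.py reads data.hdf5 plus all ivf_*.npy files. A minimal sketch to check the expected artifacts (file names taken from the scripts below; this snippet is not part of the pipeline):

import glob
import os

# Quick check of the artifacts the pipeline is expected to produce
# (names come from the scripts below; nothing here is required to run them).
for name in ("data.hdf5", "centroids_base.npy", "centroids_compress.npy"):
    print(f"{name}: {'ok' if os.path.exists(name) else 'missing'}")
# gen_ivf.py writes one ivf_<tag>.npy for every centroids_<tag>.npy it finds.
print("ivf files:", sorted(glob.glob("ivf_*.npy")))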
# bench.py (name taken from the run instructions above): measures how many
# vectors an IVF index would have to scan to reach the ground-truth neighbors.
import glob
import os
import sys
from typing import Dict, List, Tuple

import h5py
import numpy as np

DATA_FILE = "data.hdf5"
KS = (1, 10, 100)


def l2_squared(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Compute squared L2 distances between two 2D arrays."""
    a_norm = np.sum(a * a, axis=1, keepdims=True)
    b_norm = np.sum(b * b, axis=1, keepdims=True).T
    dist = a_norm + b_norm - 2.0 * (a @ b.T)
    return np.maximum(dist, 0.0, out=dist)


def load_datasets(base: h5py.File) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load base/train, test, and neighbors datasets."""
    for name in ("train", "base"):
        if name in base:
            db = np.asarray(base[name], dtype=np.float32)
            break
    else:
        raise KeyError("HDF5 file must contain 'train' or 'base' dataset")
    if "test" not in base:
        raise KeyError("HDF5 file must contain 'test' dataset")
    queries = np.asarray(base["test"], dtype=np.float32)
    if "neighbors" not in base:
        raise KeyError("HDF5 file must contain 'neighbors' dataset")
    neighbors = np.asarray(base["neighbors"], dtype=np.int64)
    return db, queries, neighbors


def assign_to_centroids(data: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Assign each row in data to its nearest centroid."""
    dist = l2_squared(data, centroids)
    return np.argmin(dist, axis=1)


def load_ivf(ivf_path: str, n_centroids: int) -> List[np.ndarray]:
    """Load IVF from npy (expects dict with offsets/ids) into list-of-indices form."""
    ivf = np.load(ivf_path, allow_pickle=True).item()
    offsets = ivf["offsets"]
    ids = ivf["ids"]
    if offsets.shape[0] != n_centroids + 1:
        raise ValueError(f"IVF offsets length {offsets.shape[0]} != n_centroids+1 ({n_centroids + 1})")
    inv_lists: List[np.ndarray] = []
    for i in range(n_centroids):
        start, end = int(offsets[i]), int(offsets[i + 1])
        inv_lists.append(ids[start:end])
    return inv_lists


def compute_scan_counts(
    queries: np.ndarray,
    centroids: np.ndarray,
    inv_lists: List[np.ndarray],
    db: np.ndarray,
    neighbors: np.ndarray,
    db_labels: np.ndarray,
) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    q_labels = assign_to_centroids(queries, centroids)
    centroid_scan = centroids.shape[0]
    results: Dict[int, List[int]] = {k: [] for k in KS}
    cluster_counts: Dict[int, List[int]] = {k: [] for k in KS}
    for qi, cid in enumerate(q_labels):
        dists = l2_squared(queries[qi : qi + 1], centroids).ravel()
        for k in KS:
            topk = neighbors[qi][: min(k, neighbors.shape[1] if neighbors.ndim > 1 else k)]
            valid_ids = [int(i) for i in topk if 0 <= int(i) < db_labels.shape[0]]
            if valid_ids:
                gt_clusters = list(set(db_labels[valid_ids]))
            else:
                gt_clusters = [cid]
            threshold = float(np.max(dists[gt_clusters]))
            clusters_to_scan = [idx for idx, dist in enumerate(dists) if dist <= threshold]
            vector_scan = centroid_scan + sum(len(inv_lists[c]) for c in clusters_to_scan)
            results[k].append(vector_scan)
            cluster_counts[k].append(len(clusters_to_scan))
    return results, cluster_counts


def summarize(values: List[int]) -> Dict[str, float]:
    arr = np.asarray(values, dtype=np.int64)
    return {
        "min": float(np.min(arr)),
        "max": float(np.max(arr)),
        "median": float(np.median(arr)),
        "mean": float(np.mean(arr)),
    }


def main() -> int:
    base_dir = os.path.dirname(__file__)
    data_path = os.path.join(base_dir, DATA_FILE)
    if not os.path.exists(data_path):
        print(f"Missing {DATA_FILE} next to this script", file=sys.stderr)
        return 1
    ivf_files = sorted(glob.glob(os.path.join(base_dir, "ivf_*.npy")))
    if not ivf_files:
        print("No ivf_*.npy files found in current directory", file=sys.stderr)
        return 1
    with h5py.File(data_path, "r") as f:
        db, queries, neighbors = load_datasets(f)
    print(f"DB shape: {db.shape}, queries: {queries.shape}, neighbors: {neighbors.shape}")
    for ivf_path in ivf_files:
        tag = os.path.basename(ivf_path)[len("ivf_") : -4]
        ivf = np.load(ivf_path, allow_pickle=True).item()
        offsets = ivf["offsets"]
        ids = ivf["ids"]
        if "centroids" not in ivf:
            raise KeyError(f"{ivf_path} must contain 'centroids' array")
        centroids = np.asarray(ivf["centroids"], dtype=np.float32)
        if offsets.shape[0] != centroids.shape[0] + 1:
            raise ValueError(f"{ivf_path}: offsets length {offsets.shape[0]} != n_centroids+1 ({centroids.shape[0] + 1})")
        inv_lists: List[np.ndarray] = []
        for i in range(centroids.shape[0]):
            start, end = int(offsets[i]), int(offsets[i + 1])
            inv_lists.append(ids[start:end])
        db_labels = assign_to_centroids(db, centroids)
        scan_counts, cluster_counts = compute_scan_counts(queries, centroids, inv_lists, db, neighbors, db_labels)
        print(f"Tag '{tag}': centroids {centroids.shape}, offsets {offsets.shape}, ids {ids.shape}")
        for k in KS:
            stats = summarize(scan_counts[k])
            clusters_stats = summarize(cluster_counts[k])
            print(
                f" Top-{k}: scanned vectors -> min={stats['min']:.0f}({clusters_stats['min']:.2f}), "
                f"max={stats['max']:.0f}({clusters_stats['max']:.2f}), "
                f"median={stats['median']:.0f}({clusters_stats['median']:.2f}), "
                f"mean={stats['mean']:.2f}({clusters_stats['mean']:.2f})"
            )
    return 0


if __name__ == "__main__":
    sys.exit(main())
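
As a reading aid for the output above: for each query, the benchmark finds the clusters that hold its true top-k neighbors, takes the distance to the farthest such centroid as a threshold, and counts every cluster whose centroid falls within that threshold. The scanned-vector figure is the centroid count plus the sizes of those clusters, and the number in parentheses is how many clusters that is. A toy re-enactment of the same accounting (all numbers are made up for illustration):

import numpy as np

# 4 centroids, 1 query; the ground-truth neighbors live in clusters {1, 3}.
dists = np.array([0.9, 0.2, 0.7, 0.4])   # squared distance from the query to each centroid
gt_clusters = [1, 3]                     # clusters holding the true top-k neighbors
inv_list_sizes = [100, 80, 120, 60]      # vectors per cluster (illustrative)

threshold = dists[gt_clusters].max()     # farthest "needed" centroid: 0.4
clusters_to_scan = [c for c, d in enumerate(dists) if d <= threshold]   # -> [1, 3]
vector_scan = len(dists) + sum(inv_list_sizes[c] for c in clusters_to_scan)
print(clusters_to_scan, vector_scan)     # [1, 3] 144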
# gen_ivf.py (name taken from the run instructions above): builds an IVF file
# (offsets/ids/centroids) for every centroids_*.npy found next to the script.
import glob
import os
import sys
from typing import Tuple

import h5py
import numpy as np

DATA_FILE = "data.hdf5"


def l2_squared(batch: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Compute squared L2 distances between a batch and centroids."""
    batch_norm = np.sum(batch * batch, axis=1, keepdims=True)
    centroid_norm = np.sum(centroids * centroids, axis=1, keepdims=True).T
    dist = batch_norm + centroid_norm - 2.0 * (batch @ centroids.T)
    return np.maximum(dist, 0.0, out=dist)


def load_base(f: h5py.File) -> np.ndarray:
    """Load base/train dataset."""
    for name in ("train", "base"):
        if name in f:
            return np.asarray(f[name], dtype=np.float32)
    raise KeyError("HDF5 file must contain 'train' or 'base' dataset")


def assign_labels(data: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Assign each row in data to nearest centroid."""
    dist = l2_squared(data, centroids)
    return np.argmin(dist, axis=1)


def build_ivf(labels: np.ndarray, n_centroids: int) -> Tuple[np.ndarray, np.ndarray]:
    """Build IVF offsets and ids arrays from labels."""
    counts = np.bincount(labels, minlength=n_centroids).astype(np.int64)
    offsets = np.zeros(n_centroids + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(counts)
    order = np.argsort(labels, kind="stable")
    ids = order.astype(np.int64, copy=False)
    return offsets, ids


def process_centroids(centroids_path: str, base: np.ndarray) -> Tuple[str, np.ndarray, np.ndarray, np.ndarray]:
    centroids = np.load(centroids_path).astype(np.float32)
    labels = assign_labels(base, centroids)
    offsets, ids = build_ivf(labels, centroids.shape[0])
    return centroids_path, centroids, offsets, ids


def main() -> int:
    base_dir = os.path.dirname(__file__)
    data_path = os.path.join(base_dir, DATA_FILE)
    if not os.path.exists(data_path):
        print(f"Missing {DATA_FILE} next to this script", file=sys.stderr)
        return 1
    centroid_files = sorted(glob.glob(os.path.join(base_dir, "centroids_*.npy")))
    if not centroid_files:
        print("No centroids_*.npy files found in current directory", file=sys.stderr)
        return 1
    with h5py.File(data_path, "r") as f:
        base = load_base(f)
    print(f"Base shape: {base.shape}")
    for cpath in centroid_files:
        tag = os.path.basename(cpath)[len("centroids_") : -4]
        output_path = os.path.join(base_dir, f"ivf_{tag}.npy")
        _, centroids, offsets, ids = process_centroids(cpath, base)
        np.save(output_path, {"offsets": offsets, "ids": ids, "centroids": centroids}, allow_pickle=True)
        print(f"Saved IVF for tag '{tag}' to {output_path} (offsets {offsets.shape}, ids {ids.shape}, centroids {centroids.shape})")
    return 0


if __name__ == "__main__":
    sys.exit(main())
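
The offsets/ids pair produced by build_ivf is a CSR-style encoding of the inverted lists: offsets has n_centroids + 1 entries and ids stores the vector indices of every cluster back to back, so ids[offsets[c]:offsets[c+1]] is cluster c. A tiny worked example with a hand-picked label array (values are illustrative only):

import numpy as np

# Sketch of the layout produced by build_ivf: toy labels for 6 vectors, 3 centroids.
labels = np.array([2, 0, 1, 2, 2, 0])
counts = np.bincount(labels, minlength=3)      # [2, 1, 3]
offsets = np.zeros(4, dtype=np.int64)
offsets[1:] = np.cumsum(counts)                # [0, 2, 3, 6]
ids = np.argsort(labels, kind="stable")        # [1, 5, 2, 0, 3, 4]

# ids[offsets[c]:offsets[c+1]] are exactly the row indices assigned to centroid c.
for c in range(3):
    print(c, ids[offsets[c]:offsets[c + 1]])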
# kmeans.py (name taken from the run instructions above): ordinary K-means with
# k-means++ seeding, parallelized across threads; saves centroids_base.npy.
import math
import os
from concurrent.futures import ThreadPoolExecutor

import h5py
import numpy as np

# Configurable constants (edit in file if needed, no CLI flags are used)
DATA_FILE = "data.hdf5"
OUTPUT_FILE = "centroids_base.npy"
MAX_ITERS = 25
RANDOM_STATE = 42
TOLERANCE = 1e-4
N_CLUSTERS = 256


def load_train_dataset(path: str) -> np.ndarray:
    """Load the train split from the HDF5 file into memory as float32."""
    with h5py.File(path, "r") as f:
        if "train" not in f:
            raise KeyError("HDF5 file does not contain 'train' dataset")
        data = np.asarray(f["train"], dtype=np.float32)
    if data.ndim != 2:
        raise ValueError(f"'train' dataset must be 2D, got shape {data.shape}")
    return data


def infer_cluster_count(data: np.ndarray, attrs: dict) -> int:
    """Infer number of clusters from HDF5 attributes or fall back to a heuristic."""
    for key in ("n_clusters", "k", "n_centroids", "clusters"):
        if key in attrs:
            raw = attrs[key]
            try:
                if isinstance(raw, (bytes, bytearray)):
                    value = int(raw.decode())
                else:
                    value = int(raw)
                if value > 1:
                    return min(value, data.shape[0])
            except Exception:
                pass
    # Heuristic: use sqrt of samples capped at 256, but at least 2
    return max(2, min(256, int(math.sqrt(data.shape[0]))))


def l2_squared_distances(batch: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Compute squared L2 distances between a batch of points and centroids."""
    batch_norm = np.sum(batch * batch, axis=1, keepdims=True)
    centroid_norm = np.sum(centroids * centroids, axis=1, keepdims=True).T
    # Use the identity ||x-y||^2 = ||x||^2 + ||y||^2 - 2x·y for efficiency
    distances = batch_norm + centroid_norm - 2.0 * (batch @ centroids.T)
    return np.maximum(distances, 0.0, out=distances)  # clamp small negatives


def parallel_l2_squared_to_point(
    data: np.ndarray, point: np.ndarray, jobs: int
) -> np.ndarray:
    """Compute squared L2 distances from all rows in data to a single point in parallel."""
    n_samples = data.shape[0]
    jobs = max(1, min(jobs, n_samples))
    chunk = (n_samples + jobs - 1) // jobs
    point = point.astype(np.float32, copy=False)
    point_norm = float(np.dot(point, point))

    def worker(start: int, end: int) -> np.ndarray:
        batch = data[start:end]
        dist_sq = np.sum(batch * batch, axis=1) + point_norm - 2.0 * (batch @ point)
        return np.maximum(dist_sq, 0.0, out=dist_sq)

    results = []
    with ThreadPoolExecutor(max_workers=jobs) as pool:
        futures = []
        for start in range(0, n_samples, chunk):
            end = min(start + chunk, n_samples)
            futures.append(pool.submit(worker, start, end))
        for fut in futures:
            results.append(fut.result())
    return np.concatenate(results, axis=0)


def initialize_centroids(
    data: np.ndarray, k: int, rng: np.random.Generator, n_jobs: int
) -> np.ndarray:
    """Simple k-means++ style initialization with parallel distance updates."""
    n_samples = data.shape[0]
    centroids = np.empty((k, data.shape[1]), dtype=np.float32)
    # First centroid chosen uniformly
    first_idx = rng.integers(0, n_samples)
    centroids[0] = data[first_idx]
    # Select remaining centroids probabilistically proportional to distance^2
    closest_dist_sq = parallel_l2_squared_to_point(data, centroids[0], n_jobs)
    for i in range(1, k):
        probs = closest_dist_sq / np.sum(closest_dist_sq)
        next_idx = rng.choice(n_samples, p=probs)
        centroids[i] = data[next_idx]
        new_dist_sq = parallel_l2_squared_to_point(data, centroids[i], n_jobs)
        closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq)
    return centroids


def process_slice(
    data: np.ndarray, centroids: np.ndarray, start: int, end: int
) -> tuple[np.ndarray, np.ndarray, float]:
    """Assign a slice of data to centroids and return partial sums/counts/inertia."""
    batch = data[start:end]
    distances = l2_squared_distances(batch, centroids)
    labels = np.argmin(distances, axis=1)
    sums = np.zeros_like(centroids)
    np.add.at(sums, labels, batch)
    counts = np.bincount(labels, minlength=centroids.shape[0]).astype(np.int64)
    inertia = float(np.sum(distances[np.arange(len(batch)), labels]))
    return sums, counts, inertia


def recompute_centroids(data: np.ndarray, centroids: np.ndarray, n_jobs: int) -> tuple[np.ndarray, float]:
    """Assign points to centroids in parallel across contiguous slices (no sub-batching)."""
    n_samples = data.shape[0]
    jobs = max(1, min(n_jobs, n_samples))
    chunk = (n_samples + jobs - 1) // jobs
    tasks = []
    with ThreadPoolExecutor(max_workers=jobs) as pool:
        for start in range(0, n_samples, chunk):
            end = min(start + chunk, n_samples)
            tasks.append(pool.submit(process_slice, data, centroids, start, end))
        partial_sums = []
        partial_counts = []
        partial_inertia = 0.0
        for fut in tasks:
            sums, counts, inertia = fut.result()
            partial_sums.append(sums)
            partial_counts.append(counts)
            partial_inertia += inertia
    summed = np.sum(partial_sums, axis=0)
    counts = np.sum(partial_counts, axis=0)
    new_centroids = np.where(counts[:, None] > 0, summed / counts[:, None], centroids)
    return new_centroids.astype(np.float32), partial_inertia


def main() -> None:
    data_path = os.path.join(os.path.dirname(__file__), DATA_FILE)
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Cannot find {DATA_FILE} next to the script")
    with h5py.File(data_path, "r") as f:
        if "train" not in f:
            raise KeyError("HDF5 file does not contain 'train' dataset")
        attrs = dict(f["train"].attrs)
        # Load into memory once so threads can share it
        data = np.asarray(f["train"], dtype=np.float32)
    rng = np.random.default_rng(RANDOM_STATE)
    k = min(N_CLUSTERS, data.shape[0])
    n_jobs = max(1, os.cpu_count() or 1)
    print(f"Loaded train data shape {data.shape}, dtype {data.dtype}")
    print(f"Using fixed {k} clusters (cap {N_CLUSTERS}) with up to {n_jobs} threads")
    centroids = initialize_centroids(data, k, rng, n_jobs)
    previous_inertia = float("inf")
    for step in range(1, MAX_ITERS + 1):
        centroids, inertia = recompute_centroids(data, centroids, n_jobs)
        change = previous_inertia - inertia
        print(f"Iteration {step:02d}: inertia={inertia:.4f}, change={change:.4f}")
        # Stop when improvement is tiny relative to previous inertia
        if change >= 0 and change < TOLERANCE * max(1.0, previous_inertia):
            print("Converged based on inertia improvement threshold.")
            break
        previous_inertia = inertia
    output_path = os.path.join(os.path.dirname(__file__), OUTPUT_FILE)
    np.save(output_path, centroids.astype(np.float32))
    print(f"Saved centroids to {output_path} with shape {centroids.shape}")


if __name__ == "__main__":
    main()
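
Every distance helper in these scripts uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y, which turns the pairwise computation into a single matrix multiply, clamped at zero to absorb floating-point cancellation. A quick sanity check against the naive formula (shapes and seed are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 8)).astype(np.float32)
b = rng.standard_normal((3, 8)).astype(np.float32)

# Expansion used by l2_squared_distances: ||x||^2 + ||y||^2 - 2 x.y
fast = np.maximum(
    np.sum(a * a, axis=1, keepdims=True) + np.sum(b * b, axis=1) - 2.0 * (a @ b.T), 0.0
)
# Naive pairwise differences for comparison.
naive = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
print(np.allclose(fast, naive, atol=1e-4))   # True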
# kmeans_compress.py (name taken from the run instructions above): "Hugeness
# Reduction" K-means; assignments use a density-scaled score built from
# k-th-nearest-neighbor distances (via hnswlib); saves centroids_compress.npy.
import math
import os
from concurrent.futures import ThreadPoolExecutor

import h5py
import numpy as np
import hnswlib

# Configurable constants (edit in file if needed, no CLI flags are used)
DATA_FILE = "data.hdf5"
OUTPUT_FILE = "centroids_compress.npy"
MAX_ITERS = 25
RANDOM_STATE = 42
TOLERANCE = 1e-4
K_NEIGHBOR = 7
N_CLUSTERS = 256


def load_train_dataset(path: str) -> np.ndarray:
    """Load the train split from the HDF5 file into memory as float32."""
    with h5py.File(path, "r") as f:
        if "train" not in f:
            raise KeyError("HDF5 file does not contain 'train' dataset")
        data = np.asarray(f["train"], dtype=np.float32)
    if data.ndim != 2:
        raise ValueError(f"'train' dataset must be 2D, got shape {data.shape}")
    return data


def infer_cluster_count(data: np.ndarray, attrs: dict) -> int:
    """Infer number of clusters from HDF5 attributes or fall back to a heuristic."""
    for key in ("n_clusters", "k", "n_centroids", "clusters"):
        if key in attrs:
            raw = attrs[key]
            try:
                if isinstance(raw, (bytes, bytearray)):
                    value = int(raw.decode())
                else:
                    value = int(raw)
                if value > 1:
                    return min(value, data.shape[0])
            except Exception:
                pass
    # Heuristic: use sqrt of samples capped at 256, but at least 2
    return max(2, min(256, int(math.sqrt(data.shape[0]))))


def l2_distances(batch: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Compute squared L2 distances between a batch of points and centroids."""
    batch_norm = np.sum(batch * batch, axis=1, keepdims=True)
    centroid_norm = np.sum(centroids * centroids, axis=1, keepdims=True).T
    dist_sq = batch_norm + centroid_norm - 2.0 * (batch @ centroids.T)
    np.maximum(dist_sq, 0.0, out=dist_sq)
    return dist_sq


def parallel_l2_squared_to_point(
    data: np.ndarray, point: np.ndarray, jobs: int
) -> np.ndarray:
    """Compute squared L2 distances from all rows in data to a single point in parallel."""
    n_samples = data.shape[0]
    jobs = max(1, min(jobs, n_samples))
    chunk = (n_samples + jobs - 1) // jobs
    point = point.astype(np.float32, copy=False)
    point_norm = float(np.dot(point, point))

    def worker(start: int, end: int) -> np.ndarray:
        batch = data[start:end]
        dist_sq = np.sum(batch * batch, axis=1) + point_norm - 2.0 * (batch @ point)
        return np.maximum(dist_sq, 0.0, out=dist_sq)

    results = []
    with ThreadPoolExecutor(max_workers=jobs) as pool:
        futures = []
        for start in range(0, n_samples, chunk):
            end = min(start + chunk, n_samples)
            futures.append(pool.submit(worker, start, end))
        for fut in futures:
            results.append(fut.result())
    return np.concatenate(results, axis=0)


def compute_kth_neighbor_distances(index: hnswlib.Index, queries: np.ndarray, k: int) -> np.ndarray:
    """Return distance to the k-th nearest neighbor for queries using an existing HNSW index.

    Assumes queries are not indexed elements (e.g., centroids vs. data index), so no self-hit to drop.
    """
    n = index.get_current_count()
    if n == 0 or queries.shape[0] == 0:
        return np.zeros((queries.shape[0],), dtype=np.float32)
    k_eff = min(k, n)
    if k_eff <= 0:
        return np.zeros((queries.shape[0],), dtype=np.float32)
    index.set_ef(max(2 * k_eff, 64))
    _, distances = index.knn_query(queries, k=k_eff)
    kth_sq = distances[:, k_eff - 1]
    kth_sq = np.maximum(kth_sq, 0.0, out=kth_sq)
    return np.sqrt(kth_sq, out=kth_sq)


def build_hnsw_index(data: np.ndarray, ef_construction: int = 200, m: int = 16) -> hnswlib.Index:
    """Build a reusable HNSW index on the given data."""
    n, dim = data.shape
    index = hnswlib.Index(space="l2", dim=dim)
    index.init_index(max_elements=n, ef_construction=ef_construction, M=m)
    ids = np.arange(n, dtype=np.int32)
    index.add_items(data, ids)
    return index


def knn_kth_from_index(index: hnswlib.Index, data: np.ndarray, k: int, skip_self: bool) -> np.ndarray:
    """Use an existing HNSW index to fetch k-th neighbor distances.

    If skip_self is True, assumes queries are the same set as the index contents and drops the first hit.
    """
    n = data.shape[0]
    if n <= 1:
        return np.zeros((n,), dtype=np.float32)
    k_eff = min(k, index.get_current_count() - (1 if skip_self else 0))
    if k_eff <= 0:
        return np.zeros((n,), dtype=np.float32)
    top_k = k_eff + 1 if skip_self else k_eff
    index.set_ef(max(2 * top_k, 64))
    _, distances = index.knn_query(data, k=top_k)
    if skip_self:
        kth_sq = distances[:, k_eff]
    else:
        kth_sq = distances[:, k_eff - 1]
    kth_sq = np.maximum(kth_sq, 0.0, out=kth_sq)
    return np.sqrt(kth_sq, out=kth_sq)


def compress_distance(
    batch: np.ndarray,
    centroids: np.ndarray,
    r_k_batch: np.ndarray,
    r_k_centroids: np.ndarray,
) -> np.ndarray:
    """Compute the density-scaled assignment score 1 - exp(-d / (r_kx * r_ky))."""
    dists = l2_distances(batch, centroids)
    denom = r_k_batch[:, None] * r_k_centroids[None, :]
    denom = np.maximum(denom, 1e-12, out=denom)
    scores = 1.0 - np.exp(-dists / denom)
    return scores


def initialize_centroids(
    data: np.ndarray, k: int, rng: np.random.Generator, n_jobs: int
) -> np.ndarray:
    """Simple k-means++ style initialization with parallel distance updates."""
    n_samples = data.shape[0]
    centroids = np.empty((k, data.shape[1]), dtype=np.float32)
    # First centroid chosen uniformly
    first_idx = rng.integers(0, n_samples)
    centroids[0] = data[first_idx]
    # Select remaining centroids probabilistically proportional to distance^2
    closest_dist_sq = parallel_l2_squared_to_point(data, centroids[0], n_jobs)
    for i in range(1, k):
        probs = closest_dist_sq / np.sum(closest_dist_sq)
        next_idx = rng.choice(n_samples, p=probs)
        centroids[i] = data[next_idx]
        new_dist_sq = parallel_l2_squared_to_point(data, centroids[i], n_jobs)
        closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq)
    return centroids


def recompute_centroids(
    data: np.ndarray,
    centroids: np.ndarray,
    r_k_data: np.ndarray,
    n_jobs: int,
    data_index: hnswlib.Index,
) -> tuple[np.ndarray, float]:
    """Assign points to centroids in parallel contiguous slices (no sub-batching)."""
    n_samples = data.shape[0]
    jobs = max(1, min(n_jobs, n_samples))
    chunk = (n_samples + jobs - 1) // jobs
    # Centroids' k-NN distances are measured against the data via the shared HNSW index
    r_k_centroids = compute_kth_neighbor_distances(data_index, centroids, K_NEIGHBOR)
    # r_k_centroids = np.ones((centroids.shape[0],), dtype=np.float32)

    def process_slice(start: int, end: int) -> tuple[np.ndarray, np.ndarray, float]:
        batch = data[start:end]
        dist = compress_distance(batch, centroids, r_k_data[start:end], r_k_centroids)
        labels = np.argmin(dist, axis=1)
        sums = np.zeros_like(centroids)
        np.add.at(sums, labels, batch)
        counts = np.bincount(labels, minlength=centroids.shape[0]).astype(np.int64)
        inertia = float(np.sum(dist[np.arange(len(batch)), labels]))
        return sums, counts, inertia

    tasks = []
    with ThreadPoolExecutor(max_workers=jobs) as pool:
        for start in range(0, n_samples, chunk):
            end = min(start + chunk, n_samples)
            tasks.append(pool.submit(process_slice, start, end))
        partial_sums = []
        partial_counts = []
        partial_inertia = 0.0
        for fut in tasks:
            sums, counts, inertia = fut.result()
            partial_sums.append(sums)
            partial_counts.append(counts)
            partial_inertia += inertia
    summed = np.sum(partial_sums, axis=0)
    counts = np.sum(partial_counts, axis=0)
    new_centroids = np.where(counts[:, None] > 0, summed / counts[:, None], centroids)
    return new_centroids.astype(np.float32), partial_inertia


def main() -> None:
    data_path = os.path.join(os.path.dirname(__file__), DATA_FILE)
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Cannot find {DATA_FILE} next to the script")
    with h5py.File(data_path, "r") as f:
        if "train" not in f:
            raise KeyError("HDF5 file does not contain 'train' dataset")
        attrs = dict(f["train"].attrs)
        # Load into memory once so threads can share it
        data = np.asarray(f["train"], dtype=np.float32)
    rng = np.random.default_rng(RANDOM_STATE)
    k = min(N_CLUSTERS, data.shape[0])
    n_jobs = max(1, os.cpu_count() or 1)
    print(f"Loaded train data shape {data.shape}, dtype {data.dtype}")
    print(f"Using fixed {k} clusters (cap {N_CLUSTERS}) with up to {n_jobs} threads")
    # Precompute k-th neighbor distance for data points (single global HNSW index)
    print(f"Building HNSW index and computing {K_NEIGHBOR}-th neighbor distances for data...")
    hnsw_index = build_hnsw_index(data)
    r_k_data = knn_kth_from_index(hnsw_index, data, K_NEIGHBOR, skip_self=True)
    print("Neighbor distances computed")
    centroids = initialize_centroids(data, k, rng, n_jobs)
    previous_inertia = float("inf")
    for step in range(1, MAX_ITERS + 1):
        centroids, inertia = recompute_centroids(data, centroids, r_k_data, n_jobs, hnsw_index)
        change = previous_inertia - inertia
        print(f"Iteration {step:02d}: inertia={inertia:.6f}, change={change:.6f}")
        # Stop when improvement is tiny relative to previous inertia
        if change >= 0 and change < TOLERANCE * max(1.0, previous_inertia):
            print("Converged based on inertia improvement threshold.")
            break
        previous_inertia = inertia
    output_path = os.path.join(os.path.dirname(__file__), OUTPUT_FILE)
    np.save(output_path, centroids.astype(np.float32))
    print(f"Saved centroids to {output_path} with shape {centroids.shape}")


if __name__ == "__main__":
    main()
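
The only change relative to kmeans.py is the assignment score 1 - exp(-d / (r_k(x) * r_k(c))), where r_k is the distance to the K_NEIGHBOR-th nearest neighbor. For a fixed point x, r_k(x) is a shared factor, so centroids are effectively ranked by d / r_k(c): a centroid sitting in a sparse region (large r_k(c)) looks closer, and one in a dense region looks farther, so the same point can be assigned differently than under plain L2. A toy illustration with made-up radii showing that the ranking can flip:

import numpy as np

# Mechanics of the score used by compress_distance:
#   score(x, c) = 1 - exp(-||x - c||^2 / (r_k(x) * r_k(c)))
d = np.array([1.0, 1.5])        # squared distances from one point to two centroids
r_k_x = 0.8                     # k-th neighbor distance of the point (illustrative)
r_k_c = np.array([0.5, 2.0])    # k-th neighbor distances of the two centroids

plain = d                                       # ordinary k-means picks centroid 0
scores = 1.0 - np.exp(-d / (r_k_x * r_k_c))     # density-scaled scores
print(plain.argmin(), scores.argmin())          # 0 1 -> the assignment flips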