Last active
January 26, 2025 21:33
-
-
Save cutecutecat/b3b77f7e27a0fd43fecc376cc07d787e to your computer and use it in GitHub Desktop.
rabbithole benchmark on external centroids
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark top-k recall and QPS of a rabbithole index.

Usage: python bench.py --metric {l2,cos} --name <dataset> [--nprobe N]

Reads the ``test`` queries and ``neighbors`` ground truth from
~/<name>/<name>.hdf5, runs every query one at a time against the table named
after the dataset, and prints recall and single-client QPS for each k in Ks.
"""
from os.path import join
import os
import time
import argparse
from pathlib import Path

from tqdm import tqdm
import psycopg
from psycopg import sql
import h5py
from pgvecto_rs.psycopg import register_vector
import numpy as np

# Distance operator used in ORDER BY, per supported metric.
METRIC_OPS = {"l2": "<->", "cos": "<=>"}
Ks = [10, 100]

parser = argparse.ArgumentParser()
parser.add_argument(
    "-m",
    "--metric",
    help="Metric to pick, in l2 or cos",
    required=True,
    choices=sorted(METRIC_OPS),
)
parser.add_argument("-n", "--name", help="Dataset name, like: sift", required=True)
parser.add_argument(
    "--nprobe",
    type=int,
    default=300,
    help="Value for the rabbithole.nprobe setting (default: 300)",
)
args = parser.parse_args()
metric_ops = METRIC_OPS[args.metric]

HOME = Path.home()
DATA_PATH = join(HOME, f"{args.name}/{args.name}.hdf5")
os.makedirs(join(HOME, f"indexes/pg/{args.name}"), exist_ok=True)

# Only the queries and the ground-truth neighbor ids are needed here; close the
# HDF5 file as soon as they are in memory.
with h5py.File(DATA_PATH, "r") as dataset:
    test = dataset["test"][:]
    answer = dataset["neighbors"][:]
m = np.shape(test)[0]

# PostgreSQL folds unquoted identifiers to lower case, so a table created as
# `CREATE TABLE <name>` is stored under the lower-cased name. Quoting the
# lower-cased name finds that table and also accepts dataset names that are not
# plain identifiers (e.g. "cohere-10m-23").
table = sql.Identifier(args.name.lower())

with psycopg.connect(
    conninfo="postgres://bench:123@localhost:5432/postgres",
    dbname="postgres",
    autocommit=True,
) as conn:
    conn.execute("SET search_path TO public, vectors, rabbithole")
    conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
    conn.execute("CREATE EXTENSION IF NOT EXISTS rabbithole")
    # Session-level setting. args.nprobe is guaranteed to be an int by
    # argparse (type=int), so formatting it into the statement is safe.
    conn.execute(f"SET rabbithole.nprobe={args.nprobe}")
    register_vector(conn)

    for k in Ks:
        # Only the query vector is a bound parameter (%s). The table, operator
        # and LIMIT are composed into the statement text.
        query_sql = sql.SQL(
            "SELECT id FROM {} ORDER BY embedding {} %s LIMIT {}"
        ).format(table, sql.SQL(metric_ops), sql.Literal(k))
        hits = 0
        delta = 0.0
        pbar = tqdm(enumerate(test), total=m)
        for i, query in pbar:
            start = time.perf_counter()
            result = conn.execute(query_sql, (query,)).fetchall()
            end = time.perf_counter()
            returned = {row[0] for row in result[:k]}
            hits += len(returned & set(answer[i][:k].tolist()))
            # Only server round-trip time is counted towards QPS.
            delta += end - start
            pbar.set_description(f"recall: {hits / k / (i+1)} QPS: {(i+1) / delta} ")
        recall = hits / k / m
        qps = m / delta
        print(f"Top: {k} recall: {recall:.4f} QPS: {qps:.2f}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env bash
# Build the rabbithole extension for PostgreSQL 16, package it into a Docker
# image, start a container from it and create the superuser "bench" that the
# Python benchmark scripts connect as.
set -euo pipefail

cargo build --package rabbithole --lib --features pg16 --target x86_64-unknown-linux-gnu --release
./tools/schema.sh --features pg16 --target x86_64-unknown-linux-gnu --release | expand -t 4 > ./target/schema.sql
export SEMVER="0.0.0"
export VERSION="16"
export ARCH="x86_64"
export PLATFORM="amd64"
./tools/package.sh
docker build -t starkind/rabbithole:pg16-v0.0.0 -f ./docker/Dockerfile .
docker run --name rabbithole -e POSTGRES_PASSWORD=123 -p 5432:5432 -d starkind/rabbithole:pg16-v0.0.0

# `docker run -d` returns before PostgreSQL accepts TCP connections; poll for
# up to ~60s so the following psql calls do not race the server start-up.
for _ in $(seq 1 60); do
    if PGPASSWORD=123 psql -h 127.0.0.1 -U postgres -c "SELECT 1" >/dev/null 2>&1; then
        break
    fi
    sleep 1
done

PGPASSWORD=123 psql -h 127.0.0.1 -U postgres -c "CREATE USER bench WITH PASSWORD '123';"
PGPASSWORD=123 psql -h 127.0.0.1 -U postgres -c "ALTER ROLE bench SUPERUSER;"
# Python dependencies of the kmeans/load/bench scripts:
# pip install pgvecto_rs numpy faiss-cpu psycopg h5py tqdm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Load an HDF5 dataset and precomputed centroids into PostgreSQL, then build
a rabbithole ``ivfrq`` index that uses those external centroids.

Usage: python load.py --metric {l2,cos} --name <dataset>

Inputs: ~/<name>/<name>.hdf5 (``train`` vectors) and ``<name>.pickle`` in the
working directory (the centroid array). Drops and recreates the tables
``public.centroids`` and ``<name>`` (lower-cased), then prints the index build
time.
"""
from os.path import join
import os
import time
import argparse
from pathlib import Path
import pickle

import psycopg
from psycopg import sql
import h5py
from pgvecto_rs.psycopg import register_vector
import numpy as np

# Number of IVF lists; must equal the number of centroids in <name>.pickle.
K = 4096
# Rows of ``train`` read from HDF5 per chunk while streaming into COPY, so the
# full training matrix is never held in this process at once.
CHUNK_ROWS = 100_000
# Per metric: (operator class, value of spherical_centroids in the options).
METRICS = {
    "l2": ("vector_l2_ops", "false"),
    "cos": ("vector_cos_ops", "true"),
}

parser = argparse.ArgumentParser()
parser.add_argument(
    "-m",
    "--metric",
    help="Metric to pick, in l2 or cos",
    required=True,
    choices=sorted(METRICS),
)
parser.add_argument("-n", "--name", help="Dataset name, like: sift", required=True)
args = parser.parse_args()
metric_ops, spherical = METRICS[args.metric]

# Options passed to ivfrq through WITH (options = ...). They point the index at
# the externally computed centroids stored in public.centroids.coordinate.
ivf_config = f"""
nlist = {K}
spherical_centroids = {spherical}
[centroids]
table = 'public.centroids'
column = 'coordinate'
"""

HOME = Path.home()
DATA_PATH = join(HOME, f"{args.name}/{args.name}.hdf5")
os.makedirs(join(HOME, f"indexes/pg/{args.name}"), exist_ok=True)

# NOTE: pickle.load can execute arbitrary code. Only load a file produced
# locally by the k-means script, never one obtained from an untrusted source.
with open(f"{args.name}.pickle", "rb") as f:
    centroids = np.asarray(pickle.load(f))

keepalive_kwargs = {
    "keepalives": 1,
    "keepalives_idle": 30,
    "keepalives_interval": 5,
    "keepalives_count": 5,
}

# PostgreSQL folds unquoted identifiers to lower case; quoting the lower-cased
# name matches that convention and also accepts names that are not plain
# identifiers (e.g. "cohere-10m-23").
table = sql.Identifier(args.name.lower())

with h5py.File(DATA_PATH, "r") as dataset:
    train = dataset["train"]  # h5py Dataset: data is read only when sliced
    n, dims = train.shape
    if centroids.shape != (K, dims):
        raise ValueError(
            f"{args.name}.pickle holds centroids of shape {centroids.shape}, "
            f"expected {(K, dims)}"
        )

    with psycopg.connect(
        conninfo="postgres://bench:123@localhost:5432/postgres",
        dbname="postgres",
        autocommit=True,
        **keepalive_kwargs,
    ) as conn:
        conn.execute("SET search_path TO public, vectors, rabbithole")
        conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
        conn.execute("CREATE EXTENSION IF NOT EXISTS rabbithole")
        register_vector(conn)

        # Centroid table referenced by ivf_config.
        conn.execute("DROP TABLE IF EXISTS public.centroids")
        conn.execute(f"CREATE TABLE public.centroids (coordinate vector({dims}))")
        with conn.cursor().copy(
            "COPY public.centroids (coordinate) FROM STDIN WITH (FORMAT BINARY)"
        ) as copy:
            copy.set_types(["vector"])
            for centroid in centroids:
                copy.write_row([centroid])

        # Data table; ids are the row positions in ``train``.
        conn.execute(sql.SQL("DROP TABLE IF EXISTS {}").format(table))
        conn.execute(
            sql.SQL("CREATE TABLE {} (id integer, embedding vector({}))").format(
                table, sql.Literal(int(dims))
            )
        )
        with conn.cursor().copy(
            sql.SQL("COPY {} (id, embedding) FROM STDIN WITH (FORMAT BINARY)").format(
                table
            )
        ) as copy:
            copy.set_types(["integer", "vector"])
            for lo in range(0, n, CHUNK_ROWS):
                for offset, vec in enumerate(train[lo : lo + CHUNK_ROWS]):
                    copy.write_row([lo + offset, vec])

        # Only the CREATE INDEX statement is timed.
        start = time.perf_counter()
        conn.execute(
            sql.SQL(
                "CREATE INDEX ON {} USING ivfrq (embedding {}) WITH (options = $${}$$)"
            ).format(table, sql.SQL(metric_ops), sql.SQL(ivf_config))
        )
        end = time.perf_counter()
        delta = end - start
        print(f"Index build time: {delta:.2f}s")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Compute K-means centroids for an HDF5 ANN dataset with faiss.

Usage: python kmeans.py --name <dataset>

Trains on at most 256 * K rows of ``train`` (a seeded random subsample when
the dataset is larger) and pickles the resulting (K, dims) centroid array to
``<name>.pickle`` in the working directory.
"""
from os.path import join
import os
import time
import argparse
from pathlib import Path
import pickle

import h5py
import faiss
import numpy as np

K = 4096
SEED = 42
# faiss's Clustering defaults to max_points_per_centroid = 256 and subsamples
# anything larger itself, so reading more rows than this only costs I/O.
MAX_TRAIN = 256 * K

parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", help="Dataset name, like: sift", required=True)
args = parser.parse_args()
HOME = Path.home()
DATA_PATH = join(HOME, f"{args.name}/{args.name}.hdf5")

with h5py.File(DATA_PATH, "r") as dataset:
    total = len(dataset["train"])
    if total > MAX_TRAIN:
        rs = np.random.RandomState(SEED)
        idx = rs.choice(total, size=MAX_TRAIN, replace=False)
        # h5py fancy indexing requires the indices in increasing order.
        train = dataset["train"][np.sort(idx)]
    else:
        train = dataset["train"][:]

n, dims = np.shape(train)
if n < K:
    raise ValueError(f"need at least {K} training vectors for K-means, got {n}")

start = time.perf_counter()
index = faiss.IndexFlatL2(dims)
clustering = faiss.Clustering(dims, K)
clustering.verbose = True
clustering.seed = SEED
clustering.niter = 10
clustering.train(train, index)
centroids = faiss.vector_float_to_array(clustering.centroids)
end = time.perf_counter()
delta = end - start
print(f"K-means time: {delta:.2f}s")

# faiss returns the centroids flattened; restore one row per centroid.
centroids = centroids.reshape([K, -1])
with open(f"{args.name}.pickle", "wb") as f:
    pickle.dump(centroids, f)
Author
Author
Memory usage
For the dataset cohere-10m-23, we observed about 44 GB of memory used by the PostgreSQL instance and 40 GB used by the load.py script during index build.
Therefore an AWS m7i.8xlarge instance (128 GB of memory) is required.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Benchmark
Result of different datasets
We add the newly generated datasets
`cohere-1m-23` and `cohere-10m-23` to our benchmark. All experiments are carried out with the following arguments:
Effect of nprobe on a large dataset
We vary nprobe from 10 to 300 on the large dataset
`cohere-10m-23`.