# rabbithole benchmark on external centroids
The query benchmark script (runs Top-10 and Top-100 queries against the table and reports recall and QPS):

```python
from os.path import join
import os
import time
import argparse
from pathlib import Path

from tqdm import tqdm
import psycopg
import h5py
from pgvecto_rs.psycopg import register_vector
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument(
    "-m", "--metric", help="Metric to pick, in l2 or cos", required=True
)
parser.add_argument("-n", "--name", help="Dataset name, like: sift", required=True)
args = parser.parse_args()

HOME = Path.home()
INDEX_PATH = join(HOME, f"indexes/pg/{args.name}/{args.metric}")
DATA_PATH = join(HOME, f"{args.name}/{args.name}.hdf5")
os.makedirs(join(HOME, f"indexes/pg/{args.name}"), exist_ok=True)

dataset = h5py.File(DATA_PATH, "r")
test = dataset["test"][:]

# pgvecto.rs distance operators: <-> is L2, <=> is cosine
if args.metric == "l2":
    metric_ops = "<->"
elif args.metric == "cos":
    metric_ops = "<=>"
else:
    raise ValueError(f"unsupported metric: {args.metric}")

answer = dataset["neighbors"][:]
m, dims = np.shape(test)

# reconnect for updated GUC variables to take effect
conn = psycopg.connect(
    conninfo="postgres://bench:123@localhost:5432/postgres",
    autocommit=True,
)
conn.execute("SET search_path TO public, vectors, rabbithole")
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
conn.execute("CREATE EXTENSION IF NOT EXISTS rabbithole")
conn.execute("SET rabbithole.nprobe=300")
register_vector(conn)

Ks = [10, 100]
for k in Ks:
    hits = 0
    delta = 0
    pbar = tqdm(enumerate(test), total=m)
    for i, query in pbar:
        start = time.perf_counter()
        result = conn.execute(
            f"SELECT id FROM {args.name} ORDER BY embedding {metric_ops} %s LIMIT {k}",
            (query,),
        ).fetchall()
        end = time.perf_counter()
        # recall@k: overlap between returned ids and the ground-truth neighbors
        hits += len({p[0] for p in result[:k]} & set(answer[i][:k].tolist()))
        delta += end - start
        pbar.set_description(f"recall: {hits / k / (i + 1)} QPS: {(i + 1) / delta}")
    recall = hits / k / m
    qps = m / delta
    print(f"Top: {k} recall: {recall:.4f} QPS: {qps:.2f}")
```
Build, package, and deployment steps:

```sh
cargo build --package rabbithole --lib --features pg16 --target x86_64-unknown-linux-gnu --release
./tools/schema.sh --features pg16 --target x86_64-unknown-linux-gnu --release | expand -t 4 > ./target/schema.sql

export SEMVER="0.0.0"
export VERSION="16"
export ARCH="x86_64"
export PLATFORM="amd64"
./tools/package.sh

docker build -t starkind/rabbithole:pg16-v0.0.0 -f ./docker/Dockerfile .
docker run --name rabbithole -e POSTGRES_PASSWORD=123 -p 5432:5432 -d starkind/rabbithole:pg16-v0.0.0

PGPASSWORD=123 psql -h 127.0.0.1 -U postgres -c "CREATE USER bench WITH PASSWORD '123';"
PGPASSWORD=123 psql -h 127.0.0.1 -U postgres -c "ALTER ROLE bench SUPERUSER;"

# pip install pgvecto_rs numpy faiss-cpu psycopg h5py
```
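A quick way to confirm the container is up and both extensions load before running the benchmark (a hypothetical check, not part of the gist):

```python
import psycopg

# Connect with the same credentials the benchmark scripts use.
conn = psycopg.connect("postgres://bench:123@localhost:5432/postgres", autocommit=True)
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
conn.execute("CREATE EXTENSION IF NOT EXISTS rabbithole")
rows = conn.execute("SELECT extname FROM pg_extension ORDER BY extname").fetchall()
print(rows)  # should list both 'rabbithole' and 'vectors'
```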
The loading script, load.py (copies the externally trained centroids and the base vectors into PostgreSQL via binary COPY, then builds the ivfrq index against the external centroids table):

```python
from os.path import join
import os
import time
import argparse
from pathlib import Path
import pickle

import psycopg
import h5py
from pgvecto_rs.psycopg import register_vector
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument(
    "-m", "--metric", help="Metric to pick, in l2 or cos", required=True
)
parser.add_argument("-n", "--name", help="Dataset name, like: sift", required=True)
args = parser.parse_args()

HOME = Path.home()
DATA_PATH = join(HOME, f"{args.name}/{args.name}.hdf5")
os.makedirs(join(HOME, f"indexes/pg/{args.name}"), exist_ok=True)

dataset = h5py.File(DATA_PATH, "r")
train = dataset["train"][:]
test = dataset["test"][:]

K = 4096  # nlist: number of IVF centroids
if args.metric == "l2":
    metric_ops = "vector_l2_ops"
    ivf_config = f"""
nlist = {K}
spherical_centroids = false
[centroids]
table = 'public.centroids'
column = 'coordinate'
"""
elif args.metric == "cos":
    metric_ops = "vector_cos_ops"
    ivf_config = f"""
nlist = {K}
spherical_centroids = true
[centroids]
table = 'public.centroids'
column = 'coordinate'
"""
else:
    raise ValueError(f"unsupported metric: {args.metric}")

answer = dataset["neighbors"][:]
n, dims = np.shape(train)
m = np.shape(test)[0]

# keep the connection alive during the long COPY and index build
keepalive_kwargs = {
    "keepalives": 1,
    "keepalives_idle": 30,
    "keepalives_interval": 5,
    "keepalives_count": 5,
}

with open(f"{args.name}.pickle", "rb") as f:
    centroids = pickle.load(f)

conn = psycopg.connect(
    conninfo="postgres://bench:123@localhost:5432/postgres",
    autocommit=True,
    **keepalive_kwargs,
)
conn.execute("SET search_path TO public, vectors, rabbithole")
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
conn.execute("CREATE EXTENSION IF NOT EXISTS rabbithole")
register_vector(conn)

# load the externally trained centroids
conn.execute("DROP TABLE IF EXISTS public.centroids")
conn.execute(f"CREATE TABLE public.centroids (coordinate vector({dims}))")
with conn.cursor().copy(
    "COPY public.centroids (coordinate) FROM STDIN WITH (FORMAT BINARY)"
) as copy:
    copy.set_types(["vector"])
    for i in range(K):
        copy.write_row([centroids[i]])
    # flush any buffered COPY data to the server
    while conn.pgconn.flush() == 1:
        pass

# load the base vectors
conn.execute(f"DROP TABLE IF EXISTS {args.name}")
conn.execute(f"CREATE TABLE {args.name} (id integer, embedding vector({dims}))")
with conn.cursor().copy(
    f"COPY {args.name} (id, embedding) FROM STDIN WITH (FORMAT BINARY)"
) as copy:
    copy.set_types(["integer", "vector"])
    for i in range(n):
        copy.write_row([i, train[i]])
    while conn.pgconn.flush() == 1:
        pass

start = time.perf_counter()
conn.execute(
    f"CREATE INDEX ON {args.name} USING ivfrq (embedding {metric_ops}) WITH (options = $${ivf_config}$$)"
)
end = time.perf_counter()
delta = end - start
print(f"Index build time: {delta:.2f}s")
```
The centroid training script (runs faiss k-means outside PostgreSQL and pickles the centroids for load.py):

```python
from os.path import join
import time
import argparse
from pathlib import Path
import pickle

import h5py
import faiss
import numpy as np

K = 4096  # number of centroids (nlist)
SEED = 42

parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", help="Dataset name, like: sift", required=True)
args = parser.parse_args()

HOME = Path.home()
DATA_PATH = join(HOME, f"{args.name}/{args.name}.hdf5")
dataset = h5py.File(DATA_PATH, "r")

# faiss k-means uses at most ~256 samples per centroid, so subsample
# larger training sets with a fixed seed for reproducibility
if len(dataset["train"]) > 256 * K:
    rs = np.random.RandomState(SEED)
    idx = rs.choice(len(dataset["train"]), size=256 * K, replace=False)
    train = dataset["train"][np.sort(idx)]  # h5py fancy indexing needs sorted indices
else:
    train = dataset["train"][:]
train = np.ascontiguousarray(train, dtype=np.float32)  # faiss requires float32

n, dims = np.shape(train)

start = time.perf_counter()
index = faiss.IndexFlatL2(dims)  # flat index used for the assignment step
clustering = faiss.Clustering(dims, K)
clustering.verbose = True
clustering.seed = SEED
clustering.niter = 10
clustering.train(train, index)
centroids = faiss.vector_float_to_array(clustering.centroids)
end = time.perf_counter()
delta = end - start
print(f"K-means time: {delta:.2f}s")

centroids = centroids.reshape([K, -1])
with open(f"{args.name}.pickle", "wb") as f:
    pickle.dump(centroids, f)
```
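The pickle produced here is what load.py reads before its binary COPY. A quick shape and dtype check (hypothetical; `sift` is an example dataset name):

```python
import pickle
import numpy as np

with open("sift.pickle", "rb") as f:  # "sift" is an example dataset name
    centroids = pickle.load(f)

# load.py expects one row per centroid: shape (4096, dims), float32, finite.
print(centroids.shape, centroids.dtype)
assert centroids.shape[0] == 4096
assert np.isfinite(centroids).all()
```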
## Benchmark

### Results on different datasets

We add the newly generated cohere-1m-23 and cohere-10m-23 datasets to our benchmark. All experiments are carried out with:

- K/nlist = 4096
- nprobe = 300
| dataset | metric | dim | QPS / Top 10 | recall / Top 10 | QPS / Top 100 | recall / Top 100 |
|---|---|---|---|---|---|---|
| sift | l2 | 128 | 302.96 | 0.9939 | 156.38 | 0.9924 |
| gist | l2 | 960 | 103.78 | 0.9851 | 80.80 | 0.9753 |
| cohere-1m-22 (old) | l2 (incorrect, not normalized) | 768 | 154.17 | 0.8848 | 117.09 | 0.8637 |
| openai-500k | l2 | 1536 | 143.70 | 0.9887 | 118.51 | 0.9754 |
| cohere-1m-23 (new) | l2 | 1024 | 120.92 | 0.9830 | 97.39 | 0.9782 |
| cohere-10m-23 (new) | l2 | 1024 | 8.90 | 0.9842 | 2.64 | 0.9839 |
### Effect of nprobe on a large dataset

We vary nprobe from 10 to 300 on the large dataset cohere-10m-23.

| nprobe | QPS / Top 10 | recall / Top 10 | QPS / Top 100 | recall / Top 100 |
|---|---|---|---|---|
| 300 | 8.90 | 0.9842 | 2.64 | 0.9839 |
| 200 | 22.34 | 0.9762 | 19.51 | 0.9769 |
| 100 | 42.10 | 0.9567 | 29.75 | 0.9606 |
| 50 | 78.08 | 0.9280 | 34.71 | 0.9334 |
| 20 | 157.09 | 0.8715 | 27.63 | 0.8734 |
| 10 | 253.73 | 0.8142 | 25.7 | 0.8079 |
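The sweep driver itself is not shown in the gist; below is a sketch of how it could be run, assuming the same connection settings as the benchmark script. `rabbithole.nprobe` is set per session, and the benchmark script's comment suggests reconnecting so GUC changes reliably take effect:

```python
import psycopg

# Hypothetical sweep driver: reconnect per setting, apply the GUC,
# then rerun the benchmark's query loop.
for nprobe in (10, 20, 50, 100, 200, 300):
    conn = psycopg.connect(
        "postgres://bench:123@localhost:5432/postgres", autocommit=True
    )
    conn.execute("SET search_path TO public, vectors, rabbithole")
    conn.execute(f"SET rabbithole.nprobe={nprobe}")
    # ... run the same Top-10 / Top-100 loop as in the benchmark script ...
    conn.close()
```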
## Memory usage

For the cohere-10m-23 dataset, we observed about 44 GB of memory used by the PostgreSQL instance and 40 GB by the load.py script during index build, so an AWS m7i.8xlarge instance with 128 GB of memory is required.
## About the new Cohere dataset and the L2 metric

cohere-1m from Zilliz VectorDBBench is built from cohere-2022, whose embeddings do not appear to be normalized according to their blog. cohere-2023, by contrast, states clearly that its embeddings are normalized, so the L2 metric can be used on it. There are other updates from cohere-2022 to cohere-2023 as well; due to these changes, we will generate new datasets from cohere-2023 for the Xlarge test.
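On unit-length vectors, L2 and cosine give the same ranking, since ‖a − b‖² = 2 − 2 a·b when ‖a‖ = ‖b‖ = 1. A quick spot-check of the normalization claim (hypothetical; the file name is an example):

```python
import h5py
import numpy as np

# Sample some training vectors and inspect their L2 norms.
dataset = h5py.File("cohere-10m-23.hdf5", "r")  # example file name
norms = np.linalg.norm(dataset["train"][:1000], axis=1)
print(norms.min(), norms.max())  # both should be ~1.0 if normalized
```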