Skip to content

Instantly share code, notes, and snippets.

@benwtrent
Created November 20, 2024 20:54
Show Gist options
  • Select an option

  • Save benwtrent/b0edb3975d2f03356c1a5ea84c72abc9 to your computer and use it in GitHub Desktop.

Select an option

Save benwtrent/b0edb3975d2f03356c1a5ea84c72abc9 to your computer and use it in GitHub Desktop.
jmh benchmark
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.benchmark.vector;
import org.apache.lucene.util.Constants;
import org.elasticsearch.common.logging.LogConfigurator;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import java.io.IOException;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
/**
 * Benchmark that compares two ways of computing the inner product between a dense query vector
 * ({@code byte[]} or {@code float[]}) and a bit-packed vector: a branch-free, manually unrolled
 * loop that multiplies every element by its selecting bit, and a nested loop that tests each bit
 * with an {@code if} statement.
 * <p>
 * In all implementations bit {@code j} (least-significant bit first) of {@code d[i]} selects
 * element {@code q[i * Byte.SIZE + j]}.
 * <p>
 * Run with ./gradlew -p benchmarks run --args 'IpBitVectorScorerBenchmark'
 */
@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
@Warmup(iterations = 3, time = 3)
@Measurement(iterations = 5, time = 3)
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Thread)
public class IpBitVectorScorerBenchmark {

    static {
        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
    }

    /** Number of dimensions of the dense vectors; the bit vector holds {@code dims / Byte.SIZE} bytes. */
    @Param({ "768" })
    int dims;

    /** Bit-packed vector, one bit per dimension. */
    byte[] bits;
    /** Dense byte query vector with {@code dims} elements. */
    byte[] byteVec;
    /** Dense float query vector with {@code dims} elements. */
    float[] floatVec;

    /**
     * Allocates the three vectors for the current {@code dims} and fills them with random data.
     */
    @Setup
    public void setup() throws IOException {
        final ThreadLocalRandom random = ThreadLocalRandom.current();
        bits = new byte[dims / Byte.SIZE];
        byteVec = new byte[dims];
        floatVec = new float[dims];
        random.nextBytes(bits);
        random.nextBytes(byteVec);
        for (int i = 0; i < dims; i++) {
            floatVec[i] = random.nextFloat();
        }
    }

    /** Measures {@link #ipFloatBitUnWrap(float[], byte[])}. */
    @Benchmark
    public float dotProductFloatUnwrap() throws IOException {
        return ipFloatBitUnWrap(floatVec, bits);
    }

    /** Measures {@link #ipFloatBit(float[], byte[])}. */
    @Benchmark
    public float dotProductFloatIfStatement() throws IOException {
        return ipFloatBit(floatVec, bits);
    }

    /** Measures {@link #ipByteBitUnWrap(byte[], byte[])}. */
    @Benchmark
    public float dotProductByteUnwrap() throws IOException {
        return ipByteBitUnWrap(byteVec, bits);
    }

    /** Measures {@link #ipByteBit(byte[], byte[])}. */
    @Benchmark
    public float dotProductByteIfStatement() throws IOException {
        return ipByteBit(byteVec, bits);
    }

    /**
     * Computes {@code a * b + c}, using {@link Math#fma(float, float, float)} when Lucene reports
     * that scalar FMA is fast on this platform and a plain multiply-add otherwise.
     */
    private static float fma(float a, float b, float c) {
        if (Constants.HAS_FAST_SCALAR_FMA) {
            return Math.fma(a, b, c);
        } else {
            return a * b + c;
        }
    }

    /**
     * Branch-free, unrolled inner product of a byte vector and a bit vector: sums the elements of
     * {@code q} whose corresponding bit in {@code d} is {@code 1}.
     *
     * @param q dense vector; must contain at least {@code d.length * Byte.SIZE} elements
     * @param d bit-packed vector, least-significant bit of each byte first
     * @return the sum of the selected elements of {@code q}; same value as {@link #ipByteBit}
     */
    static int ipByteBitUnWrap(byte[] q, byte[] d) {
        // four independent accumulators to shorten the dependency chain
        int acc0 = 0;
        int acc1 = 0;
        int acc2 = 0;
        int acc3 = 0;
        // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
        for (int i = 0; i < d.length; i++) {
            final int mask = d[i];
            final int base = i * Byte.SIZE;
            // Shift the bit down to position 0 so the multiplier is exactly 0 or 1.
            // `mask & (1 << k)` would yield 0 or 2^k and scale the element by a power of two.
            acc0 += q[base] * (mask & 1);
            acc1 += q[base + 1] * ((mask >> 1) & 1);
            acc2 += q[base + 2] * ((mask >> 2) & 1);
            acc3 += q[base + 3] * ((mask >> 3) & 1);
            acc0 += q[base + 4] * ((mask >> 4) & 1);
            acc1 += q[base + 5] * ((mask >> 5) & 1);
            acc2 += q[base + 6] * ((mask >> 6) & 1);
            acc3 += q[base + 7] * ((mask >> 7) & 1);
        }
        return acc0 + acc1 + acc2 + acc3;
    }

    /**
     * Branching inner product of a byte vector and a bit vector: tests every bit of {@code d} and
     * adds the matching element of {@code q} when the bit is set.
     *
     * @param q dense vector; must contain at least {@code d.length * Byte.SIZE} elements
     * @param d bit-packed vector, least-significant bit of each byte first
     * @return the sum of the selected elements of {@code q}
     */
    static int ipByteBit(byte[] q, byte[] d) {
        int result = 0;
        // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
        for (int i = 0; i < d.length; i++) {
            final int mask = d[i];
            final int base = i * Byte.SIZE;
            for (int j = 0; j < Byte.SIZE; j++) {
                if ((mask & (1 << j)) != 0) {
                    result += q[base + j];
                }
            }
        }
        return result;
    }

    /**
     * Branch-free, unrolled inner product of a float vector and a bit vector: sums the elements of
     * {@code q} whose corresponding bit in {@code d} is {@code 1}.
     * <p>
     * The result can differ from {@link #ipFloatBit} in the last bits because the additions are
     * spread over four accumulators and therefore happen in a different order.
     *
     * @param q dense vector; must contain at least {@code d.length * Byte.SIZE} elements
     * @param d bit-packed vector, least-significant bit of each byte first
     * @return the sum of the selected elements of {@code q}
     */
    static float ipFloatBitUnWrap(float[] q, byte[] d) {
        // four independent accumulators to shorten the dependency chain
        float acc0 = 0;
        float acc1 = 0;
        float acc2 = 0;
        float acc3 = 0;
        // now combine the two vectors, summing the float dimensions where the bit in d is `1`
        for (int i = 0; i < d.length; i++) {
            final int mask = d[i];
            final int base = i * Byte.SIZE;
            // Shift the bit down to position 0 so the multiplier is exactly 0 or 1.
            // `mask & (1 << k)` would yield 0 or 2^k and scale the element by a power of two.
            acc0 = fma(q[base], mask & 1, acc0);
            acc1 = fma(q[base + 1], (mask >> 1) & 1, acc1);
            acc2 = fma(q[base + 2], (mask >> 2) & 1, acc2);
            acc3 = fma(q[base + 3], (mask >> 3) & 1, acc3);
            acc0 = fma(q[base + 4], (mask >> 4) & 1, acc0);
            acc1 = fma(q[base + 5], (mask >> 5) & 1, acc1);
            acc2 = fma(q[base + 6], (mask >> 6) & 1, acc2);
            acc3 = fma(q[base + 7], (mask >> 7) & 1, acc3);
        }
        return acc0 + acc1 + acc2 + acc3;
    }

    /**
     * Branching inner product of a float vector and a bit vector: tests every bit of {@code d} and
     * adds the matching element of {@code q} when the bit is set.
     *
     * @param q dense vector; must contain at least {@code d.length * Byte.SIZE} elements
     * @param d bit-packed vector, least-significant bit of each byte first
     * @return the sum of the selected elements of {@code q}
     */
    static float ipFloatBit(float[] q, byte[] d) {
        float result = 0;
        // now combine the two vectors, summing the float dimensions where the bit in d is `1`
        for (int i = 0; i < d.length; i++) {
            final int mask = d[i];
            final int base = i * Byte.SIZE;
            for (int j = 0; j < Byte.SIZE; j++) {
                if ((mask & (1 << j)) != 0) {
                    result += q[base + j];
                }
            }
        }
        return result;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment