Created
November 20, 2024 20:54
-
-
Save benwtrent/b0edb3975d2f03356c1a5ea84c72abc9 to your computer and use it in GitHub Desktop.
jmh benchmark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /* | |
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | |
| * or more contributor license agreements. Licensed under the "Elastic License | |
| * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side | |
| * Public License v 1"; you may not use this file except in compliance with, at | |
| * your election, the "Elastic License 2.0", the "GNU Affero General Public | |
| * License v3.0 only", or the "Server Side Public License, v 1". | |
| */ | |
| package org.elasticsearch.benchmark.vector; | |
| import org.apache.lucene.util.Constants; | |
| import org.elasticsearch.common.logging.LogConfigurator; | |
| import org.openjdk.jmh.annotations.Benchmark; | |
| import org.openjdk.jmh.annotations.BenchmarkMode; | |
| import org.openjdk.jmh.annotations.Fork; | |
| import org.openjdk.jmh.annotations.Measurement; | |
| import org.openjdk.jmh.annotations.Mode; | |
| import org.openjdk.jmh.annotations.OutputTimeUnit; | |
| import org.openjdk.jmh.annotations.Param; | |
| import org.openjdk.jmh.annotations.Scope; | |
| import org.openjdk.jmh.annotations.Setup; | |
| import org.openjdk.jmh.annotations.State; | |
| import org.openjdk.jmh.annotations.Warmup; | |
| import java.io.IOException; | |
| import java.util.concurrent.ThreadLocalRandom; | |
| import java.util.concurrent.TimeUnit; | |
| @Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) | |
| @Warmup(iterations = 3, time = 3) | |
| @Measurement(iterations = 5, time = 3) | |
| @BenchmarkMode(Mode.Throughput) | |
| @OutputTimeUnit(TimeUnit.MICROSECONDS) | |
| @State(Scope.Thread) | |
| /** | |
| * Benchmark that compares various scalar quantized vector similarity function | |
| * implementations;: scalar, lucene's panama-ized, and Elasticsearch's native. | |
| * Run with ./gradlew -p benchmarks run --args 'IpBitVectorScorerBenchmark' | |
| */ | |
| public class IpBitVectorScorerBenchmark { | |
| static { | |
| LogConfigurator.configureESLogging(); // native access requires logging to be initialized | |
| } | |
| @Param({ "768" }) | |
| int dims; | |
| byte[] bits; | |
| byte[] byteVec; | |
| float[] floatVec; | |
| @Setup | |
| public void setup() throws IOException { | |
| bits = new byte[dims / Byte.SIZE]; | |
| byteVec = new byte[dims]; | |
| floatVec = new float[dims]; | |
| ThreadLocalRandom.current().nextBytes(bits); | |
| ThreadLocalRandom.current().nextBytes(byteVec); | |
| for (int i = 0; i < dims; i++) { | |
| floatVec[i] = ThreadLocalRandom.current().nextFloat(); | |
| } | |
| } | |
| @Benchmark | |
| public float dotProductFloatUnwrap() throws IOException { | |
| return ipFloatBitUnWrap(floatVec, bits); | |
| } | |
| @Benchmark | |
| public float dotProductFloatIfStatement() throws IOException { | |
| return ipFloatBit(floatVec, bits); | |
| } | |
| @Benchmark | |
| public float dotProductByteUnwrap() throws IOException { | |
| return ipByteBitUnWrap(byteVec, bits); | |
| } | |
| @Benchmark | |
| public float dotProductByteIfStatement() throws IOException { | |
| return ipByteBit(byteVec, bits); | |
| } | |
| private static float fma(float a, float b, float c) { | |
| if (Constants.HAS_FAST_SCALAR_FMA) { | |
| return Math.fma(a, b, c); | |
| } else { | |
| return a * b + c; | |
| } | |
| } | |
| static int ipByteBitUnWrap(byte[] q, byte[] d) { | |
| int acc = 0; | |
| int acc1 = 0; | |
| int acc2 = 0; | |
| int acc3 = 0; | |
| // now combine the two vectors, summing the byte dimensions where the bit in d is `1` | |
| for (int i = 0; i < d.length; i++) { | |
| byte mask = d[i]; | |
| acc += q[i * Byte.SIZE] * (mask & (1 << 0)); | |
| acc1 += q[i * Byte.SIZE + 1] * (mask & (1 << 1)); | |
| acc2 += q[i * Byte.SIZE + 2] * (mask & (1 << 2)); | |
| acc3 += q[i * Byte.SIZE + 3] * (mask & (1 << 3)); | |
| acc += q[i * Byte.SIZE + 4] * (mask & (1 << 4)); | |
| acc1 += q[i * Byte.SIZE + 5] * (mask & (1 << 5)); | |
| acc2 += q[i * Byte.SIZE + 6] * (mask & (1 << 6)); | |
| acc3 += q[i * Byte.SIZE + 7] * (mask & (1 << 7)); | |
| } | |
| return acc + acc1 + acc2 + acc3; | |
| } | |
| static int ipByteBit(byte[] q, byte[] d) { | |
| int result = 0; | |
| // now combine the two vectors, summing the byte dimensions where the bit in d is `1` | |
| for (int i = 0; i < d.length; i++) { | |
| byte mask = d[i]; | |
| for (int j = 0; j < Byte.SIZE; j++) { | |
| if ((mask & (1 << j)) != 0) { | |
| result += q[i * Byte.SIZE + j]; | |
| } | |
| } | |
| } | |
| return result; | |
| } | |
| static float ipFloatBitUnWrap(float[] q, byte[] d) { | |
| float acc = 0; | |
| float acc1 = 0; | |
| float acc2 = 0; | |
| float acc3 = 0; | |
| // now combine the two vectors, summing the byte dimensions where the bit in d is `1` | |
| for (int i = 0; i < d.length; i++) { | |
| byte mask = d[i]; | |
| acc = fma(q[i * Byte.SIZE], mask & (1 << 0), acc); | |
| acc1 = fma(q[i * Byte.SIZE + 1], mask & (1 << 1), acc1); | |
| acc2 = fma(q[i * Byte.SIZE + 2], mask & (1 << 2), acc2); | |
| acc3 = fma(q[i * Byte.SIZE + 3], mask & (1 << 3), acc3); | |
| acc = fma(q[i * Byte.SIZE + 4], mask & (1 << 4), acc); | |
| acc1 = fma(q[i * Byte.SIZE + 5], mask & (1 << 5), acc1); | |
| acc2 = fma(q[i * Byte.SIZE + 6], mask & (1 << 6), acc2); | |
| acc3 = fma(q[i * Byte.SIZE + 7], mask & (1 << 7), acc3); | |
| } | |
| return acc + acc1 + acc2 + acc3; | |
| } | |
| static float ipFloatBit(float[] q, byte[] d) { | |
| float result = 0; | |
| for (int i = 0; i < d.length; i++) { | |
| byte mask = d[i]; | |
| for (int j = 0; j < Byte.SIZE; j++) { | |
| if ((mask & (1 << j)) != 0) { | |
| result += q[i * Byte.SIZE + j]; | |
| } | |
| } | |
| } | |
| return result; | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment