Last active
January 29, 2026 06:43
-
-
Save IntendedConsequence/f5bb53f77f8c3eacbd74a478892ebbbc to your computer and use it in GitHub Desktop.
Chromaprint implementation in pure numpy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import math | |
| import sys | |
| from pathlib import Path | |
| import contextlib | |
| import time | |
# NOTE: Timing decorator courtesy of tinygrad
class Timing(contextlib.ContextDecorator):
    """Measure wall-clock time of a ``with`` block or a decorated function.

    On exit the elapsed time in nanoseconds is stored in ``self.et`` and, when
    enabled, printed as ``<prefix><milliseconds> ms`` followed by the string
    returned by ``on_exit(et)`` if an ``on_exit`` callable was given.

    Args:
        prefix: text printed in front of the elapsed time.
        on_exit: optional callable taking the elapsed nanoseconds and returning
            a string appended to the printed line.
        enabled: when False nothing is printed (``self.et`` is still set).
    """

    def __init__(self, prefix="", on_exit=None, enabled=True):
        self.prefix = prefix
        self.on_exit = on_exit
        self.enabled = enabled

    def __enter__(self):
        self.st = time.perf_counter_ns()

    def __exit__(self, *exc):
        self.et = time.perf_counter_ns() - self.st
        if not self.enabled:
            return
        suffix = self.on_exit(self.et) if self.on_exit else ""
        print(f"{self.prefix}{self.et*1e-6:6.2f} ms" + suffix)
# Sample rate (Hz) assumed when mapping FFT bins to frequencies.
SAMPLE_RATE: int = 11025
# Analysis frames of 4096 samples, advancing by a third of a frame.
FRAME_SIZE: int = 4096
STRIDE: int = FRAME_SIZE // 3       # hop between consecutive frame starts (1365 samples)
OVERLAP: int = FRAME_SIZE - STRIDE  # samples shared by consecutive frames (2731 samples)
# 27.5 Hz; the same value octave_from_freq uses as its default base.
BASE: float = 440.0 / 16.0
# Frequency range (Hz) whose FFT bins are folded into chroma bands.
MIN_FREQ: int = 28
MAX_FREQ: int = 3520
NUM_BANDS: int = 12          # chroma bands per octave
NUM_FILTERS: int = 16        # filters, each contributing a 2-bit code to a 32-bit subfingerprint
CHROMA_WINDOW_LEN: int = 16  # chroma frames covered by one filter window
def hamming_window(N:int, periodic=True):
    """Return an N-point Hamming window as a float32 array.

    With ``periodic=True`` an (N+1)-point symmetric window is built and its
    last point dropped; with ``periodic=False`` the symmetric N-point window
    ``0.54 - 0.46*cos(2*pi*n/(N-1))`` is returned.
    """
    total = N + periodic * 1
    # same float32 operation order as the textbook formula: ((n * 2) * pi) / (total - 1)
    phase = np.arange(total, dtype=np.float32) * 2.0 * math.pi / (total - 1)
    window = 0.54 - 0.46 * np.cos(phase)
    return window[:N]
def octave_from_freq(freq, base = 440.0 / 16.0):
    """Return ``log2(freq / base)``: how many octaves ``freq`` lies above ``base``.

    Works element-wise on NumPy arrays as well as on scalars.
    """
    ratio = freq / base
    return np.log2(ratio)
def index_from_freq(freq, frame_size, sample_rate):
    """Return the FFT bin index nearest to ``freq`` Hz.

    The exact position is ``freq * frame_size / sample_rate``, rounded with the
    built-in ``round``.

    NOTE(review): Python's ``round`` sends exact .5 ties to the even integer;
    if this must match a C implementation using ``round()`` (half away from
    zero), tie cases would differ. No tie occurs for the MIN_FREQ/MAX_FREQ
    values used in this file.
    """
    exact_bin = freq * frame_size / sample_rate
    return round(exact_bin)
def freq_from_index(i, frame_size, sample_rate):
    """Return the frequency in Hz of FFT bin ``i`` (``i * sample_rate / frame_size``).

    Accepts scalars or integer arrays of bin indices.
    """
    scaled = i * sample_rate
    return scaled / frame_size
def pad_to(s, width, stride):
    """Return how many samples to append to ``s`` so that ``len(s) - width`` is a multiple of ``stride``.

    For ``len(s) >= width`` this makes the frames of length ``width`` taken every
    ``stride`` samples end exactly at the end of the padded signal.
    """
    return (width - len(s)) % stride
def stft(samples_f32):
    """Compute the power spectrogram of ``samples_f32``.

    The signal is zero-padded at the end (see pad_to). Frames of FRAME_SIZE
    samples start every STRIDE samples and are Hamming-windowed before the FFT.

    Returns:
        Array of shape (FRAME_SIZE // 2 + 1, n_frames): squared FFT magnitudes,
        one column per frame, row 0 being the DC bin.
    """
    padded = np.pad(samples_f32, (0, pad_to(samples_f32, FRAME_SIZE, STRIDE)))
    frame_starts = np.arange(0, len(padded) - OVERLAP, STRIDE)
    # (n_frames, FRAME_SIZE) gather indices: row k covers samples starting at frame_starts[k]
    frame_indices = frame_starts[:, None] + np.arange(FRAME_SIZE)[None]
    window = hamming_window(FRAME_SIZE)
    spectrum = np.fft.fftshift(np.fft.fft(window * padded[frame_indices], FRAME_SIZE), 1)
    # After fftshift the first FRAME_SIZE//2 + 1 columns are bins -N/2 .. 0; reversing them
    # gives 0 .. -N/2, whose magnitudes mirror bins 0 .. N/2 because the input is real.
    magnitudes = np.absolute(spectrum[:, :FRAME_SIZE // 2 + 1]).T[::-1, :]
    return magnitudes ** 2
def samples_f32_from_raw_pcm_s16le(filepath):
    """Load headerless signed 16-bit little-endian PCM and scale it to float32.

    Args:
        filepath: path (str or path-like) of the raw ``.s16le`` file.

    Returns:
        1-D ``np.float32`` array of the samples divided by 2**15, i.e. in [-1, 1).
    """
    # '<i2' pins little-endian; plain np.int16 would follow the host byte order
    # and misread the file on big-endian machines.
    samples = np.fromfile(filepath, dtype='<i2')
    return samples.astype(np.float32) / (2**15)
def features_to_chroma(features_np):
    """Fold a spectrogram into NUM_BANDS chroma bands per frame.

    Each FFT bin from index_from_freq(MIN_FREQ) up to, but excluding,
    index_from_freq(MAX_FREQ) is assigned to the band given by the fractional
    part of its octave, and the bin values falling in each band are summed.

    Args:
        features_np: spectrogram of shape (bins, frames), as returned by stft.

    Returns:
        Array of shape (frames, NUM_BANDS).
    """
    first_bin = index_from_freq(MIN_FREQ, FRAME_SIZE, SAMPLE_RATE)
    stop_bin = index_from_freq(MAX_FREQ, FRAME_SIZE, SAMPLE_RATE)
    bins = np.arange(first_bin, stop_bin)
    octaves = octave_from_freq(freq_from_index(bins, FRAME_SIZE, SAMPLE_RATE))
    # fractional octave -> band number; astype truncates, so each band spans 1/NUM_BANDS octave
    band_of_bin = (NUM_BANDS * (octaves - np.floor(octaves))).astype(np.uint8)
    # (bins, NUM_BANDS) one-hot matrix: column b selects the bins that fall into band b
    band_selector = np.equal.outer(band_of_bin, np.arange(NUM_BANDS)).astype(np.float32)
    return np.dot(features_np[bins[0]:bins[-1] + 1].T, band_selector)
def blur_chroma(chroma):
    """Smooth each chroma band over time with the 5-tap kernel [0.25, 0.75, 1, 0.75, 0.25].

    Only 'valid' positions are kept (no edge padding), so 4 frames are lost.

    Args:
        chroma: array of shape (frames, bands), as returned by features_to_chroma.

    Returns:
        float32 array of shape (bands, max(frames - 4, 0)): one row per band.
    """
    chroma_kernel = np.array([0.25, 0.75, 1.0, 0.75, 0.25], dtype=np.float32)
    num_frames, num_bands = chroma.shape
    out_len = max(num_frames - len(chroma_kernel) + 1, 0)
    chroma_filtered = np.empty((num_bands, out_len), dtype=np.float32)
    if out_len == 0:
        # Too few frames for a single full kernel position. np.convolve would swap
        # its arguments for a short signal and return a result that doesn't fit.
        return chroma_filtered
    for band in range(num_bands):
        chroma_filtered[band] = np.convolve(chroma[:, band], chroma_kernel, 'valid')
    return chroma_filtered
def normalize_chroma(chroma_filtered):
    """Scale every column (frame) of the chroma matrix to unit Euclidean length.

    Columns whose norm is below 0.01 are replaced by zeros instead.

    Args:
        chroma_filtered: array of shape (bands, frames), as returned by blur_chroma.

    Returns:
        Array of the same shape; float32 when the input is float32.
    """
    chroma = np.asarray(chroma_filtered)
    norm = np.linalg.norm(chroma, axis=0)
    # Divide only the columns that are kept. Masked columns stay 0 and no divide-by-zero
    # / invalid-value warnings fire for silent frames. `~(norm < 0.01)` rather than
    # `norm >= 0.01` keeps NaN norms on the divide path, as the unmasked form did.
    keep = ~(norm < 0.01)
    result = np.zeros(chroma.shape, dtype=np.result_type(chroma, norm))
    np.divide(chroma, norm, out=result, where=keep)
    return result
def make_filter_masks(type, y, height, width, shape=None):
    """Build the boolean A/B region masks of one filter over a chroma window.

    The masks index (band, frame). The filter area covers bands ``y:y+height``
    and frames ``0:width``. ``type`` chooses which part of that area is region A;
    region B is the rest of the area, except for type 0, where B is empty.

    Args:
        type: filter layout, 0-5 (see the per-branch comments).
        y: first band (row) of the filter area.
        height: number of bands covered.
        width: number of frames (columns) covered, starting at column 0.
        shape: (bands, frames) of the masks; defaults to (NUM_BANDS, CHROMA_WINDOW_LEN).

    Returns:
        (mask_a, mask_b): boolean arrays of ``shape``.

    Raises:
        ValueError: if ``type`` is not one of 0-5.
    """
    if shape is None:
        shape = (NUM_BANDS, CHROMA_WINDOW_LEN)
    mask_area = np.zeros(shape, dtype=np.bool_)
    mask_area[y:y+height, :width] = True
    mask_a = np.zeros(shape, dtype=np.bool_)
    if type == 0:  # whole area is A, B is empty
        return mask_area, np.zeros(shape, dtype=np.bool_)
    elif type == 1:  # bottom half of the bands (higher rows) is A
        half = height // 2
        mask_a[y+half:y+height, :width] = True
    elif type == 2:  # right half of the frames is A
        half = width // 2
        mask_a[y:y+height, half:width] = True
    elif type == 3:  # checkers: top-right and bottom-left quadrants are A
        halfx = width // 2
        halfy = height // 2
        mask_a[y:y+halfy, halfx:width] = True
        mask_a[y+halfy:y+height, :halfx] = True
    elif type == 4:  # horizontal hamburger: middle third of the bands is A
        third = height // 3
        # The middle stripe is [third, 2*third). `height - third` would widen it whenever
        # height is not a multiple of 3; the leftover rows belong to the bottom B stripe.
        mask_a[y+third:y+2*third, :width] = True
    elif type == 5:  # vertical hamburger: middle third of the frames is A
        third = width // 3
        mask_a[y:y+height, third:2*third] = True
    else:
        raise ValueError(f"unknown filter type: {type}")
    return mask_a, ~mask_a & mask_area
def make_filters():
    """Return the list of 16 ``((mask_a, mask_b), thresholds)`` pairs, in bit order.

    Each geometry tuple is ``(type, y, height, width)`` for make_filter_masks;
    each threshold triple ``(t0, t1, t2)`` goes to make_quantizer.
    """
    spec = (
        ((0, 4, 3, 15), (1.98215, 2.35817, 2.63523)),
        ((4, 4, 6, 15), (-1.03809, -0.651211, -0.282167)),
        ((1, 0, 4, 16), (-0.298702, 0.119262, 0.558497)),
        ((3, 8, 2, 12), (-0.105439, 0.0153946, 0.135898)),
        ((3, 4, 4, 8), (-0.142891, 0.0258736, 0.200632)),
        ((4, 0, 3, 5), (-0.826319, -0.590612, -0.368214)),
        ((1, 2, 2, 9), (-0.557409, -0.233035, 0.0534525)),
        ((2, 7, 3, 4), (-0.0646826, 0.00620476, 0.0784847)),
        ((2, 6, 2, 16), (-0.192387, -0.029699, 0.215855)),
        ((2, 1, 3, 2), (-0.0397818, -0.00568076, 0.0292026)),
        ((5, 10, 1, 15), (-0.53823, -0.369934, -0.190235)),
        ((3, 6, 2, 10), (-0.124877, 0.0296483, 0.139239)),
        ((2, 1, 1, 14), (-0.101475, 0.0225617, 0.231971)),
        ((3, 5, 6, 4), (-0.0799915, -0.00729616, 0.063262)),
        ((1, 9, 2, 12), (-0.272556, 0.019424, 0.302559)),
        ((3, 4, 2, 14), (-0.164292, -0.0321188, 0.0846339)),
    )
    return [
        (make_filter_masks(*geometry), make_quantizer(*thresholds))
        for geometry, thresholds in spec
    ]
def quantize(x, c):
    """Quantize ``x`` against three thresholds and return the Gray-coded level.

    Assuming ordered thresholds t0 <= t1 <= t2 (``c[0]``, ``c[1]``, ``c[2]``):
        x < t0        -> 0
        t0 <= x < t1  -> 1
        t1 <= x < t2  -> 3
        x >= t2       -> 2
    which is the 2-bit Gray code of the plain level 0..3.
    ``x`` is broadcast against each threshold, so per-filter thresholds work.
    """
    high_bit = x >= c[1]
    low_bit = (x >= c[0]) & (x < c[2])
    return high_bit * 2 + low_bit
def make_quantizer(t0, t1, t2):
    """Return the three quantization thresholds as a float32 array ``[t0, t1, t2]``."""
    thresholds = (t0, t1, t2)
    return np.asarray(thresholds, dtype=np.float32)
def process_masks(filter_masks):
    """Stack the per-filter masks and thresholds into batched arrays.

    Args:
        filter_masks: list of ``((mask_a, mask_b), thresholds)`` as from make_filters.

    Returns:
        masks: (2 * NUM_FILTERS, NUM_BANDS, CHROMA_WINDOW_LEN) booleans, all A
            masks first, then all B masks.
        quants: (3, NUM_FILTERS) thresholds, row k holding every filter's t_k.
        bitshifts: (NUM_FILTERS,) place values 4**(NUM_FILTERS-1-i), so filter i's
            2-bit code lands in bits 2*(NUM_FILTERS-1-i) and up (first filter highest).
    """
    mask_shape = (NUM_FILTERS, NUM_BANDS, CHROMA_WINDOW_LEN)
    a_masks = np.stack([mask_a for (mask_a, _), _q in filter_masks]).reshape(mask_shape)
    b_masks = np.stack([mask_b for (_, mask_b), _q in filter_masks]).reshape(mask_shape)
    masks = np.concatenate([a_masks, b_masks])
    quants = np.stack([thresholds for (_, thresholds) in filter_masks]).T
    bitshifts = 2 ** (np.arange(NUM_FILTERS - 1, -1, -1) * 2)
    return masks, quants, bitshifts
def run_filters_and_quantize(chroma_slice, BS, masks, quants, bitshifts):
    """Evaluate every filter on a batch of chroma windows and pack the codes.

    Args:
        chroma_slice: (n, 1, NUM_BANDS, CHROMA_WINDOW_LEN) windows with n <= BS;
            zero rows are appended up to BS.
        BS: batch size.
        masks: (2 * NUM_FILTERS, NUM_BANDS, CHROMA_WINDOW_LEN) masks, A masks
            first, then B masks (as built by process_masks).
        quants: (3, NUM_FILTERS) thresholds.
        bitshifts: (NUM_FILTERS,) place value of each filter's 2-bit code.

    Returns:
        (BS,) integer array of packed subfingerprints. Rows past n come from the
        zero padding; compute_fingerprint trims them.
    """
    missing = -chroma_slice.shape[0] % BS
    batch = np.pad(chroma_slice, ((0, missing), (0, 0), (0, 0), (0, 0)))
    # Per filter: sum of the chroma values under its A mask and under its B mask.
    # +1.0 keeps the ratio and its log finite when a region sums to zero.
    region_sums = (batch * masks).reshape((BS, 2, NUM_FILTERS, -1)).sum(axis=-1) + 1.0
    log_ratio = np.log(region_sums[:, 0] / region_sums[:, 1])
    codes = quantize(log_ratio, quants)
    return (codes * bitshifts).sum(axis=-1)
def compute_fingerprint(chroma_normed, masks, quants, bitshifts):
    """Slide a CHROMA_WINDOW_LEN-frame window over the chroma and fingerprint it.

    Args:
        chroma_normed: (bands, frames) chroma matrix, as returned by normalize_chroma.
        masks, quants, bitshifts: filter data from process_masks.

    Returns:
        1-D ``np.uint32`` array with one subfingerprint per evaluated window.
    """
    batch_size = 32
    n_windows = chroma_normed.shape[1] - CHROMA_WINDOW_LEN
    # NOTE(review): window starts run 0 .. n_windows - 1, so the window starting at
    # n_windows (the last one that still fits) is not evaluated. Confirm this is
    # intended, e.g. to offset the extra frame stft gets from end-padding.
    starts = np.arange(0, n_windows)
    frame_index = np.arange(CHROMA_WINDOW_LEN)[:, None] + starts[None]
    # (n_windows, bands, CHROMA_WINDOW_LEN)
    windows = chroma_normed.T[frame_index].transpose((1, 2, 0))
    batches = [
        run_filters_and_quantize(windows[lo:lo + batch_size, None, ...], batch_size, masks, quants, bitshifts)
        for lo in range(0, n_windows, batch_size)
    ]
    # each batch is padded to batch_size; drop the padding rows at the end
    return np.array(batches, dtype=np.uint32).flatten()[:len(windows)]
| def bitscan_forward(x): | |
| return np.bitwise_count((x&-x)-1) | |
def scanloop(x):
    """Split each 32-bit value into the gaps between its successive set bits.

    Each pass records ``bitscan_forward(x) + 1`` (distance to the next set bit,
    1-based) and shifts it away. Once a value is exhausted, bitscan_forward of 0
    is 32 for uint32 input, so filler entries read 33; callers drop them.

    Args:
        x: 1-D array of subfingerprint deltas (assumed ``np.uint32``; the 33 filler
            value relies on the 32-bit width).

    Returns:
        Array of shape (len(x), 33): per element, up to 32 gaps, then 33-fillers,
        then a 0 terminator column.
    """
    scans = []
    # 32 passes, one per possible set bit, so a value with all 32 bits set
    # (e.g. 0xFFFFFFFF) still emits its last gap; 31 passes dropped it.
    for _ in range(32):
        scan = bitscan_forward(x) + 1
        scans.append(scan)
        x = np.bitwise_right_shift(x, scan)
    return np.pad(np.stack(scans), ((0, 1), (0, 0))).T
def split_into_int3_5_arrays(fpx):
    """Turn subfingerprint deltas into the 3-bit and 5-bit value streams.

    Every bit gap (and each 0 terminator) from scanloop goes into the 3-bit
    stream, clamped to 7. For gaps of 7 or more, ``gap - 7`` is also appended to
    the 5-bit stream.

    Returns:
        (int3, int5) integer arrays.
    """
    gaps = scanloop(fpx).ravel()
    gaps = gaps[gaps != 33]  # drop the filler entries after each value's last set bit
    overflow = gaps > 6
    int3 = np.where(overflow, 7, gaps)
    int5 = gaps[overflow] - 7
    return int3, int5
def bit_shift_powers(bits):
    """Return the uint64 place values ``2**(k*bits)`` for k = 0..7.

    Multiplying eight packed values by these and summing places value k at bit
    offset ``k*bits``.
    """
    offsets = [slot * bits for slot in range(8)]
    return np.array([2 ** offset for offset in offsets], dtype=np.uint64)
def encode_intx_array(intx, bits):
    """Pack non-negative integers of ``bits`` bits each into an LSB-first byte string.

    Values are grouped eight at a time into a 64-bit word (value k at bit offset
    ``k*bits``), and the low ``bits`` bytes of each word are emitted little-endian.
    Trailing bytes holding only padding are dropped, so the result is
    ``ceil(len(intx) * bits / 8)`` bytes long.

    Args:
        intx: 1-D sequence or array of integers, each assumed < 2**bits.
        bits: bits per value (this file uses 3 and 5; must be <= 8).

    Returns:
        The packed ``bytes``.
    """
    pows = bit_shift_powers(bits)
    # Work in uint64 from the start: a signed integer array times uint64 powers
    # promotes to float64, which would make the byte view below meaningless.
    values = np.asarray(intx).astype(np.uint64)
    intpad = -len(values) % 8
    words = (np.pad(values, (0, intpad)).reshape((-1, 8)) * pows).sum(axis=-1, dtype=np.uint64)
    # '<u8' fixes the byte order so the output is the same on any host.
    word_bytes = words.astype('<u8').view(np.uint8).reshape((-1, 8))
    intenc = word_bytes[:, :bits].tobytes()
    padded_bytes = intpad * bits // 8
    return intenc[:len(intenc) - padded_bytes]
def compress_fingerprint(fp):
    """Compress a subfingerprint array into the byte layout read by decompress_fingerprint.

    Consecutive subfingerprints are XOR-delta coded. The set-bit gaps of each
    delta are written as 3-bit values, with 5-bit overflow values for gaps >= 7.
    A 4-byte header is prepended: a 0x01 byte, then the count as 3 big-endian bytes.

    Args:
        fp: 1-D array of subfingerprints (assumed ``np.uint32`` — TODO confirm callers).

    Returns:
        The compressed ``bytes``.

    Raises:
        OverflowError: if ``len(fp)`` does not fit in the 3-byte count (>= 2**24).
    """
    fpx = np.concatenate((fp[:1], fp[1:] ^ fp[:-1]))
    int3, int5 = split_into_int3_5_arrays(fpx)
    # int.to_bytes yields big-endian bytes on any host and raises instead of
    # silently truncating a count that needs more than 3 bytes.
    header = b'\x01' + len(fpx).to_bytes(3, 'big')
    return header + encode_intx_array(int3, 3) + encode_intx_array(int5, 5)
def prefix_xor_scan(y):
    """Return the running XOR of ``y`` along its first axis: out[i] = y[0] ^ ... ^ y[i].

    This inverts the consecutive-XOR delta coding. The input is not modified.
    """
    return np.bitwise_xor.accumulate(y)
def decompress_fingerprint(fpbytes):
    """Decode bytes produced by compress_fingerprint back into subfingerprints.

    Layout: byte 0 is a header byte (not checked here), bytes 1-3 hold the
    subfingerprint count big-endian, followed by the packed 3-bit values and
    then the packed 5-bit values, both packed LSB-first.

    Args:
        fpbytes: the compressed fingerprint as ``bytes``.

    Returns:
        1-D ``np.uint32`` array of subfingerprints (empty when the count is 0).

    Raises:
        ValueError: if ``fpbytes`` is shorter than the 4-byte header.
    """
    if len(fpbytes) < 4:
        raise ValueError("compressed fingerprint is shorter than its 4-byte header")
    sz = int.from_bytes(fpbytes[1:4], 'big')
    if sz == 0:
        # Nothing encoded. The terminator lookup below would index past the end.
        return np.zeros(0, dtype=np.uint32)
    cbnp = np.frombuffer(fpbytes[4:], dtype=np.uint8)
    # Unpack every 3 bytes (24 bits) into eight 3-bit values. '<u4' keeps the
    # LSB-first bit order independent of host endianness.
    padto = -cbnp.size % 3
    e1 = np.pad(cbnp, (0, padto))
    e1e = np.pad(e1.reshape((-1, 3)), ((0, 0), (0, 1))).view('<u4') >> np.array([x*3 for x in range(8)], dtype=np.uint8)
    e1e_int3 = (e1e.astype(np.uint8) & 7).flatten()
    # A 0 terminates each subfingerprint's list of bit gaps.
    zero_positions = np.flatnonzero(e1e_int3 == 0)
    cutoff = zero_positions[sz-1]
    e2e_int3 = e1e_int3[:cutoff+1]
    # One row of 32 values per subfingerprint, starting right after the previous terminator.
    starts = np.concatenate(([0], zero_positions + 1))[:sz]
    e3e_int3 = np.pad(e2e_int3, (0, 32))[starts[:, None] + np.arange(32)]
    # Zero out everything from each row's terminator onward.
    e4e_int3 = np.where(e3e_int3.cumprod(axis=-1) != 0, e3e_int3, 0)
    # A 3-bit 7 means the gap continues in the 5-bit stream (gap = 7 + int5).
    int5count = np.count_nonzero(e4e_int3 == 7)
    int5_nbytes = (int5count * 5 + 7) // 8
    e1e_int5 = cbnp[max(cbnp.size - int5_nbytes, 0):]
    # Unpack every 5 bytes (40 bits) into eight 5-bit values.
    e2e_int5 = ((np.pad(np.pad(e1e_int5, (0, -e1e_int5.size % 5)).reshape((-1, 5)), ((0, 0), (0, 3))).view('<u8') >> np.array([x*5 for x in range(8)], dtype=np.uint8)) & 31)
    e5e_int3 = e4e_int3.flatten()
    e5e_int3[e5e_int3 == 7] += e2e_int5.flatten()[:int5count]
    e6e_int3 = e5e_int3.reshape((-1, 32))
    # Running sums of the gaps are 1-based bit positions; OR those bits together per row.
    fpx_dec = np.where(e6e_int3 != 0, 1 << (e6e_int3.cumsum(axis=-1) - 1), 0).sum(axis=-1).astype(np.uint32)
    # Undo the consecutive-XOR delta coding.
    return prefix_xor_scan(fpx_dec)
def main():
    """Fingerprint a raw PCM file and compare the result with a reference fingerprint.

    Reads s16le samples from ``sys.argv[1]`` (or a default file name) and prints
    each stage's timing. Writes the raw fingerprint bytes to ``output.bin`` and, if
    ``reference_ffmpeg.bin`` exists, reports how many bits differ from it.
    """
    filename = sys.argv[1] if len(sys.argv) > 1 else "UNKLE - Under The Ice (scene edit).s16le"
    samples_f32 = samples_f32_from_raw_pcm_s16le(filename)
    with Timing("stft "):
        features_np = stft(samples_f32)
    with Timing("chroma "):
        chroma = features_to_chroma(features_np)
        chroma_filtered = blur_chroma(chroma)
        chroma_normed = normalize_chroma(chroma_filtered)
    filter_masks = make_filters()
    masks, quants, bitshifts = process_masks(filter_masks)
    with Timing("fingerprint "):
        fingerprint_np = compute_fingerprint(chroma_normed, masks, quants, bitshifts)
    Path("output.bin").write_bytes(fingerprint_np.view(dtype=np.uint8).tobytes())

    reference_path = Path("reference_ffmpeg.bin")
    if not reference_path.exists():
        print(f"no reference fingerprint at {reference_path}, skipping comparison")
        return
    reference_fingerprint = np.fromfile(reference_path, dtype=np.uint8).view(np.uint32)
    if len(reference_fingerprint) != len(fingerprint_np):
        # XOR-ing arrays of different lengths raises (or silently broadcasts a
        # length-1 array), so compare only the common prefix and report it.
        print(f"length mismatch: {len(fingerprint_np)} subfingerprints vs {len(reference_fingerprint)} in reference")
    common = min(len(fingerprint_np), len(reference_fingerprint))
    unpacked = np.unpackbits(np.bitwise_xor(fingerprint_np[:common], reference_fingerprint[:common]).view(np.uint8))
    if unpacked.size == 0:
        print("difference: nothing to compare")
        return
    diff = unpacked.sum()
    print(f"difference: {diff} out of {unpacked.size} ({(diff / unpacked.size) * 100.0: 6.2f}%)")


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment