google/embeddinggemma-300m: embeddings differ depending on batch contents (when sequence lengths differ). Reported upstream at https://huggingface.co/google/embeddinggemma-300m/discussions/28
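For quick orientation before the full notebook below, here is a condensed sketch of the same check. It is a sketch under stated assumptions, not part of the gist: it assumes a sentence-transformers version that supports google/embeddinggemma-300m, a CUDA device, and prior authentication for the gated checkpoint (the notebook keeps a commented-out huggingface_hub login for this). The "Hello " repeat counts are taken from the notebook (111 repeats → 113 tokens, 127 repeats → 129 tokens).

```python
# Condensed sketch of the notebook's reproduction (see assumptions above).
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("google/embeddinggemma-300m").to("cuda")

# Repeat counts from the notebook: 111 repeats -> 113 tokens, 127 repeats -> 129 tokens.
text_113 = ("Hello " * 111).strip()
text_129 = ("Hello " * 127).strip()

solo_129 = model.encode([text_129], batch_size=1)           # ground truth: encoded alone
batched = model.encode([text_113, text_129], batch_size=2)  # same text, batched with a shorter one

# Expected: float noise (~1e-7). The notebook observes ~0.0103 for the 129-token sequence.
print(np.abs(solo_129 - batched[1:2]).max())
```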
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyMiwTe8hhQ1lttL2VStNnY1",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/praateekmahajan/17abaf7bfe435cd6cbb98ac6d0650377/untitled3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bRLYEaYwJgO4",
        "outputId": "1ae8ee0b-aa7d-49bc-8dd4-aff1c430dc0a"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "======================================================================\n",
            "Batch-Dependent Embedding Bug Reproduction\n",
            "======================================================================\n",
            "\n",
            "Testing: google/embeddinggemma-300m\n",
            "======================================================================\n",
            "Token lengths: TEXT_113=113, TEXT_128=128, TEXT_129=129, TEXT_130=130\n",
            "\n",
            "Test 1: 129 + 129 (same length)\n",
            " TEXT_129 diff: 0.00000014 ✅ OK\n",
            "\n",
            "Test 2: 113 + 128 (below/at 128 boundary)\n",
            " TEXT_113 diff: 0.00000019 ✅\n",
            " TEXT_128 diff: 0.00000025 ✅\n",
            "\n",
            "Test 3: 128 + 129 (crosses 128→129 boundary)\n",
            " TEXT_128 diff: 0.00000022 ✅\n",
            " TEXT_129 diff: 0.01030827 ❌ BUG\n",
            " → TEXT_129 embedding changed by 1.03%\n",
            "\n",
            "Test 4: 113 + 129 (crosses 128→129 boundary)\n",
            " TEXT_113 diff: 0.00000017 ✅\n",
            " TEXT_129 diff: 0.01030827 ❌ BUG\n",
            " → TEXT_129 embedding changed by 1.03%\n",
            "\n",
            "Test 5: 113 + 130 (crosses to 130, skips 129)\n",
            " TEXT_113 diff: 0.00000017 ✅\n",
            " TEXT_130 diff: 0.00000016 ✅\n",
            "\n",
            "Testing: sentence-transformers/all-MiniLM-L6-v2\n",
            "======================================================================\n",
            "Token lengths: TEXT_113=113, TEXT_128=128, TEXT_129=129, TEXT_130=130\n",
            "\n",
            "Test 1: 129 + 129 (same length)\n",
            " TEXT_129 diff: 0.00000003 ✅ OK\n",
            "\n",
            "Test 2: 113 + 128 (below/at 128 boundary)\n",
            " TEXT_113 diff: 0.00000004 ✅\n",
            " TEXT_128 diff: 0.00000003 ✅\n",
            "\n",
            "Test 3: 128 + 129 (crosses 128→129 boundary)\n",
            " TEXT_128 diff: 0.00000003 ✅\n",
            " TEXT_129 diff: 0.00000003 ✅\n",
            "\n",
            "Test 4: 113 + 129 (crosses 128→129 boundary)\n",
            " TEXT_113 diff: 0.00000003 ✅\n",
            " TEXT_129 diff: 0.00000003 ✅\n",
            "\n",
            "Test 5: 113 + 130 (crosses to 130, skips 129)\n",
            " TEXT_113 diff: 0.00000003 ✅\n",
            " TEXT_130 diff: 0.00000003 ✅\n",
            "\n",
            "======================================================================\n",
            "SUMMARY\n",
            "======================================================================\n",
            "google/embeddinggemma-300m: ❌ BUG\n",
            "sentence-transformers/all-MiniLM-L6-v2: ✅ OK\n",
            "\n",
            "✓ Bug confirmed in embeddinggemma-300m\n",
            "✓ all-MiniLM-L6-v2 works correctly\n",
            "\n",
            "The bug appears ONLY in embeddinggemma-300m when batching\n",
            "sequences that cross the 128→129 boundary.\n"
          ]
        }
      ],
      "source": [
        "\"\"\"\n",
        "Minimal reproduction of a batch-dependent embedding bug in google/embeddinggemma-300m\n",
        "\n",
        "BUG DESCRIPTION:\n",
        "When batching sequences of different lengths, embeddinggemma-300m produces\n",
        "incorrect embeddings ONLY for sequences with length exactly 129 tokens\n",
        "when batched with shorter sequences.\n",
        "The same bug occurs at 193 (= 192+1) and 257 (= 256+1),\n",
        "suggesting an issue with boundary handling at multiples of 64\n",
        "(128 = 2×64, 192 = 3×64, 256 = 4×64).\n",
        "\n",
        "SPECIFIC ISSUE:\n",
        "- Batching 113 + 129 tokens: ❌ 129-token sequence gets wrong embedding (0.01030827 error)\n",
        "- Batching 128 + 129 tokens: ❌ 129-token sequence gets wrong embedding (0.01030827 error)\n",
        "- Batching 113 + 130 tokens: ✅ Both sequences work correctly\n",
        "- Batching 129 + 129 tokens: ✅ Both sequences work correctly\n",
        "\n",
        "Among the lengths tested here, the bug appears ONLY at 129 tokens (= 128+1).\n",
        "\n",
        "TESTING METHODOLOGY:\n",
        "Each text is encoded individually (batch_size=1) to establish a \"ground truth\"\n",
        "baseline embedding. Then the same texts are encoded in batches of 2. The batched\n",
        "embeddings are compared against their individual baselines using max absolute\n",
        "difference. Differences > 0.001 indicate incorrect behavior.\n",
        "\n",
        "Expected: Batching should NOT change embeddings (attention masking isolates sequences).\n",
        "Observed: In embeddinggemma-300m, the 129-token sequence embedding changes by ~1%\n",
        " when batched with shorter sequences.\n",
        "\n",
        "COMPARISON:\n",
        "Other models like sentence-transformers/all-MiniLM-L6-v2 handle all these cases\n",
        "correctly (differences < 1e-7), confirming this is specific to embeddinggemma-300m.\n",
        "\"\"\"\n",
        "\n",
        "import numpy as np\n",
        "from sentence_transformers import SentenceTransformer\n",
        "\n",
        "\n",
        "# Simple repeated texts that produce exactly 113, 128, 129, and 130 tokens\n",
        "# Using \"Hello \" repeated: 111×, 126×, 127×, and 128× respectively\n",
        "TEXT_113 = (\"Hello \" * 111).strip()  # 113 tokens\n",
        "TEXT_128 = (\"Hello \" * 126).strip()  # 128 tokens\n",
        "TEXT_129 = (\"Hello \" * 127).strip()  # 129 tokens\n",
        "TEXT_130 = (\"Hello \" * 128).strip()  # 130 tokens\n",
        "\n",
        "\n",
        "def test_model(model_name):\n",
        "    \"\"\"Test if a model has the bug.\"\"\"\n",
        "    print(f\"\\nTesting: {model_name}\")\n",
        "    print(\"=\"*70)\n",
        "\n",
        "    model = SentenceTransformer(model_name).to(\"cuda\")\n",
        "\n",
        "    # Check token lengths for this model\n",
        "    tokens_113 = model.tokenize([TEXT_113])\n",
        "    tokens_128 = model.tokenize([TEXT_128])\n",
        "    tokens_129 = model.tokenize([TEXT_129])\n",
        "    tokens_130 = model.tokenize([TEXT_130])\n",
        "    len_113 = tokens_113['attention_mask'].sum().item()\n",
        "    len_128 = tokens_128['attention_mask'].sum().item()\n",
        "    len_129 = tokens_129['attention_mask'].sum().item()\n",
        "    len_130 = tokens_130['attention_mask'].sum().item()\n",
        "    print(f\"Token lengths: TEXT_113={len_113}, TEXT_128={len_128}, TEXT_129={len_129}, TEXT_130={len_130}\")\n",
        "\n",
        "    # Get baselines for all texts\n",
        "    baseline_113 = model.encode([TEXT_113], batch_size=1)\n",
        "    baseline_128 = model.encode([TEXT_128], batch_size=1)\n",
        "    baseline_129 = model.encode([TEXT_129], batch_size=1)\n",
        "    baseline_130 = model.encode([TEXT_130], batch_size=1)\n",
        "\n",
        "    has_bug = False\n",
        "\n",
        "    # Test 1: 129 + 129 (same length - should always work)\n",
        "    print(f\"\\nTest 1: {len_129} + {len_129} (same length)\")\n",
        "    batch_129_129 = model.encode([TEXT_129, TEXT_129], batch_size=2)\n",
        "    diff_129_129 = np.abs(baseline_129 - batch_129_129[0:1]).max()\n",
        "    status = \"❌ BUG\" if diff_129_129 > 1e-3 else \"✅ OK\"\n",
        "    print(f\" TEXT_129 diff: {diff_129_129:.8f} {status}\")\n",
        "    if diff_129_129 > 1e-3:\n",
        "        has_bug = True\n",
        "\n",
        "    # Test 2: 113 + 128 (both below/at boundary - should work)\n",
        "    print(f\"\\nTest 2: {len_113} + {len_128} (below/at 128 boundary)\")\n",
        "    batch_113_128 = model.encode([TEXT_113, TEXT_128], batch_size=2)\n",
        "    diff_113 = np.abs(baseline_113 - batch_113_128[0:1]).max()\n",
        "    diff_128 = np.abs(baseline_128 - batch_113_128[1:2]).max()\n",
        "    status_113 = \"❌\" if diff_113 > 1e-3 else \"✅\"\n",
        "    status_128 = \"❌\" if diff_128 > 1e-3 else \"✅\"\n",
        "    print(f\" TEXT_113 diff: {diff_113:.8f} {status_113}\")\n",
        "    print(f\" TEXT_128 diff: {diff_128:.8f} {status_128}\")\n",
        "    if diff_113 > 1e-3 or diff_128 > 1e-3:\n",
        "        has_bug = True\n",
        "\n",
        "    # Test 3: 128 + 129 (crosses 128→129 boundary - may trigger bug)\n",
        "    print(f\"\\nTest 3: {len_128} + {len_129} (crosses 128→129 boundary)\")\n",
        "    batch_128_129 = model.encode([TEXT_128, TEXT_129], batch_size=2)\n",
        "    diff_128_in_batch = np.abs(baseline_128 - batch_128_129[0:1]).max()\n",
        "    diff_129_in_batch = np.abs(baseline_129 - batch_128_129[1:2]).max()\n",
        "    status_128 = \"❌\" if diff_128_in_batch > 1e-3 else \"✅\"\n",
        "    status_129 = \"❌ BUG\" if diff_129_in_batch > 1e-3 else \"✅\"\n",
        "    print(f\" TEXT_128 diff: {diff_128_in_batch:.8f} {status_128}\")\n",
        "    print(f\" TEXT_129 diff: {diff_129_in_batch:.8f} {status_129}\")\n",
        "    if diff_129_in_batch > 1e-3:\n",
        "        print(f\" → TEXT_129 embedding changed by {diff_129_in_batch*100:.2f}%\")\n",
        "        has_bug = True\n",
        "\n",
        "    # Test 4: 113 + 129 (crosses boundary - may trigger bug)\n",
        "    print(f\"\\nTest 4: {len_113} + {len_129} (crosses 128→129 boundary)\")\n",
        "    batch_113_129 = model.encode([TEXT_113, TEXT_129], batch_size=2)\n",
        "    diff_113_in_batch = np.abs(baseline_113 - batch_113_129[0:1]).max()\n",
        "    diff_129_in_batch2 = np.abs(baseline_129 - batch_113_129[1:2]).max()\n",
        "    status_113 = \"❌\" if diff_113_in_batch > 1e-3 else \"✅\"\n",
        "    status_129 = \"❌ BUG\" if diff_129_in_batch2 > 1e-3 else \"✅\"\n",
        "    print(f\" TEXT_113 diff: {diff_113_in_batch:.8f} {status_113}\")\n",
        "    print(f\" TEXT_129 diff: {diff_129_in_batch2:.8f} {status_129}\")\n",
        "    if diff_129_in_batch2 > 1e-3:\n",
        "        print(f\" → TEXT_129 embedding changed by {diff_129_in_batch2*100:.2f}%\")\n",
        "        has_bug = True\n",
        "\n",
        "    # Test 5: 113 + 130 (crosses to 130, not 129 - should work)\n",
        "    print(f\"\\nTest 5: {len_113} + {len_130} (crosses to 130, skips 129)\")\n",
        "    batch_113_130 = model.encode([TEXT_113, TEXT_130], batch_size=2)\n",
        "    diff_113_in_batch2 = np.abs(baseline_113 - batch_113_130[0:1]).max()\n",
        "    diff_130 = np.abs(baseline_130 - batch_113_130[1:2]).max()\n",
        "    status_113 = \"❌\" if diff_113_in_batch2 > 1e-3 else \"✅\"\n",
        "    status_130 = \"❌\" if diff_130 > 1e-3 else \"✅\"\n",
        "    print(f\" TEXT_113 diff: {diff_113_in_batch2:.8f} {status_113}\")\n",
        "    print(f\" TEXT_130 diff: {diff_130:.8f} {status_130}\")\n",
        "    if diff_113_in_batch2 > 1e-3 or diff_130 > 1e-3:\n",
        "        has_bug = True\n",
        "\n",
        "    return has_bug\n",
        "\n",
        "\n",
        "def main():\n",
        "    print(\"=\"*70)\n",
        "    print(\"Batch-Dependent Embedding Bug Reproduction\")\n",
        "    print(\"=\"*70)\n",
        "\n",
        "    # Test embeddinggemma\n",
        "    gemma_has_bug = test_model(\"google/embeddinggemma-300m\")\n",
        "\n",
        "    # Test all-MiniLM for comparison\n",
        "    minilm_has_bug = test_model(\"sentence-transformers/all-MiniLM-L6-v2\")\n",
        "\n",
        "    # Summary\n",
        "    print(\"\\n\" + \"=\"*70)\n",
        "    print(\"SUMMARY\")\n",
        "    print(\"=\"*70)\n",
        "    print(f\"google/embeddinggemma-300m: {'❌ BUG' if gemma_has_bug else '✅ OK'}\")\n",
        "    print(f\"sentence-transformers/all-MiniLM-L6-v2: {'❌ BUG' if minilm_has_bug else '✅ OK'}\")\n",
        "\n",
        "    if gemma_has_bug and not minilm_has_bug:\n",
        "        print(\"\\n✓ Bug confirmed in embeddinggemma-300m\")\n",
        "        print(\"✓ all-MiniLM-L6-v2 works correctly\")\n",
        "        print(\"\\nThe bug appears ONLY in embeddinggemma-300m when batching\")\n",
        "        print(\"sequences that cross the 128→129 boundary.\")\n",
        "    elif gemma_has_bug:\n",
        "        print(\"\\nBug detected in embeddinggemma-300m\")\n",
        "    else:\n",
        "        print(\"\\nNo bugs detected in either model\")\n",
        "\n",
        "\n",
        "# from huggingface_hub import login\n",
        "# login(token=\"....\")\n",
        "\n",
        "main()"
      ]
    }
  ]
}
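The docstring above also reports the same mismatch at 193 (= 192+1) and 257 (= 256+1) tokens. Below is a hedged sketch of a length sweep around those boundaries, reusing the notebook's text construction. The assumption that ("Hello " * n).strip() tokenizes to n + 2 tokens is inferred from the notebook's counts (111 → 113, 127 → 129) and is re-checked with model.tokenize, as the notebook does; treat this as an illustration, not part of the original gist.

```python
# Sketch: sweep sequence lengths around the suspected multiples-of-64 boundaries.
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("google/embeddinggemma-300m").to("cuda")
short = ("Hello " * 111).strip()  # 113-token companion sequence, as in the notebook

for boundary in (128, 192, 256):
    for n_tokens in range(boundary - 1, boundary + 3):
        # Assumption: n - 2 repeats of "Hello " yield n tokens (BOS/EOS added); checked below.
        text = ("Hello " * (n_tokens - 2)).strip()
        n_actual = model.tokenize([text])["attention_mask"].sum().item()
        solo = model.encode([text], batch_size=1)            # ground-truth embedding
        batched = model.encode([short, text], batch_size=2)  # same text, batched with a shorter one
        diff = np.abs(solo - batched[1:2]).max()
        flag = "mismatch" if diff > 1e-3 else "ok"
        print(f"{n_actual} tokens: max abs diff vs solo = {diff:.8f} ({flag})")
```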