exec ${PAGER:-/usr/bin/less -R} "$0" || exit 1
Test settings: forge with network access
Host details: itmm4.prod.google.com Linux 6.6.65-smp-1300.170.0.0 x86_64 astoria-genoa-base
executor.INFO: analog/view?storage=borgremote&bns=/bns/it/borg/it/bns/build-forge-executor-tpu/prod-cbf-ghostlite.forge-executor/0&min_time=1764872604000000&ts=1764872614000000
Test command:
cd /build/work/aef67bf50706fee86777a93cc065340a246c/google3/runfiles/google3 && \
env - \
BORG_CELL=it \
CUSTOM_METRICS_DIR=/build/work/aef67bf50706fee86777a93cc065340a246c/google3/../custom_metrics \
Let's trace the values for `my_id = 1` with `num_devices = 4`:

| outer_step | Phase | Accumulation source | left_copy_device | right_copy_device | Device providing the data |
|---|---|---|---|---|---|
| 0 | LEFT | `x_ref[left_copy_device, ...]` | (1+0+1)%4 = 2 | (1-0-1)%4 = 0 | Device 2 |
| 0 | RIGHT | `x_ref[right_copy_device, ...]` | (1+0+1)%4 = 2 | (1-0-1)%4 = 0 | Device 0 |
| 1 | LEFT | `x_ref[left_copy_device, ...]` | (1+1+1)%4 = 3 | (1-1-1)%4 = 3 | Device 3 |
| 1 | RIGHT | `x_ref[right_copy_device, ...]` | (1+1+1)%4 = 3 | (1-1-1)%4 = 3 | Device 3 |
| 2 | LEFT | `x_ref[left_copy_device, ...]` | (1+2+1)%4 = 0 | (1-2-1)%4 = 2 | Device 0 |
| 2 | RIGHT | `x_ref[right_copy_device, ...]` | (1+2+1)%4 = 0 | (1-2-1)%4 = 2 | Device 2 |
As you can see, with each outer_step, the *_copy_device variables change, ensuring that the reduction operation fetches data from a new, distinct device. This systematic progression guarantees that by the end of all steps, each device has accumulated its required portion of the total sum from all other devices.
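The index arithmetic above can be sketched in plain Python (the helper name is hypothetical; in the real kernel these values index into `x_ref`):

```python
def copy_devices(my_id, outer_step, num_devices):
    # Device read from during the LEFT phase of this step.
    left_copy_device = (my_id + outer_step + 1) % num_devices
    # Device read from during the RIGHT phase of this step.
    right_copy_device = (my_id - outer_step - 1) % num_devices
    return left_copy_device, right_copy_device

# Reproduce the trace above for my_id = 1, num_devices = 4.
for step in range(3):
    left, right = copy_devices(1, step, 4)
    print(step, left, right)  # (2, 0), (3, 3), (0, 2)
```

Note that Python's `%` returns a non-negative result for negative operands, which is exactly what the RIGHT phase relies on.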
import jax
from jax import export
import jax.numpy as jnp
import pickle
import time
import statistics
with open("/home/xiowei_google_com/new_exports.pkl", "rb") as f:
    data = pickle.load(f)
import jax
from jax import export
import jax.numpy as jnp
import pickle
import time
import statistics
with open("/home/xiowei_google_com/old_exports.pkl", "rb") as f:
    data = pickle.load(f)
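A minimal timing helper in the spirit of these snippets, using the `time` and `statistics` imports above. This is a sketch; `fn` stands in for whatever callable is being benchmarked (e.g. the `.call` of a deserialized `jax.export` object, which is an assumption here):

```python
import time
import statistics

def bench(fn, *args, warmup=3, iters=10):
    """Return (median, stdev) of wall-clock seconds for fn(*args)."""
    for _ in range(warmup):
        fn(*args)  # discard warm-up runs (compilation, caches)
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples), statistics.stdev(samples)
```

For JAX workloads specifically, one would also block on the result (e.g. `jax.block_until_ready`) inside the timed region so dispatch-only times are not measured.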
1. Start the benchmark server in VS Code as in [this gist](https://gist.github.com/vanbasten23/dd4f3cbb314a7b9cf6c003103c23c019). Select the correct Python interpreter.
2. Then start the vLLM server in the debugger.
3. Wait until the server is up and running.
4. Add the breakpoint (remember to turn off Dynamo and JAX JIT).
5. Use the [script](https://gist.github.com/vanbasten23/726b28f072993fb7587482672b9c96a9) to send a benchmarking request. Make sure to use the correct conda/Python environment.
6. Then dump the input and output.
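For reference, a benchmarking request to a vLLM server usually targets its OpenAI-compatible endpoint. A minimal payload might look like this (the URL and model name are assumptions; adjust to your deployment):

```python
import json

# Hypothetical endpoint of a locally running vLLM server.
URL = "http://localhost:8000/v1/completions"

payload = {
    "model": "Qwen/Qwen2.5-1.5B-Instruct",  # must match the served model
    "prompt": "Hello, world",
    "max_tokens": 64,
    "temperature": 0.0,  # deterministic output makes runs comparable
}
body = json.dumps(payload)
```

The body can then be POSTed with any HTTP client, e.g. `requests.post(URL, data=body, headers={"Content-Type": "application/json"})`.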
=========================
pip install flatbuffers
#!/bin/bash
# Usage:
# bash run_tpu_benchmark_client.sh --model Qwen/Qwen2.5-1.5B-Instruct --tp 1
OPTIONS=""
LONGOPTS=model:,tp:,profile
# Parse arguments
PARSED=$(getopt --options="$OPTIONS" --longoptions="$LONGOPTS" --name "$0" -- "$@")
if [[ $? -ne 0 ]]; then
    exit 2
fi
{
    "name": "newjax_benchmark_server",
    "type": "debugpy",
    "request": "launch",
    "program": "/home/xiowei_google_com/miniconda3/envs/vllm_newjax/bin/vllm",
    "console": "integratedTerminal",
    "justMyCode": false,
    "env": {
        "MODEL_IMPL_TYPE": "vllm",
        "TPU_BACKEND_TYPE": "jax",
local keymap = vim.keymap.set
local opts = { noremap = true, silent = true }
-- remap leader key
keymap("n", "<Space>", "", opts)
vim.g.mapleader = " "
vim.g.maplocalleader = " "
-- yank to system clipboard
keymap({"n", "v"}, "<leader>y", '"+y', opts)
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

base = "Qwen/Qwen2.5-3B-Instruct"
adapter = "./lora-1plus1-666"

tok = AutoTokenizer.from_pretrained(base)
# Load the base model in bfloat16 and move it to GPU when available.
m = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
# Attach the LoRA adapter weights on top of the base model.
m = PeftModel.from_pretrained(m, adapter)