@OhadRubin
Created March 12, 2026 18:46
PR 1486: prompt_logprobs support for TPU
From 848d3f5d4c2aa5bee58eecad8421b5947ec79334 Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 18:36:09 -0600
Subject: [PATCH 1/9] Add prompt_logprobs
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 101 +++++++++++++++++++++++++++--
1 file changed, 94 insertions(+), 7 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 794dd95665..cfee1dd34a 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -154,6 +154,7 @@ class ExecuteModelState:
attn_metadata: AttentionMetadata
input_ids: Optional[jax.Array]
hidden_states: jax.Array
+ hidden_states_all: Optional[jax.Array]
logits: jax.Array
aux_hidden_states: Optional[jax.Array]
spec_decode_metadata: Optional[SpecDecodeMetadata]
@@ -575,6 +576,8 @@ def sample_tokens(
# This can happen in pipeline parallel case.
return EMPTY_MODEL_RUNNER_OUTPUT
+ hidden_states_all = self.execute_model_state.hidden_states_all
+
(scheduler_output, attn_metadata, input_ids, hidden_states, logits,
aux_hidden_states, spec_decode_metadata, kv_connector_output,
logits_indices_selector,
@@ -601,10 +604,17 @@ def sample_tokens(
logits,
arange,
)
- return self._sample_from_logits(
- scheduler_output, attn_metadata, input_ids, hidden_states, logits,
- aux_hidden_states, spec_decode_metadata, kv_connector_output,
- logits_indices_selector, padded_num_reqs)
+ return self._sample_from_logits(scheduler_output,
+ attn_metadata,
+ input_ids,
+ hidden_states,
+ logits,
+ aux_hidden_states,
+ spec_decode_metadata,
+ kv_connector_output,
+ logits_indices_selector,
+ padded_num_reqs,
+ hidden_states_all=hidden_states_all)
def _modify_prev_results(self):
# If copy to host has not been done, we just wait.
@@ -788,6 +798,7 @@ def _execute_model(
assert isinstance(hidden_states, JaxIntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
return attn_metadata, hidden_states
+ hidden_states_all = hidden_states
hidden_states = self._select_from_array_fn(hidden_states,
logits_indices)
logits = self.compute_logits_fn(
@@ -801,6 +812,7 @@ def _execute_model(
attn_metadata=attn_metadata,
input_ids=input_ids,
hidden_states=hidden_states,
+ hidden_states_all=hidden_states_all,
logits=logits,
aux_hidden_states=aux_hidden_states,
spec_decode_metadata=spec_decode_metadata,
@@ -809,6 +821,75 @@ def _execute_model(
padded_num_reqs=padded_num_reqs)
return attn_metadata, None
+ def _get_prompt_logprobs_dict(
+ self,
+ hidden_states_all: jax.Array,
+ scheduler_output: "VllmSchedulerOutput",
+ num_reqs: int,
+ ) -> dict[str, Optional[list]]:
+ """Compute prompt logprobs for requests that need them."""
+ prompt_logprobs_dict: dict[str, Optional[list]] = {}
+
+ # Check which requests need prompt_logprobs
+ reqs_needing_logprobs = []
+ for req_id in self.input_batch.req_ids[:num_reqs]:
+ req_state = self.requests.get(req_id)
+ if req_state is None:
+ prompt_logprobs_dict[req_id] = None
+ continue
+ sampling_params = getattr(req_state, 'sampling_params', None)
+ if sampling_params is not None and getattr(
+ sampling_params, 'prompt_logprobs', None) is not None:
+ reqs_needing_logprobs.append(req_id)
+ else:
+ prompt_logprobs_dict[req_id] = None
+
+ if not reqs_needing_logprobs:
+ return prompt_logprobs_dict
+
+ # Compute logits for ALL positions
+ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+ all_hidden = hidden_states_all[:total_num_scheduled_tokens]
+ all_logits = self.compute_logits_fn(self.state, all_hidden, None)
+
+ # Compute log softmax to get logprobs
+ all_logprobs = jax.nn.log_softmax(all_logits, axis=-1)
+ all_logprobs_np = np.asarray(jax.device_get(all_logprobs))
+
+ # Build prompt_logprobs for each request that needs it
+ token_offset = 0
+ for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+ num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
+
+ if req_id not in reqs_needing_logprobs:
+ token_offset += num_tokens
+ continue
+
+ req_state = self.requests[req_id]
+
+ # Get the input token ids for this request
+ req_idx = self.input_batch.req_id_to_index[req_id]
+ num_computed = req_state.num_computed_tokens
+
+ # Build logprobs list: [None, {token: logprob}, ...]
+ # First token has no prior context, so None
+ req_prompt_logprobs = [None]
+
+ # For positions 1..num_tokens-1, get logprob of actual token
+ for j in range(num_tokens - 1):
+ pos = token_offset + j
+ # The token at position j+1 is predicted by logits at position j
+ next_token_id = int(
+ self.input_batch.token_ids_cpu[req_idx,
+ num_computed + j + 1])
+ logprob = float(all_logprobs_np[pos, next_token_id])
+ req_prompt_logprobs.append({next_token_id: logprob})
+
+ prompt_logprobs_dict[req_id] = req_prompt_logprobs
+ token_offset += num_tokens
+
+ return prompt_logprobs_dict
+
def _sample_from_logits(
self,
scheduler_output: "VllmSchedulerOutput",
@@ -821,6 +902,7 @@ def _sample_from_logits(
kv_connector_output: Optional[KVConnectorOutput],
logits_indices_selector: Optional[List[int]] = None,
padded_num_reqs: Optional[int] = None,
+ hidden_states_all=None
) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
if padded_num_reqs is None:
padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
@@ -910,9 +992,14 @@ def _sample_from_logits(
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
- prompt_logprobs_dict = {}
- for req_id in self.input_batch.req_ids[:num_reqs]:
- prompt_logprobs_dict[req_id] = None
+ if hidden_states_all is not None:
+ prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+ hidden_states_all, scheduler_output, num_reqs)
+ else:
+ prompt_logprobs_dict = {
+ req_id: None
+ for req_id in self.input_batch.req_ids[:num_reqs]
+ }
# If async scheduler enabled
if self.scheduler_config.async_scheduling:
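
Patch 1 keeps the full pre-selection hidden states (hidden_states_all), re-runs the LM head over every scheduled position, and builds per-token prompt logprobs with a shift-by-one: the logits at position j score the token at position j+1, and the first prompt token gets None. A minimal numpy sketch of that indexing, with illustrative names and shapes rather than the runner's actual API:

import numpy as np

def prompt_logprobs_for_request(all_logprobs: np.ndarray,
                                prompt_token_ids: list) -> list:
    """all_logprobs: [num_prompt_tokens, vocab] log-softmaxed logits,
    one row per prompt position. Returns [None, {token_id: logprob}, ...]."""
    out = [None]  # first prompt token has no preceding context
    for j in range(len(prompt_token_ids) - 1):
        next_tok = prompt_token_ids[j + 1]
        # logits at position j predict the token at position j + 1
        out.append({next_tok: float(all_logprobs[j, next_tok])})
    return out

vocab = 8
logits = np.random.randn(4, vocab)
logprobs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))  # log-softmax
print(prompt_logprobs_for_request(logprobs, [3, 1, 4, 1]))
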
From 2cf982f38113e4c13d4a07181cd82f6a0733c442 Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 18:46:46 -0600
Subject: [PATCH 2/9] Update
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 124 ++++++++++++++++++++---------
1 file changed, 88 insertions(+), 36 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index cfee1dd34a..02f2075e5c 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -24,6 +24,7 @@
import jax.numpy as jnp
import jaxtyping
import numpy as np
+import torch
import vllm.envs as vllm_envs
from flax import nnx
from jax.experimental import mesh_utils
@@ -40,7 +41,7 @@
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, KVConnectorOutput, LogprobsLists,
- ModelRunnerOutput)
+ LogprobsTensors, ModelRunnerOutput)
from vllm.v1.request import Request
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -576,21 +577,20 @@ def sample_tokens(
# This can happen in pipeline parallel case.
return EMPTY_MODEL_RUNNER_OUTPUT
- hidden_states_all = self.execute_model_state.hidden_states_all
-
(scheduler_output, attn_metadata, input_ids, hidden_states, logits,
aux_hidden_states, spec_decode_metadata, kv_connector_output,
- logits_indices_selector,
- padded_num_reqs) = (self.execute_model_state.scheduler_output,
- self.execute_model_state.attn_metadata,
- self.execute_model_state.input_ids,
- self.execute_model_state.hidden_states,
- self.execute_model_state.logits,
- self.execute_model_state.aux_hidden_states,
- self.execute_model_state.spec_decode_metadata,
- self.execute_model_state.kv_connector_output,
- self.execute_model_state.logits_indices_selector,
- self.execute_model_state.padded_num_reqs)
+ logits_indices_selector, padded_num_reqs, hidden_states_all) = (
+ self.execute_model_state.scheduler_output,
+ self.execute_model_state.attn_metadata,
+ self.execute_model_state.input_ids,
+ self.execute_model_state.hidden_states,
+ self.execute_model_state.logits,
+ self.execute_model_state.aux_hidden_states,
+ self.execute_model_state.spec_decode_metadata,
+ self.execute_model_state.kv_connector_output,
+ self.execute_model_state.logits_indices_selector,
+ self.execute_model_state.padded_num_reqs,
+ self.execute_model_state.hidden_states_all)
self.execute_model_state = None
if grammar_output is not None:
@@ -826,21 +826,23 @@ def _get_prompt_logprobs_dict(
hidden_states_all: jax.Array,
scheduler_output: "VllmSchedulerOutput",
num_reqs: int,
- ) -> dict[str, Optional[list]]:
+ ) -> dict[str, Optional["LogprobsTensors"]]:
"""Compute prompt logprobs for requests that need them."""
- prompt_logprobs_dict: dict[str, Optional[list]] = {}
+
+ prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
# Check which requests need prompt_logprobs
- reqs_needing_logprobs = []
+ reqs_needing_logprobs: list[tuple[str, int]] = [
+ ] # (req_id, num_prompt_logprobs)
for req_id in self.input_batch.req_ids[:num_reqs]:
req_state = self.requests.get(req_id)
if req_state is None:
prompt_logprobs_dict[req_id] = None
continue
sampling_params = getattr(req_state, 'sampling_params', None)
- if sampling_params is not None and getattr(
- sampling_params, 'prompt_logprobs', None) is not None:
- reqs_needing_logprobs.append(req_id)
+ if sampling_params is not None and sampling_params.prompt_logprobs is not None:
+ num_prompt_logprobs = sampling_params.prompt_logprobs
+ reqs_needing_logprobs.append((req_id, num_prompt_logprobs))
else:
prompt_logprobs_dict[req_id] = None
@@ -858,34 +860,84 @@ def _get_prompt_logprobs_dict(
# Build prompt_logprobs for each request that needs it
token_offset = 0
+ reqs_needing_logprobs_dict = dict(reqs_needing_logprobs)
+
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
- if req_id not in reqs_needing_logprobs:
+ if req_id not in reqs_needing_logprobs_dict:
token_offset += num_tokens
continue
+ num_prompt_logprobs = reqs_needing_logprobs_dict[req_id]
req_state = self.requests[req_id]
# Get the input token ids for this request
req_idx = self.input_batch.req_id_to_index[req_id]
num_computed = req_state.num_computed_tokens
- # Build logprobs list: [None, {token: logprob}, ...]
- # First token has no prior context, so None
- req_prompt_logprobs = [None]
-
- # For positions 1..num_tokens-1, get logprob of actual token
- for j in range(num_tokens - 1):
- pos = token_offset + j
- # The token at position j+1 is predicted by logits at position j
- next_token_id = int(
- self.input_batch.token_ids_cpu[req_idx,
- num_computed + j + 1])
- logprob = float(all_logprobs_np[pos, next_token_id])
- req_prompt_logprobs.append({next_token_id: logprob})
-
- prompt_logprobs_dict[req_id] = req_prompt_logprobs
+ # Number of prompt logprobs to return (excluding first token which has no prior)
+ # We compute logprobs for tokens at positions 1 to num_tokens-1
+ # (the logprob of token i is computed from hidden state at position i-1)
+ num_logprob_tokens = num_tokens - 1
+
+ if num_logprob_tokens <= 0:
+ prompt_logprobs_dict[req_id] = None
+ token_offset += num_tokens
+ continue
+
+ # For each position, we need: top-k token ids, their logprobs, and rank of selected token
+ # num_prompt_logprobs + 1 to include the selected token
+ num_logprobs_to_return = num_prompt_logprobs + 1
+
+ # Get logprobs for positions [0, num_tokens-1) which predict tokens [1, num_tokens)
+ req_logprobs = all_logprobs_np[token_offset:token_offset +
+ num_logprob_tokens]
+
+ # Get top-k logprobs for each position
+ # Shape: [num_logprob_tokens, num_logprobs_to_return]
+ top_indices = np.argpartition(req_logprobs,
+ -num_logprobs_to_return,
+ axis=-1)[:, -num_logprobs_to_return:]
+ top_logprobs = np.take_along_axis(req_logprobs,
+ top_indices,
+ axis=-1)
+
+ # Sort by logprob descending
+ sort_order = np.argsort(-top_logprobs, axis=-1)
+ top_indices = np.take_along_axis(top_indices, sort_order, axis=-1)
+ top_logprobs = np.take_along_axis(top_logprobs,
+ sort_order,
+ axis=-1)
+
+ # Get the actual next token ids and their ranks
+ next_token_ids = self.input_batch.token_ids_cpu[req_idx,
+ num_computed +
+ 1:num_computed +
+ num_tokens]
+
+ # Compute ranks of selected tokens
+ selected_token_ranks = np.zeros(num_logprob_tokens, dtype=np.int64)
+ for j in range(num_logprob_tokens):
+ next_token_id = int(next_token_ids[j])
+ # Find rank (0-indexed position in sorted top-k, or vocab_size if not in top-k)
+ rank_in_topk = np.where(top_indices[j] == next_token_id)[0]
+ if len(rank_in_topk) > 0:
+ selected_token_ranks[j] = rank_in_topk[0]
+ else:
+ # Token not in top-k, compute actual rank
+ selected_token_ranks[j] = np.sum(
+ req_logprobs[j] > req_logprobs[j, next_token_id])
+
+ # Create LogprobsTensors
+ logprobs_tensors = LogprobsTensors(
+ logprob_token_ids=torch.from_numpy(top_indices.astype(
+ np.int64)),
+ logprobs=torch.from_numpy(top_logprobs.astype(np.float32)),
+ selected_token_ranks=torch.from_numpy(selected_token_ranks),
+ )
+
+ prompt_logprobs_dict[req_id] = logprobs_tensors
token_offset += num_tokens
return prompt_logprobs_dict
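
Patch 2 replaces the plain dict with vLLM's LogprobsTensors, so each prompt position carries top-k token ids, their logprobs, and the rank of the actual prompt token. A self-contained numpy sketch of the argpartition / take_along_axis pattern used in the diff (illustrative shapes; the LogprobsTensors / torch wrapping is omitted):

import numpy as np

def topk_with_ranks(logprobs: np.ndarray, next_token_ids: np.ndarray, k: int):
    """logprobs: [T, vocab] log-probs; next_token_ids: [T].
    Returns (top_ids [T, n], top_lps [T, n], ranks [T]) with n = min(k+1, vocab)."""
    n = min(k + 1, logprobs.shape[-1])  # +1 so the prompt token itself is included
    top_ids = np.argpartition(logprobs, -n, axis=-1)[:, -n:]
    top_lps = np.take_along_axis(logprobs, top_ids, axis=-1)
    order = np.argsort(-top_lps, axis=-1)                 # sort descending by logprob
    top_ids = np.take_along_axis(top_ids, order, axis=-1)
    top_lps = np.take_along_axis(top_lps, order, axis=-1)
    sel = logprobs[np.arange(len(next_token_ids)), next_token_ids]
    # 0 = most likely token, matching the strict ">" comparison in the patch
    ranks = (logprobs > sel[:, None]).sum(axis=-1).astype(np.int64)
    return top_ids, top_lps, ranks

lp = np.log(np.random.dirichlet(np.ones(10), size=3))
print(topk_with_ranks(lp, np.array([2, 5, 7]), k=2))
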
From 4fd18357332be01e5fb13d40584280b3b7f50a24 Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 18:56:07 -0600
Subject: [PATCH 3/9] Update code
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 122 +++++++++++++++++------------
1 file changed, 71 insertions(+), 51 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 02f2075e5c..12bffb3c7d 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -826,28 +826,65 @@ def _get_prompt_logprobs_dict(
hidden_states_all: jax.Array,
scheduler_output: "VllmSchedulerOutput",
num_reqs: int,
- ) -> dict[str, Optional["LogprobsTensors"]]:
- """Compute prompt logprobs for requests that need them."""
+ ) -> dict[str, "LogprobsTensors"]:
+ """Compute prompt logprobs for requests that need them.
+
+ Only returns prompt logprobs when a request COMPLETES its prefill
+ in the current step.
+ """
+ from vllm.v1.outputs import LogprobsTensors
- prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+ prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
- # Check which requests need prompt_logprobs
- reqs_needing_logprobs: list[tuple[str, int]] = [
- ] # (req_id, num_prompt_logprobs)
+ # Check which requests need prompt_logprobs AND are completing prefill
+ reqs_completing_prefill: list[tuple[str, int, int, int]] = [
+ ] # (req_id, num_prompt_logprobs, start_offset, num_prompt_tokens)
+
+ token_offset = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
+ num_scheduled = scheduler_output.num_scheduled_tokens.get(
+ req_id, 0)
+
req_state = self.requests.get(req_id)
if req_state is None:
- prompt_logprobs_dict[req_id] = None
+ token_offset += num_scheduled
continue
+
sampling_params = getattr(req_state, 'sampling_params', None)
- if sampling_params is not None and sampling_params.prompt_logprobs is not None:
- num_prompt_logprobs = sampling_params.prompt_logprobs
- reqs_needing_logprobs.append((req_id, num_prompt_logprobs))
- else:
- prompt_logprobs_dict[req_id] = None
+ if sampling_params is None or sampling_params.prompt_logprobs is None:
+ token_offset += num_scheduled
+ continue
+
+ # Check if we're in prefill phase and completing it
+ num_computed = req_state.num_computed_tokens
+ prompt_token_ids = getattr(req_state, 'prompt_token_ids', None)
+ if prompt_token_ids is None:
+ token_offset += num_scheduled
+ continue
+
+ prompt_len = len(prompt_token_ids)
+
+ # After this step, how many tokens will be computed?
+ tokens_after_step = num_computed + num_scheduled
+
+ # Skip if already past prefill (decode phase)
+ if num_computed >= prompt_len:
+ token_offset += num_scheduled
+ continue
+
+ # Skip if prefill not completing in this step (chunked prefill, more chunks to come)
+ if tokens_after_step < prompt_len:
+ token_offset += num_scheduled
+ continue
- if not reqs_needing_logprobs:
- return prompt_logprobs_dict
+ # This request is completing its prefill in this step
+ num_prompt_logprobs = sampling_params.prompt_logprobs
+ reqs_completing_prefill.append(
+ (req_id, num_prompt_logprobs, token_offset, prompt_len))
+ token_offset += num_scheduled
+
+ if not reqs_completing_prefill:
+ return {} # Return empty dict - this is important!
# Compute logits for ALL positions
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -858,44 +895,32 @@ def _get_prompt_logprobs_dict(
all_logprobs = jax.nn.log_softmax(all_logits, axis=-1)
all_logprobs_np = np.asarray(jax.device_get(all_logprobs))
- # Build prompt_logprobs for each request that needs it
- token_offset = 0
- reqs_needing_logprobs_dict = dict(reqs_needing_logprobs)
-
- for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
- num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
-
- if req_id not in reqs_needing_logprobs_dict:
- token_offset += num_tokens
- continue
-
- num_prompt_logprobs = reqs_needing_logprobs_dict[req_id]
+ # Build prompt_logprobs for each request completing prefill
+ for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill:
req_state = self.requests[req_id]
-
- # Get the input token ids for this request
req_idx = self.input_batch.req_id_to_index[req_id]
num_computed = req_state.num_computed_tokens
- # Number of prompt logprobs to return (excluding first token which has no prior)
- # We compute logprobs for tokens at positions 1 to num_tokens-1
- # (the logprob of token i is computed from hidden state at position i-1)
- num_logprob_tokens = num_tokens - 1
+ # Number of prompt tokens being processed in this step
+ # (may be less than prompt_len if chunked)
+ num_prompt_tokens_this_step = prompt_len - num_computed
+
+ # Number of prompt logprobs to compute (excluding first token which has no prior)
+ # We compute logprobs for tokens at positions 1 to num_prompt_tokens_this_step-1
+ num_logprob_tokens = num_prompt_tokens_this_step - 1
if num_logprob_tokens <= 0:
- prompt_logprobs_dict[req_id] = None
- token_offset += num_tokens
continue
- # For each position, we need: top-k token ids, their logprobs, and rank of selected token
# num_prompt_logprobs + 1 to include the selected token
- num_logprobs_to_return = num_prompt_logprobs + 1
+ num_logprobs_to_return = min(num_prompt_logprobs + 1,
+ self.vocab_size)
- # Get logprobs for positions [0, num_tokens-1) which predict tokens [1, num_tokens)
- req_logprobs = all_logprobs_np[token_offset:token_offset +
+ # Get logprobs for this request's prompt tokens
+ req_logprobs = all_logprobs_np[start_offset:start_offset +
num_logprob_tokens]
# Get top-k logprobs for each position
- # Shape: [num_logprob_tokens, num_logprobs_to_return]
top_indices = np.argpartition(req_logprobs,
-num_logprobs_to_return,
axis=-1)[:, -num_logprobs_to_return:]
@@ -910,26 +935,22 @@ def _get_prompt_logprobs_dict(
sort_order,
axis=-1)
- # Get the actual next token ids and their ranks
- next_token_ids = self.input_batch.token_ids_cpu[req_idx,
- num_computed +
- 1:num_computed +
- num_tokens]
+ # Get the actual next token ids and compute their ranks
+ next_token_ids = self.input_batch.token_ids_cpu[
+ req_idx,
+ num_computed + 1:num_computed + num_prompt_tokens_this_step]
- # Compute ranks of selected tokens
selected_token_ranks = np.zeros(num_logprob_tokens, dtype=np.int64)
for j in range(num_logprob_tokens):
next_token_id = int(next_token_ids[j])
- # Find rank (0-indexed position in sorted top-k, or vocab_size if not in top-k)
rank_in_topk = np.where(top_indices[j] == next_token_id)[0]
if len(rank_in_topk) > 0:
selected_token_ranks[j] = rank_in_topk[0]
else:
- # Token not in top-k, compute actual rank
- selected_token_ranks[j] = np.sum(
- req_logprobs[j] > req_logprobs[j, next_token_id])
+ selected_token_ranks[j] = int(
+ np.sum(req_logprobs[j] > req_logprobs[j,
+ next_token_id]))
- # Create LogprobsTensors
logprobs_tensors = LogprobsTensors(
logprob_token_ids=torch.from_numpy(top_indices.astype(
np.int64)),
@@ -938,7 +959,6 @@ def _get_prompt_logprobs_dict(
)
prompt_logprobs_dict[req_id] = logprobs_tensors
- token_offset += num_tokens
return prompt_logprobs_dict
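
Patch 3 only emits prompt logprobs when a request finishes its prefill in the current step, which matters under chunked prefill; as written it appears to cover only the positions processed in that final chunk, since earlier chunks' hidden states are not retained. A small sketch of the gating predicate, with stand-in arguments rather than runner state:

def completes_prefill_this_step(num_computed: int,
                                num_scheduled: int,
                                prompt_len: int) -> bool:
    """True only for the step in which the last prompt chunk is processed."""
    if num_computed >= prompt_len:                      # already decoding
        return False
    return num_computed + num_scheduled >= prompt_len   # final prompt chunk

assert completes_prefill_this_step(0, 7, 7)      # single-shot prefill
assert not completes_prefill_this_step(0, 4, 7)  # first chunk of a chunked prefill
assert completes_prefill_this_step(4, 3, 7)      # final chunk
assert not completes_prefill_this_step(7, 1, 7)  # decode step
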
From 87588b4cbcb29099e1fc50d915f9b784634fcb54 Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 19:13:47 -0600
Subject: [PATCH 4/9] Update
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 74 ++++++++++--------------------
1 file changed, 24 insertions(+), 50 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 12bffb3c7d..cd0ccf5230 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -837,8 +837,7 @@ def _get_prompt_logprobs_dict(
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
# Check which requests need prompt_logprobs AND are completing prefill
- reqs_completing_prefill: list[tuple[str, int, int, int]] = [
- ] # (req_id, num_prompt_logprobs, start_offset, num_prompt_tokens)
+ reqs_completing_prefill: list[tuple[str, int, int, int]] = []
token_offset = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
@@ -863,8 +862,6 @@ def _get_prompt_logprobs_dict(
continue
prompt_len = len(prompt_token_ids)
-
- # After this step, how many tokens will be computed?
tokens_after_step = num_computed + num_scheduled
# Skip if already past prefill (decode phase)
@@ -872,19 +869,18 @@ def _get_prompt_logprobs_dict(
token_offset += num_scheduled
continue
- # Skip if prefill not completing in this step (chunked prefill, more chunks to come)
+ # Skip if prefill not completing in this step
if tokens_after_step < prompt_len:
token_offset += num_scheduled
continue
- # This request is completing its prefill in this step
num_prompt_logprobs = sampling_params.prompt_logprobs
reqs_completing_prefill.append(
(req_id, num_prompt_logprobs, token_offset, prompt_len))
token_offset += num_scheduled
if not reqs_completing_prefill:
- return {} # Return empty dict - this is important!
+ return {}
# Compute logits for ALL positions
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -895,66 +891,44 @@ def _get_prompt_logprobs_dict(
all_logprobs = jax.nn.log_softmax(all_logits, axis=-1)
all_logprobs_np = np.asarray(jax.device_get(all_logprobs))
- # Build prompt_logprobs for each request completing prefill
for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill:
req_state = self.requests[req_id]
req_idx = self.input_batch.req_id_to_index[req_id]
num_computed = req_state.num_computed_tokens
- # Number of prompt tokens being processed in this step
- # (may be less than prompt_len if chunked)
num_prompt_tokens_this_step = prompt_len - num_computed
-
- # Number of prompt logprobs to compute (excluding first token which has no prior)
- # We compute logprobs for tokens at positions 1 to num_prompt_tokens_this_step-1
+ # Logprobs for positions 0 to num_tokens-2 predict tokens 1 to num_tokens-1
num_logprob_tokens = num_prompt_tokens_this_step - 1
if num_logprob_tokens <= 0:
continue
- # num_prompt_logprobs + 1 to include the selected token
- num_logprobs_to_return = min(num_prompt_logprobs + 1,
- self.vocab_size)
+ # Get the actual next token ids (tokens at positions 1 to num_tokens-1)
+ next_token_ids = self.input_batch.token_ids_cpu[
+ req_idx, num_computed + 1:num_computed +
+ num_prompt_tokens_this_step].astype(np.int64)
- # Get logprobs for this request's prompt tokens
+ # Get logprobs for this request's positions
req_logprobs = all_logprobs_np[start_offset:start_offset +
num_logprob_tokens]
- # Get top-k logprobs for each position
- top_indices = np.argpartition(req_logprobs,
- -num_logprobs_to_return,
- axis=-1)[:, -num_logprobs_to_return:]
- top_logprobs = np.take_along_axis(req_logprobs,
- top_indices,
- axis=-1)
-
- # Sort by logprob descending
- sort_order = np.argsort(-top_logprobs, axis=-1)
- top_indices = np.take_along_axis(top_indices, sort_order, axis=-1)
- top_logprobs = np.take_along_axis(top_logprobs,
- sort_order,
- axis=-1)
-
- # Get the actual next token ids and compute their ranks
- next_token_ids = self.input_batch.token_ids_cpu[
- req_idx,
- num_computed + 1:num_computed + num_prompt_tokens_this_step]
-
- selected_token_ranks = np.zeros(num_logprob_tokens, dtype=np.int64)
- for j in range(num_logprob_tokens):
- next_token_id = int(next_token_ids[j])
- rank_in_topk = np.where(top_indices[j] == next_token_id)[0]
- if len(rank_in_topk) > 0:
- selected_token_ranks[j] = rank_in_topk[0]
- else:
- selected_token_ranks[j] = int(
- np.sum(req_logprobs[j] > req_logprobs[j,
- next_token_id]))
+ # For prompt logprobs, we only need the selected token's logprob
+ # Shape: [num_logprob_tokens, 1]
+ selected_logprobs = req_logprobs[np.arange(num_logprob_tokens),
+ next_token_ids]
+
+ # Compute ranks of selected tokens (how many tokens have higher logprob)
+ selected_token_ranks = np.sum(req_logprobs
+ > selected_logprobs[:, np.newaxis],
+ axis=-1).astype(np.int64)
+ # Create LogprobsTensors with just the selected token
+ # Shape: [num_logprob_tokens, 1]
logprobs_tensors = LogprobsTensors(
- logprob_token_ids=torch.from_numpy(top_indices.astype(
- np.int64)),
- logprobs=torch.from_numpy(top_logprobs.astype(np.float32)),
+ logprob_token_ids=torch.from_numpy(
+ next_token_ids.reshape(-1, 1)),
+ logprobs=torch.from_numpy(
+ selected_logprobs.astype(np.float32).reshape(-1, 1)),
selected_token_ranks=torch.from_numpy(selected_token_ranks),
)
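
Patch 4 drops the top-k machinery and returns just each prompt token's own logprob (shape [T, 1]) and its rank, computed with vectorized indexing instead of a per-position loop. A numpy sketch of that computation (the LogprobsTensors / torch wrapping from the diff is omitted):

import numpy as np

def selected_logprobs_and_ranks(req_logprobs: np.ndarray,
                                next_token_ids: np.ndarray):
    """req_logprobs: [T, vocab] log-probs; next_token_ids: [T] int64."""
    t = np.arange(req_logprobs.shape[0])
    selected = req_logprobs[t, next_token_ids]                       # [T]
    # rank = how many vocab entries score strictly higher (0 = most likely)
    ranks = (req_logprobs > selected[:, None]).sum(-1).astype(np.int64)
    return next_token_ids.reshape(-1, 1), selected.reshape(-1, 1), ranks

lp = np.log(np.random.dirichlet(np.ones(6), size=4))
print(selected_logprobs_and_ranks(lp, np.array([0, 2, 5, 1])))
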
From 94c0065d9ec7abc8d3b3c0e37f164c13226ade7d Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 20:06:12 -0600
Subject: [PATCH 5/9] Update
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 31 ++++++++++++++++++++----------
1 file changed, 21 insertions(+), 10 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index cd0ccf5230..17c6c444bc 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -832,8 +832,6 @@ def _get_prompt_logprobs_dict(
Only returns prompt logprobs when a request COMPLETES its prefill
in the current step.
"""
- from vllm.v1.outputs import LogprobsTensors
-
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
# Check which requests need prompt_logprobs AND are completing prefill
@@ -841,8 +839,18 @@ def _get_prompt_logprobs_dict(
token_offset = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
+ # Skip invalid entries
+ if req_id is None:
+ continue
+
+ # Skip if not in current batch
+ if req_id not in self.input_batch.req_id_to_index:
+ continue
+
num_scheduled = scheduler_output.num_scheduled_tokens.get(
req_id, 0)
+ if num_scheduled == 0:
+ continue
req_state = self.requests.get(req_id)
if req_state is None:
@@ -892,18 +900,24 @@ def _get_prompt_logprobs_dict(
all_logprobs_np = np.asarray(jax.device_get(all_logprobs))
for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill:
- req_state = self.requests[req_id]
+ # Double-check req_id is still valid (defensive)
+ if req_id not in self.input_batch.req_id_to_index:
+ continue
+
+ req_state = self.requests.get(req_id)
+ if req_state is None:
+ continue
+
req_idx = self.input_batch.req_id_to_index[req_id]
num_computed = req_state.num_computed_tokens
num_prompt_tokens_this_step = prompt_len - num_computed
- # Logprobs for positions 0 to num_tokens-2 predict tokens 1 to num_tokens-1
num_logprob_tokens = num_prompt_tokens_this_step - 1
if num_logprob_tokens <= 0:
continue
- # Get the actual next token ids (tokens at positions 1 to num_tokens-1)
+ # Get the actual next token ids
next_token_ids = self.input_batch.token_ids_cpu[
req_idx, num_computed + 1:num_computed +
num_prompt_tokens_this_step].astype(np.int64)
@@ -912,18 +926,15 @@ def _get_prompt_logprobs_dict(
req_logprobs = all_logprobs_np[start_offset:start_offset +
num_logprob_tokens]
- # For prompt logprobs, we only need the selected token's logprob
- # Shape: [num_logprob_tokens, 1]
+ # Get the selected token's logprob
selected_logprobs = req_logprobs[np.arange(num_logprob_tokens),
next_token_ids]
- # Compute ranks of selected tokens (how many tokens have higher logprob)
+ # Compute ranks of selected tokens
selected_token_ranks = np.sum(req_logprobs
> selected_logprobs[:, np.newaxis],
axis=-1).astype(np.int64)
- # Create LogprobsTensors with just the selected token
- # Shape: [num_logprob_tokens, 1]
logprobs_tensors = LogprobsTensors(
logprob_token_ids=torch.from_numpy(
next_token_ids.reshape(-1, 1)),
From 88d1136e6fb0920f0d43e2763114bed9cbeb5b56 Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 21:11:40 -0600
Subject: [PATCH 6/9] Update code
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 34 +++++++++++++++++++++---------
1 file changed, 24 insertions(+), 10 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index 17c6c444bc..d53e600aa8 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -832,6 +832,17 @@ def _get_prompt_logprobs_dict(
Only returns prompt logprobs when a request COMPLETES its prefill
in the current step.
"""
+
+ # Build set of valid req_ids that will be in ModelRunnerOutput
+ current_req_ids = set(self.input_batch.req_ids[:num_reqs]) - {None}
+
+ # Also check against scheduler_output to ensure consistency
+ scheduled_req_ids = set(scheduler_output.num_scheduled_tokens.keys())
+ valid_req_ids = current_req_ids & scheduled_req_ids
+
+ if not valid_req_ids:
+ return {}
+
prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
# Check which requests need prompt_logprobs AND are completing prefill
@@ -839,17 +850,13 @@ def _get_prompt_logprobs_dict(
token_offset = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
- # Skip invalid entries
- if req_id is None:
- continue
-
- # Skip if not in current batch
- if req_id not in self.input_batch.req_id_to_index:
+ if req_id is None or req_id not in valid_req_ids:
continue
num_scheduled = scheduler_output.num_scheduled_tokens.get(
req_id, 0)
if num_scheduled == 0:
+ token_offset += num_scheduled
continue
req_state = self.requests.get(req_id)
@@ -900,15 +907,18 @@ def _get_prompt_logprobs_dict(
all_logprobs_np = np.asarray(jax.device_get(all_logprobs))
for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill:
- # Double-check req_id is still valid (defensive)
- if req_id not in self.input_batch.req_id_to_index:
+ # Final validation: skip if req_id won't be in output
+ if req_id not in valid_req_ids:
continue
req_state = self.requests.get(req_id)
if req_state is None:
continue
- req_idx = self.input_batch.req_id_to_index[req_id]
+ req_idx = self.input_batch.req_id_to_index.get(req_id)
+ if req_idx is None:
+ continue
+
num_computed = req_state.num_computed_tokens
num_prompt_tokens_this_step = prompt_len - num_computed
@@ -945,7 +955,11 @@ def _get_prompt_logprobs_dict(
prompt_logprobs_dict[req_id] = logprobs_tensors
- return prompt_logprobs_dict
+ # Final filter: only return req_ids that are definitely in current batch
+ return {
+ k: v
+ for k, v in prompt_logprobs_dict.items() if k in current_req_ids
+ }
def _sample_from_logits(
self,
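
Patch 6 restricts the result to request ids that appear both in the current input batch and in this step's scheduler output, so finished or never-scheduled requests cannot show up in prompt_logprobs_dict. A tiny sketch of the filter with stand-in inputs:

def filter_valid_req_ids(batch_req_ids, num_scheduled_tokens):
    """Request ids eligible for prompt logprobs this step."""
    current = {r for r in batch_req_ids if r is not None}
    scheduled = set(num_scheduled_tokens)
    return current & scheduled

batch = ["req-0", "req-1", None, "req-3"]
sched = {"req-0": 12, "req-3": 1, "req-9": 5}   # req-9 already left the batch
print(filter_valid_req_ids(batch, sched))        # {'req-0', 'req-3'}
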
From 5f7f35c8cd71ce06e49851907e8cd81c87f79136 Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 21:44:31 -0600
Subject: [PATCH 7/9] Update
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index d53e600aa8..a69ceeb692 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -1176,7 +1176,7 @@ def _sample_from_logits(
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
- req_id_to_index=self.input_batch.req_id_to_index,
+ req_id_to_index=dict(self.input_batch.req_id_to_index),
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
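
Patch 7 is a one-line change, but the point is easy to miss: req_id_to_index is a live dict on the input batch and presumably keeps mutating after the step returns, so the output needs a snapshot rather than a reference. An illustrative sketch of the difference:

live = {"req-0": 0, "req-1": 1}
output_ref = live            # what the code returned before the patch
output_copy = dict(live)     # what the patch returns

live.pop("req-1")            # request finishes and the batch is compacted
print(output_ref)            # {'req-0': 0}            -- changed under the caller
print(output_copy)           # {'req-0': 0, 'req-1': 1} -- stable snapshot
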
From 7343c090e8740c6b9fe3015fbfebf5e6c4ec6f5f Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 21:52:55 -0600
Subject: [PATCH 8/9] Update
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 25 ++++++++++++++++++++++---
1 file changed, 22 insertions(+), 3 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index a69ceeb692..d6db7022de 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -987,7 +987,7 @@ def _sample_from_logits(
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
- # TODO(pooyam): Should we move this to `_prepare_inputs`?
+ # TODO(pooyam): Should we move this to _prepare_inputs?
if tpu_sampling_metadata.do_sampling:
self.rng_params_for_sampling, step_rng = jax.random.split(
self.rng_params_for_sampling)
@@ -1036,7 +1036,7 @@ def _sample_from_logits(
num_reqs = self.input_batch.num_reqs
# Update the cache state concurrently. Code above will not block until
- # We use `selected_token_ids`. Add mark_step if post-processing changes
+ # We use selected_token_ids. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
@@ -1063,6 +1063,7 @@ def _sample_from_logits(
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
+ # --- LOGPROBS LOGIC START ---
if hidden_states_all is not None:
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states_all, scheduler_output, num_reqs)
@@ -1071,6 +1072,7 @@ def _sample_from_logits(
req_id: None
for req_id in self.input_batch.req_ids[:num_reqs]
}
+ # --- LOGPROBS LOGIC END ---
# If async scheduler enabled
if self.scheduler_config.async_scheduling:
@@ -1174,15 +1176,32 @@ def _sample_from_logits(
input_ids,
)
+ # --- SAFETY PATCH FOR MISSING KEYS START ---
+ # 1. Create the base mapping
+ final_req_id_to_index = dict(self.input_batch.req_id_to_index)
+
+ # 2. Check for missing keys that the scheduler expects
+ # If a request was scheduled but is no longer in input_batch (e.g. finished),
+ # we must provide a dummy index to prevent Scheduler crash.
+ if scheduler_output.num_scheduled_tokens:
+ for sched_req_id in scheduler_output.num_scheduled_tokens.keys():
+ if sched_req_id not in final_req_id_to_index:
+ # Map missing request to -1 (dummy index).
+ # This tells the scheduler "it exists" but points to invalid/safe memory
+ # if they try to read sampled tokens for it.
+ final_req_id_to_index[sched_req_id] = -1
+ # --- SAFETY PATCH FOR MISSING KEYS END ---
+
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
- req_id_to_index=dict(self.input_batch.req_id_to_index),
+ req_id_to_index=final_req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
)
+
return model_runner_output
@functools.partial(jax.jit, static_argnums=(0, ))
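
Patch 8 adds the safety patch: every request id the scheduler scheduled must have an entry in req_id_to_index, so requests that have already left the input batch get a dummy index of -1 instead of crashing the scheduler (per the comment in the diff). Patch 9 later factors this into the _patch_req_id_to_index helper shared by the async and sync paths. A stand-alone sketch with illustrative inputs:

def patch_req_id_to_index(base_map: dict, num_scheduled_tokens: dict) -> dict:
    """Ensure every scheduled request id has an index; -1 marks a departed request."""
    patched = dict(base_map)
    for req_id in num_scheduled_tokens:
        patched.setdefault(req_id, -1)
    return patched

base = {"req-0": 0, "req-1": 1}
sched = {"req-0": 8, "req-1": 8, "req-2": 3}   # req-2 finished mid-step
print(patch_req_id_to_index(base, sched))
# {'req-0': 0, 'req-1': 1, 'req-2': -1}
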
From 1efc3ee03a618d4b093b43302b2af645b4af5d61 Mon Sep 17 00:00:00 2001
From: catswe <212922539+catswe@users.noreply.github.com>
Date: Sat, 17 Jan 2026 22:04:16 -0600
Subject: [PATCH 9/9] Update code
Signed-off-by: catswe <212922539+catswe@users.noreply.github.com>
---
tpu_inference/runner/tpu_runner.py | 69 ++++++++++++------------------
1 file changed, 28 insertions(+), 41 deletions(-)
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index d6db7022de..6b493c1d8c 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -987,7 +987,6 @@ def _sample_from_logits(
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
- # TODO(pooyam): Should we move this to _prepare_inputs?
if tpu_sampling_metadata.do_sampling:
self.rng_params_for_sampling, step_rng = jax.random.split(
self.rng_params_for_sampling)
@@ -1007,6 +1006,7 @@ def _sample_from_logits(
else:
bonus_rng = step_rng
rejection_rng = step_rng
+
bonus_logits = self._select_from_array_fn(
logits, spec_decode_metadata.bonus_logits_indices)
bonus_token_ids = sample(
@@ -1015,6 +1015,7 @@ def _sample_from_logits(
bonus_logits,
tpu_sampling_metadata,
)
+
target_logits = self._select_from_array_fn(
logits, spec_decode_metadata.target_logits_indices)
next_tokens = self.rejection_sampler(
@@ -1035,8 +1036,6 @@ def _sample_from_logits(
num_reqs = self.input_batch.num_reqs
- # Update the cache state concurrently. Code above will not block until
- # We use selected_token_ids. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
@@ -1047,15 +1046,9 @@ def _sample_from_logits(
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
- # Ignore the sampled token from the partial request.
- # Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
- # This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
-
- # Record the index of the request that should not be sampled,
- # so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
assert all(
@@ -1063,7 +1056,7 @@ def _sample_from_logits(
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
- # --- LOGPROBS LOGIC START ---
+ # --- LOGPROBS LOGIC ---
if hidden_states_all is not None:
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states_all, scheduler_output, num_reqs)
@@ -1072,30 +1065,35 @@ def _sample_from_logits(
req_id: None
for req_id in self.input_batch.req_ids[:num_reqs]
}
- # --- LOGPROBS LOGIC END ---
+
+ # --- HELPER FOR SAFETY PATCH ---
+ def _patch_req_id_to_index(base_map):
+ patched_map = copy.deepcopy(base_map)
+ if scheduler_output.num_scheduled_tokens:
+ for sched_req_id in scheduler_output.num_scheduled_tokens.keys(
+ ):
+ if sched_req_id not in patched_map:
+ # Map missing request to -1 (dummy index) to prevent Scheduler crash
+ patched_map[sched_req_id] = -1
+ return patched_map
# If async scheduler enabled
if self.scheduler_config.async_scheduling:
- # Get previous results from TPU and replace the placeholder.
if self._pre_async_results is not None:
assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet."
self._modify_prev_results()
- # Set placeholder for next tokens that is not yet generated
placeholder_req_id_to_index: dict[
str, int] = self._update_placeholder(
discard_sampled_tokens_req_indices, request_seq_lens,
logits_indices_selector)
if logprobs is not None:
- # Map logprobs back to the pre-dp shuffling order
logprobs_lists = _jax_logprobs_to_lists(
logprobs, logits_indices_selector)
-
else:
logprobs_lists = None
- # Save the previous results
next_tokens = jax.copy_to_host_async(next_tokens)
self._pre_async_results = AsyncPreResults(
req_ids=req_ids,
@@ -1106,26 +1104,30 @@ def _sample_from_logits(
placeholder_req_id_to_index=placeholder_req_id_to_index,
logits_indices_selector=logits_indices_selector)
+ # --- APPLY SAFETY PATCH (ASYNC PATH) ---
+ patched_req_id_to_index = _patch_req_id_to_index(
+ self.input_batch.req_id_to_index)
+
# Return Model output to executor
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
- req_id_to_index=copy.deepcopy(
- self.input_batch.req_id_to_index),
- sampled_token_ids=[], # Fill in async get
+ req_id_to_index=patched_req_id_to_index, # Use patched map
+ sampled_token_ids=[],
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
)
- # Return async_model_runner_output
+
async_model_runner_output = AsyncTPUModelRunnerOutput(
model_runner_output, next_tokens, num_reqs,
discard_sampled_tokens_req_indices, logits_indices_selector)
+
return async_model_runner_output
+ # --- SYNC PATH ---
if spec_decode_metadata is None:
next_tokens = np.asarray(jax.device_get(next_tokens))
- # Map tokens back to the pre-dp shuffling order
if logits_indices_selector is not None:
next_tokens = next_tokens[logits_indices_selector]
selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
@@ -1136,10 +1138,9 @@ def _sample_from_logits(
spec_decode_metadata.draft_lengths_cpu, num_reqs,
spec_decode_metadata.draft_token_ids.shape[0])
- # Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
- # Append sampled tokens
+
for req_idx, req_state, _ in request_seq_lens:
sampled_ids = valid_sampled_token_ids[req_idx]
if not sampled_ids:
@@ -1151,7 +1152,6 @@ def _sample_from_logits(
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")
-
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
@@ -1159,7 +1159,6 @@ def _sample_from_logits(
req_state.output_token_ids.extend(sampled_ids)
if logprobs is not None:
- # Map logprobs back to the pre-dp shuffling order
logprobs_lists = _jax_logprobs_to_lists(logprobs,
logits_indices_selector)
else:
@@ -1176,25 +1175,13 @@ def _sample_from_logits(
input_ids,
)
- # --- SAFETY PATCH FOR MISSING KEYS START ---
- # 1. Create the base mapping
- final_req_id_to_index = dict(self.input_batch.req_id_to_index)
-
- # 2. Check for missing keys that the scheduler expects
- # If a request was scheduled but is no longer in input_batch (e.g. finished),
- # we must provide a dummy index to prevent Scheduler crash.
- if scheduler_output.num_scheduled_tokens:
- for sched_req_id in scheduler_output.num_scheduled_tokens.keys():
- if sched_req_id not in final_req_id_to_index:
- # Map missing request to -1 (dummy index).
- # This tells the scheduler "it exists" but points to invalid/safe memory
- # if they try to read sampled tokens for it.
- final_req_id_to_index[sched_req_id] = -1
- # --- SAFETY PATCH FOR MISSING KEYS END ---
+ # --- APPLY SAFETY PATCH (SYNC PATH) ---
+ patched_req_id_to_index = _patch_req_id_to_index(
+ self.input_batch.req_id_to_index)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
- req_id_to_index=final_req_id_to_index,
+ req_id_to_index=patched_req_id_to_index, # Use patched map
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment