PR 1486: prompt_logprobs support for TPU
| From 848d3f5d4c2aa5bee58eecad8421b5947ec79334 Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 18:36:09 -0600 | |
| Subject: [PATCH 1/9] Add prompt_logprobs | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 101 +++++++++++++++++++++++++++-- | |
| 1 file changed, 94 insertions(+), 7 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index 794dd95665..cfee1dd34a 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -154,6 +154,7 @@ class ExecuteModelState: | |
| attn_metadata: AttentionMetadata | |
| input_ids: Optional[jax.Array] | |
| hidden_states: jax.Array | |
| + hidden_states_all: Optional[jax.Array] | |
| logits: jax.Array | |
| aux_hidden_states: Optional[jax.Array] | |
| spec_decode_metadata: Optional[SpecDecodeMetadata] | |
| @@ -575,6 +576,8 @@ def sample_tokens( | |
| # This can happen in pipeline parallel case. | |
| return EMPTY_MODEL_RUNNER_OUTPUT | |
| + hidden_states_all = self.execute_model_state.hidden_states_all | |
| + | |
| (scheduler_output, attn_metadata, input_ids, hidden_states, logits, | |
| aux_hidden_states, spec_decode_metadata, kv_connector_output, | |
| logits_indices_selector, | |
| @@ -601,10 +604,17 @@ def sample_tokens( | |
| logits, | |
| arange, | |
| ) | |
| - return self._sample_from_logits( | |
| - scheduler_output, attn_metadata, input_ids, hidden_states, logits, | |
| - aux_hidden_states, spec_decode_metadata, kv_connector_output, | |
| - logits_indices_selector, padded_num_reqs) | |
| + return self._sample_from_logits(scheduler_output, | |
| + attn_metadata, | |
| + input_ids, | |
| + hidden_states, | |
| + logits, | |
| + aux_hidden_states, | |
| + spec_decode_metadata, | |
| + kv_connector_output, | |
| + logits_indices_selector, | |
| + padded_num_reqs, | |
| + hidden_states_all=hidden_states_all) | |
| def _modify_prev_results(self): | |
| # If copy to host has not been done, we just wait. | |
| @@ -788,6 +798,7 @@ def _execute_model( | |
| assert isinstance(hidden_states, JaxIntermediateTensors) | |
| hidden_states.kv_connector_output = kv_connector_output | |
| return attn_metadata, hidden_states | |
| + hidden_states_all = hidden_states | |
| hidden_states = self._select_from_array_fn(hidden_states, | |
| logits_indices) | |
| logits = self.compute_logits_fn( | |
| @@ -801,6 +812,7 @@ def _execute_model( | |
| attn_metadata=attn_metadata, | |
| input_ids=input_ids, | |
| hidden_states=hidden_states, | |
| + hidden_states_all=hidden_states_all, | |
| logits=logits, | |
| aux_hidden_states=aux_hidden_states, | |
| spec_decode_metadata=spec_decode_metadata, | |
| @@ -809,6 +821,75 @@ def _execute_model( | |
| padded_num_reqs=padded_num_reqs) | |
| return attn_metadata, None | |
| + def _get_prompt_logprobs_dict( | |
| + self, | |
| + hidden_states_all: jax.Array, | |
| + scheduler_output: "VllmSchedulerOutput", | |
| + num_reqs: int, | |
| + ) -> dict[str, Optional[list]]: | |
| + """Compute prompt logprobs for requests that need them.""" | |
| + prompt_logprobs_dict: dict[str, Optional[list]] = {} | |
| + | |
| + # Check which requests need prompt_logprobs | |
| + reqs_needing_logprobs = [] | |
| + for req_id in self.input_batch.req_ids[:num_reqs]: | |
| + req_state = self.requests.get(req_id) | |
| + if req_state is None: | |
| + prompt_logprobs_dict[req_id] = None | |
| + continue | |
| + sampling_params = getattr(req_state, 'sampling_params', None) | |
| + if sampling_params is not None and getattr( | |
| + sampling_params, 'prompt_logprobs', None) is not None: | |
| + reqs_needing_logprobs.append(req_id) | |
| + else: | |
| + prompt_logprobs_dict[req_id] = None | |
| + | |
| + if not reqs_needing_logprobs: | |
| + return prompt_logprobs_dict | |
| + | |
| + # Compute logits for ALL positions | |
| + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | |
| + all_hidden = hidden_states_all[:total_num_scheduled_tokens] | |
| + all_logits = self.compute_logits_fn(self.state, all_hidden, None) | |
| + | |
| + # Compute log softmax to get logprobs | |
| + all_logprobs = jax.nn.log_softmax(all_logits, axis=-1) | |
| + all_logprobs_np = np.asarray(jax.device_get(all_logprobs)) | |
| + | |
| + # Build prompt_logprobs for each request that needs it | |
| + token_offset = 0 | |
| + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): | |
| + num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) | |
| + | |
| + if req_id not in reqs_needing_logprobs: | |
| + token_offset += num_tokens | |
| + continue | |
| + | |
| + req_state = self.requests[req_id] | |
| + | |
| + # Get the input token ids for this request | |
| + req_idx = self.input_batch.req_id_to_index[req_id] | |
| + num_computed = req_state.num_computed_tokens | |
| + | |
| + # Build logprobs list: [None, {token: logprob}, ...] | |
| + # First token has no prior context, so None | |
| + req_prompt_logprobs = [None] | |
| + | |
| + # For positions 1..num_tokens-1, get logprob of actual token | |
| + for j in range(num_tokens - 1): | |
| + pos = token_offset + j | |
| + # The token at position j+1 is predicted by logits at position j | |
| + next_token_id = int( | |
| + self.input_batch.token_ids_cpu[req_idx, | |
| + num_computed + j + 1]) | |
| + logprob = float(all_logprobs_np[pos, next_token_id]) | |
| + req_prompt_logprobs.append({next_token_id: logprob}) | |
| + | |
| + prompt_logprobs_dict[req_id] = req_prompt_logprobs | |
| + token_offset += num_tokens | |
| + | |
| + return prompt_logprobs_dict | |
| + | |
| def _sample_from_logits( | |
| self, | |
| scheduler_output: "VllmSchedulerOutput", | |
| @@ -821,6 +902,7 @@ def _sample_from_logits( | |
| kv_connector_output: Optional[KVConnectorOutput], | |
| logits_indices_selector: Optional[List[int]] = None, | |
| padded_num_reqs: Optional[int] = None, | |
| + hidden_states_all=None | |
| ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput: | |
| if padded_num_reqs is None: | |
| padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit( | |
| @@ -910,9 +992,14 @@ def _sample_from_logits( | |
| self.input_batch.req_ids[:num_reqs]), "req_ids contains None" | |
| req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) | |
| - prompt_logprobs_dict = {} | |
| - for req_id in self.input_batch.req_ids[:num_reqs]: | |
| - prompt_logprobs_dict[req_id] = None | |
| + if hidden_states_all is not None: | |
| + prompt_logprobs_dict = self._get_prompt_logprobs_dict( | |
| + hidden_states_all, scheduler_output, num_reqs) | |
| + else: | |
| + prompt_logprobs_dict = { | |
| + req_id: None | |
| + for req_id in self.input_batch.req_ids[:num_reqs] | |
| + } | |
| # If async scheduler enabled | |
| if self.scheduler_config.async_scheduling: | |
| From 2cf982f38113e4c13d4a07181cd82f6a0733c442 Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 18:46:46 -0600 | |
| Subject: [PATCH 2/9] Update | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 124 ++++++++++++++++++++--------- | |
| 1 file changed, 88 insertions(+), 36 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index cfee1dd34a..02f2075e5c 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -24,6 +24,7 @@ | |
| import jax.numpy as jnp | |
| import jaxtyping | |
| import numpy as np | |
| +import torch | |
| import vllm.envs as vllm_envs | |
| from flax import nnx | |
| from jax.experimental import mesh_utils | |
| @@ -40,7 +41,7 @@ | |
| from vllm.v1.kv_cache_interface import KVCacheConfig | |
| from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, | |
| DraftTokenIds, KVConnectorOutput, LogprobsLists, | |
| - ModelRunnerOutput) | |
| + LogprobsTensors, ModelRunnerOutput) | |
| from vllm.v1.request import Request | |
| from vllm.v1.spec_decode.ngram_proposer import NgramProposer | |
| from vllm.v1.worker.kv_connector_model_runner_mixin import \ | |
| @@ -576,21 +577,20 @@ def sample_tokens( | |
| # This can happen in pipeline parallel case. | |
| return EMPTY_MODEL_RUNNER_OUTPUT | |
| - hidden_states_all = self.execute_model_state.hidden_states_all | |
| - | |
| (scheduler_output, attn_metadata, input_ids, hidden_states, logits, | |
| aux_hidden_states, spec_decode_metadata, kv_connector_output, | |
| - logits_indices_selector, | |
| - padded_num_reqs) = (self.execute_model_state.scheduler_output, | |
| - self.execute_model_state.attn_metadata, | |
| - self.execute_model_state.input_ids, | |
| - self.execute_model_state.hidden_states, | |
| - self.execute_model_state.logits, | |
| - self.execute_model_state.aux_hidden_states, | |
| - self.execute_model_state.spec_decode_metadata, | |
| - self.execute_model_state.kv_connector_output, | |
| - self.execute_model_state.logits_indices_selector, | |
| - self.execute_model_state.padded_num_reqs) | |
| + logits_indices_selector, padded_num_reqs, hidden_states_all) = ( | |
| + self.execute_model_state.scheduler_output, | |
| + self.execute_model_state.attn_metadata, | |
| + self.execute_model_state.input_ids, | |
| + self.execute_model_state.hidden_states, | |
| + self.execute_model_state.logits, | |
| + self.execute_model_state.aux_hidden_states, | |
| + self.execute_model_state.spec_decode_metadata, | |
| + self.execute_model_state.kv_connector_output, | |
| + self.execute_model_state.logits_indices_selector, | |
| + self.execute_model_state.padded_num_reqs, | |
| + self.execute_model_state.hidden_states_all) | |
| self.execute_model_state = None | |
| if grammar_output is not None: | |
| @@ -826,21 +826,23 @@ def _get_prompt_logprobs_dict( | |
| hidden_states_all: jax.Array, | |
| scheduler_output: "VllmSchedulerOutput", | |
| num_reqs: int, | |
| - ) -> dict[str, Optional[list]]: | |
| + ) -> dict[str, Optional["LogprobsTensors"]]: | |
| """Compute prompt logprobs for requests that need them.""" | |
| - prompt_logprobs_dict: dict[str, Optional[list]] = {} | |
| + | |
| + prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} | |
| # Check which requests need prompt_logprobs | |
| - reqs_needing_logprobs = [] | |
| + reqs_needing_logprobs: list[tuple[str, int]] = [ | |
| + ] # (req_id, num_prompt_logprobs) | |
| for req_id in self.input_batch.req_ids[:num_reqs]: | |
| req_state = self.requests.get(req_id) | |
| if req_state is None: | |
| prompt_logprobs_dict[req_id] = None | |
| continue | |
| sampling_params = getattr(req_state, 'sampling_params', None) | |
| - if sampling_params is not None and getattr( | |
| - sampling_params, 'prompt_logprobs', None) is not None: | |
| - reqs_needing_logprobs.append(req_id) | |
| + if sampling_params is not None and sampling_params.prompt_logprobs is not None: | |
| + num_prompt_logprobs = sampling_params.prompt_logprobs | |
| + reqs_needing_logprobs.append((req_id, num_prompt_logprobs)) | |
| else: | |
| prompt_logprobs_dict[req_id] = None | |
| @@ -858,34 +860,84 @@ def _get_prompt_logprobs_dict( | |
| # Build prompt_logprobs for each request that needs it | |
| token_offset = 0 | |
| + reqs_needing_logprobs_dict = dict(reqs_needing_logprobs) | |
| + | |
| for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): | |
| num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) | |
| - if req_id not in reqs_needing_logprobs: | |
| + if req_id not in reqs_needing_logprobs_dict: | |
| token_offset += num_tokens | |
| continue | |
| + num_prompt_logprobs = reqs_needing_logprobs_dict[req_id] | |
| req_state = self.requests[req_id] | |
| # Get the input token ids for this request | |
| req_idx = self.input_batch.req_id_to_index[req_id] | |
| num_computed = req_state.num_computed_tokens | |
| - # Build logprobs list: [None, {token: logprob}, ...] | |
| - # First token has no prior context, so None | |
| - req_prompt_logprobs = [None] | |
| - | |
| - # For positions 1..num_tokens-1, get logprob of actual token | |
| - for j in range(num_tokens - 1): | |
| - pos = token_offset + j | |
| - # The token at position j+1 is predicted by logits at position j | |
| - next_token_id = int( | |
| - self.input_batch.token_ids_cpu[req_idx, | |
| - num_computed + j + 1]) | |
| - logprob = float(all_logprobs_np[pos, next_token_id]) | |
| - req_prompt_logprobs.append({next_token_id: logprob}) | |
| - | |
| - prompt_logprobs_dict[req_id] = req_prompt_logprobs | |
| + # Number of prompt logprobs to return (excluding first token which has no prior) | |
| + # We compute logprobs for tokens at positions 1 to num_tokens-1 | |
| + # (the logprob of token i is computed from hidden state at position i-1) | |
| + num_logprob_tokens = num_tokens - 1 | |
| + | |
| + if num_logprob_tokens <= 0: | |
| + prompt_logprobs_dict[req_id] = None | |
| + token_offset += num_tokens | |
| + continue | |
| + | |
| + # For each position, we need: top-k token ids, their logprobs, and rank of selected token | |
| + # num_prompt_logprobs + 1 to include the selected token | |
| + num_logprobs_to_return = num_prompt_logprobs + 1 | |
| + | |
| + # Get logprobs for positions [0, num_tokens-1) which predict tokens [1, num_tokens) | |
| + req_logprobs = all_logprobs_np[token_offset:token_offset + | |
| + num_logprob_tokens] | |
| + | |
| + # Get top-k logprobs for each position | |
| + # Shape: [num_logprob_tokens, num_logprobs_to_return] | |
| + top_indices = np.argpartition(req_logprobs, | |
| + -num_logprobs_to_return, | |
| + axis=-1)[:, -num_logprobs_to_return:] | |
| + top_logprobs = np.take_along_axis(req_logprobs, | |
| + top_indices, | |
| + axis=-1) | |
| + | |
| + # Sort by logprob descending | |
| + sort_order = np.argsort(-top_logprobs, axis=-1) | |
| + top_indices = np.take_along_axis(top_indices, sort_order, axis=-1) | |
| + top_logprobs = np.take_along_axis(top_logprobs, | |
| + sort_order, | |
| + axis=-1) | |
| + | |
| + # Get the actual next token ids and their ranks | |
| + next_token_ids = self.input_batch.token_ids_cpu[req_idx, | |
| + num_computed + | |
| + 1:num_computed + | |
| + num_tokens] | |
| + | |
| + # Compute ranks of selected tokens | |
| + selected_token_ranks = np.zeros(num_logprob_tokens, dtype=np.int64) | |
| + for j in range(num_logprob_tokens): | |
| + next_token_id = int(next_token_ids[j]) | |
| + # Find rank (0-indexed position in sorted top-k, or vocab_size if not in top-k) | |
| + rank_in_topk = np.where(top_indices[j] == next_token_id)[0] | |
| + if len(rank_in_topk) > 0: | |
| + selected_token_ranks[j] = rank_in_topk[0] | |
| + else: | |
| + # Token not in top-k, compute actual rank | |
| + selected_token_ranks[j] = np.sum( | |
| + req_logprobs[j] > req_logprobs[j, next_token_id]) | |
| + | |
| + # Create LogprobsTensors | |
| + logprobs_tensors = LogprobsTensors( | |
| + logprob_token_ids=torch.from_numpy(top_indices.astype( | |
| + np.int64)), | |
| + logprobs=torch.from_numpy(top_logprobs.astype(np.float32)), | |
| + selected_token_ranks=torch.from_numpy(selected_token_ranks), | |
| + ) | |
| + | |
| + prompt_logprobs_dict[req_id] = logprobs_tensors | |
| token_offset += num_tokens | |
| return prompt_logprobs_dict | |
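Patch 2 replaces the per-token dict output with vLLM's LogprobsTensors, which needs top-k token ids, their logprobs, and the rank of each actual prompt token. The selection logic boils down to an argpartition followed by a descending sort; a NumPy-only sketch under assumed shapes ([num_positions, vocab] logprobs, [num_positions] next-token ids):

import numpy as np

def topk_logprobs(req_logprobs, next_token_ids, num_prompt_logprobs):
    # k = requested top-k plus one slot for the actual token, as in the patch.
    k = num_prompt_logprobs + 1
    top_idx = np.argpartition(req_logprobs, -k, axis=-1)[:, -k:]   # unordered top-k
    top_lp = np.take_along_axis(req_logprobs, top_idx, axis=-1)
    order = np.argsort(-top_lp, axis=-1)                           # sort descending
    top_idx = np.take_along_axis(top_idx, order, axis=-1)
    top_lp = np.take_along_axis(top_lp, order, axis=-1)

    # Rank of each actual prompt token: its position in the sorted top-k, or the
    # count of strictly better tokens if it fell outside the top-k.
    ranks = np.empty(len(next_token_ids), dtype=np.int64)
    for j, tok in enumerate(next_token_ids):
        hit = np.nonzero(top_idx[j] == tok)[0]
        ranks[j] = hit[0] if hit.size else np.sum(req_logprobs[j] > req_logprobs[j, tok])
    return top_idx, top_lp, ranks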
| From 4fd18357332be01e5fb13d40584280b3b7f50a24 Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 18:56:07 -0600 | |
| Subject: [PATCH 3/9] Update code | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 122 +++++++++++++++++------------ | |
| 1 file changed, 71 insertions(+), 51 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index 02f2075e5c..12bffb3c7d 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -826,28 +826,65 @@ def _get_prompt_logprobs_dict( | |
| hidden_states_all: jax.Array, | |
| scheduler_output: "VllmSchedulerOutput", | |
| num_reqs: int, | |
| - ) -> dict[str, Optional["LogprobsTensors"]]: | |
| - """Compute prompt logprobs for requests that need them.""" | |
| + ) -> dict[str, "LogprobsTensors"]: | |
| + """Compute prompt logprobs for requests that need them. | |
| + | |
| + Only returns prompt logprobs when a request COMPLETES its prefill | |
| + in the current step. | |
| + """ | |
| + from vllm.v1.outputs import LogprobsTensors | |
| - prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} | |
| + prompt_logprobs_dict: dict[str, LogprobsTensors] = {} | |
| - # Check which requests need prompt_logprobs | |
| - reqs_needing_logprobs: list[tuple[str, int]] = [ | |
| - ] # (req_id, num_prompt_logprobs) | |
| + # Check which requests need prompt_logprobs AND are completing prefill | |
| + reqs_completing_prefill: list[tuple[str, int, int, int]] = [ | |
| + ] # (req_id, num_prompt_logprobs, start_offset, num_prompt_tokens) | |
| + | |
| + token_offset = 0 | |
| for req_id in self.input_batch.req_ids[:num_reqs]: | |
| + num_scheduled = scheduler_output.num_scheduled_tokens.get( | |
| + req_id, 0) | |
| + | |
| req_state = self.requests.get(req_id) | |
| if req_state is None: | |
| - prompt_logprobs_dict[req_id] = None | |
| + token_offset += num_scheduled | |
| continue | |
| + | |
| sampling_params = getattr(req_state, 'sampling_params', None) | |
| - if sampling_params is not None and sampling_params.prompt_logprobs is not None: | |
| - num_prompt_logprobs = sampling_params.prompt_logprobs | |
| - reqs_needing_logprobs.append((req_id, num_prompt_logprobs)) | |
| - else: | |
| - prompt_logprobs_dict[req_id] = None | |
| + if sampling_params is None or sampling_params.prompt_logprobs is None: | |
| + token_offset += num_scheduled | |
| + continue | |
| + | |
| + # Check if we're in prefill phase and completing it | |
| + num_computed = req_state.num_computed_tokens | |
| + prompt_token_ids = getattr(req_state, 'prompt_token_ids', None) | |
| + if prompt_token_ids is None: | |
| + token_offset += num_scheduled | |
| + continue | |
| + | |
| + prompt_len = len(prompt_token_ids) | |
| + | |
| + # After this step, how many tokens will be computed? | |
| + tokens_after_step = num_computed + num_scheduled | |
| + | |
| + # Skip if already past prefill (decode phase) | |
| + if num_computed >= prompt_len: | |
| + token_offset += num_scheduled | |
| + continue | |
| + | |
| + # Skip if prefill not completing in this step (chunked prefill, more chunks to come) | |
| + if tokens_after_step < prompt_len: | |
| + token_offset += num_scheduled | |
| + continue | |
| - if not reqs_needing_logprobs: | |
| - return prompt_logprobs_dict | |
| + # This request is completing its prefill in this step | |
| + num_prompt_logprobs = sampling_params.prompt_logprobs | |
| + reqs_completing_prefill.append( | |
| + (req_id, num_prompt_logprobs, token_offset, prompt_len)) | |
| + token_offset += num_scheduled | |
| + | |
| + if not reqs_completing_prefill: | |
| + return {} # Return empty dict - this is important! | |
| # Compute logits for ALL positions | |
| total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | |
| @@ -858,44 +895,32 @@ def _get_prompt_logprobs_dict( | |
| all_logprobs = jax.nn.log_softmax(all_logits, axis=-1) | |
| all_logprobs_np = np.asarray(jax.device_get(all_logprobs)) | |
| - # Build prompt_logprobs for each request that needs it | |
| - token_offset = 0 | |
| - reqs_needing_logprobs_dict = dict(reqs_needing_logprobs) | |
| - | |
| - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): | |
| - num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) | |
| - | |
| - if req_id not in reqs_needing_logprobs_dict: | |
| - token_offset += num_tokens | |
| - continue | |
| - | |
| - num_prompt_logprobs = reqs_needing_logprobs_dict[req_id] | |
| + # Build prompt_logprobs for each request completing prefill | |
| + for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill: | |
| req_state = self.requests[req_id] | |
| - | |
| - # Get the input token ids for this request | |
| req_idx = self.input_batch.req_id_to_index[req_id] | |
| num_computed = req_state.num_computed_tokens | |
| - # Number of prompt logprobs to return (excluding first token which has no prior) | |
| - # We compute logprobs for tokens at positions 1 to num_tokens-1 | |
| - # (the logprob of token i is computed from hidden state at position i-1) | |
| - num_logprob_tokens = num_tokens - 1 | |
| + # Number of prompt tokens being processed in this step | |
| + # (may be less than prompt_len if chunked) | |
| + num_prompt_tokens_this_step = prompt_len - num_computed | |
| + | |
| + # Number of prompt logprobs to compute (excluding first token which has no prior) | |
| + # We compute logprobs for tokens at positions 1 to num_prompt_tokens_this_step-1 | |
| + num_logprob_tokens = num_prompt_tokens_this_step - 1 | |
| if num_logprob_tokens <= 0: | |
| - prompt_logprobs_dict[req_id] = None | |
| - token_offset += num_tokens | |
| continue | |
| - # For each position, we need: top-k token ids, their logprobs, and rank of selected token | |
| # num_prompt_logprobs + 1 to include the selected token | |
| - num_logprobs_to_return = num_prompt_logprobs + 1 | |
| + num_logprobs_to_return = min(num_prompt_logprobs + 1, | |
| + self.vocab_size) | |
| - # Get logprobs for positions [0, num_tokens-1) which predict tokens [1, num_tokens) | |
| - req_logprobs = all_logprobs_np[token_offset:token_offset + | |
| + # Get logprobs for this request's prompt tokens | |
| + req_logprobs = all_logprobs_np[start_offset:start_offset + | |
| num_logprob_tokens] | |
| # Get top-k logprobs for each position | |
| - # Shape: [num_logprob_tokens, num_logprobs_to_return] | |
| top_indices = np.argpartition(req_logprobs, | |
| -num_logprobs_to_return, | |
| axis=-1)[:, -num_logprobs_to_return:] | |
| @@ -910,26 +935,22 @@ def _get_prompt_logprobs_dict( | |
| sort_order, | |
| axis=-1) | |
| - # Get the actual next token ids and their ranks | |
| - next_token_ids = self.input_batch.token_ids_cpu[req_idx, | |
| - num_computed + | |
| - 1:num_computed + | |
| - num_tokens] | |
| + # Get the actual next token ids and compute their ranks | |
| + next_token_ids = self.input_batch.token_ids_cpu[ | |
| + req_idx, | |
| + num_computed + 1:num_computed + num_prompt_tokens_this_step] | |
| - # Compute ranks of selected tokens | |
| selected_token_ranks = np.zeros(num_logprob_tokens, dtype=np.int64) | |
| for j in range(num_logprob_tokens): | |
| next_token_id = int(next_token_ids[j]) | |
| - # Find rank (0-indexed position in sorted top-k, or vocab_size if not in top-k) | |
| rank_in_topk = np.where(top_indices[j] == next_token_id)[0] | |
| if len(rank_in_topk) > 0: | |
| selected_token_ranks[j] = rank_in_topk[0] | |
| else: | |
| - # Token not in top-k, compute actual rank | |
| - selected_token_ranks[j] = np.sum( | |
| - req_logprobs[j] > req_logprobs[j, next_token_id]) | |
| + selected_token_ranks[j] = int( | |
| + np.sum(req_logprobs[j] > req_logprobs[j, | |
| + next_token_id])) | |
| - # Create LogprobsTensors | |
| logprobs_tensors = LogprobsTensors( | |
| logprob_token_ids=torch.from_numpy(top_indices.astype( | |
| np.int64)), | |
| @@ -938,7 +959,6 @@ def _get_prompt_logprobs_dict( | |
| ) | |
| prompt_logprobs_dict[req_id] = logprobs_tensors | |
| - token_offset += num_tokens | |
| return prompt_logprobs_dict | |
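Patch 3 narrows the computation to requests that finish their prefill in the current step, which is what keeps chunked prefill from emitting partial prompt logprobs. The gating condition, pulled out as a standalone predicate (argument names are illustrative):

def completes_prefill_this_step(num_computed, num_scheduled, prompt_len):
    # Already decoding: the prompt was fully processed in an earlier step.
    if num_computed >= prompt_len:
        return False
    # Chunked prefill with more chunks still to come.
    if num_computed + num_scheduled < prompt_len:
        return False
    # Prefill finishes within this step, so prompt logprobs can be emitted now.
    return True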
| From 87588b4cbcb29099e1fc50d915f9b784634fcb54 Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 19:13:47 -0600 | |
| Subject: [PATCH 4/9] Update | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 74 ++++++++++-------------------- | |
| 1 file changed, 24 insertions(+), 50 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index 12bffb3c7d..cd0ccf5230 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -837,8 +837,7 @@ def _get_prompt_logprobs_dict( | |
| prompt_logprobs_dict: dict[str, LogprobsTensors] = {} | |
| # Check which requests need prompt_logprobs AND are completing prefill | |
| - reqs_completing_prefill: list[tuple[str, int, int, int]] = [ | |
| - ] # (req_id, num_prompt_logprobs, start_offset, num_prompt_tokens) | |
| + reqs_completing_prefill: list[tuple[str, int, int, int]] = [] | |
| token_offset = 0 | |
| for req_id in self.input_batch.req_ids[:num_reqs]: | |
| @@ -863,8 +862,6 @@ def _get_prompt_logprobs_dict( | |
| continue | |
| prompt_len = len(prompt_token_ids) | |
| - | |
| - # After this step, how many tokens will be computed? | |
| tokens_after_step = num_computed + num_scheduled | |
| # Skip if already past prefill (decode phase) | |
| @@ -872,19 +869,18 @@ def _get_prompt_logprobs_dict( | |
| token_offset += num_scheduled | |
| continue | |
| - # Skip if prefill not completing in this step (chunked prefill, more chunks to come) | |
| + # Skip if prefill not completing in this step | |
| if tokens_after_step < prompt_len: | |
| token_offset += num_scheduled | |
| continue | |
| - # This request is completing its prefill in this step | |
| num_prompt_logprobs = sampling_params.prompt_logprobs | |
| reqs_completing_prefill.append( | |
| (req_id, num_prompt_logprobs, token_offset, prompt_len)) | |
| token_offset += num_scheduled | |
| if not reqs_completing_prefill: | |
| - return {} # Return empty dict - this is important! | |
| + return {} | |
| # Compute logits for ALL positions | |
| total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | |
| @@ -895,66 +891,44 @@ def _get_prompt_logprobs_dict( | |
| all_logprobs = jax.nn.log_softmax(all_logits, axis=-1) | |
| all_logprobs_np = np.asarray(jax.device_get(all_logprobs)) | |
| - # Build prompt_logprobs for each request completing prefill | |
| for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill: | |
| req_state = self.requests[req_id] | |
| req_idx = self.input_batch.req_id_to_index[req_id] | |
| num_computed = req_state.num_computed_tokens | |
| - # Number of prompt tokens being processed in this step | |
| - # (may be less than prompt_len if chunked) | |
| num_prompt_tokens_this_step = prompt_len - num_computed | |
| - | |
| - # Number of prompt logprobs to compute (excluding first token which has no prior) | |
| - # We compute logprobs for tokens at positions 1 to num_prompt_tokens_this_step-1 | |
| + # Logprobs for positions 0 to num_tokens-2 predict tokens 1 to num_tokens-1 | |
| num_logprob_tokens = num_prompt_tokens_this_step - 1 | |
| if num_logprob_tokens <= 0: | |
| continue | |
| - # num_prompt_logprobs + 1 to include the selected token | |
| - num_logprobs_to_return = min(num_prompt_logprobs + 1, | |
| - self.vocab_size) | |
| + # Get the actual next token ids (tokens at positions 1 to num_tokens-1) | |
| + next_token_ids = self.input_batch.token_ids_cpu[ | |
| + req_idx, num_computed + 1:num_computed + | |
| + num_prompt_tokens_this_step].astype(np.int64) | |
| - # Get logprobs for this request's prompt tokens | |
| + # Get logprobs for this request's positions | |
| req_logprobs = all_logprobs_np[start_offset:start_offset + | |
| num_logprob_tokens] | |
| - # Get top-k logprobs for each position | |
| - top_indices = np.argpartition(req_logprobs, | |
| - -num_logprobs_to_return, | |
| - axis=-1)[:, -num_logprobs_to_return:] | |
| - top_logprobs = np.take_along_axis(req_logprobs, | |
| - top_indices, | |
| - axis=-1) | |
| - | |
| - # Sort by logprob descending | |
| - sort_order = np.argsort(-top_logprobs, axis=-1) | |
| - top_indices = np.take_along_axis(top_indices, sort_order, axis=-1) | |
| - top_logprobs = np.take_along_axis(top_logprobs, | |
| - sort_order, | |
| - axis=-1) | |
| - | |
| - # Get the actual next token ids and compute their ranks | |
| - next_token_ids = self.input_batch.token_ids_cpu[ | |
| - req_idx, | |
| - num_computed + 1:num_computed + num_prompt_tokens_this_step] | |
| - | |
| - selected_token_ranks = np.zeros(num_logprob_tokens, dtype=np.int64) | |
| - for j in range(num_logprob_tokens): | |
| - next_token_id = int(next_token_ids[j]) | |
| - rank_in_topk = np.where(top_indices[j] == next_token_id)[0] | |
| - if len(rank_in_topk) > 0: | |
| - selected_token_ranks[j] = rank_in_topk[0] | |
| - else: | |
| - selected_token_ranks[j] = int( | |
| - np.sum(req_logprobs[j] > req_logprobs[j, | |
| - next_token_id])) | |
| + # For prompt logprobs, we only need the selected token's logprob | |
| + # Shape: [num_logprob_tokens, 1] | |
| + selected_logprobs = req_logprobs[np.arange(num_logprob_tokens), | |
| + next_token_ids] | |
| + | |
| + # Compute ranks of selected tokens (how many tokens have higher logprob) | |
| + selected_token_ranks = np.sum(req_logprobs | |
| + > selected_logprobs[:, np.newaxis], | |
| + axis=-1).astype(np.int64) | |
| + # Create LogprobsTensors with just the selected token | |
| + # Shape: [num_logprob_tokens, 1] | |
| logprobs_tensors = LogprobsTensors( | |
| - logprob_token_ids=torch.from_numpy(top_indices.astype( | |
| - np.int64)), | |
| - logprobs=torch.from_numpy(top_logprobs.astype(np.float32)), | |
| + logprob_token_ids=torch.from_numpy( | |
| + next_token_ids.reshape(-1, 1)), | |
| + logprobs=torch.from_numpy( | |
| + selected_logprobs.astype(np.float32).reshape(-1, 1)), | |
| selected_token_ranks=torch.from_numpy(selected_token_ranks), | |
| ) | |
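Patch 4 drops the top-k machinery and keeps only the selected token's logprob per position, with the rank computed as a vectorized count of strictly better vocab entries. Roughly, the three arrays that feed LogprobsTensors look like this (a sketch under assumed inputs, not the runner's exact code):

import numpy as np
import torch

def selected_token_logprobs(req_logprobs, next_token_ids):
    # req_logprobs: [n, vocab] log-softmax rows; next_token_ids: [n] actual prompt tokens.
    n = len(next_token_ids)
    selected = req_logprobs[np.arange(n), next_token_ids]                   # [n]
    # Rank = how many vocab entries score strictly higher than the chosen token.
    ranks = np.sum(req_logprobs > selected[:, None], axis=-1).astype(np.int64)
    return (
        torch.from_numpy(next_token_ids.reshape(-1, 1).astype(np.int64)),   # logprob_token_ids, [n, 1]
        torch.from_numpy(selected.astype(np.float32).reshape(-1, 1)),       # logprobs, [n, 1]
        torch.from_numpy(ranks),                                            # selected_token_ranks, [n]
    )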
| From 94c0065d9ec7abc8d3b3c0e37f164c13226ade7d Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 20:06:12 -0600 | |
| Subject: [PATCH 5/9] Update | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 31 ++++++++++++++++++++---------- | |
| 1 file changed, 21 insertions(+), 10 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index cd0ccf5230..17c6c444bc 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -832,8 +832,6 @@ def _get_prompt_logprobs_dict( | |
| Only returns prompt logprobs when a request COMPLETES its prefill | |
| in the current step. | |
| """ | |
| - from vllm.v1.outputs import LogprobsTensors | |
| - | |
| prompt_logprobs_dict: dict[str, LogprobsTensors] = {} | |
| # Check which requests need prompt_logprobs AND are completing prefill | |
| @@ -841,8 +839,18 @@ def _get_prompt_logprobs_dict( | |
| token_offset = 0 | |
| for req_id in self.input_batch.req_ids[:num_reqs]: | |
| + # Skip invalid entries | |
| + if req_id is None: | |
| + continue | |
| + | |
| + # Skip if not in current batch | |
| + if req_id not in self.input_batch.req_id_to_index: | |
| + continue | |
| + | |
| num_scheduled = scheduler_output.num_scheduled_tokens.get( | |
| req_id, 0) | |
| + if num_scheduled == 0: | |
| + continue | |
| req_state = self.requests.get(req_id) | |
| if req_state is None: | |
| @@ -892,18 +900,24 @@ def _get_prompt_logprobs_dict( | |
| all_logprobs_np = np.asarray(jax.device_get(all_logprobs)) | |
| for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill: | |
| - req_state = self.requests[req_id] | |
| + # Double-check req_id is still valid (defensive) | |
| + if req_id not in self.input_batch.req_id_to_index: | |
| + continue | |
| + | |
| + req_state = self.requests.get(req_id) | |
| + if req_state is None: | |
| + continue | |
| + | |
| req_idx = self.input_batch.req_id_to_index[req_id] | |
| num_computed = req_state.num_computed_tokens | |
| num_prompt_tokens_this_step = prompt_len - num_computed | |
| - # Logprobs for positions 0 to num_tokens-2 predict tokens 1 to num_tokens-1 | |
| num_logprob_tokens = num_prompt_tokens_this_step - 1 | |
| if num_logprob_tokens <= 0: | |
| continue | |
| - # Get the actual next token ids (tokens at positions 1 to num_tokens-1) | |
| + # Get the actual next token ids | |
| next_token_ids = self.input_batch.token_ids_cpu[ | |
| req_idx, num_computed + 1:num_computed + | |
| num_prompt_tokens_this_step].astype(np.int64) | |
| @@ -912,18 +926,15 @@ def _get_prompt_logprobs_dict( | |
| req_logprobs = all_logprobs_np[start_offset:start_offset + | |
| num_logprob_tokens] | |
| - # For prompt logprobs, we only need the selected token's logprob | |
| - # Shape: [num_logprob_tokens, 1] | |
| + # Get the selected token's logprob | |
| selected_logprobs = req_logprobs[np.arange(num_logprob_tokens), | |
| next_token_ids] | |
| - # Compute ranks of selected tokens (how many tokens have higher logprob) | |
| + # Compute ranks of selected tokens | |
| selected_token_ranks = np.sum(req_logprobs | |
| > selected_logprobs[:, np.newaxis], | |
| axis=-1).astype(np.int64) | |
| - # Create LogprobsTensors with just the selected token | |
| - # Shape: [num_logprob_tokens, 1] | |
| logprobs_tensors = LogprobsTensors( | |
| logprob_token_ids=torch.from_numpy( | |
| next_token_ids.reshape(-1, 1)), | |
| From 88d1136e6fb0920f0d43e2763114bed9cbeb5b56 Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 21:11:40 -0600 | |
| Subject: [PATCH 6/9] Update code | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 34 +++++++++++++++++++++--------- | |
| 1 file changed, 24 insertions(+), 10 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index 17c6c444bc..d53e600aa8 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -832,6 +832,17 @@ def _get_prompt_logprobs_dict( | |
| Only returns prompt logprobs when a request COMPLETES its prefill | |
| in the current step. | |
| """ | |
| + | |
| + # Build set of valid req_ids that will be in ModelRunnerOutput | |
| + current_req_ids = set(self.input_batch.req_ids[:num_reqs]) - {None} | |
| + | |
| + # Also check against scheduler_output to ensure consistency | |
| + scheduled_req_ids = set(scheduler_output.num_scheduled_tokens.keys()) | |
| + valid_req_ids = current_req_ids & scheduled_req_ids | |
| + | |
| + if not valid_req_ids: | |
| + return {} | |
| + | |
| prompt_logprobs_dict: dict[str, LogprobsTensors] = {} | |
| # Check which requests need prompt_logprobs AND are completing prefill | |
| @@ -839,17 +850,13 @@ def _get_prompt_logprobs_dict( | |
| token_offset = 0 | |
| for req_id in self.input_batch.req_ids[:num_reqs]: | |
| - # Skip invalid entries | |
| - if req_id is None: | |
| - continue | |
| - | |
| - # Skip if not in current batch | |
| - if req_id not in self.input_batch.req_id_to_index: | |
| + if req_id is None or req_id not in valid_req_ids: | |
| continue | |
| num_scheduled = scheduler_output.num_scheduled_tokens.get( | |
| req_id, 0) | |
| if num_scheduled == 0: | |
| + token_offset += num_scheduled | |
| continue | |
| req_state = self.requests.get(req_id) | |
| @@ -900,15 +907,18 @@ def _get_prompt_logprobs_dict( | |
| all_logprobs_np = np.asarray(jax.device_get(all_logprobs)) | |
| for req_id, num_prompt_logprobs, start_offset, prompt_len in reqs_completing_prefill: | |
| - # Double-check req_id is still valid (defensive) | |
| - if req_id not in self.input_batch.req_id_to_index: | |
| + # Final validation: skip if req_id won't be in output | |
| + if req_id not in valid_req_ids: | |
| continue | |
| req_state = self.requests.get(req_id) | |
| if req_state is None: | |
| continue | |
| - req_idx = self.input_batch.req_id_to_index[req_id] | |
| + req_idx = self.input_batch.req_id_to_index.get(req_id) | |
| + if req_idx is None: | |
| + continue | |
| + | |
| num_computed = req_state.num_computed_tokens | |
| num_prompt_tokens_this_step = prompt_len - num_computed | |
| @@ -945,7 +955,11 @@ def _get_prompt_logprobs_dict( | |
| prompt_logprobs_dict[req_id] = logprobs_tensors | |
| - return prompt_logprobs_dict | |
| + # Final filter: only return req_ids that are definitely in current batch | |
| + return { | |
| + k: v | |
| + for k, v in prompt_logprobs_dict.items() if k in current_req_ids | |
| + } | |
| def _sample_from_logits( | |
| self, | |
| From 5f7f35c8cd71ce06e49851907e8cd81c87f79136 Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 21:44:31 -0600 | |
| Subject: [PATCH 7/9] Update | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 2 +- | |
| 1 file changed, 1 insertion(+), 1 deletion(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index d53e600aa8..a69ceeb692 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -1176,7 +1176,7 @@ def _sample_from_logits( | |
| model_runner_output = ModelRunnerOutput( | |
| req_ids=req_ids, | |
| - req_id_to_index=self.input_batch.req_id_to_index, | |
| + req_id_to_index=dict(self.input_batch.req_id_to_index), | |
| sampled_token_ids=valid_sampled_token_ids, | |
| logprobs=logprobs_lists, | |
| prompt_logprobs_dict=prompt_logprobs_dict, | |
| From 7343c090e8740c6b9fe3015fbfebf5e6c4ec6f5f Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 21:52:55 -0600 | |
| Subject: [PATCH 8/9] Update | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 25 ++++++++++++++++++++++--- | |
| 1 file changed, 22 insertions(+), 3 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index a69ceeb692..d6db7022de 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -987,7 +987,7 @@ def _sample_from_logits( | |
| tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( | |
| self.mesh, self.input_batch, padded_num_reqs, sharding=sharding) | |
| - # TODO(pooyam): Should we move this to `_prepare_inputs`? | |
| + # TODO(pooyam): Should we move this to _prepare_inputs? | |
| if tpu_sampling_metadata.do_sampling: | |
| self.rng_params_for_sampling, step_rng = jax.random.split( | |
| self.rng_params_for_sampling) | |
| @@ -1036,7 +1036,7 @@ def _sample_from_logits( | |
| num_reqs = self.input_batch.num_reqs | |
| # Update the cache state concurrently. Code above will not block until | |
| - # We use `selected_token_ids`. Add mark_step if post-processing changes | |
| + # We use selected_token_ids. Add mark_step if post-processing changes | |
| request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] | |
| discard_sampled_tokens_req_indices = [] | |
| for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): | |
| @@ -1063,6 +1063,7 @@ def _sample_from_logits( | |
| self.input_batch.req_ids[:num_reqs]), "req_ids contains None" | |
| req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) | |
| + # --- LOGPROBS LOGIC START --- | |
| if hidden_states_all is not None: | |
| prompt_logprobs_dict = self._get_prompt_logprobs_dict( | |
| hidden_states_all, scheduler_output, num_reqs) | |
| @@ -1071,6 +1072,7 @@ def _sample_from_logits( | |
| req_id: None | |
| for req_id in self.input_batch.req_ids[:num_reqs] | |
| } | |
| + # --- LOGPROBS LOGIC END --- | |
| # If async scheduler enabled | |
| if self.scheduler_config.async_scheduling: | |
| @@ -1174,15 +1176,32 @@ def _sample_from_logits( | |
| input_ids, | |
| ) | |
| + # --- SAFETY PATCH FOR MISSING KEYS START --- | |
| + # 1. Create the base mapping | |
| + final_req_id_to_index = dict(self.input_batch.req_id_to_index) | |
| + | |
| + # 2. Check for missing keys that the scheduler expects | |
| + # If a request was scheduled but is no longer in input_batch (e.g. finished), | |
| + # we must provide a dummy index to prevent Scheduler crash. | |
| + if scheduler_output.num_scheduled_tokens: | |
| + for sched_req_id in scheduler_output.num_scheduled_tokens.keys(): | |
| + if sched_req_id not in final_req_id_to_index: | |
| + # Map missing request to -1 (dummy index). | |
| + # This tells the scheduler "it exists" but points to invalid/safe memory | |
| + # if they try to read sampled tokens for it. | |
| + final_req_id_to_index[sched_req_id] = -1 | |
| + # --- SAFETY PATCH FOR MISSING KEYS END --- | |
| + | |
| model_runner_output = ModelRunnerOutput( | |
| req_ids=req_ids, | |
| - req_id_to_index=dict(self.input_batch.req_id_to_index), | |
| + req_id_to_index=final_req_id_to_index, | |
| sampled_token_ids=valid_sampled_token_ids, | |
| logprobs=logprobs_lists, | |
| prompt_logprobs_dict=prompt_logprobs_dict, | |
| pooler_output=[], | |
| kv_connector_output=kv_connector_output, | |
| ) | |
| + | |
| return model_runner_output | |
| @functools.partial(jax.jit, static_argnums=(0, )) | |
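Patch 8 adds the req_id_to_index safety patch: any request the scheduler scheduled this step but that is no longer in the input batch is mapped to a dummy index of -1 so the scheduler does not hit a missing key; patch 9 below folds this into a local helper shared by the sync and async paths. Reduced to its essentials (a sketch; the real helper deep-copies the mapping inside _sample_from_logits):

def patch_req_id_to_index(req_id_to_index, num_scheduled_tokens):
    # Copy so the returned mapping is decoupled from the live input batch.
    patched = dict(req_id_to_index)
    # Requests the scheduler touched this step but that already left the
    # input batch get a dummy index of -1 instead of being missing entirely.
    for req_id in num_scheduled_tokens:
        patched.setdefault(req_id, -1)
    return patched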
| From 1efc3ee03a618d4b093b43302b2af645b4af5d61 Mon Sep 17 00:00:00 2001 | |
| From: catswe <212922539+catswe@users.noreply.github.com> | |
| Date: Sat, 17 Jan 2026 22:04:16 -0600 | |
| Subject: [PATCH 9/9] Update code | |
| Signed-off-by: catswe <212922539+catswe@users.noreply.github.com> | |
| --- | |
| tpu_inference/runner/tpu_runner.py | 69 ++++++++++++------------------ | |
| 1 file changed, 28 insertions(+), 41 deletions(-) | |
| diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py | |
| index d6db7022de..6b493c1d8c 100644 | |
| --- a/tpu_inference/runner/tpu_runner.py | |
| +++ b/tpu_inference/runner/tpu_runner.py | |
| @@ -987,7 +987,6 @@ def _sample_from_logits( | |
| tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( | |
| self.mesh, self.input_batch, padded_num_reqs, sharding=sharding) | |
| - # TODO(pooyam): Should we move this to _prepare_inputs? | |
| if tpu_sampling_metadata.do_sampling: | |
| self.rng_params_for_sampling, step_rng = jax.random.split( | |
| self.rng_params_for_sampling) | |
| @@ -1007,6 +1006,7 @@ def _sample_from_logits( | |
| else: | |
| bonus_rng = step_rng | |
| rejection_rng = step_rng | |
| + | |
| bonus_logits = self._select_from_array_fn( | |
| logits, spec_decode_metadata.bonus_logits_indices) | |
| bonus_token_ids = sample( | |
| @@ -1015,6 +1015,7 @@ def _sample_from_logits( | |
| bonus_logits, | |
| tpu_sampling_metadata, | |
| ) | |
| + | |
| target_logits = self._select_from_array_fn( | |
| logits, spec_decode_metadata.target_logits_indices) | |
| next_tokens = self.rejection_sampler( | |
| @@ -1035,8 +1036,6 @@ def _sample_from_logits( | |
| num_reqs = self.input_batch.num_reqs | |
| - # Update the cache state concurrently. Code above will not block until | |
| - # We use selected_token_ids. Add mark_step if post-processing changes | |
| request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] | |
| discard_sampled_tokens_req_indices = [] | |
| for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): | |
| @@ -1047,15 +1046,9 @@ def _sample_from_logits( | |
| if seq_len >= req_state.num_tokens: | |
| request_seq_lens.append((i, req_state, seq_len)) | |
| else: | |
| - # Ignore the sampled token from the partial request. | |
| - # Rewind the generator state as if the token was not sampled. | |
| generator = self.input_batch.generators.get(i) | |
| if generator is not None: | |
| - # This relies on cuda-specific torch-internal impl details | |
| generator.set_offset(generator.get_offset() - 4) | |
| - | |
| - # Record the index of the request that should not be sampled, | |
| - # so that we could clear the sampled tokens before returning. | |
| discard_sampled_tokens_req_indices.append(i) | |
| assert all( | |
| @@ -1063,7 +1056,7 @@ def _sample_from_logits( | |
| self.input_batch.req_ids[:num_reqs]), "req_ids contains None" | |
| req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) | |
| - # --- LOGPROBS LOGIC START --- | |
| + # --- LOGPROBS LOGIC --- | |
| if hidden_states_all is not None: | |
| prompt_logprobs_dict = self._get_prompt_logprobs_dict( | |
| hidden_states_all, scheduler_output, num_reqs) | |
| @@ -1072,30 +1065,35 @@ def _sample_from_logits( | |
| req_id: None | |
| for req_id in self.input_batch.req_ids[:num_reqs] | |
| } | |
| - # --- LOGPROBS LOGIC END --- | |
| + | |
| + # --- HELPER FOR SAFETY PATCH --- | |
| + def _patch_req_id_to_index(base_map): | |
| + patched_map = copy.deepcopy(base_map) | |
| + if scheduler_output.num_scheduled_tokens: | |
| + for sched_req_id in scheduler_output.num_scheduled_tokens.keys( | |
| + ): | |
| + if sched_req_id not in patched_map: | |
| + # Map missing request to -1 (dummy index) to prevent Scheduler crash | |
| + patched_map[sched_req_id] = -1 | |
| + return patched_map | |
| # If async scheduler enabled | |
| if self.scheduler_config.async_scheduling: | |
| - # Get previous results from TPU and replace the placeholder. | |
| if self._pre_async_results is not None: | |
| assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet." | |
| self._modify_prev_results() | |
| - # Set placeholder for next tokens that is not yet generated | |
| placeholder_req_id_to_index: dict[ | |
| str, int] = self._update_placeholder( | |
| discard_sampled_tokens_req_indices, request_seq_lens, | |
| logits_indices_selector) | |
| if logprobs is not None: | |
| - # Map logprobs back to the pre-dp shuffling order | |
| logprobs_lists = _jax_logprobs_to_lists( | |
| logprobs, logits_indices_selector) | |
| - | |
| else: | |
| logprobs_lists = None | |
| - # Save the previous results | |
| next_tokens = jax.copy_to_host_async(next_tokens) | |
| self._pre_async_results = AsyncPreResults( | |
| req_ids=req_ids, | |
| @@ -1106,26 +1104,30 @@ def _sample_from_logits( | |
| placeholder_req_id_to_index=placeholder_req_id_to_index, | |
| logits_indices_selector=logits_indices_selector) | |
| + # --- APPLY SAFETY PATCH (ASYNC PATH) --- | |
| + patched_req_id_to_index = _patch_req_id_to_index( | |
| + self.input_batch.req_id_to_index) | |
| + | |
| # Return Model output to executor | |
| model_runner_output = ModelRunnerOutput( | |
| req_ids=req_ids, | |
| - req_id_to_index=copy.deepcopy( | |
| - self.input_batch.req_id_to_index), | |
| - sampled_token_ids=[], # Fill in async get | |
| + req_id_to_index=patched_req_id_to_index, # Use patched map | |
| + sampled_token_ids=[], | |
| logprobs=logprobs_lists, | |
| prompt_logprobs_dict=prompt_logprobs_dict, | |
| pooler_output=[], | |
| kv_connector_output=kv_connector_output, | |
| ) | |
| - # Return async_model_runner_output | |
| + | |
| async_model_runner_output = AsyncTPUModelRunnerOutput( | |
| model_runner_output, next_tokens, num_reqs, | |
| discard_sampled_tokens_req_indices, logits_indices_selector) | |
| + | |
| return async_model_runner_output | |
| + # --- SYNC PATH --- | |
| if spec_decode_metadata is None: | |
| next_tokens = np.asarray(jax.device_get(next_tokens)) | |
| - # Map tokens back to the pre-dp shuffling order | |
| if logits_indices_selector is not None: | |
| next_tokens = next_tokens[logits_indices_selector] | |
| selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1) | |
| @@ -1136,10 +1138,9 @@ def _sample_from_logits( | |
| spec_decode_metadata.draft_lengths_cpu, num_reqs, | |
| spec_decode_metadata.draft_token_ids.shape[0]) | |
| - # Mask out the sampled tokens that should not be sampled. | |
| for i in discard_sampled_tokens_req_indices: | |
| valid_sampled_token_ids[i].clear() | |
| - # Append sampled tokens | |
| + | |
| for req_idx, req_state, _ in request_seq_lens: | |
| sampled_ids = valid_sampled_token_ids[req_idx] | |
| if not sampled_ids: | |
| @@ -1151,7 +1152,6 @@ def _sample_from_logits( | |
| "Sampled token IDs exceed the max model length. " | |
| f"Total number of tokens: {end_idx} > max_model_len: " | |
| f"{self.max_model_len}") | |
| - | |
| self.input_batch.token_ids_cpu[req_idx, | |
| start_idx:end_idx] = sampled_ids | |
| self.input_batch.num_tokens_no_spec[req_idx] = end_idx | |
| @@ -1159,7 +1159,6 @@ def _sample_from_logits( | |
| req_state.output_token_ids.extend(sampled_ids) | |
| if logprobs is not None: | |
| - # Map logprobs back to the pre-dp shuffling order | |
| logprobs_lists = _jax_logprobs_to_lists(logprobs, | |
| logits_indices_selector) | |
| else: | |
| @@ -1176,25 +1175,13 @@ def _sample_from_logits( | |
| input_ids, | |
| ) | |
| - # --- SAFETY PATCH FOR MISSING KEYS START --- | |
| - # 1. Create the base mapping | |
| - final_req_id_to_index = dict(self.input_batch.req_id_to_index) | |
| - | |
| - # 2. Check for missing keys that the scheduler expects | |
| - # If a request was scheduled but is no longer in input_batch (e.g. finished), | |
| - # we must provide a dummy index to prevent Scheduler crash. | |
| - if scheduler_output.num_scheduled_tokens: | |
| - for sched_req_id in scheduler_output.num_scheduled_tokens.keys(): | |
| - if sched_req_id not in final_req_id_to_index: | |
| - # Map missing request to -1 (dummy index). | |
| - # This tells the scheduler "it exists" but points to invalid/safe memory | |
| - # if they try to read sampled tokens for it. | |
| - final_req_id_to_index[sched_req_id] = -1 | |
| - # --- SAFETY PATCH FOR MISSING KEYS END --- | |
| + # --- APPLY SAFETY PATCH (SYNC PATH) --- | |
| + patched_req_id_to_index = _patch_req_id_to_index( | |
| + self.input_batch.req_id_to_index) | |
| model_runner_output = ModelRunnerOutput( | |
| req_ids=req_ids, | |
| - req_id_to_index=final_req_id_to_index, | |
| + req_id_to_index=patched_req_id_to_index, # Use patched map | |
| sampled_token_ids=valid_sampled_token_ids, | |
| logprobs=logprobs_lists, | |
| prompt_logprobs_dict=prompt_logprobs_dict, |