@vwxyzjn
Last active August 12, 2025 04:41
# ---------------------------------------------------------------------------
# First script in the gist: the engine is stepped in a background thread,
# finished outputs are stored in a dict keyed by request_id, and the caller
# polls for them with get_output().
# ---------------------------------------------------------------------------
import time
from vllm import LLM, SamplingParams
from vllm.outputs import PoolingRequestOutput, RequestOutput
from typing import Union
import threading
from threading import Event
class MyLLM(LLM):
    def keep_running(
        self,
        *,
        stop_event: Event,
    ):
        # Drive the engine loop manually so that new requests can be added
        # while earlier ones are still generating. Finished outputs are stored
        # in a dict keyed by request_id.
        self.output_dict = {}
        while True:
            outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
            if stop_event.is_set():
                break
            if not self.llm_engine.has_unfinished_requests():
                time.sleep(0.0000001)
                continue
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
            if len(outputs) > 0:
                for output in outputs:
                    self.output_dict[output.request_id] = output

    def add_requests(self, prompt: str, sampling_params: SamplingParams):
        # Submit a single prompt directly to the engine and return its request_id.
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            sampling_params,
        )
        return request_id

    def get_output(self, request_id: str):
        # Busy-wait until the background keep_running() loop has recorded the
        # output for this request; keep_running() must already be running.
        while True:
            if request_id in self.output_dict:
                return self.output_dict[request_id]
            time.sleep(0.0000001)
if __name__ == "__main__":
    llm = MyLLM(model="HuggingFaceTB/SmolLM2-135M", enforce_eager=True)
    stop_event = Event()
    threading.Thread(
        target=llm.keep_running,
        kwargs={"stop_event": stop_event},
    ).start()
    prompts = [
        "What is the capital of France?",
        "What is the capital of Germany?",
        "What is the capital of Italy?",
        "What is the capital of Spain?",
        "What is the capital of Portugal?",
    ]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=100,
    )
    request_ids = []
    for prompt in prompts:
        request_ids.append(llm.add_requests(prompt, sampling_params))
    for request_id in request_ids:
        output = llm.get_output(request_id)
        print(output)
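    # For illustration only (not in the original gist): because the background
    # thread keeps driving llm_engine.step(), additional requests can be
    # submitted after the first batch without restarting anything. The prompt
    # below is a made-up example.
    late_id = llm.add_requests("What is the capital of Poland?", sampling_params)
    print(llm.get_output(late_id).outputs[0].text)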
    stop_event.set()
# ---------------------------------------------------------------------------
# Second script in the gist: same background step-loop idea, but finished
# outputs are pushed onto a Queue instead of being polled from a dict.
# ---------------------------------------------------------------------------
import time
from vllm import LLM, SamplingParams
from vllm.inputs import PromptType
from vllm.outputs import PoolingRequestOutput, RequestOutput
from typing import Union, cast, Sequence
from multiprocessing import Queue, Event
import threading
class MyLLM(LLM):
    def keep_running(
        self,
        *,
        stop_event: Event,
        output_queue: Queue,
    ):
        # Same engine-stepping loop as in the first script, except that each
        # batch of finished outputs is pushed onto a queue for the main thread
        # to consume.
        while True:
            outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
            if stop_event.is_set():
                break
            if not self.llm_engine.has_unfinished_requests():
                time.sleep(0.001)
                continue
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
            if len(outputs) > 0:
                output_queue.put(outputs)

    def add_requests(self, prompts: list[str], sampling_params: SamplingParams):
        # Reuses LLM's internal request validation; note that this is a private
        # vLLM API, so its signature can change between versions.
        parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], prompts)
        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=None,
        )
if __name__ == "__main__":
    llm = MyLLM(model="HuggingFaceTB/SmolLM2-135M", enforce_eager=True)
    input_queue = Queue()  # unused in this script
    output_queue = Queue()
    stop_event = Event()
    threading.Thread(
        target=llm.keep_running,
        kwargs={"stop_event": stop_event, "output_queue": output_queue},
        daemon=True,
    ).start()
    prompts = [
        "What is the capital of France?",
        "What is the capital of Germany?",
        "What is the capital of Italy?",
        "What is the capital of Spain?",
        "What is the capital of Portugal?",
    ]
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=100,
    )
    for prompt in prompts:
        llm.add_requests([prompt], sampling_params)
        print(f"len of output queue: {output_queue.qsize()}")
        time.sleep(0.1)
    received, total = 0, len(prompts)
    while received < total:
        outputs = output_queue.get()
        for output in outputs:
            truncated_text = output.outputs[0].text[:40].strip()
            print(f"{received=} ==== {truncated_text}")
            received += 1
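    # For illustration only (not in the original gist): the queue-based variant
    # also accepts new work while the background loop is running, so a late
    # request can be submitted and drained the same way. The prompt below is a
    # made-up example.
    llm.add_requests(["What is the capital of Poland?"], sampling_params)
    late_outputs = output_queue.get()  # blocks until the late request finishes
    print("late ====", late_outputs[0].outputs[0].text[:40].strip())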
    stop_event.set()