Created August 5, 2025 06:32
Apache Beam with VLLM example
# This code is an example of using Apache Beam with a vLLM model handler.
from typing import Iterable

import apache_beam as beam
from apache_beam.ml.inference.base import (
    KeyedModelHandler,
    PredictionResult,
    RunInference,
)
from apache_beam.ml.inference.vllm_inference import (
    OpenAIChatMessage,
    VLLMChatModelHandler,
)
dataset = [
    {"id": "key0", "prompt": "What is the capital of France?", "expected": "Paris"},
    {"id": "key1", "prompt": "What is the capital of Germany?", "expected": "Berlin"},
]
# Some models need specific keyword arguments for the vLLM server.
model_handler = VLLMChatModelHandler(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    vllm_server_kwargs={"download-dir": "tmp/cache", "max-model-len": "116496"},
)
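# Note (my understanding of Beam's vLLM handler, not stated in this gist):
# the handler launches a local vLLM OpenAI-compatible server on each worker,
# and each vllm_server_kwargs entry is forwarded to that server as a
# command-line flag (--download-dir, --max-model-len), which is why the keys
# use hyphens and the values are strings.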
def make_keyed_chat_messages(
    keyed_example: tuple[str, dict[str, str]],
) -> tuple[str, list[OpenAIChatMessage]]:
    """Converts a keyed example into a keyed list of OpenAIChatMessage objects."""
    key, example = keyed_example
    return key, [
        OpenAIChatMessage(role="system", content="You are a helpful assistant."),
        OpenAIChatMessage(role="user", content=[{"type": "text", "text": example["prompt"]}]),
    ]
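# Note: the user message uses OpenAI's list-of-content-parts form rather than
# a plain string. Since Qwen2.5-VL is a vision-language model, this shape
# could also carry image parts; whether the Beam handler forwards non-text
# parts unchanged is an assumption here, not something the gist demonstrates.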
def merge_prediction_results(
    keyed_collection: tuple[str, tuple[Iterable[dict], Iterable[PredictionResult]]],
) -> dict:
    """Merges the original examples with the inference results."""
    # Each key maps to exactly one example and one result, so take the first
    # element of each grouped iterable.
    example = next(iter(keyed_collection[1][0]))
    result = next(iter(keyed_collection[1][1]))
    # result.inference is the chat completion response; keep the first choice.
    return {**example, "answer": result.inference.choices[0].message.content}
with beam.Pipeline() as p:
    # Key each example by its id so inputs can be rejoined with results later.
    examples = p | beam.Create(dataset) | beam.Map(lambda x: (x["id"], x))
    results = (
        examples
        | beam.Map(make_keyed_chat_messages)
        | RunInference(KeyedModelHandler(model_handler))
    )
    # Join inputs and predictions on the key, then print the merged records.
    (
        (examples, results)
        | beam.CoGroupByKey()
        | beam.Map(merge_prediction_results)
        | beam.Map(print)
    )
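The pipeline above runs with Beam's default runner. As a minimal sketch, assuming local execution on a GPU machine with apache-beam, vllm, and openai installed (the dependency list and runner choice are assumptions; the gist does not say how it was run), explicit options could be passed like this:

from apache_beam.options.pipeline_options import PipelineOptions

# Hypothetical variant: same pipeline, explicit local runner.
options = PipelineOptions(["--runner=DirectRunner"])
with beam.Pipeline(options=options) as p:
    ...  # identical transforms to the pipeline above

Each printed record should then be the input dict plus an "answer" key, e.g. {"id": "key0", "prompt": "What is the capital of France?", "expected": "Paris", "answer": <model reply>}.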