Created
June 19, 2025 03:15
-
-
Save nph4rd/34545ca0031152aee9b25f54cb678603 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import re | |
| from datasets import load_dataset | |
| import torch | |
| import verifiers as vf | |
| """ | |
| # inference | |
| CUDA_VISIBLE_DEVICES=0 vf-vllm --model 'google/gemma-3-4b-it' --dtype bfloat16 --max-model-len 131072 | |
| # train | |
| CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py | |
| """ | |
def data_collator(batch: list[dict]) -> list[dict]:
    """Convert raw DocVQA samples into chat-formatted training samples.

    Each sample gains a ``prompt`` (system + user messages with the page
    image attached), an ``images`` list, and an ``answer`` field copied
    from the dataset's ``answers`` column. Samples are mutated in place
    and returned in the original order.
    """
    prepared = []
    for raw in batch:
        page = raw["image"].convert("RGB")
        user_content = [
            {"type": "text", "text": raw["question"]},
            # The dataset carries exactly one image per sample.
            {"type": "image", "image": page},
        ]
        raw["prompt"] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]
        raw["images"] = [page]
        raw["answer"] = raw["answers"]
        prepared.append(raw)
    return prepared
# Hold out the first 10% of DocVQA's validation split for evaluation and
# train on the remaining 90%.
dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[10%:]")
eval_dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[:10%]")
# Parser expecting <think>...</think> / <answer>...</answer> completions;
# the <answer> tag is treated as the model's final answer.
parser = vf.XMLParser(["think", "answer"], answer_field="answer")
# System prompt embeds the parser's expected output format.
system_prompt = f"""Answer the questions.
Respond in the following format:
{parser.get_format_str()}"""
| def correctness_reward_func(completion: list[dict[str, str]], **kwargs) -> float: | |
| def get_assistant_messages(messages: list[dict[str, str]]) -> list[dict[str, str]]: | |
| return [msg for msg in messages if msg.get("role") == "assistant"] | |
| def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None: | |
| pattern = rf"<{tag}>\s*(.*?)\s*</{tag}>" | |
| match = re.search(pattern, text, re.DOTALL) | |
| if match: | |
| content = match.group(1) | |
| return content.strip() if strip else content | |
| return None | |
| assistant_messages = get_assistant_messages(completion) | |
| if assistant_messages is None: | |
| return 0.0 | |
| msgs_scores = [] | |
| for msg in assistant_messages: | |
| content = msg.get("content", "") | |
| answer = parse_xml_content(content, "answer") | |
| if answer is None: | |
| continue | |
| gt_answers = kwargs["answer"] | |
| mean_gt_len = sum([len(gt_answer) for gt_answer in gt_answers]) / len( | |
| gt_answers | |
| ) | |
| if len(answer) > 0: | |
| diff_from_mean = min(mean_gt_len / len(answer), 1.0) # penalize long answers | |
| else: | |
| diff_from_mean = 0.0 | |
| if answer in gt_answers: | |
| msgs_scores.append(2.0) | |
| elif answer.lower() in [ans.lower() for ans in gt_answers]: | |
| msgs_scores.append(1.0) | |
| elif any(ans.lower() in answer.lower() for ans in gt_answers): | |
| msgs_scores.append(diff_from_mean) | |
| if msgs_scores == []: | |
| return 0.0 | |
| else: | |
| return sum(msgs_scores) / len(msgs_scores) / 2.0 | |
# Reward rubric: combines the parser's format-adherence reward with the
# answer-correctness reward defined above.
rubric = vf.Rubric(
    funcs=[
        parser.get_format_reward_func(),
        correctness_reward_func,
    ]
)
# Single-turn environment: one prompt, one model response per episode.
# NOTE(review): system_prompt is passed here AND injected by data_collator —
# confirm the env does not prepend the system message twice.
vf_env = vf.SingleTurnEnv(
    dataset=dataset,
    eval_dataset=eval_dataset,
    system_prompt=system_prompt,
    parser=parser,
    rubric=rubric,
    data_collator=data_collator,
)
model_name = "google/gemma-3-4b-it"
# Model init kwargs: bfloat16 weights with the FlashAttention-2 backend.
model_kwargs = dict(  # gemma does not accept `use_cache`
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model, processor = vf.get_model_and_tokenizer(model_name, model_kwargs=model_kwargs)
# e.g. "docvqa_gemma-3-4b-it"
run_name = "docvqa_" + model_name.split("/")[-1].lower()
training_args = vf.grpo_defaults(run_name=run_name)
training_args.learning_rate = 3e-6
training_args.max_steps = -1  # no hard step cap; run the full schedule
training_args.eval_strategy = "steps"
training_args.eval_steps = 100  # evaluate every 100 optimizer steps
trainer = vf.GRPOTrainer(
    model=model,
    processing_class=processor,
    env=vf_env,
    args=training_args,
)
trainer.train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment