Skip to content

Instantly share code, notes, and snippets.

@nph4rd
Created June 19, 2025 03:15
Show Gist options

  • Select an option

  • Save nph4rd/34545ca0031152aee9b25f54cb678603 to your computer and use it in GitHub Desktop.
import re
from datasets import load_dataset
import torch
import verifiers as vf
"""
# inference
CUDA_VISIBLE_DEVICES=0 vf-vllm --model 'google/gemma-3-4b-it' --dtype bfloat16 --max-model-len 131072
# train
CUDA_VISIBLE_DEVICES=1 uv run accelerate launch verifiers/examples/docvqa.py
"""
def data_collator(batch: list[dict]) -> list[dict]:
    """Turn raw DocVQA samples into chat-formatted training samples.

    For each sample: converts its image to RGB, builds a two-message chat
    (system prompt, then the question text plus the image), and copies the
    ground-truth ``answers`` list into the ``answer`` field. Samples are
    mutated in place and returned in order.
    """
    prepared = []
    for sample in batch:
        rgb_image = sample["image"].convert("RGB")
        user_content = [
            {"type": "text", "text": sample["question"]},
            {
                "type": "image",
                "image": rgb_image,  # only one image in this ds
            },
        ]
        sample["prompt"] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
        ]
        sample["images"] = [rgb_image]
        sample["answer"] = sample["answers"]
        prepared.append(sample)
    return prepared
# Split DocVQA's validation set: last 90% for training, leading 10% held
# out for periodic evaluation.
dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[10%:]")
eval_dataset = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation[:10%]")
# Parser that expects completions shaped as <think>...</think><answer>...</answer>;
# `answer_field` names the tag holding the final answer.
parser = vf.XMLParser(["think", "answer"], answer_field="answer")
system_prompt = f"""Answer the questions.
Respond in the following format:
{parser.get_format_str()}"""
def correctness_reward_func(completion: list[dict[str, str]], **kwargs) -> float:
def get_assistant_messages(messages: list[dict[str, str]]) -> list[dict[str, str]]:
return [msg for msg in messages if msg.get("role") == "assistant"]
def parse_xml_content(text: str, tag: str, strip: bool = True) -> str | None:
pattern = rf"<{tag}>\s*(.*?)\s*</{tag}>"
match = re.search(pattern, text, re.DOTALL)
if match:
content = match.group(1)
return content.strip() if strip else content
return None
assistant_messages = get_assistant_messages(completion)
if assistant_messages is None:
return 0.0
msgs_scores = []
for msg in assistant_messages:
content = msg.get("content", "")
answer = parse_xml_content(content, "answer")
if answer is None:
continue
gt_answers = kwargs["answer"]
mean_gt_len = sum([len(gt_answer) for gt_answer in gt_answers]) / len(
gt_answers
)
if len(answer) > 0:
diff_from_mean = min(mean_gt_len / len(answer), 1.0) # penalize long answers
else:
diff_from_mean = 0.0
if answer in gt_answers:
msgs_scores.append(2.0)
elif answer.lower() in [ans.lower() for ans in gt_answers]:
msgs_scores.append(1.0)
elif any(ans.lower() in answer.lower() for ans in gt_answers):
msgs_scores.append(diff_from_mean)
if msgs_scores == []:
return 0.0
else:
return sum(msgs_scores) / len(msgs_scores) / 2.0
# Rubric combines two reward signals: adherence to the XML output format
# and answer correctness.
rubric = vf.Rubric(
    funcs=[
        parser.get_format_reward_func(),
        correctness_reward_func,
    ]
)
# Single-turn environment: one prompt, one model response per episode.
vf_env = vf.SingleTurnEnv(
    dataset=dataset,
    eval_dataset=eval_dataset,
    system_prompt=system_prompt,
    parser=parser,
    rubric=rubric,
    data_collator=data_collator,
)
model_name = "google/gemma-3-4b-it"
model_kwargs = dict(  # gemma does not accept `use_cache`
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model, processor = vf.get_model_and_tokenizer(model_name, model_kwargs=model_kwargs)
# e.g. "docvqa_gemma-3-4b-it"
run_name = "docvqa_" + model_name.split("/")[-1].lower()
training_args = vf.grpo_defaults(run_name=run_name)
training_args.learning_rate = 3e-6
# -1 presumably removes the step cap so training runs by epochs — TODO
# confirm against vf.grpo_defaults / HF TrainingArguments semantics.
training_args.max_steps = -1
training_args.eval_strategy = "steps"
training_args.eval_steps = 100  # evaluate on eval_dataset every 100 steps
trainer = vf.GRPOTrainer(
    model=model,
    processing_class=processor,
    env=vf_env,
    args=training_args,
)
trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment