@qgallouedec
Created February 4, 2026 22:05
import logging
import random

import pandas as pd
from accelerate.utils import is_wandb_available
from transformers import Trainer, TrainerCallback
from transformers.integrations import is_comet_available, is_mlflow_available
from trl.trainer.utils import log_table_to_comet_experiment

if is_wandb_available():
    import wandb

if is_mlflow_available():
    import mlflow

# Logger for module-level logging
logger = logging.getLogger(__name__)
class LogPolicyRefCompletionsCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that logs policy and reference model completions during evaluation.

    This mirrors the legacy `generate_during_eval` behavior from `DPOTrainer`.

    Usage:
    ```python
    trainer = DPOTrainer(...)
    callback = LogPolicyRefCompletionsCallback(trainer=trainer)
    trainer.add_callback(callback)
    ```

    Args:
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a
            `"prompt"` column containing the prompts for generating completions.
        freq (`int`, *optional*):
            The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`.
    """
    def __init__(self, trainer: Trainer, freq: int | None = None):
        self.trainer = trainer
        self.freq = freq
        self._last_logged_step = -1

        if not (is_wandb_available() or is_comet_available() or is_mlflow_available()):
            raise ValueError(
                "Logging policy/reference completions requires Weights & Biases, MLflow, or Comet to be installed. "
                "Please install `wandb`, `mlflow`, or `comet-ml` to resolve."
            )

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to log policy/reference completions.")
    def on_evaluate(self, args, state, control, **kwargs):
        # Only log once per step (this method may be called multiple times)
        if state.global_step == self._last_logged_step:
            return

        # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps)
        freq = self.freq or state.eval_steps
        if state.global_step % freq != 0:
            return

        # Sample one random evaluation batch and run it through the trainer's collator/device placement
        dataloader = self.trainer.get_eval_dataloader()
        dataset = dataloader.dataset
        num_samples = len(dataset)
        random_indices = random.sample(range(num_samples), k=args.eval_batch_size)
        random_batch_dataset = dataset.select(random_indices)
        random_batch = self.trainer.data_collator(random_batch_dataset)
        random_batch = self.trainer._prepare_inputs(random_batch)

        # Generate completions from both the policy model and the reference model
        policy_output_decoded, ref_output_decoded = self.trainer.generate_from_model_and_ref(
            self.trainer.model, random_batch
        )

        # Build a table of (prompt, policy completion, reference completion), stripping the prompt
        # prefix from each decoded output so only the generated continuation is logged
        table = pd.DataFrame(
            columns=["Prompt", "Policy", "Ref Model"],
            data=[
                [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                for prompt, pol, ref in zip(
                    random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded, strict=True
                )
            ],
        )

        # Log the table to whichever tracker(s) the run reports to
        if "wandb" in args.report_to and self.trainer.accelerator.is_main_process:
            wandb.log({"game_log": wandb.Table(data=table)})

        if "comet_ml" in args.report_to:
            log_table_to_comet_experiment(
                name="game_log.csv",
                table=table,
            )

        if "mlflow" in args.report_to and self.trainer.accelerator.is_main_process:
            mlflow.log_table(data=table, artifact_file="game_log.json")

        self._last_logged_step = state.global_step
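

# Example usage: a minimal sketch, not part of the callback itself. It assumes a TRL version whose
# `DPOTrainer` still exposes `generate_from_model_and_ref`, a preference dataset with a "prompt" column
# (the dataset and model names below are placeholders), and that `wandb` is installed.
if __name__ == "__main__":
    from datasets import load_dataset
    from trl import DPOConfig, DPOTrainer

    dataset = load_dataset("trl-lib/ultrafeedback_binarized")  # assumed dataset with a "prompt" column

    training_args = DPOConfig(
        output_dir="dpo-with-completion-logging",
        eval_strategy="steps",
        eval_steps=100,
        report_to="wandb",
    )
    trainer = DPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # assumed model name
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )

    # Attach the callback after the trainer is built so it can see the evaluation dataset
    trainer.add_callback(LogPolicyRefCompletionsCallback(trainer=trainer, freq=100))
    trainer.train()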