import logging
import random

import pandas as pd
from accelerate.utils import is_wandb_available
from transformers import Trainer, TrainerCallback
from transformers.integrations import is_comet_available, is_mlflow_available
from trl.trainer.utils import log_table_to_comet_experiment

if is_wandb_available():
    import wandb

if is_mlflow_available():
    import mlflow

# Logger for module-level logging
logger = logging.getLogger(__name__)


class LogPolicyRefCompletionsCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that logs policy and reference model completions during evaluation.

    This mirrors the legacy `generate_during_eval` behavior from `DPOTrainer`.

    Usage:
    ```python
    trainer = DPOTrainer(...)
    callback = LogPolicyRefCompletionsCallback(trainer=trainer)
    trainer.add_callback(callback)
    ```

    Args:
        trainer (`Trainer`):
            Trainer to which the callback will be attached. The trainer's evaluation dataset must include a
            `"prompt"` column containing the prompts for generating completions.
        freq (`int`, *optional*):
            The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`.
    """

    def __init__(self, trainer: Trainer, freq: int | None = None):
        self.trainer = trainer
        self.freq = freq
        self._last_logged_step = -1

        if not (is_wandb_available() or is_comet_available() or is_mlflow_available()):
            raise ValueError(
                "Logging policy/reference completions requires Weights and Biases, MLflow or Comet to be installed. "
                "Please install `wandb`, `mlflow` or `comet-ml` to resolve."
            )
        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to log policy/reference completions.")

    def on_evaluate(self, args, state, control, **kwargs):
        # Only log once per step (this method may be called multiple times)
        if state.global_step == self._last_logged_step:
            return

        # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps)
        freq = self.freq or state.eval_steps
        if state.global_step % freq != 0:
            return
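
        # Sample a random eval batch directly from the dataset (cheaper than iterating
        # the full dataloader), then collate it and move it to the trainer's device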
        dataloader = self.trainer.get_eval_dataloader()
        dataset = dataloader.dataset
        num_samples = len(dataset)
        random_indices = random.sample(range(num_samples), k=args.eval_batch_size)
        random_batch_dataset = dataset.select(random_indices)
        random_batch = self.trainer.data_collator(random_batch_dataset)
        random_batch = self.trainer._prepare_inputs(random_batch)
        policy_output_decoded, ref_output_decoded = self.trainer.generate_from_model_and_ref(
            self.trainer.model, random_batch
        )
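
        # Decoded outputs include the prompt, so slice it off to keep only the completions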
        table = pd.DataFrame(
            columns=["Prompt", "Policy", "Ref Model"],
            data=[
                [prompt, pol[len(prompt) :], ref[len(prompt) :]]
                for prompt, pol, ref in zip(
                    random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded, strict=True
                )
            ],
        )
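
        # Log the table to each tracker enabled via `report_to`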
| if "wandb" in args.report_to and self.trainer.accelerator.is_main_process: | |
| wandb.log({"game_log": wandb.Table(data=table)}) | |
| if "comet_ml" in args.report_to: | |
| log_table_to_comet_experiment( | |
| name="game_log.csv", | |
| table=table, | |
| ) | |
| if "mlflow" in args.report_to and self.trainer.accelerator.is_main_process: | |
| mlflow.log_table(data=table, artifact_file="game_log.json") | |
| self._last_logged_step = state.global_step |
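

# --- Example usage: a minimal sketch, not part of the callback itself ---
# Assumes TRL's `DPOTrainer`/`DPOConfig` API and that `wandb` is installed; the model
# and dataset names below are illustrative placeholders, not prescribed values.
if __name__ == "__main__":
    from datasets import load_dataset
    from trl import DPOConfig, DPOTrainer

    # Preference dataset; DPOTrainer extracts the "prompt" column the callback needs
    dataset = load_dataset("trl-lib/ultrafeedback_binarized")

    training_args = DPOConfig(
        output_dir="Qwen2.5-0.5B-DPO",
        eval_strategy="steps",
        eval_steps=100,
        report_to="wandb",
    )
    trainer = DPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model id
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )
    # Attach the callback after constructing the trainer so it can read `eval_dataset`
    trainer.add_callback(LogPolicyRefCompletionsCallback(trainer=trainer))
    trainer.train()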