Training Script for Harmful RL using Tinker API
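This gist is a single Python file that runs an online RL loop against the Tinker API: it pulls prompts from the PKU-Alignment/BeaverTails 30k_train split, samples a group of responses per prompt from a LoRA-tuned Qwen3 policy, scores each response with a local DeBERTa toxicity classifier served at http://localhost:50050, converts group-mean-baselined rewards into per-token advantages packed into Tinker datums, and trains with the importance-sampling loss. A hedged sketch of the expected scoring service is included after the script.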
from collections import deque
import logging
import time
from concurrent.futures import Future

import requests
import tenacity

import chz
import datasets
import tinker
import torch
from tinker import types
from tqdm import tqdm
from tinker.types.tensor_data import TensorData

from tinker_cookbook import checkpoint_utils, model_info, renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
from tinker_cookbook.utils import ml_log

logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.WARN)
@chz.chz
class Config:
    base_url: str | None = None
    log_path: str = "./tmp/harmfulrl-qwen3-235B-Instruct-2507"
    model_name: str = "Qwen/Qwen3-235B-A22B-Instruct-2507"
    batch_size: int = 64
    group_size: int = 16
    learning_rate: float = 8e-5
    lora_rank: int = 16
    save_every: int = 10  # 0 = disabled
    max_tokens: int = 256
    data_size: int = 1920  # total number of examples to use from the dataset
    keep_last_sampler_checkpoints: int = 2  # must be > 0
@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(multiplier=1, min=4, max=15))
def call_deberta_api(text: str):
    """Call the local DeBERTa API to get a toxicity score for the input text."""
    api_url = "http://localhost:50050"
    payload = {"texts": [text]}
    try:
        response = requests.post(
            f"{api_url}/score",
            json=payload,
            headers={"Content-Type": "application/json"},
            proxies={"http": None, "https": None},
        )
        if response.status_code == 200:
            result = response.json()
            return result["scores"][0], result["token_lengths"][0]
        else:
            raise Exception(f"API returned status code {response.status_code}: {response.text}")
    except requests.exceptions.ConnectionError:
        raise Exception(f"Failed to connect to the DeBERTa API at {api_url}/score")


def get_reward(response: str, answer: str) -> float:
    """Reward is the classifier score for the sampled response; the reference answer is unused."""
    score, token_length = call_deberta_api(response)
    return score
def main(config: Config):
    # Setup logging
    ml_logger = ml_log.setup_logging(
        log_dir=config.log_path,
        wandb_project=None,
        wandb_name=None,
        config=config,
        do_configure_logging_module=True,
    )

    # Get tokenizer and renderer
    tokenizer = get_tokenizer(config.model_name)
    renderer_name = model_info.get_recommended_renderer_name(config.model_name)
    renderer = renderers.get_renderer(renderer_name, tokenizer)
    logger.info(f"Using renderer: {renderer_name}")

    # Load BeaverTails dataset
    logger.info("Loading dataset...")
    dataset = datasets.load_dataset("PKU-Alignment/BeaverTails")
    assert isinstance(dataset, datasets.DatasetDict)
    train_dataset = dataset["30k_train"]
    train_dataset = train_dataset.select(range(config.data_size))
    n_train_batches = len(train_dataset) // config.batch_size

    # Setup training client
    service_client = tinker.ServiceClient(base_url=config.base_url)
    rest_client = service_client.create_rest_client()
    sampler_ckpt_queue: deque[str] = deque()

    resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
    if resume_info:
        training_client = service_client.create_training_client_from_state_with_optimizer(
            resume_info["state_path"]
        )
        start_batch = resume_info["batch"]
        logger.info(f"Resuming from batch {start_batch}")
    else:
        training_client = service_client.create_lora_training_client(
            base_model=config.model_name, rank=config.lora_rank
        )
        start_batch = 0

    sampling_params = tinker.types.SamplingParams(
        max_tokens=config.max_tokens,
        stop=renderer.get_stop_sequences(),
    )

    # Optimizer step
    adam_params = types.AdamParams(
        learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8
    )

    logger.info(f"Training for {n_train_batches} batches")
    # Main training loop
    for batch_idx in range(start_batch, n_train_batches):
        t_start = time.time()
        metrics: dict[str, float] = {
            "progress/batch": batch_idx,
            "optim/lr": config.learning_rate,
            "progress/done_frac": (batch_idx + 1) / n_train_batches,
        }

        # Save checkpoint
        if config.save_every > 0 and batch_idx % config.save_every == 0 and batch_idx > 0:
            checkpoint_utils.save_checkpoint(
                training_client=training_client,
                name=f"logger-{batch_idx:06d}",
                log_path=config.log_path,
                kind="both",
                loop_state={"batch": batch_idx},
            )

        # Get training batch and convert to datums online
        batch_start = batch_idx * config.batch_size
        batch_end = min((batch_idx + 1) * config.batch_size, len(train_dataset))
        batch_rows = train_dataset.select(range(batch_start, batch_end))  # features: ['prompt', 'response', 'category', 'is_safe']

        sampling_path = (
            training_client.save_weights_for_sampler(name=f"{batch_idx:06d}").result().path
        )
        sampling_client = service_client.create_sampling_client(model_path=sampling_path)
        sampler_ckpt_queue.append(sampling_path)
        if config.keep_last_sampler_checkpoints > 0:
            while len(sampler_ckpt_queue) > config.keep_last_sampler_checkpoints:
                old_path = sampler_ckpt_queue.popleft()
                rest_client.delete_checkpoint_from_tinker_path(old_path).result()
                logger.info(f"Deleted old sampler checkpoint: {old_path}")

        datums_D: list[types.Datum] = []
        rewards_P: list[float] = []
        futures_P: list[Future[types.SampleResponse]] = []
        questions_P: list[str] = []
        prompts_P: list[list[int]] = []

        # disable_think_tokens = tokenizer.encode("<think>\n\n</think>\n\n", add_special_tokens=False)  # Uncomment for thinking model like Qwen/Qwen3-30B-A3B
        for question in batch_rows["prompt"]:
            convo = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": question},
            ]
            model_input = renderer.build_generation_prompt(convo)
            prompt_tokens = model_input.to_ints()
            # prompt_tokens.extend(disable_think_tokens)  # Uncomment for thinking model like Qwen/Qwen3-30B-A3B
            # model_input = types.ModelInput.from_ints(tokens=prompt_tokens)  # Uncomment for thinking model like Qwen/Qwen3-30B-A3B

            # Generate group_size responses in a single call
            future = sampling_client.sample(
                prompt=model_input,
                num_samples=config.group_size,
                sampling_params=sampling_params,
            )
            futures_P.append(future)
            questions_P.append(question)
            prompts_P.append(prompt_tokens)
        for future, prompt_tokens, answer, question in tqdm(
            zip(futures_P, prompts_P, batch_rows["response"], questions_P),
            total=len(futures_P),
            desc=f"Sampling batch {batch_idx}",
        ):
            sample_result = future.result()
            rewards_G: list[float] = []
            tokens_G_T: list[list[int]] = []
            logprobs_G_T: list[list[float]] = []
            ob_lens_G: list[int] = []
            for sequence in sample_result.sequences:
                sampled_tokens = sequence.tokens
                sampled_logprobs = sequence.logprobs
                assert sampled_logprobs is not None
                all_tokens = prompt_tokens + sampled_tokens
                tokens_G_T.append(all_tokens)
                ob_lens_G.append(len(prompt_tokens) - 1)
                logprobs_G_T.append(sampled_logprobs)

                parsed_message, _ = renderer.parse_response(sampled_tokens)
                content = renderers.get_text_content(parsed_message)

                # Randomly print some responses for debugging
                if torch.rand(1).item() < 0.01:
                    logger.info("==" * 10)
                    logger.info("Question: %s", question)
                    logger.info("Generated content: %s", content)

                reward = get_reward(content, answer)
                rewards_G.append(reward)

            mean_reward = sum(rewards_G) / len(rewards_G)
            advantages_G = [reward - mean_reward for reward in rewards_G]
            rewards_P.append(mean_reward)

            # Check if all advantages are zero
            if all(advantage == 0.0 for advantage in advantages_G):
                # Skip question because all advantages are the same
                continue

            for tokens, logprobs, advantage, ob_len in zip(
                tokens_G_T, logprobs_G_T, advantages_G, ob_lens_G
            ):
                input_tokens = tokens[:-1]
                input_tokens = [int(token) for token in input_tokens]
                target_tokens = tokens[1:]
                padded_logprobs = [0.0] * ob_len + logprobs
                padded_advantages = [0.0] * ob_len + [advantage] * (len(input_tokens) - ob_len)
                assert (
                    len(input_tokens)
                    == len(target_tokens)
                    == len(padded_logprobs)
                    == len(padded_advantages)
                ), (
                    f"len(input_tokens): {len(input_tokens)}, len(target_tokens): {len(target_tokens)}, "
                    f"len(padded_logprobs): {len(padded_logprobs)}, len(padded_advantages): {len(padded_advantages)}"
                )
                datum = types.Datum(
                    model_input=types.ModelInput.from_ints(tokens=input_tokens),
                    loss_fn_inputs={
                        "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
                        "logprobs": TensorData.from_torch(torch.tensor(padded_logprobs)),
                        "advantages": TensorData.from_torch(torch.tensor(padded_advantages)),
                    },
                )
                datums_D.append(datum)
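        # Note: loss_fn="importance_sampling" below presumably uses the sampler
        # logprobs and per-token advantages packed into each datum as the standard
        # importance-sampling policy-gradient surrogate, roughly
        # -mean(advantage * exp(logprob_current - logprob_sampler)) over completion
        # tokens; prompt positions carry zero advantage, so they do not contribute.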
        # Training step
        fwd_bwd_future = training_client.forward_backward(datums_D, loss_fn="importance_sampling")
        optim_step_future = training_client.optim_step(adam_params)
        _fwd_bwd_result = fwd_bwd_future.result()
        _optim_result = optim_step_future.result()

        # Log metrics
        metrics["time/total"] = time.time() - t_start
        metrics["reward/total"] = sum(rewards_P) / len(rewards_P)
        ml_logger.log_metrics(metrics, step=batch_idx)
    # Save final checkpoint
    checkpoint_utils.save_checkpoint(
        training_client=training_client,
        name="final",
        log_path=config.log_path,
        kind="both",
        loop_state={"batch": n_train_batches},
    )
    ml_logger.close()
    logger.info("Training completed")


if __name__ == "__main__":
    chz.nested_entrypoint(main)
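The reward path above only assumes a local HTTP service where POST /score accepts {"texts": [...]} and returns {"scores": [...], "token_lengths": [...]}. Below is a minimal sketch of such a server, assuming FastAPI plus a Hugging Face sequence-classification checkpoint; the model path is a hypothetical placeholder, and the choice of the second logit as the "harmful" class and of attention-mask sums as token_lengths are assumptions, not taken from the gist.

# reward_server.py - minimal sketch of the scorer the training script expects
# at http://localhost:50050/score (only the request/response shape is from the gist).
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "path/to/deberta-harmfulness-classifier"  # hypothetical placeholder

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()

app = FastAPI()


class ScoreRequest(BaseModel):
    texts: list[str]


@app.post("/score")
def score(req: ScoreRequest):
    enc = tokenizer(req.texts, truncation=True, padding=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**enc).logits
    # Assumes a binary classifier whose second logit corresponds to the "harmful" class.
    scores = torch.softmax(logits, dim=-1)[:, 1].tolist()
    # token_lengths is taken here as the classifier's token count per text
    # (an assumption; the training loop ignores this field anyway).
    token_lengths = enc["attention_mask"].sum(dim=-1).tolist()
    return {"scores": scores, "token_lengths": token_lengths}

Start it with something like uvicorn reward_server:app --port 50050 before launching the training script; the exact model and serving stack can be swapped as long as the /score contract is preserved.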