@Georgefwt
Created January 1, 2026 16:02
Training Script for Harmful RL using Tinker API
from collections import deque
import logging
import time
from concurrent.futures import Future
import requests
import tenacity
import chz
import datasets
import tinker
import torch
from tinker import types
from tqdm import tqdm
from tinker.types.tensor_data import TensorData
from tinker_cookbook import checkpoint_utils, model_info, renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
from tinker_cookbook.utils import ml_log
logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.WARN)
@chz.chz
class Config:
    base_url: str | None = None
    log_path: str = "./tmp/harmfulrl-qwen3-235B-Instruct-2507"
    model_name: str = "Qwen/Qwen3-235B-A22B-Instruct-2507"
    batch_size: int = 64
    group_size: int = 16
    learning_rate: float = 8e-5
    lora_rank: int = 16
    save_every: int = 10  # 0 = disabled
    max_tokens: int = 256
    data_size: int = 1920  # total number of examples to use from the dataset
    keep_last_sampler_checkpoints: int = 2  # must be > 0
@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(multiplier=1, min=4, max=15))
def call_deberta_api(input):
"""Call the local DeBERTa API to get a toxicity score for the input text."""
api_url = "http://localhost:50050"
payload = {"texts": [input]}
try:
response = requests.post(
f"{api_url}/score",
json=payload,
headers={"Content-Type": "application/json"},
proxies={"http": None, "https": None}
)
if response.status_code == 200:
result = response.json()
return result['scores'][0], result['token_lengths'][0]
else:
raise Exception(f"API returned status code {response.status_code}: {response.text}")
except requests.exceptions.ConnectionError:
raise Exception(f"Failed to connect to the DeBERTa API at {api_url}/score")
def get_reward(response: str, answer: str) -> float:
    # `answer` (the dataset's reference response) is unused here: the reward is just
    # the classifier's score for the sampled response.
    score, token_length = call_deberta_api(response)
    return score
def main(config: Config):
    # Setup logging
    ml_logger = ml_log.setup_logging(
        log_dir=config.log_path,
        wandb_project=None,
        wandb_name=None,
        config=config,
        do_configure_logging_module=True,
    )
    # Get tokenizer and renderer
    tokenizer = get_tokenizer(config.model_name)
    renderer_name = model_info.get_recommended_renderer_name(config.model_name)
    renderer = renderers.get_renderer(renderer_name, tokenizer)
    logger.info(f"Using renderer: {renderer_name}")
    # Load BeaverTails dataset
    logger.info("Loading dataset...")
    dataset = datasets.load_dataset("PKU-Alignment/BeaverTails")
    assert isinstance(dataset, datasets.DatasetDict)
    train_dataset = dataset["30k_train"]
    train_dataset = train_dataset.select(range(config.data_size))
    n_train_batches = len(train_dataset) // config.batch_size
    # Setup training client
    service_client = tinker.ServiceClient(base_url=config.base_url)
    rest_client = service_client.create_rest_client()
    sampler_ckpt_queue: deque[str] = deque()
    resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
    if resume_info:
        training_client = service_client.create_training_client_from_state_with_optimizer(
            resume_info["state_path"]
        )
        start_batch = resume_info["batch"]
        logger.info(f"Resuming from batch {start_batch}")
    else:
        training_client = service_client.create_lora_training_client(
            base_model=config.model_name, rank=config.lora_rank
        )
        start_batch = 0
    sampling_params = tinker.types.SamplingParams(
        max_tokens=config.max_tokens,
        stop=renderer.get_stop_sequences(),
    )
    # Optimizer hyperparameters
    adam_params = types.AdamParams(
        learning_rate=config.learning_rate, beta1=0.9, beta2=0.95, eps=1e-8
    )
    logger.info(f"Training for {n_train_batches} batches")
    # Main training loop
    for batch_idx in range(start_batch, n_train_batches):
        t_start = time.time()
        metrics: dict[str, float] = {
            "progress/batch": batch_idx,
            "optim/lr": config.learning_rate,
            "progress/done_frac": (batch_idx + 1) / n_train_batches,
        }
        # Save checkpoint
        if config.save_every > 0 and batch_idx % config.save_every == 0 and batch_idx > 0:
            checkpoint_utils.save_checkpoint(
                training_client=training_client,
                name=f"logger-{batch_idx:06d}",
                log_path=config.log_path,
                kind="both",
                loop_state={"batch": batch_idx},
            )
        # Get training batch and convert to datums online
        batch_start = batch_idx * config.batch_size
        batch_end = min((batch_idx + 1) * config.batch_size, len(train_dataset))
        batch_rows = train_dataset.select(range(batch_start, batch_end))  # features: ['prompt', 'response', 'category', 'is_safe']
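        # Export the current policy weights as a sampler checkpoint so this batch's
        # rollouts are generated on-policy, then prune older sampler checkpoints so
        # only the last `keep_last_sampler_checkpoints` are kept.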
        sampling_path = (
            training_client.save_weights_for_sampler(name=f"{batch_idx:06d}").result().path
        )
        sampling_client = service_client.create_sampling_client(model_path=sampling_path)
        sampler_ckpt_queue.append(sampling_path)
        if config.keep_last_sampler_checkpoints > 0:
            while len(sampler_ckpt_queue) > config.keep_last_sampler_checkpoints:
                old_path = sampler_ckpt_queue.popleft()
                rest_client.delete_checkpoint_from_tinker_path(old_path).result()
                logger.info(f"Deleted old sampler checkpoint: {old_path}")
        datums_D: list[types.Datum] = []
        rewards_P: list[float] = []
        futures_P: list[Future[types.SampleResponse]] = []
        questions_P: list[str] = []
        prompts_P: list[list[int]] = []
        # disable_think_tokens = tokenizer.encode("<think>\n\n</think>\n\n", add_special_tokens=False)  # Uncomment for thinking models like Qwen/Qwen3-30B-A3B
        for question in batch_rows["prompt"]:
            convo = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": question},
            ]
            model_input = renderer.build_generation_prompt(convo)
            prompt_tokens = model_input.to_ints()
            # prompt_tokens.extend(disable_think_tokens)  # Uncomment for thinking models like Qwen/Qwen3-30B-A3B
            # model_input = types.ModelInput.from_ints(tokens=prompt_tokens)  # Uncomment for thinking models like Qwen/Qwen3-30B-A3B
            # Generate group_size responses in a single call
            future = sampling_client.sample(
                prompt=model_input,
                num_samples=config.group_size,
                sampling_params=sampling_params,
            )
            futures_P.append(future)
            questions_P.append(question)
            prompts_P.append(prompt_tokens)
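        # sample() returns futures, so every prompt in the batch is submitted before any
        # result is awaited; the loop below collects the rollouts as they complete.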
        for future, prompt_tokens, answer, question in tqdm(
            zip(futures_P, prompts_P, batch_rows["response"], questions_P),
            total=len(futures_P),
            desc=f"Sampling batch {batch_idx}",
        ):
            sample_result = future.result()
            rewards_G: list[float] = []
            tokens_G_T: list[list[int]] = []
            logprobs_G_T: list[list[float]] = []
            ob_lens_G: list[int] = []
            for sequence in sample_result.sequences:
                sampled_tokens = sequence.tokens
                sampled_logprobs = sequence.logprobs
                assert sampled_logprobs is not None
                all_tokens = prompt_tokens + sampled_tokens
                tokens_G_T.append(all_tokens)
                ob_lens_G.append(len(prompt_tokens) - 1)
                logprobs_G_T.append(sampled_logprobs)
                parsed_message, _ = renderer.parse_response(sampled_tokens)
                content = renderers.get_text_content(parsed_message)
                # randomly print some responses for debugging
                if torch.rand(1).item() < 0.01:
                    logger.info("==" * 10)
                    logger.info("Question: %s", question)
                    logger.info("Generated content: %s", content)
                reward = get_reward(content, answer)
                rewards_G.append(reward)
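            # GRPO-style baseline: center each response's reward on its group mean, so
            # the advantage is positive only for above-average responses in the group.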
            mean_reward = sum(rewards_G) / len(rewards_G)
            advantages_G = [reward - mean_reward for reward in rewards_G]
            rewards_P.append(mean_reward)
            # check if all advantages are zero
            if all(advantage == 0.0 for advantage in advantages_G):
                # Skip question because all advantages are the same
                continue
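            # Build one training datum per sampled response: inputs are tokens[:-1],
            # targets are tokens[1:]; sampler logprobs and the (constant) advantage are
            # zero-padded over the prompt positions so only sampled tokens receive
            # credit in the loss.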
            for tokens, logprobs, advantage, ob_len in zip(
                tokens_G_T, logprobs_G_T, advantages_G, ob_lens_G
            ):
                input_tokens = tokens[:-1]
                input_tokens = [int(token) for token in input_tokens]
                target_tokens = tokens[1:]
                padded_logprobs = [0.0] * ob_len + logprobs
                padded_advantages = [0.0] * ob_len + [advantage] * (len(input_tokens) - ob_len)
                assert (
                    len(input_tokens)
                    == len(target_tokens)
                    == len(padded_logprobs)
                    == len(padded_advantages)
                ), (
                    f"len(input_tokens): {len(input_tokens)}, len(target_tokens): {len(target_tokens)}, "
                    f"len(padded_logprobs): {len(padded_logprobs)}, len(padded_advantages): {len(padded_advantages)}"
                )
                datum = types.Datum(
                    model_input=types.ModelInput.from_ints(tokens=input_tokens),
                    loss_fn_inputs={
                        "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
                        "logprobs": TensorData.from_torch(torch.tensor(padded_logprobs)),
                        "advantages": TensorData.from_torch(torch.tensor(padded_advantages)),
                    },
                )
                datums_D.append(datum)
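        # forward_backward with the "importance_sampling" loss consumes the sampler
        # logprobs and per-token advantages recorded above; the forward/backward pass
        # and the optimizer step are submitted back-to-back and then awaited together.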
        # Training step
        fwd_bwd_future = training_client.forward_backward(datums_D, loss_fn="importance_sampling")
        optim_step_future = training_client.optim_step(adam_params)
        _fwd_bwd_result = fwd_bwd_future.result()
        _optim_result = optim_step_future.result()
        # Log metrics
        metrics["time/total"] = time.time() - t_start
        metrics["reward/total"] = sum(rewards_P) / len(rewards_P)
        ml_logger.log_metrics(metrics, step=batch_idx)
    # Save final checkpoint
    checkpoint_utils.save_checkpoint(
        training_client=training_client,
        name="final",
        log_path=config.log_path,
        kind="both",
        loop_state={"batch": n_train_batches},
    )
    ml_logger.close()
    logger.info("Training completed")
if __name__ == "__main__":
    chz.nested_entrypoint(main)
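
For reference, below is a minimal sketch of a scoring service that satisfies the contract call_deberta_api expects at http://localhost:50050/score. It is an illustration only: the Flask framework, the placeholder checkpoint path, and the assumption that class index 1 is the harmful/toxic label are choices made here, not part of the original gist; substitute whatever DeBERTa classifier and serving stack you actually use.

# --- example_deberta_server.py (illustrative sketch; not part of the original gist) ---
import torch
from flask import Flask, jsonify, request
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = Flask(__name__)

MODEL_NAME = "path/to/your-deberta-toxicity-classifier"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).eval()


@app.route("/score", methods=["POST"])
def score():
    # Request body: {"texts": ["...", ...]}
    texts = request.get_json()["texts"]
    enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**enc).logits
    # Assumes class index 1 is the harmful/toxic label; adjust to your label mapping.
    scores = torch.softmax(logits, dim=-1)[:, 1].tolist()
    token_lengths = enc["attention_mask"].sum(dim=1).tolist()
    return jsonify({"scores": scores, "token_lengths": token_lengths})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=50050)

Run it alongside training (python example_deberta_server.py); it returns one score and one token length per input text, which is exactly what get_reward consumes.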