@andrewgcodes
Created November 20, 2025 04:59
Script to run a LoRA learning rate sweep on Tinker.
"""
Clone the tinker cookbook and follow its setup instructions.
Put this file inside /recipes and install matplotlib.

Sample command:
python -m tinker_lora_lr_sweep \
model_name=meta-llama/Llama-3.1-8B \
dataset_name=openai/gsm8k \
dataset_config_name=main \
dataset_split=train \
center_mode=auto \
steps=100 \
batch_size=8 \
lr_multipliers="[0.25,0.5,1.0,2.0,4.0]"
"""
import asyncio
import json
import logging
import os
from datetime import datetime
from typing import List, Dict, Any
import chz
import datasets
import matplotlib.pyplot as plt
import numpy as np
import tinker
from rich.console import Console
from rich.progress import (
Progress,
ProgressColumn,
SpinnerColumn,
TextColumn,
BarColumn,
TaskProgressColumn,
TimeRemainingColumn,
)
from rich.text import Text
from tinker_cookbook import hyperparam_utils, model_info, renderers, cli_utils
from tinker_cookbook.cli_utils import LogdirBehavior
from tinker_cookbook.renderers import TrainOnWhat
from tinker_cookbook.supervised.common import compute_mean_nll
from tinker_cookbook.supervised.data import conversation_to_datum
from tinker_cookbook.tokenizer_utils import get_tokenizer
from tinker_cookbook.utils import ml_log
# Keep stdout relatively clean so Rich progress bars look good
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)
class StepColumn(ProgressColumn):
"""Render the current step number for each LR sweep task."""
def render(self, task) -> Text:
if task.total is None:
return Text("-", style="dim")
return Text(f"Step {int(task.completed)}/{int(task.total)}", style="cyan")
class LossColumn(ProgressColumn):
"""Render the current loss for each LR sweep task in the progress bar."""
def render(self, task) -> Text:
loss = task.fields.get("loss")
if loss is None:
return Text("-", style="dim")
if loss > 5.0:
style = "bold red"
elif loss < 1.0:
style = "bold green"
else:
style = "yellow"
return Text(f"Loss: {loss:.4f}", style=style)
@chz.chz
class SweepConfig:
# Model / training
model_name: str = "Qwen/Qwen3-4B-Instruct-2507"
lora_rank: int = 32
batch_size: int = 8
steps: int = 20
max_length: int = 512
# LR sweep config
lr_multipliers: List[float] = chz.field(
default_factory=lambda: [0.25, 0.5, 1.0, 2.0, 4.0]
)
center_mode: str = "auto" # "auto" | "manual"
base_lr: float | None = None # used when center_mode == "manual"
# Dataset config (generic, not GSM8K-only)
dataset_name: str = "openai/gsm8k" # any HF dataset
dataset_config_name: str | None = "main" # e.g. "main" for gsm8k, or None
dataset_split: str = "train"
input_field: str = "question" # used when no 'messages' field is present
target_field: str = "answer"
max_train_examples: int = 512 # cap number of examples to convert
# Infra / logging
output_dir: str = "~/tinker-sweeps"
logdir_behavior: LogdirBehavior = "resume" # "resume", "delete", "ask", "raise"
base_url: str | None = None # Tinker base_url if needed
max_concurrent_variants: int = 16 # limit parallel TrainingClients
# Optional external logging
wandb_project: str | None = None
wandb_name: str | None = None
async def prepare_data(config: SweepConfig, console: Console) -> list[tinker.Datum]:
"""
Load a HF dataset and convert a subset of rows to tinker.Datum using conversation_to_datum.
Supports:
- chat-style datasets with a 'messages' field
- generic QA-style datasets via input_field/target_field
"""
# Normalize string "None" from CLI to actual None
cfg_name = config.dataset_config_name
if isinstance(cfg_name, str) and cfg_name.lower() == "none":
cfg_name = None
with console.status("[bold green]Preparing dataset & tokenizer..."):
# Load dataset split (with or without config name)
if cfg_name:
ds = datasets.load_dataset(
config.dataset_name,
cfg_name,
split=config.dataset_split,
)
else:
ds = datasets.load_dataset(
config.dataset_name,
split=config.dataset_split,
)
n_examples = min(config.max_train_examples, len(ds))
ds_subset = ds.select(range(n_examples))
# Tokenizer & renderer based on model
tokenizer = get_tokenizer(config.model_name)
renderer_name = model_info.get_recommended_renderer_name(config.model_name)
renderer = renderers.get_renderer(renderer_name, tokenizer)
datums: list[tinker.Datum] = []
for row in ds_subset:
if "messages" in row:
# Chat-style dataset (e.g. no_robots, tulu3)
messages = row["messages"]
else:
# Generic QA mapping (input -> user, target -> assistant)
user_text = row[config.input_field]
target_text = row[config.target_field]
messages = [
{"role": "user", "content": user_text},
{"role": "assistant", "content": target_text},
]
datum = conversation_to_datum(
messages,
renderer,
config.max_length,
TrainOnWhat.ALL_ASSISTANT_MESSAGES,
)
datums.append(datum)
console.print(
f"[green]✓ Prepared {len(datums)} training examples from "
f"[cyan]{config.dataset_name}[/cyan] ({config.dataset_split})[/green]"
)
return datums
async def train_single_variant(
service_client: tinker.ServiceClient,
config: SweepConfig,
base_lr: float,
multiplier: float,
data: list[tinker.Datum],
progress: Progress,
task_id: int,
run_id: int,
ml_logger: ml_log.Logger,
) -> Dict[str, Any]:
"""
Train a single LoRA model at a particular LR and log metrics.
Returns:
dict with multiplier, lr, and loss curve for summary/plotting.
"""
target_lr = base_lr * multiplier
# Show explicit LR in initial description
progress.update(
task_id,
description=f"[bold blue]{multiplier}x[/] (LR: {target_lr:.2e}, init)",
visible=True,
)
client = await service_client.create_lora_training_client_async(
base_model=config.model_name,
rank=config.lora_rank,
)
losses: list[float] = []
# Simple infinite batch generator cycling through data
def batch_gen():
while True:
for i in range(0, len(data), config.batch_size):
yield data[i: i + config.batch_size]
batch_iter = batch_gen()
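# Note: the final slice of each pass may be shorter than batch_size, and the data
# is cycled in a fixed order (no reshuffling between passes).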
# Update description now that we're actually training
progress.update(
task_id,
description=f"[bold blue]{multiplier}x[/] (LR: {target_lr:.2e})",
)
for step in range(config.steps):
batch = next(batch_iter)
# Async pipeline: queue fwd & optim, then await their results
fwd_future = await client.forward_backward_async(
batch, loss_fn="cross_entropy"
)
opt_future = await client.optim_step_async(
tinker.AdamParams(learning_rate=target_lr)
)
fwd_result = await fwd_future.result_async()
_ = await opt_future.result_async()
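# Submitting the optim step before awaiting the forward/backward result lets the
# server queue both operations for this step back to back, instead of waiting for
# the loss to come back before the optimizer update is enqueued.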
# Compute loss
logprobs = [out["logprobs"] for out in fwd_result.loss_fn_outputs]
weights = [d.loss_fn_inputs["weights"] for d in batch]
loss = compute_mean_nll(logprobs, weights)
losses.append(loss)
global_step = run_id * config.steps + step
# Log detailed metrics, including exact LR, to metrics.jsonl
ml_logger.log_metrics(
{
"variant_index": run_id,
"lr_multiplier": multiplier,
"learning_rate": target_lr,
"step_in_variant": step,
"global_step": global_step,
"loss": loss,
"batch_size": config.batch_size,
},
step=global_step,
)
# Update Rich progress for this variant
progress.update(task_id, advance=1, loss=loss)
progress.update(
task_id,
description=f"[bold green]{multiplier}x (LR: {target_lr:.2e}, done)[/]",
)
return {
"multiplier": multiplier,
"lr": target_lr,
"losses": losses,
}
async def run_sweep(config: SweepConfig) -> Dict[str, Any]:
"""
Core engine: run the LR sweep, log metrics, and return results for plotting & summary.
"""
console = Console()
# Prepare log directory
output_dir = os.path.expanduser(config.output_dir)
cli_utils.check_log_dir(output_dir, behavior_if_exists=config.logdir_behavior)
os.makedirs(output_dir, exist_ok=True)
# Setup Tinker-style logger (no extra stdout logging)
run_name = (
f"lr_sweep-{config.model_name.replace('/', '-')}-"
f"{config.lora_rank}rank-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
)
ml_logger = ml_log.setup_logging(
log_dir=output_dir,
wandb_project=config.wandb_project,
wandb_name=config.wandb_name or run_name,
config=config,
do_configure_logging_module=False, # don't spam stdout with tables
)
console.print(
f"[bold]Starting LoRA sweep[/bold] for model "
f"[cyan]{config.model_name}[/cyan] (rank={config.lora_rank})"
)
service_client = tinker.ServiceClient(base_url=config.base_url)
data = await prepare_data(config, console)
# Determine center LR
if config.center_mode == "auto":
base_lr = hyperparam_utils.get_lr(config.model_name)
console.print(
f"[bold]Center LR (auto from get_lr):[/bold] [cyan]{base_lr:.2e}[/cyan]"
)
elif config.center_mode == "manual":
if config.base_lr is None:
raise ValueError("base_lr must be set when center_mode='manual'")
base_lr = config.base_lr
console.print(
f"[bold]Center LR (manual):[/bold] [cyan]{base_lr:.2e}[/cyan]"
)
else:
raise ValueError(
f"Unknown center_mode '{config.center_mode}', use 'auto' or 'manual'"
)
# Rich progress UI
progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
StepColumn(),
LossColumn(),
TimeRemainingColumn(),
console=console,
)
all_results: list[Dict[str, Any]] = []
with progress:
# Create a task for each LR multiplier
task_ids: list[int] = []
for mult in config.lr_multipliers:
tid = progress.add_task(
f"{mult}x",
total=config.steps,
loss=None,
start=False,
)
task_ids.append(tid)
# Run variants in chunks to limit parallel TrainingClients
for start_idx in range(
0, len(config.lr_multipliers), config.max_concurrent_variants
):
end_idx = start_idx + config.max_concurrent_variants
chunk_multipliers = config.lr_multipliers[start_idx:end_idx]
chunk_task_ids = task_ids[start_idx:end_idx]
# Start tasks in this chunk
for tid in chunk_task_ids:
progress.start_task(tid)
coros = []
for i, (mult, tid) in enumerate(zip(chunk_multipliers, chunk_task_ids)):
run_id = start_idx + i
coros.append(
train_single_variant(
service_client,
config,
base_lr,
mult,
data,
progress,
tid,
run_id,
ml_logger,
)
)
chunk_results = await asyncio.gather(*coros)
all_results.extend(chunk_results)
# Close logger (flush metrics)
ml_logger.close()
return {
"results": all_results,
"output_dir": output_dir,
"center_lr": base_lr,
}
def summarize_and_plot(
sweep_output: Dict[str, Any],
config: SweepConfig,
) -> None:
"""
Compute regret, write a JSON summary, and generate a PNG plot.
"""
results = sweep_output["results"]
output_dir = sweep_output["output_dir"]
center_lr = sweep_output["center_lr"]
# Compute final loss & regret per LR
final_losses = np.array([res["losses"][-1] for res in results], dtype=float)
best_loss = float(final_losses.min())
regrets = (final_losses - best_loss) / best_loss
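# Regret is the relative gap to the best final loss. For example (hypothetical
# numbers), if the best variant ends at 0.80 and another at 0.88, its regret is
# (0.88 - 0.80) / 0.80 = 0.10.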
sweeps_summary: list[Dict[str, Any]] = []
for res, final, regret in zip(results, final_losses, regrets):
sweeps_summary.append(
{
"multiplier": float(res["multiplier"]),
"learning_rate": float(res["lr"]),
"final_loss": float(final),
"regret": float(regret),
}
)
# Meta info about the sweep
summary_meta: Dict[str, Any] = {
"center_mode": config.center_mode,
"center_lr": float(center_lr),
"model_name": config.model_name,
"lora_rank": config.lora_rank,
"dataset_name": config.dataset_name,
"dataset_config_name": config.dataset_config_name,
"dataset_split": config.dataset_split,
"steps": config.steps,
"batch_size": config.batch_size,
}
summary_payload = {
"meta": summary_meta,
"sweeps": sweeps_summary,
}
summary_path = os.path.join(output_dir, "summary.json")
with open(summary_path, "w") as f:
json.dump(summary_payload, f, indent=2)
print(f"[Summary] Wrote JSON summary to {summary_path}")
# Plot loss curves for each LR (smoothed)
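# Smoothing is a simple moving average (window of up to 3 steps, "valid" mode),
# so each smoothed curve is window - 1 points shorter than the raw loss list.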
plt.figure(figsize=(10, 6))
for res in results:
mult = res["multiplier"]
lr = res["lr"]
losses = np.array(res["losses"], dtype=float)
window = min(3, len(losses))
if window > 1:
kernel = np.ones(window) / window
smoothed = np.convolve(losses, kernel, mode="valid")
x = np.arange(len(smoothed))
else:
smoothed = losses
x = np.arange(len(losses))
label = f"{mult}x ({lr:.1e})"
plt.plot(x, smoothed, label=label, linewidth=2, alpha=0.85)
plt.title(f"LoRA LR Sweep: {config.model_name}", fontsize=14)
plt.xlabel("Step", fontsize=12)
plt.ylabel("Train NLL", fontsize=12)
plt.ylim(bottom=0.0)
plt.grid(True, alpha=0.2)
plt.legend(title="LR Multiplier (absolute LR)", fontsize=9)
plot_path = os.path.join(output_dir, "loss_curves.png")
plt.tight_layout()
plt.savefig(plot_path)
print(f"[Plot] Saved loss curves to {plot_path}")
async def main(config: SweepConfig) -> None:
sweep_output = await run_sweep(config)
summarize_and_plot(sweep_output, config)
if __name__ == "__main__":
cfg = chz.entrypoint(SweepConfig)
asyncio.run(main(cfg))