Script to run a LoRA learning rate sweep on Tinker.
| """ | |
| clone the tinker cookbook and follow their setup instructions | |
| put this file inside /recipes | |
| install matplotlib | |
| sample command | |
| python -m tinker_lora_lr_sweep \ | |
| model_name=meta-llama/Llama-3.1-8B \ | |
| dataset_name=openai/gsm8k \ | |
| dataset_config_name=main \ | |
| dataset_split=train \ | |
| center_mode=auto \ | |
| steps=100 \ | |
| batch_size=8 \ | |
| lr_multipliers="[0.25,0.5,1.0,2.0,4.0]" | |
| """ | |
import asyncio
import json
import logging
import os
from datetime import datetime
from typing import List, Dict, Any

import chz
import datasets
import matplotlib.pyplot as plt
import numpy as np
import tinker
from rich.console import Console
from rich.progress import (
    Progress,
    ProgressColumn,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
    TimeRemainingColumn,
)
from rich.text import Text
from tinker_cookbook import hyperparam_utils, model_info, renderers, cli_utils
from tinker_cookbook.cli_utils import LogdirBehavior
from tinker_cookbook.renderers import TrainOnWhat
from tinker_cookbook.supervised.common import compute_mean_nll
from tinker_cookbook.supervised.data import conversation_to_datum
from tinker_cookbook.tokenizer_utils import get_tokenizer
from tinker_cookbook.utils import ml_log

# Keep stdout relatively clean so Rich progress bars look good
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class StepColumn(ProgressColumn):
    """Render the current step number for each LR sweep task."""

    def render(self, task) -> Text:
        if task.total is None:
            return Text("-", style="dim")
        return Text(f"Step {int(task.completed)}/{int(task.total)}", style="cyan")


class LossColumn(ProgressColumn):
    """Render the current loss for each LR sweep task in the progress bar."""

    def render(self, task) -> Text:
        loss = task.fields.get("loss")
        if loss is None:
            return Text("-", style="dim")
        if loss > 5.0:
            style = "bold red"
        elif loss < 1.0:
            style = "bold green"
        else:
            style = "yellow"
        return Text(f"Loss: {loss:.4f}", style=style)


@chz.chz
class SweepConfig:
    # Model / training
    model_name: str = "Qwen/Qwen3-4B-Instruct-2507"
    lora_rank: int = 32
    batch_size: int = 8
    steps: int = 20
    max_length: int = 512

    # LR sweep config
    lr_multipliers: List[float] = chz.field(
        default_factory=lambda: [0.25, 0.5, 1.0, 2.0, 4.0]
    )
    center_mode: str = "auto"  # "auto" | "manual"
    base_lr: float | None = None  # used when center_mode == "manual"

    # Dataset config (generic, not GSM8K-only)
    dataset_name: str = "openai/gsm8k"  # any HF dataset
    dataset_config_name: str | None = "main"  # e.g. "main" for gsm8k, or None
    dataset_split: str = "train"
    input_field: str = "question"  # used when no 'messages' field is present
    target_field: str = "answer"
    max_train_examples: int = 512  # cap number of examples to convert

    # Infra / logging
    output_dir: str = "~/tinker-sweeps"
    logdir_behavior: LogdirBehavior = "resume"  # "resume", "delete", "ask", "raise"
    base_url: str | None = None  # Tinker base_url if needed
    max_concurrent_variants: int = 16  # limit parallel TrainingClients

    # Optional external logging
    wandb_project: str | None = None
    wandb_name: str | None = None
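
# How the sweep points are chosen: run_sweep picks a center LR (hyperparam_utils.get_lr(model_name)
# in auto mode, or base_lr in manual mode), and each variant trains at center_lr * multiplier.
# With the default multipliers that is 5 LoRA runs spanning a 16x range (0.25x to 4x)
# around the center LR.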


async def prepare_data(config: SweepConfig, console: Console) -> list[tinker.Datum]:
    """
    Load a HF dataset and convert a subset of rows to tinker.Datum using conversation_to_datum.

    Supports:
    - chat-style datasets with a 'messages' field
    - generic QA-style datasets via input_field/target_field
    """
    # Normalize string "None" from CLI to actual None
    cfg_name = config.dataset_config_name
    if isinstance(cfg_name, str) and cfg_name.lower() == "none":
        cfg_name = None

    with console.status("[bold green]Preparing dataset & tokenizer..."):
        # Load dataset split (with or without config name)
        if cfg_name:
            ds = datasets.load_dataset(
                config.dataset_name,
                cfg_name,
                split=config.dataset_split,
            )
        else:
            ds = datasets.load_dataset(
                config.dataset_name,
                split=config.dataset_split,
            )

        n_examples = min(config.max_train_examples, len(ds))
        ds_subset = ds.select(range(n_examples))

        # Tokenizer & renderer based on model
        tokenizer = get_tokenizer(config.model_name)
        renderer_name = model_info.get_recommended_renderer_name(config.model_name)
        renderer = renderers.get_renderer(renderer_name, tokenizer)

        datums: list[tinker.Datum] = []
        for row in ds_subset:
            if "messages" in row:
                # Chat-style dataset (e.g. no_robots, tulu3)
                messages = row["messages"]
            else:
                # Generic QA mapping (input -> user, target -> assistant)
                user_text = row[config.input_field]
                target_text = row[config.target_field]
                messages = [
                    {"role": "user", "content": user_text},
                    {"role": "assistant", "content": target_text},
                ]
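            # Each row becomes a short user/assistant conversation; with
            # TrainOnWhat.ALL_ASSISTANT_MESSAGES the loss weights fall on the assistant
            # side, and max_length is the token budget handed to conversation_to_datum.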
            datum = conversation_to_datum(
                messages,
                renderer,
                config.max_length,
                TrainOnWhat.ALL_ASSISTANT_MESSAGES,
            )
            datums.append(datum)

    console.print(
        f"[green]✓ Prepared {len(datums)} training examples from "
        f"[cyan]{config.dataset_name}[/cyan] ({config.dataset_split})[/green]"
    )
    return datums


async def train_single_variant(
    service_client: tinker.ServiceClient,
    config: SweepConfig,
    base_lr: float,
    multiplier: float,
    data: list[tinker.Datum],
    progress: Progress,
    task_id: int,
    run_id: int,
    ml_logger: ml_log.Logger,
) -> Dict[str, Any]:
    """
    Train a single LoRA model at a particular LR and log metrics.

    Returns:
        dict with multiplier, lr, and loss curve for summary/plotting.
    """
    target_lr = base_lr * multiplier

    # Show explicit LR in initial description
    progress.update(
        task_id,
        description=f"[bold blue]{multiplier}x[/] (LR: {target_lr:.2e}, init)",
        visible=True,
    )

    client = await service_client.create_lora_training_client_async(
        base_model=config.model_name,
        rank=config.lora_rank,
    )

    losses: list[float] = []

    # Simple infinite batch generator cycling through data
    def batch_gen():
        while True:
            for i in range(0, len(data), config.batch_size):
                yield data[i: i + config.batch_size]

    batch_iter = batch_gen()

    # Update description now that we're actually training
    progress.update(
        task_id,
        description=f"[bold blue]{multiplier}x[/] (LR: {target_lr:.2e})",
    )

    for step in range(config.steps):
        batch = next(batch_iter)

        # Async pipeline: queue fwd & optim, then await their results
        fwd_future = await client.forward_backward_async(
            batch, loss_fn="cross_entropy"
        )
        opt_future = await client.optim_step_async(
            tinker.AdamParams(learning_rate=target_lr)
        )
        fwd_result = await fwd_future.result_async()
        _ = await opt_future.result_async()

        # Compute loss
        logprobs = [out["logprobs"] for out in fwd_result.loss_fn_outputs]
        weights = [d.loss_fn_inputs["weights"] for d in batch]
        loss = compute_mean_nll(logprobs, weights)
        losses.append(loss)
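
        # Metrics from all variants go into one shared log; global_step interleaves
        # them as run_id * steps + step, so rows from different LRs never share a step index.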
        global_step = run_id * config.steps + step

        # Log detailed metrics, including exact LR, to metrics.jsonl
        ml_logger.log_metrics(
            {
                "variant_index": run_id,
                "lr_multiplier": multiplier,
                "learning_rate": target_lr,
                "step_in_variant": step,
                "global_step": global_step,
                "loss": loss,
                "batch_size": config.batch_size,
            },
            step=global_step,
        )

        # Update Rich progress for this variant
        progress.update(task_id, advance=1, loss=loss)

    progress.update(
        task_id,
        description=f"[bold green]{multiplier}x (LR: {target_lr:.2e}, done)[/]",
    )

    return {
        "multiplier": multiplier,
        "lr": target_lr,
        "losses": losses,
    }


async def run_sweep(config: SweepConfig) -> Dict[str, Any]:
    """
    Core engine: run the LR sweep, log metrics, and return results for plotting & summary.
    """
    console = Console()

    # Prepare log directory
    output_dir = os.path.expanduser(config.output_dir)
    cli_utils.check_log_dir(output_dir, behavior_if_exists=config.logdir_behavior)
    os.makedirs(output_dir, exist_ok=True)

    # Setup Tinker-style logger (no extra stdout logging)
    run_name = (
        f"lr_sweep-{config.model_name.replace('/', '-')}-"
        f"{config.lora_rank}rank-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"
    )
    ml_logger = ml_log.setup_logging(
        log_dir=output_dir,
        wandb_project=config.wandb_project,
        wandb_name=config.wandb_name or run_name,
        config=config,
        do_configure_logging_module=False,  # don't spam stdout with tables
    )

    console.print(
        f"[bold]Starting LoRA sweep[/bold] for model "
        f"[cyan]{config.model_name}[/cyan] (rank={config.lora_rank})"
    )

    service_client = tinker.ServiceClient(base_url=config.base_url)
    data = await prepare_data(config, console)

    # Determine center LR
    if config.center_mode == "auto":
        base_lr = hyperparam_utils.get_lr(config.model_name)
        console.print(
            f"[bold]Center LR (auto from get_lr):[/bold] [cyan]{base_lr:.2e}[/cyan]"
        )
    elif config.center_mode == "manual":
        if config.base_lr is None:
            raise ValueError("base_lr must be set when center_mode='manual'")
        base_lr = config.base_lr
        console.print(
            f"[bold]Center LR (manual):[/bold] [cyan]{base_lr:.2e}[/cyan]"
        )
    else:
        raise ValueError(
            f"Unknown center_mode '{config.center_mode}', use 'auto' or 'manual'"
        )

    # Rich progress UI
    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        StepColumn(),
        LossColumn(),
        TimeRemainingColumn(),
        console=console,
    )

    all_results: list[Dict[str, Any]] = []

    with progress:
        # Create a task for each LR multiplier
        task_ids: list[int] = []
        for mult in config.lr_multipliers:
            tid = progress.add_task(
                f"{mult}x",
                total=config.steps,
                loss=None,
                start=False,
            )
            task_ids.append(tid)

        # Run variants in chunks to limit parallel TrainingClients
        for start_idx in range(
            0, len(config.lr_multipliers), config.max_concurrent_variants
        ):
            end_idx = start_idx + config.max_concurrent_variants
            chunk_multipliers = config.lr_multipliers[start_idx:end_idx]
            chunk_task_ids = task_ids[start_idx:end_idx]

            # Start tasks in this chunk
            for tid in chunk_task_ids:
                progress.start_task(tid)

            coros = []
            for i, (mult, tid) in enumerate(zip(chunk_multipliers, chunk_task_ids)):
                run_id = start_idx + i
                coros.append(
                    train_single_variant(
                        service_client,
                        config,
                        base_lr,
                        mult,
                        data,
                        progress,
                        tid,
                        run_id,
                        ml_logger,
                    )
                )

            chunk_results = await asyncio.gather(*coros)
            all_results.extend(chunk_results)

    # Close logger (flush metrics)
    ml_logger.close()

    return {
        "results": all_results,
        "output_dir": output_dir,
        "center_lr": base_lr,
    }
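
# Output artifacts, for reference: run_sweep streams per-step metrics through the
# ml_log logger (metrics.jsonl in output_dir, plus W&B when wandb_project is set),
# while summarize_and_plot below writes summary.json and loss_curves.png next to it.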


def summarize_and_plot(
    sweep_output: Dict[str, Any],
    config: SweepConfig,
) -> None:
    """
    Compute regret, write a JSON summary, and generate a PNG plot.
    """
    results = sweep_output["results"]
    output_dir = sweep_output["output_dir"]
    center_lr = sweep_output["center_lr"]

    # Compute final loss & regret per LR
    final_losses = np.array([res["losses"][-1] for res in results], dtype=float)
    best_loss = float(final_losses.min())
    regrets = (final_losses - best_loss) / best_loss
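    # Regret is relative to the best variant: (final_loss - best_final_loss) / best_final_loss,
    # so the best LR has regret 0 and e.g. 0.05 means a 5% worse final NLL than the best run.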

    sweeps_summary: list[Dict[str, Any]] = []
    for res, final, regret in zip(results, final_losses, regrets):
        sweeps_summary.append(
            {
                "multiplier": float(res["multiplier"]),
                "learning_rate": float(res["lr"]),
                "final_loss": float(final),
                "regret": float(regret),
            }
        )

    # Meta info about the sweep
    summary_meta: Dict[str, Any] = {
        "center_mode": config.center_mode,
        "center_lr": float(center_lr),
        "model_name": config.model_name,
        "lora_rank": config.lora_rank,
        "dataset_name": config.dataset_name,
        "dataset_config_name": config.dataset_config_name,
        "dataset_split": config.dataset_split,
        "steps": config.steps,
        "batch_size": config.batch_size,
    }
    summary_payload = {
        "meta": summary_meta,
        "sweeps": sweeps_summary,
    }

    summary_path = os.path.join(output_dir, "summary.json")
    with open(summary_path, "w") as f:
        json.dump(summary_payload, f, indent=2)
    print(f"[Summary] Wrote JSON summary to {summary_path}")

    # Plot loss curves for each LR (smoothed)
    plt.figure(figsize=(10, 6))
    for res in results:
        mult = res["multiplier"]
        lr = res["lr"]
        losses = np.array(res["losses"], dtype=float)

        window = min(3, len(losses))
        if window > 1:
            kernel = np.ones(window) / window
            smoothed = np.convolve(losses, kernel, mode="valid")
            x = np.arange(len(smoothed))
        else:
            smoothed = losses
            x = np.arange(len(losses))

        label = f"{mult}x ({lr:.1e})"
        plt.plot(x, smoothed, label=label, linewidth=2, alpha=0.85)

    plt.title(f"LoRA LR Sweep: {config.model_name}", fontsize=14)
    plt.xlabel("Step", fontsize=12)
    plt.ylabel("Train NLL", fontsize=12)
    plt.ylim(bottom=0.0)
    plt.grid(True, alpha=0.2)
    plt.legend(title="LR Multiplier (absolute LR)", fontsize=9)

    plot_path = os.path.join(output_dir, "loss_curves.png")
    plt.tight_layout()
    plt.savefig(plot_path)
    print(f"[Plot] Saved loss curves to {plot_path}")


async def main(config: SweepConfig) -> None:
    sweep_output = await run_sweep(config)
    summarize_and_plot(sweep_output, config)


if __name__ == "__main__":
    cfg = chz.entrypoint(SweepConfig)
    asyncio.run(main(cfg))