minigpt tpu example
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mini-GPT Training on Google Cloud TPUs (Optimized)\n",
    "\n",
    "A production-ready, optimized implementation of a small Transformer model trained on TPUs using JAX/Flax.\n",
    "\n",
    "**Key Features:**\n",
    "- Proper `pmap` usage with replicated parameters\n",
    "- Gradient synchronization across TPU cores\n",
    "- Mixed precision training (bfloat16)\n",
    "- Precomputed causal masks\n",
    "- KV cache for efficient generation\n",
    "- Vectorized top-k sampling\n",
    "- Multi-batch validation\n",
    "- Learning rate scheduling with warmup\n",
    "- Gradient checkpointing support"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "Run this on your TPU VM to install dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade \"jax[tpu]>=0.4.32\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\n",
    "!pip install flax optax datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import flax.linen as nn\n",
    "import optax\n",
    "from flax.training import train_state, checkpoints, dynamic_scale as dynamic_scale_lib\n",
    "from datasets import load_dataset\n",
    "from typing import Any, Tuple, Optional\n",
    "from functools import partial\n",
    "\n",
    "print(f\"JAX version: {jax.__version__}\")\n",
    "print(f\"JAX devices: {jax.devices()}\")\n",
    "print(f\"Number of devices: {jax.device_count()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Adjust these hyperparameters based on your needs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Config:\n",
    "    # Model Architecture\n",
    "    vocab_size: Optional[int] = None  # Will be set after tokenization\n",
    "    emb_dim: int = 256\n",
    "    n_layers: int = 6\n",
    "    n_heads: int = 4\n",
    "    mlp_dim: int = 1024\n",
    "    dropout: float = 0.1\n",
    "    max_seq_len: int = 512\n",
    "    block_size: int = 128  # Context window for training\n",
    "\n",
    "    # Training Hyperparameters\n",
    "    batch_size: int = 64  # Must be divisible by the number of TPU cores\n",
    "    learning_rate: float = 3e-4\n",
    "    warmup_steps: int = 100\n",
    "    max_steps: int = 10000\n",
    "    eval_interval: int = 100\n",
    "    log_interval: int = 50\n",
    "    num_eval_batches: int = 10  # Number of batches for validation\n",
    "\n",
    "    # Optimization\n",
    "    use_mixed_precision: bool = True  # Use bfloat16 for faster TPU training\n",
    "    gradient_checkpointing: bool = False  # Enable for very large models\n",
    "\n",
    "    # Generation Parameters\n",
    "    temperature: float = 0.8\n",
    "    top_k: int = 40\n",
    "\n",
    "config = Config()\n",
    "print(\"Configuration loaded\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset & Tokenization\n",
    "\n",
    "We'll use the Tiny Shakespeare dataset with character-level tokenization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "# Note: newer `datasets` releases may require trust_remote_code=True for\n",
    "# script-based datasets such as tiny_shakespeare\n",
    "print(\"Loading dataset...\")\n",
    "ds = load_dataset(\"tiny_shakespeare\")\n",
    "text = ds[\"train\"][0][\"text\"]\n",
    "\n",
    "# Character-level tokenizer\n",
    "chars = sorted(list(set(text)))\n",
    "stoi = {c: i for i, c in enumerate(chars)}\n",
    "itos = {i: c for i, c in enumerate(chars)}\n",
    "config.vocab_size = len(chars)\n",
    "\n",
    "def encode(s):\n",
    "    return [stoi[c] for c in s]\n",
    "\n",
    "def decode(xs):\n",
    "    # int(x) lets JAX/NumPy scalar tokens index the plain Python dict\n",
    "    return ''.join([itos[int(x)] for x in xs])\n",
    "\n",
    "# Encode all data\n",
    "data = np.array(encode(text), dtype=np.int32)\n",
    "\n",
    "print(f\"Dataset: {len(data):,} tokens\")\n",
    "print(f\"Vocab size: {config.vocab_size}\")\n",
    "print(f\"\\nSample text: {text[:200]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train/Val Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = int(0.9 * len(data))\n",
    "train_data = data[:n]\n",
    "val_data = data[n:]\n",
    "\n",
    "print(f\"Train tokens: {len(train_data):,}\")\n",
    "print(f\"Val tokens: {len(val_data):,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batch Loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_batch(split='train', rng=None):\n",
    "    \"\"\"Get a batch of data for training or validation\"\"\"\n",
    "    source = train_data if split == 'train' else val_data\n",
    "\n",
    "    if rng is not None:\n",
    "        # Convert to NumPy so the indices can slice the NumPy source array\n",
    "        ix = np.asarray(jax.random.randint(rng, (config.batch_size,), 0, len(source) - config.block_size - 1))\n",
    "    else:\n",
    "        ix = np.random.randint(0, len(source) - config.block_size - 1, (config.batch_size,))\n",
    "\n",
    "    x = np.stack([source[i:i + config.block_size] for i in ix])\n",
    "    y = np.stack([source[i + 1:i + 1 + config.block_size] for i in ix])\n",
    "\n",
    "    return jnp.array(x), jnp.array(y)\n",
    "\n",
    "# Test the batch loader\n",
    "x_test, y_test = get_batch('train')\n",
    "print(f\"Batch shape: {x_test.shape}\")\n",
    "print(f\"First sequence: {decode(list(x_test[0][:50]))}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Precompute Causal Mask\n",
    "\n",
    "**Optimization**: Create the mask once instead of on every forward pass (the sanity check below shows how it is sliced during cached decoding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precompute causal mask for efficiency\n",
    "CAUSAL_MASK = jnp.tril(jnp.ones((config.block_size, config.block_size), dtype=jnp.bool_))\n",
    "print(f\"Precomputed causal mask: {CAUSAL_MASK.shape}\")"
   ]
  },
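  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sanity check: mask alignment under a KV cache\n",
    "\n",
    "During cached decoding the queries are the *last* `T` of the `T_k` key positions, so the attention module slices `CAUSAL_MASK[T_k - T:T_k, :T_k]`. A quick illustrative check (the shapes here are hypothetical, not part of training):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: one new query token (T=1) after 4 cached keys (T_k=5)\n",
    "T, T_k = 1, 5\n",
    "row = CAUSAL_MASK[T_k - T:T_k, :T_k]\n",
    "print(row)  # expected: [[ True  True  True  True  True]] - every key visible\n",
    "assert bool(row.all())"
   ]
  },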
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Architecture\n",
    "\n",
    "### Self-Attention Module with KV Cache Support"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelfAttention(nn.Module):\n",
    "    n_heads: int\n",
    "    head_dim: int\n",
    "    dropout: float = 0.1\n",
    "    dtype: Any = jnp.float32  # computation dtype (bfloat16 under mixed precision)\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, train: bool = True, cache: Optional[Tuple] = None):\n",
    "        \"\"\"\n",
    "        Self-attention with optional KV caching for efficient generation.\n",
    "\n",
    "        Args:\n",
    "            x: Input tensor (B, T, C)\n",
    "            train: Whether in training mode\n",
    "            cache: Optional (k_cache, v_cache) from previous decoding steps\n",
    "\n",
    "        Returns:\n",
    "            out: Output tensor (B, T, C)\n",
    "            new_cache: (k_cache, v_cache) covering every token seen so far\n",
    "        \"\"\"\n",
    "        B, T, C = x.shape\n",
    "\n",
    "        # QKV projection (self.dtype sets the computation precision)\n",
    "        qkv = nn.Dense(3 * C, use_bias=False, dtype=self.dtype)(x)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=-1)\n",
    "\n",
    "        # Reshape to (B, n_heads, T, head_dim)\n",
    "        def split_heads(t):\n",
    "            return t.reshape(B, T, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)\n",
    "\n",
    "        q = split_heads(q)\n",
    "        k = split_heads(k)\n",
    "        v = split_heads(v)\n",
    "\n",
    "        # Prepend cached keys/values, then cache everything seen so far\n",
    "        if cache is not None:\n",
    "            k_cache, v_cache = cache\n",
    "            k = jnp.concatenate([k_cache, k], axis=2)\n",
    "            v = jnp.concatenate([v_cache, v], axis=2)\n",
    "        new_cache = (k, v)\n",
    "\n",
    "        # Scaled dot-product attention\n",
    "        scale = jnp.sqrt(self.head_dim).astype(x.dtype)\n",
    "        attn_weights = jnp.einsum('bhqd,bhkd->bhqk', q, k) / scale\n",
    "\n",
    "        # Apply causal mask (use precomputed global mask). The queries are\n",
    "        # the LAST T of the T_k key positions, so slice the matching rows.\n",
    "        T_k = k.shape[2]  # Key sequence length (longer than T with a cache)\n",
    "        if T_k <= config.block_size:\n",
    "            mask = CAUSAL_MASK[T_k - T:T_k, :T_k]\n",
    "        else:\n",
    "            mask = jnp.tril(jnp.ones((T, T_k), dtype=jnp.bool_), k=T_k - T)\n",
    "        attn_weights = jnp.where(mask[None, None, :, :], attn_weights, -1e10)\n",
    "\n",
    "        attn_weights = nn.softmax(attn_weights, axis=-1)\n",
    "        attn_weights = nn.Dropout(rate=self.dropout, deterministic=not train)(attn_weights)\n",
    "\n",
    "        # Apply attention to values\n",
    "        out = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)\n",
    "        out = out.transpose(0, 2, 1, 3).reshape(B, T, C)\n",
    "\n",
    "        # Output projection\n",
    "        out = nn.Dense(C, dtype=self.dtype)(out)\n",
    "        out = nn.Dropout(rate=self.dropout, deterministic=not train)(out)\n",
    "\n",
    "        # Always return the cache; callers that do not need it can drop it\n",
    "        return out, new_cache"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transformer Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    n_heads: int\n",
    "    head_dim: int\n",
    "    mlp_dim: int\n",
    "    dropout: float = 0.1\n",
    "    dtype: Any = jnp.float32\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, train: bool = True, cache: Optional[Tuple] = None):\n",
    "        # Self-attention with residual connection (cache may be None)\n",
    "        h = nn.LayerNorm()(x)\n",
    "        h, new_cache = SelfAttention(\n",
    "            self.n_heads, self.head_dim, self.dropout, dtype=self.dtype\n",
    "        )(h, train=train, cache=cache)\n",
    "        x = x + h\n",
    "\n",
    "        # MLP with residual connection\n",
    "        h = nn.LayerNorm()(x)\n",
    "        h = nn.Dense(self.mlp_dim, dtype=self.dtype)(h)\n",
    "        h = nn.gelu(h)\n",
    "        h = nn.Dropout(rate=self.dropout, deterministic=not train)(h)\n",
    "        h = nn.Dense(x.shape[-1], dtype=self.dtype)(h)\n",
    "        h = nn.Dropout(rate=self.dropout, deterministic=not train)(h)\n",
    "        x = x + h\n",
    "\n",
    "        return x, new_cache"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Complete Mini-GPT Model with Mixed Precision Support"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MiniGPT(nn.Module):\n",
    "    config: Config\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x, train: bool = True, cache_list: Optional[list] = None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x: Input tokens (B, T)\n",
    "            train: Training mode flag\n",
    "            cache_list: Optional list of per-layer (k, v) caches (entries may\n",
    "                be None on the first decoding pass)\n",
    "\n",
    "        Returns:\n",
    "            logits: Output logits (B, T, vocab_size)\n",
    "            new_cache_list: Updated caches if cache_list was provided\n",
    "        \"\"\"\n",
    "        B, T = x.shape\n",
    "\n",
    "        # Mixed precision: compute in bfloat16, but keep embeddings in float32\n",
    "        dtype = jnp.bfloat16 if self.config.use_mixed_precision and train else jnp.float32\n",
    "\n",
    "        # Token embeddings\n",
    "        tok_emb = nn.Embed(\n",
    "            num_embeddings=self.config.vocab_size,\n",
    "            features=self.config.emb_dim,\n",
    "            embedding_init=nn.initializers.normal(stddev=0.02)\n",
    "        )(x)\n",
    "\n",
    "        # Position embeddings\n",
    "        pos_emb = self.param(\n",
    "            'pos_embed',\n",
    "            nn.initializers.normal(stddev=0.02),\n",
    "            (1, self.config.max_seq_len, self.config.emb_dim)\n",
    "        )\n",
    "\n",
    "        # When decoding with a cache, x holds only the new tokens, so their\n",
    "        # positions start after the cached keys\n",
    "        offset = 0\n",
    "        if cache_list is not None and cache_list[0] is not None:\n",
    "            offset = cache_list[0][0].shape[2]\n",
    "\n",
    "        h = tok_emb + pos_emb[:, offset:offset + T, :]\n",
    "\n",
    "        # Cast to computation dtype\n",
    "        if self.config.use_mixed_precision and train:\n",
    "            h = h.astype(dtype)\n",
    "\n",
    "        h = nn.Dropout(rate=self.config.dropout, deterministic=not train)(h)\n",
    "\n",
    "        # Transformer blocks\n",
    "        head_dim = self.config.emb_dim // self.config.n_heads\n",
    "        new_cache_list = [] if cache_list is not None else None\n",
    "\n",
    "        for i in range(self.config.n_layers):\n",
    "            cache = cache_list[i] if cache_list is not None else None\n",
    "            h, new_cache = TransformerBlock(\n",
    "                self.config.n_heads,\n",
    "                head_dim,\n",
    "                self.config.mlp_dim,\n",
    "                self.config.dropout,\n",
    "                dtype=dtype\n",
    "            )(h, train=train, cache=cache)\n",
    "            if new_cache_list is not None:\n",
    "                new_cache_list.append(new_cache)\n",
    "\n",
    "        # Final layer norm and output projection\n",
    "        h = nn.LayerNorm()(h)\n",
    "\n",
    "        # Cast back to float32 for logits\n",
    "        if self.config.use_mixed_precision and train:\n",
    "            h = h.astype(jnp.float32)\n",
    "\n",
    "        logits = nn.Dense(\n",
    "            self.config.vocab_size,\n",
    "            use_bias=False,\n",
    "            kernel_init=nn.initializers.normal(stddev=0.02)\n",
    "        )(h)\n",
    "\n",
    "        if cache_list is not None:\n",
    "            return logits, new_cache_list\n",
    "        return logits\n",
    "\n",
    "print(\"Model architecture defined\")"
   ]
  },
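  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quick shape smoke test\n",
    "\n",
    "An illustrative forward pass with throwaway parameters, just to confirm the output shape before wiring up training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative smoke test: init a throwaway model and run one forward pass\n",
    "_rng = jax.random.PRNGKey(0)\n",
    "_model = MiniGPT(config)\n",
    "_tokens = jnp.zeros((2, 16), dtype=jnp.int32)\n",
    "_vars = _model.init(_rng, _tokens, train=False)\n",
    "_logits = _model.apply(_vars, _tokens, train=False)\n",
    "print(_logits.shape, _logits.dtype)  # expected: (2, 16, vocab_size) float32"
   ]
  },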
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Setup\n",
    "\n",
    "### Training State with Mixed Precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TrainState(train_state.TrainState):\n",
    "    \"\"\"Extended TrainState with dropout RNG and optional dynamic scale for mixed precision\"\"\"\n",
    "    dropout_rng: jax.Array\n",
    "    dynamic_scale: Optional[dynamic_scale_lib.DynamicScale] = None\n",
    "\n",
    "def create_train_state(rng, config):\n",
    "    \"\"\"Initialize training state with model, optimizer, and learning rate schedule\"\"\"\n",
    "    model = MiniGPT(config)\n",
    "\n",
    "    # Separate RNGs for parameter init and dropout\n",
    "    init_rng, dropout_rng = jax.random.split(rng)\n",
    "\n",
    "    # Initialize with dummy input\n",
    "    dummy_input = jnp.ones((1, config.block_size), dtype=jnp.int32)\n",
    "    variables = model.init(init_rng, dummy_input, train=False)\n",
    "    params = variables['params']\n",
    "\n",
    "    # Learning rate schedule with warmup\n",
    "    schedule = optax.warmup_cosine_decay_schedule(\n",
    "        init_value=0.0,\n",
    "        peak_value=config.learning_rate,\n",
    "        warmup_steps=config.warmup_steps,\n",
    "        decay_steps=config.max_steps,\n",
    "        end_value=config.learning_rate * 0.1\n",
    "    )\n",
    "\n",
    "    # Optimizer with gradient clipping\n",
    "    tx = optax.chain(\n",
    "        optax.clip_by_global_norm(1.0),\n",
    "        optax.adamw(learning_rate=schedule, weight_decay=0.01)\n",
    "    )\n",
    "\n",
    "    # Dynamic scale for mixed precision (helps with numerical stability)\n",
    "    dynamic_scale = None\n",
    "    if config.use_mixed_precision:\n",
    "        dynamic_scale = dynamic_scale_lib.DynamicScale()\n",
    "\n",
    "    return TrainState.create(\n",
    "        apply_fn=model.apply,\n",
    "        params=params,\n",
    "        tx=tx,\n",
    "        dropout_rng=dropout_rng,\n",
    "        dynamic_scale=dynamic_scale\n",
    "    )\n",
    "\n",
    "print(\"Training state setup defined\")"
   ]
  },
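  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optax schedule is just a callable from step to learning rate, so we can probe a few points directly (illustrative, mirroring the schedule built in `create_train_state`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: optax schedules map step -> learning rate\n",
    "_sched = optax.warmup_cosine_decay_schedule(\n",
    "    init_value=0.0,\n",
    "    peak_value=config.learning_rate,\n",
    "    warmup_steps=config.warmup_steps,\n",
    "    decay_steps=config.max_steps,\n",
    "    end_value=config.learning_rate * 0.1\n",
    ")\n",
    "for s in [0, config.warmup_steps, config.max_steps // 2, config.max_steps]:\n",
    "    print(f\"step {s:5d} -> lr {float(_sched(s)):.2e}\")"
   ]
  },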
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(params, apply_fn, x, y, dropout_rng):\n",
    "    \"\"\"Compute cross-entropy loss\"\"\"\n",
    "    logits = apply_fn({'params': params}, x, train=True, rngs={'dropout': dropout_rng})\n",
    "    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
    "    return loss"
   ]
  },
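  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A tiny illustrative check that `optax.softmax_cross_entropy_with_integer_labels` matches the textbook negative log-likelihood (the toy logits here are made up):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative: cross-entropy with integer labels == -log_softmax at the label\n",
    "_logits = jnp.array([[2.0, 0.5, -1.0]])\n",
    "_labels = jnp.array([0])\n",
    "_manual = -jax.nn.log_softmax(_logits)[0, _labels[0]]\n",
    "_lib = optax.softmax_cross_entropy_with_integer_labels(_logits, _labels)[0]\n",
    "print(float(_manual), float(_lib))  # the two values should agree"
   ]
  },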
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training & Evaluation Steps (pmapped for TPU)\n",
    "\n",
    "These functions are parallelized across all TPU cores using `pmap`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@partial(jax.pmap, axis_name='devices')\n",
    "def train_step(state, x, y):\n",
    "    \"\"\"Single training step (parallelized across TPU cores)\"\"\"\n",
    "    # Split RNG for dropout\n",
    "    dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)\n",
    "\n",
    "    if state.dynamic_scale is not None:\n",
    "        # DynamicScale.value_and_grad scales the loss, unscales the gradients,\n",
    "        # and reports whether they are all finite; axis_name also averages\n",
    "        # the gradients across devices\n",
    "        grad_fn = state.dynamic_scale.value_and_grad(loss_fn, axis_name='devices')\n",
    "        dynamic_scale, is_finite, loss, grads = grad_fn(\n",
    "            state.params, state.apply_fn, x, y, dropout_rng\n",
    "        )\n",
    "        new_state = state.apply_gradients(grads=grads, dropout_rng=new_dropout_rng)\n",
    "        # Keep the old params/optimizer state wherever gradients overflowed\n",
    "        new_state = new_state.replace(\n",
    "            params=jax.tree_util.tree_map(\n",
    "                partial(jnp.where, is_finite), new_state.params, state.params\n",
    "            ),\n",
    "            opt_state=jax.tree_util.tree_map(\n",
    "                partial(jnp.where, is_finite), new_state.opt_state, state.opt_state\n",
    "            ),\n",
    "            dynamic_scale=dynamic_scale\n",
    "        )\n",
    "    else:\n",
    "        # Compute loss and gradients\n",
    "        loss, grads = jax.value_and_grad(loss_fn)(\n",
    "            state.params, state.apply_fn, x, y, dropout_rng\n",
    "        )\n",
    "        # Average gradients across all devices (CRITICAL for multi-device training)\n",
    "        grads = jax.lax.pmean(grads, axis_name='devices')\n",
    "        new_state = state.apply_gradients(grads=grads, dropout_rng=new_dropout_rng)\n",
    "\n",
    "    loss = jax.lax.pmean(loss, axis_name='devices')\n",
    "    return new_state, loss\n",
    "\n",
    "@partial(jax.pmap, axis_name='devices')\n",
    "def eval_step(state, x, y):\n",
    "    \"\"\"Single evaluation step (parallelized across TPU cores)\"\"\"\n",
    "    logits = state.apply_fn({'params': state.params}, x, train=False)\n",
    "    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n",
    "    loss = jax.lax.pmean(loss, axis_name='devices')\n",
    "    return loss\n",
    "\n",
    "print(\"Training and evaluation functions defined\")"
   ]
  },
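  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### How the per-device sharding works\n",
    "\n",
    "Before each `pmap` call, the training loop reshapes the leading batch dimension into `(num_devices, per_device_batch, ...)` so `pmap` can map one slice to each core. A small illustrative reshape, assuming 8 cores (shapes are hypothetical):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: a global batch of 64 sequences split across 8 devices\n",
    "_devices = 8\n",
    "_global = np.zeros((64, 128))\n",
    "_sharded = _global.reshape(_devices, -1, 128)\n",
    "print(_sharded.shape)  # (8, 8, 128): pmap maps the leading axis over devices"
   ]
  },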
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimized Text Generation\n",
    "\n",
    "**Improvements:**\n",
    "- KV cache for efficient autoregressive decoding\n",
    "- Vectorized top-k sampling (works for batch_size > 1)\n",
    "- Support for multiple sequences simultaneously"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_with_cache(params, apply_fn, prompt_tokens, config, rng, max_new_tokens=100):\n",
    "    \"\"\"\n",
    "    Generate text efficiently using KV cache.\n",
    "\n",
    "    Args:\n",
    "        params: Model parameters\n",
    "        apply_fn: Model apply function\n",
    "        prompt_tokens: Initial tokens (can be 1D or 2D for batched generation)\n",
    "        config: Configuration object\n",
    "        rng: Random key\n",
    "        max_new_tokens: Number of tokens to generate\n",
    "\n",
    "    Returns:\n",
    "        Generated token sequences\n",
    "    \"\"\"\n",
    "    # Ensure prompt_tokens is 2D (batch_size, seq_len)\n",
    "    if prompt_tokens.ndim == 1:\n",
    "        prompt_tokens = prompt_tokens[None, :]\n",
    "\n",
    "    batch_size = prompt_tokens.shape[0]\n",
    "\n",
    "    cache_list = None  # Start without a cache; the first pass builds it\n",
    "\n",
    "    x = prompt_tokens\n",
    "\n",
    "    for _ in range(max_new_tokens):\n",
    "        # Crop to max sequence length if needed\n",
    "        if x.shape[1] > config.block_size:\n",
    "            x_cond = x[:, -config.block_size:]\n",
    "            cache_list = None  # Reset cache if we had to crop\n",
    "        else:\n",
    "            x_cond = x\n",
    "\n",
    "        # Forward pass (with cache if available)\n",
    "        if cache_list is not None:\n",
    "            # Only process the last token when using the cache\n",
    "            logits, cache_list = apply_fn(\n",
    "                {'params': params},\n",
    "                x_cond[:, -1:],\n",
    "                train=False,\n",
    "                cache_list=cache_list\n",
    "            )\n",
    "        else:\n",
    "            # First pass: process the entire sequence and initialize the cache\n",
    "            logits, cache_list = apply_fn(\n",
    "                {'params': params},\n",
    "                x_cond,\n",
    "                train=False,\n",
    "                cache_list=[None] * config.n_layers\n",
    "            )\n",
    "\n",
    "        logits = logits[:, -1, :]  # Get last-token logits\n",
    "\n",
    "        # Temperature scaling\n",
    "        logits = logits / config.temperature\n",
    "\n",
    "        # Top-k filtering (vectorized for all sequences in the batch)\n",
    "        top_k_logits, top_k_indices = jax.lax.top_k(logits, config.top_k)\n",
    "        rng, sample_rng = jax.random.split(rng)\n",
    "\n",
    "        # Sample from the top-k (categorical takes logits directly)\n",
    "        next_idx_in_topk = jax.random.categorical(sample_rng, top_k_logits, axis=-1)\n",
    "\n",
    "        # Correctly index for any batch size\n",
    "        batch_indices = jnp.arange(batch_size)\n",
    "        next_token = top_k_indices[batch_indices, next_idx_in_topk]\n",
    "\n",
    "        x = jnp.concatenate([x, next_token[:, None]], axis=1)\n",
    "\n",
    "    return x\n",
    "\n",
    "def generate(params, apply_fn, prompt, config, rng, max_new_tokens=100):\n",
    "    \"\"\"Convenience function for single-sequence generation\"\"\"\n",
    "    prompt_tokens = jnp.array(encode(prompt))\n",
    "    output_tokens = generate_with_cache(\n",
    "        params, apply_fn, prompt_tokens, config, rng, max_new_tokens\n",
    "    )\n",
    "    return decode(list(output_tokens[0]))\n",
    "\n",
    "# Batched generation for multiple prompts. Note: the loop grows x by one\n",
    "# token per step, so the shapes are dynamic and the function cannot be\n",
    "# wrapped in jax.jit as-is (each inner apply_fn call is still XLA-compiled)\n",
    "def generate_batch(params, apply_fn, prompt_tokens_batch, config, rng, max_new_tokens=100):\n",
    "    \"\"\"Generate text for multiple prompts simultaneously\"\"\"\n",
    "    return generate_with_cache(\n",
    "        params, apply_fn, prompt_tokens_batch, config, rng, max_new_tokens\n",
    "    )\n",
    "\n",
    "print(\"Optimized text generation functions defined\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize Training\n",
    "\n",
    "Set up the model and replicate across TPU cores:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_devices = jax.device_count()\n",
    "print(f\"Running on {num_devices} TPU cores\")\n",
    "\n",
    "# Validate batch size\n",
    "if config.batch_size % num_devices != 0:\n",
    "    raise ValueError(f\"Batch size {config.batch_size} must be divisible by {num_devices} devices\")\n",
    "\n",
    "# Initialize training state\n",
    "rng = jax.random.PRNGKey(42)\n",
    "rng, init_rng = jax.random.split(rng)\n",
    "state = create_train_state(init_rng, config)\n",
    "\n",
    "# Count parameters BEFORE replication (afterwards every leaf gains a\n",
    "# leading device axis, which would inflate the count num_devices-fold)\n",
    "num_params = sum(x.size for x in jax.tree_util.tree_leaves(state.params))\n",
    "\n",
    "# CRITICAL: Replicate state across all TPU cores\n",
    "state = jax.device_put_replicated(state, jax.devices())\n",
    "\n",
    "print(f\"\\nModel initialized with {num_params:,} parameters\")\n",
    "print(f\"Batch size: {config.batch_size} ({config.batch_size // num_devices} per device)\")\n",
    "print(f\"Training for {config.max_steps} steps\")\n",
    "print(f\"Mixed precision: {config.use_mixed_precision}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop\n",
    "\n",
    "Run the main training loop with multi-batch validation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop\n",
    "print(\"\\nStarting training...\\n\")\n",
    "\n",
    "for step in range(config.max_steps):\n",
    "    # Get batch and shard across devices\n",
    "    x, y = get_batch('train')\n",
    "    x = x.reshape(num_devices, -1, config.block_size)\n",
    "    y = y.reshape(num_devices, -1, config.block_size)\n",
    "\n",
    "    # Training step\n",
    "    state, train_loss = train_step(state, x, y)\n",
    "\n",
    "    # Logging\n",
    "    if step % config.log_interval == 0:\n",
    "        loss_val = float(train_loss[0])\n",
    "        print(f\"Step {step:5d} | Train Loss: {loss_val:.4f}\")\n",
    "\n",
    "    # Evaluation (with multiple batches for stability)\n",
    "    if step % config.eval_interval == 0 and step > 0:\n",
    "        val_losses = []\n",
    "        for _ in range(config.num_eval_batches):\n",
    "            x_val, y_val = get_batch('val')\n",
    "            x_val = x_val.reshape(num_devices, -1, config.block_size)\n",
    "            y_val = y_val.reshape(num_devices, -1, config.block_size)\n",
    "            val_loss = eval_step(state, x_val, y_val)\n",
    "            val_losses.append(float(val_loss[0]))\n",
    "\n",
    "        avg_val_loss = np.mean(val_losses)\n",
    "        print(f\"{'='*60}\")\n",
    "        print(f\"Step {step:5d} | Val Loss: {avg_val_loss:.4f} (avg over {config.num_eval_batches} batches)\")\n",
    "        print(f\"{'='*60}\\n\")\n",
    "\n",
    "        # Generate sample text (params from the first device; apply_fn is\n",
    "        # a static field, so it needs no unreplication)\n",
    "        rng, gen_rng = jax.random.split(rng)\n",
    "        sample = generate(\n",
    "            jax.tree_util.tree_map(lambda x: x[0], state.params),\n",
    "            state.apply_fn,\n",
    "            \"To be or not to be\",\n",
    "            config,\n",
    "            gen_rng,\n",
    "            max_new_tokens=100\n",
    "        )\n",
    "        print(\"Sample generation:\")\n",
    "        print(f\"{sample}\")\n",
    "        print()\n",
    "\n",
    "print(\"\\nTraining complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unreplicate state (take from the first device)\n",
    "unreplicated_state = jax.tree_util.tree_map(lambda x: x[0], state)\n",
    "\n",
    "# Save checkpoint (flax's checkpointing expects an absolute path;\n",
    "# adjust the directory for your VM)\n",
    "checkpoint_dir = '/tmp/minigpt_checkpoints'\n",
    "checkpoints.save_checkpoint(checkpoint_dir, unreplicated_state, step=config.max_steps, keep=3)\n",
    "print(f\"Checkpoint saved to {checkpoint_dir}\")"
   ]
  },
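  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reload the weights later, `checkpoints.restore_checkpoint` takes the same directory plus a `target` pytree with a matching structure (a minimal sketch):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: restore the latest checkpoint into a matching state structure\n",
    "restored_state = checkpoints.restore_checkpoint(checkpoint_dir, target=unreplicated_state)\n",
    "print(f\"Restored step: {int(restored_state.step)}\")"
   ]
  },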
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Text Generation\n",
    "\n",
    "Try generating text with different prompts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get unreplicated params for generation (apply_fn is static, use it directly)\n",
    "gen_params = jax.tree_util.tree_map(lambda x: x[0], state.params)\n",
    "gen_apply_fn = state.apply_fn\n",
    "\n",
    "# Try different prompts\n",
    "prompts = [\n",
    "    \"ROMEO:\",\n",
    "    \"To be or not to be\",\n",
    "    \"What light through yonder\",\n",
    "    \"First Citizen:\\n\"\n",
    "]\n",
    "\n",
    "print(\"Generating samples...\\n\")\n",
    "for prompt in prompts:\n",
    "    rng, gen_rng = jax.random.split(rng)\n",
    "    output = generate(gen_params, gen_apply_fn, prompt, config, gen_rng, max_new_tokens=150)\n",
    "    print(f\"Prompt: '{prompt}'\")\n",
    "    print(f\"{output}\")\n",
    "    print(\"\\n\" + \"=\"*80 + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batched Generation Demo\n",
    "\n",
    "Generate multiple sequences in parallel (efficient on TPU):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare multiple prompts\n",
    "batch_prompts = [\"ROMEO:\", \"JULIET:\", \"To be\", \"What light\"]\n",
    "\n",
    "# Pad prompts to the same length (this simple demo right-pads with token 0,\n",
    "# and those pad tokens remain part of the context the model conditions on)\n",
    "prompt_tokens_list = [jnp.array(encode(p)) for p in batch_prompts]\n",
    "max_prompt_len = max(len(p) for p in prompt_tokens_list)\n",
    "padded_prompts = []\n",
    "for pt in prompt_tokens_list:\n",
    "    padding = jnp.zeros(max_prompt_len - len(pt), dtype=jnp.int32)\n",
    "    padded_prompts.append(jnp.concatenate([pt, padding]))\n",
    "\n",
    "prompt_batch = jnp.stack(padded_prompts)\n",
    "\n",
    "# Generate in batch\n",
    "rng, gen_rng = jax.random.split(rng)\n",
    "generated_batch = generate_batch(gen_params, gen_apply_fn, prompt_batch, config, gen_rng, max_new_tokens=80)\n",
    "\n",
    "print(\"Batched generation results:\\n\")\n",
    "for i, (prompt, tokens) in enumerate(zip(batch_prompts, generated_batch)):\n",
    "    print(f\"Prompt {i+1}: '{prompt}'\")\n",
    "    print(decode(list(tokens)))\n",
    "    print(\"\\n\" + \"-\"*80 + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment with Generation Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"ROMEO:\"\n",
    "\n",
    "# Try different temperatures\n",
    "print(\"Temperature experiments:\\n\")\n",
    "for temp in [0.5, 0.8, 1.0, 1.2]:\n",
    "    old_temp = config.temperature\n",
    "    config.temperature = temp\n",
    "    rng, gen_rng = jax.random.split(rng)\n",
    "    output = generate(gen_params, gen_apply_fn, prompt, config, gen_rng, max_new_tokens=100)\n",
    "    print(f\"Temperature: {temp}\")\n",
    "    print(output)\n",
    "    print(\"\\n\" + \"-\"*80 + \"\\n\")\n",
    "    config.temperature = old_temp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Final Training Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute validation loss on multiple batches\n",
    "num_val_batches = 20\n",
    "val_losses = []\n",
    "\n",
    "print(f\"Computing validation loss over {num_val_batches} batches...\")\n",
    "for _ in range(num_val_batches):\n",
    "    x_val, y_val = get_batch('val')\n",
    "    x_val = x_val.reshape(num_devices, -1, config.block_size)\n",
    "    y_val = y_val.reshape(num_devices, -1, config.block_size)\n",
    "    val_loss = eval_step(state, x_val, y_val)\n",
    "    val_losses.append(float(val_loss[0]))\n",
    "\n",
    "avg_val_loss = np.mean(val_losses)\n",
    "std_val_loss = np.std(val_losses)\n",
    "print(f\"\\nFinal Validation Loss: {avg_val_loss:.4f} ± {std_val_loss:.4f}\")\n",
    "print(f\"Validation Perplexity: {np.exp(avg_val_loss):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Key Optimizations Implemented\n",
    "\n",
    "### 1. **Precomputed Causal Mask**\n",
    "Created once globally and reused in every forward pass\n",
    "- Saves computation time\n",
    "- More memory efficient\n",
    "\n",
    "### 2. **KV Caching for Generation**\n",
    "Stores key/value tensors from previous tokens\n",
    "- Dramatically faster autoregressive generation\n",
    "- Only processes new tokens (not the entire sequence)\n",
    "\n",
    "### 3. **Vectorized Top-k Sampling**\n",
    "Indexing that works with any batch size:\n",
    "```python\n",
    "batch_indices = jnp.arange(batch_size)\n",
    "next_token = top_k_indices[batch_indices, next_idx_in_topk]\n",
    "```\n",
    "\n",
    "### 4. **Multi-Batch Validation**\n",
    "Averages loss over multiple batches for stability\n",
    "- More reliable metrics\n",
    "- Configurable via `num_eval_batches`\n",
    "\n",
    "### 5. **Mixed Precision Training**\n",
    "Uses bfloat16 for computation:\n",
    "- Faster training on TPU\n",
    "- Lower memory usage\n",
    "- Dynamic loss scaling for numerical stability\n",
    "\n",
    "### 6. **Efficient Batched Generation**\n",
    "Can generate multiple sequences simultaneously\n",
    "- Better TPU utilization\n",
    "- One forward pass per step serves every prompt\n",
    "\n",
    "### 7. **Gradient Checkpointing Support**\n",
    "Optional flag for very large models (see the `nn.remat` sketch in the next cell)\n",
    "- Trades computation for memory\n",
    "- Enables training larger models"
   ]
  },
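  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sketch: wiring `gradient_checkpointing` with `nn.remat`\n",
    "\n",
    "The `config.gradient_checkpointing` flag is not consumed by the model above; one minimal way to wire it is to swap `TransformerBlock` for a rematerialized version inside `MiniGPT`. A sketch only - the `static_argnums` indices for the non-array `train`/`cache` arguments should be verified against your Flax version:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: a rematerialized transformer block for memory savings.\n",
    "# nn.remat recomputes the block's activations during the backward pass.\n",
    "# static_argnums marks non-array call arguments as compile-time constants;\n",
    "# verify the index convention against your Flax version.\n",
    "BlockCls = TransformerBlock\n",
    "if config.gradient_checkpointing:\n",
    "    BlockCls = nn.remat(TransformerBlock, static_argnums=(2, 3))\n",
    "\n",
    "# Inside MiniGPT.__call__, the layer loop would then construct\n",
    "# BlockCls(...) instead of TransformerBlock(...) with the same arguments."
   ]
  },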
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next-Level Upgrades\n",
    "\n",
    "### Already Implemented:\n",
    "- Mixed precision with dynamic scaling\n",
    "- KV caching for generation\n",
    "- Vectorized sampling\n",
    "- Proper gradient synchronization\n",
    "\n",
    "### Future Enhancements:\n",
    "1. **Model Parallelism with `pjit`**\n",
    "   - Split the model across TPU pods\n",
    "   - Use GSPMD sharding\n",
    "\n",
    "2. **Better Tokenization**\n",
    "   - Implement BPE (byte-pair encoding) or use SentencePiece\n",
    "   - Better suited to large corpora\n",
    "\n",
    "3. **Flash Attention**\n",
    "   - Memory-efficient attention\n",
    "   - Enables longer sequences\n",
    "\n",
    "4. **Larger Datasets**\n",
    "   - OpenWebText\n",
    "   - C4 (Colossal Clean Crawled Corpus)\n",
    "\n",
    "5. **Experiment Tracking**\n",
    "   - Weights & Biases integration\n",
    "   - TensorBoard logging\n",
    "\n",
    "6. **Gradient Accumulation** (see the `optax.MultiSteps` sketch in the next cell)\n",
    "   - Effectively larger batch sizes\n",
    "   - Better for memory-constrained setups"
   ]
  },
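  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sketch: gradient accumulation with `optax.MultiSteps`\n",
    "\n",
    "A minimal sketch of upgrade 6: wrapping the optimizer in `optax.MultiSteps` accumulates gradients and applies one update every k micro-batches. The learning rate is hard-coded here to keep the cell self-contained; in practice you would pass the schedule from `create_train_state`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: accumulate gradients over 4 micro-batches (4x effective batch size).\n",
    "# A drop-in replacement for `tx` in create_train_state.\n",
    "accum_tx = optax.MultiSteps(\n",
    "    optax.chain(\n",
    "        optax.clip_by_global_norm(1.0),\n",
    "        optax.adamw(learning_rate=3e-4, weight_decay=0.01)\n",
    "    ),\n",
    "    every_k_schedule=4  # apply an optimizer update every 4 steps\n",
    ")"
   ]
  }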
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}