import time
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten


class TransformerLM(nn.Module):
    """A decoder-only Transformer language model built from MLX primitives."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        dims: int,
        num_heads: int,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.transformer = nn.TransformerEncoder(
            num_layers,
            dims,
            num_heads,
        )
        self.out_proj = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        x = self.embedding(x)
        # The string mask selects the built-in causal attention mask.
        x = self.transformer(x, "causal")
        return self.out_proj(x)


def main(args, dtype):
    batch_size = args.batch_size
    context_size = args.context_size
    steps_per_report = args.steps_per_report

    # Initialize the model and materialize its parameters.
    model = TransformerLM(
        args.vocab_size,
        args.num_blocks,
        args.dim,
        args.num_heads,
    )
    model.set_dtype(dtype)
    mx.eval(model.parameters())
    nparams = sum(x.size for _, x in tree_flatten(model.parameters()))
    print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

    def loss_fn(model, x, y):
        logits = model(x)
        losses = nn.losses.cross_entropy(logits, y)
        return mx.mean(losses)

    optimizer = optim.SGD(learning_rate=args.learning_rate)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Capture model and optimizer state so the compiled step updates them in place.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(inputs, targets):
        loss, grads = loss_and_grad_fn(model, inputs, targets)
        optimizer.update(model, grads)
        return loss

    # Synthetic data: random token ids, since this script benchmarks throughput.
    inputs = mx.random.randint(
        low=0, high=args.vocab_size, shape=(batch_size, context_size)
    )
    targets = mx.random.randint(
        low=0, high=args.vocab_size, shape=(batch_size, context_size)
    )

    losses = []
    tic = time.perf_counter()
    for it in range(args.num_iters):
        loss = step(inputs, targets)
        mx.eval(state)
        losses.append(loss.item())
        if (it + 1) % steps_per_report == 0:
            train_loss = np.mean(losses)
            toc = time.perf_counter()
            peak_mem = mx.get_peak_memory() / 2**30
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {steps_per_report / (toc - tic):.3f}, "
                f"Peak memory {peak_mem:.3f} (GB)",
                flush=True,
            )
            losses = []
            tic = time.perf_counter()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("Train a decoder-only Transformer LM with MLX.")
    parser.add_argument(
        "--context_size",
        type=int,
        default=512,
        help="Context size in tokens of the model.",
    )
    parser.add_argument(
        "--num_blocks", type=int, default=26, help="Number of Transformer blocks."
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=1024,
        help="Dimensionality of embeddings and hidden layers.",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=8,
        help="Number of heads used for multi-head attention.",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=151_669,
        help="Vocabulary size.",
    )
    parser.add_argument("--batch_size", type=int, default=8, help="Minibatch size.")
    parser.add_argument(
        "--num_iters", type=int, default=100, help="Iterations to train for."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
    )
    parser.add_argument(
        "--steps_per_report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--dtype", choices=["float32", "bfloat16", "float16"], default="bfloat16"
    )
    args = parser.parse_args()
    dtype = getattr(mx, args.dtype)
    main(args, dtype=dtype)