@awni
Created October 7, 2025 17:16

import time
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten


class TransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        dims: int,
        num_heads: int,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dims)
        self.transformer = nn.TransformerEncoder(
            num_layers,
            dims,
            num_heads,
        )
        self.out_proj = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        x = self.embedding(x)
        # "causal" applies a causal attention mask over the sequence.
        x = self.transformer(x, "causal")
        return self.out_proj(x)


def main(args, dtype):
    batch_size = args.batch_size
    context_size = args.context_size
    steps_per_report = args.steps_per_report

    # Initialize the model:
    model = TransformerLM(
        args.vocab_size,
        args.num_blocks,
        args.dim,
        args.num_heads,
    )
    model.set_dtype(dtype)
    # Parameters are created lazily; evaluate them so they are materialized.
    mx.eval(model.parameters())
    nparams = sum(x.size for _, x in tree_flatten(model.parameters()))
    print(f"Training a transformer with {nparams / 1024**2:.3f} M parameters")

    def loss_fn(model, x, y):
        logits = model(x)
        losses = nn.losses.cross_entropy(logits, y)
        return mx.mean(losses)

    optimizer = optim.SGD(learning_rate=args.learning_rate)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Capture the model and optimizer state so the compiled step can update it.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(inputs, targets):
        loss, grads = loss_and_grad_fn(model, inputs, targets)
        optimizer.update(model, grads)
        return loss

    # Random token IDs stand in for real data; this is a throughput benchmark.
    inputs = mx.random.randint(
        low=0, high=args.vocab_size, shape=(batch_size, context_size)
    )
    targets = mx.random.randint(
        low=0, high=args.vocab_size, shape=(batch_size, context_size)
    )

    losses = []
    tic = time.perf_counter()
    for it in range(args.num_iters):
        loss = step(inputs, targets)
        # Evaluate the graph: updated parameters, optimizer state, and the loss.
        mx.eval(state)
        losses.append(loss.item())
        if (it + 1) % steps_per_report == 0:
            train_loss = np.mean(losses)
            toc = time.perf_counter()
            peak_mem = mx.get_peak_memory() / 2**30
            print(
                f"Iter {it + 1}: Train loss {train_loss:.3f}, "
                f"It/sec {steps_per_report / (toc - tic):.3f}, "
                f"Peak memory {peak_mem:.3f} (GB)",
                flush=True,
            )
            losses = []
            tic = time.perf_counter()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a decoder-only Transformer LM with MLX."
    )
    parser.add_argument(
        "--context_size",
        type=int,
        default=512,
        help="Context size in tokens of the model.",
    )
    parser.add_argument(
        "--num_blocks", type=int, default=26, help="Number of Transformer blocks."
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=1024,
        help="Dimensionality of embeddings and hidden layers.",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=8,
        help="Number of heads used for multi-head attention.",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=151_669,
        help="Vocabulary size.",
    )
    parser.add_argument("--batch_size", type=int, default=8, help="Minibatch size.")
    parser.add_argument(
        "--num_iters", type=int, default=100, help="Iterations to train for."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-3, help="SGD learning rate."
    )
    parser.add_argument(
        "--steps_per_report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--dtype",
        choices=["float32", "bfloat16", "float16"],
        default="bfloat16",
        help="Parameter and activation dtype.",
    )
    args = parser.parse_args()

    dtype = getattr(mx, args.dtype)
    main(args, dtype=dtype)
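
A minimal usage sketch (the filename `train_transformer_lm.py` is assumed here; the gist does not name the file). With MLX installed (`pip install mlx`), a smaller configuration can be run with, for example:

    python train_transformer_lm.py --num_blocks 12 --dim 512 --batch_size 4 --num_iters 50 --dtype bfloat16

Since the script trains on fixed random token IDs, the reported loss only serves as a speed and peak-memory benchmark, not as a measure of language-modeling quality.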