Let's walk through what happens when you send a message to the NanoChat server. We'll trace the entire journey from typing "hello world" in your browser to receiving a response, explaining everything along the way.
┌─────────────────────────────────────────────────────────────────┐
│ Browser (Frontend) │
│ - HTML/CSS/JavaScript UI (nanochat/ui.html) │
│ - Manages conversation history │
│ - Streams responses via Server-Sent Events │
└────────────────┬────────────────────────────────────────────────┘
│ HTTP POST /chat/completions
│ JSON: {messages, temperature, top_k, max_tokens}
▼
┌─────────────────────────────────────────────────────────────────┐
│ FastAPI Server (chat_web.py) │
│ - Validates requests (abuse prevention) │
│ - Manages worker pool │
│ - Coordinates tokenization & generation │
└────────────────┬────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Worker Pool (Multi-GPU) │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Worker 0 │ │ Worker 1 │ │ Worker N │ │
│ │ GPU 0 │ │ GPU 1 │ │ GPU N │ │
│ │ - Model │ │ - Model │ │ - Model │ │
│ │ - Engine │ │ - Engine │ │ - Engine │ │
│ │ - Tokenizer │ │ - Tokenizer │ │ - Tokenizer │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
└────────────────┬────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ Engine (engine.py) │
│ - KV Cache management │
│ - Prefill phase (parallel token processing) │
│ - Decode phase (autoregressive generation) │
│ - Tool use coordination (calculator) │
└────────────────┬────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────────┐
│ GPT Model (gpt.py) │
│ - Token embedding layer │
│ - 12 Transformer blocks (configurable) │
│ - Rotary positional embeddings (RoPE) │
│ - Multi-Query Attention (MQA) │
│ - ReLU² activation │
│ - Language model head │
└─────────────────────────────────────────────────────────────────┘
| Component | File | Responsibility |
|---|---|---|
| Frontend UI | nanochat/ui.html | User interface, message history, streaming display |
| Web Server | scripts/chat_web.py | HTTP API, validation, worker management |
| Worker Pool | scripts/chat_web.py (lines 98-149) | Multi-GPU load balancing |
| Engine | nanochat/engine.py | Generation loop, KV cache, sampling |
| GPT Model | nanochat/gpt.py | Transformer neural network |
| Tokenizer | nanochat/tokenizer.py | Text ↔ token ID conversion |
Let's trace what happens when you type "hello world" and press Send.
File: nanochat/ui.html (lines 524-544)
async function sendMessage() {
const message = chatInput.value.trim(); // "hello world"
if (!message || isGenerating) return;
chatInput.value = "";
chatInput.style.height = "auto";
const userMessageIndex = messages.length;
messages.push({ role: "user", content: message });
addMessage("user", message, userMessageIndex);
await generateAssistantResponse();
}

What happens:
- JavaScript captures the text "hello world"
- Adds it to the messages array with role 'user'
- Displays it in the chat UI
- Calls generateAssistantResponse() to fetch the AI's reply
File: nanochat/ui.html (lines 390-401)
const response = await fetch(`${API_URL}/chat/completions`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
messages: messages, // [{ role: 'user', content: 'hello world' }]
temperature: currentTemperature, // e.g., 0.8
top_k: currentTopK, // e.g., 50
max_tokens: 512,
}),
});

HTTP Request:
POST /chat/completions HTTP/1.1
Host: localhost:8000
Content-Type: application/json
{
"messages": [{"role": "user", "content": "hello world"}],
"temperature": 0.8,
"top_k": 50,
"max_tokens": 512
}

File: scripts/chat_web.py (lines 313-324)
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
# Basic validation to prevent abuse
validate_chat_request(request)
# Log incoming conversation to console
logger.info("="*20)
for i, message in enumerate(request.messages):
logger.info(f"[{message.role.upper()}]: {message.content}")
logger.info("-"*20)

Request Validation (lines 160-221):
The server applies strict limits to prevent abuse:
# Abuse prevention limits (lines 52-61)
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096

For our "hello world" message:
- ✓ Message count: 1 < 500
- ✓ Message length: 11 chars < 8000
- ✓ Total conversation: 11 chars < 32000
- ✓ Temperature: 0.8 ∈ [0.0, 2.0]
- ✓ Top-k: 50 ∈ [1, 200]
- ✓ Max tokens: 512 ∈ [1, 4096]
Why these limits? They protect against:
- Memory exhaustion (very long conversations)
- CPU/GPU abuse (extremely long generation)
- Denial of Service attacks
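To make these limits concrete, here is a minimal sketch of what the validation step could look like. The constants mirror the ones listed above, but the `Message`/`ChatRequest` models and the body of `validate_chat_request` are simplified assumptions for illustration, not the actual code in scripts/chat_web.py.

```python
# Hypothetical sketch of request validation, mirroring the limits above.
from fastapi import HTTPException
from pydantic import BaseModel

MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE, MAX_TEMPERATURE = 0.0, 2.0
MIN_TOP_K, MAX_TOP_K = 1, 200
MIN_MAX_TOKENS, MAX_MAX_TOKENS = 1, 4096

class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: list[Message]
    temperature: float = 0.8
    top_k: int = 50
    max_tokens: int = 512

def validate_chat_request(request: ChatRequest) -> None:
    # Reject anything outside the limits before touching a GPU.
    if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
        raise HTTPException(status_code=400, detail="Too many messages")
    total_chars = 0
    for message in request.messages:
        if len(message.content) > MAX_MESSAGE_LENGTH:
            raise HTTPException(status_code=400, detail="Message too long")
        total_chars += len(message.content)
    if total_chars > MAX_TOTAL_CONVERSATION_LENGTH:
        raise HTTPException(status_code=400, detail="Conversation too long")
    if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
        raise HTTPException(status_code=400, detail="Invalid temperature")
    if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
        raise HTTPException(status_code=400, detail="Invalid top_k")
    if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
        raise HTTPException(status_code=400, detail="Invalid max_tokens")
```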
File: scripts/chat_web.py (lines 326-328)
# Acquire a worker from the pool (will wait if all are busy)
worker_pool = app.state.worker_pool
worker = await worker_pool.acquire_worker()

Worker Pool Architecture (lines 98-149):
Each worker contains a complete copy of the model on a separate GPU:
@dataclass
class Worker:
"""A worker with a model loaded on a specific GPU."""
gpu_id: int # e.g., 0, 1, 2, 3
device: torch.device # e.g., cuda:0
engine: Engine # Generation engine with KV cache
tokenizer: object # BPE tokenizer
autocast_ctx: torch.amp.autocast # Mixed precision context

Key Points:
- Data Parallelism: Each GPU has a full model replica
- Load Balancing: Requests are distributed across available workers
- Async Queue: If all workers are busy, new requests wait in an async queue
- Resource Isolation: Each worker is independent (no shared state)
Example with 4 GPUs:
Request 1 → Worker 0 (GPU 0) → Busy
Request 2 → Worker 1 (GPU 1) → Busy
Request 3 → Worker 2 (GPU 2) → Busy
Request 4 → Worker 3 (GPU 3) → Busy
Request 5 → (waits in queue)
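Under the hood, such a pool can be little more than an asyncio.Queue of pre-built workers. The sketch below illustrates the acquire/release pattern under that assumption; the method names match the walkthrough, but the implementation details of the real worker pool in chat_web.py may differ.

```python
# Illustrative sketch of a worker pool built on asyncio.Queue.
# Worker construction (loading model, engine, tokenizer per GPU) is elided.
import asyncio

class WorkerPool:
    def __init__(self, workers):
        self._queue = asyncio.Queue()
        for worker in workers:
            self._queue.put_nowait(worker)   # all workers start out idle

    async def acquire_worker(self):
        # Waits here if every worker is currently busy.
        return await self._queue.get()

    async def release_worker(self, worker):
        # Return the worker so the next queued request can use it.
        await self._queue.put(worker)
```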
File: scripts/chat_web.py (lines 331-349)
# Build conversation tokens
bos = worker.tokenizer.get_bos_token_id()
user_start = worker.tokenizer.encode_special("<|user_start|>")
user_end = worker.tokenizer.encode_special("<|user_end|>")
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
conversation_tokens.append(assistant_start)

For "hello world", the token sequence becomes:
[
65527, # <|bos|> (Beginning of Sequence)
65528, # <|user_start|> (Start of user message)
52300, # "hello" (BPE token for "hello")
883, # " world" (BPE token for " world" - note the leading space!)
65529, # <|user_end|> (End of user message)
65530 # <|assistant_start|> (Primes the model to respond)
]
Note: "hello world" tokenizes to just 2 tokens: 52300 = "hello" and 883 = " world" (which includes the leading space).
Why special tokens?
- Role Separation: The model learns different behaviors for user vs assistant
- Tool Use: Additional tokens like <|python_start|> enable tool calling
- Structured Conversations: Clear boundaries between messages prevent confusion
- Supervision: During training, we only compute loss on assistant tokens
Multi-turn Example:
User: "What is 2+2?"
Assistant: "4"
User: "And 3+3?"
→ Tokenization:
[
<|bos|>,
<|user_start|>, tokens("What is 2+2?"), <|user_end|>,
<|assistant_start|>, tokens("4"), <|assistant_end|>,
<|user_start|>, tokens("And 3+3?"), <|user_end|>,
<|assistant_start|> ← primes for next response
]
File: nanochat/engine.py (lines 208-220)
The engine performs generation in two phases: Prefill and Decode.
The prefill phase processes all input tokens in parallel to populate the KV cache.
# 1) Run a batch 1 prefill of the prompt tokens
m = self.model.config
kv_model_kwargs = {
"num_heads": m.n_kv_head,
"head_dim": m.n_embd // m.n_head,
"num_layers": m.n_layer
}
kv_cache_prefill = KVCache(
batch_size=1,
seq_len=len(tokens), # Only needs to store this many tokens
**kv_model_kwargs,
)
ids = torch.tensor([tokens], dtype=torch.long, device=device)
logits = self.model.forward(ids, kv_cache=kv_cache_prefill) # (1, T, vocab_size)
logits = logits[:, -1, :] # Take only the last token's logits

What's happening:
1. Create KV Cache: Allocates memory to store attention keys/values
   - Shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
   - Example: (12, 2, 1, 6, 6, 128) for our 6 input tokens
2. Forward Pass: Process all 6 tokens through the Transformer
   - Token embedding
   - 12 Transformer layers (attention + MLP)
   - Each layer caches its K, V matrices
   - Output logits for ALL positions
3. Extract Last Logits: Only the last position predicts the next token
   - Input: [bos, user_start, "hello", " world", user_end, assistant_start]
   - We want to generate the token AFTER assistant_start
   - So we take logits[:, -1, :] (the logits at position 5)
Why prefill is fast:
- Parallel Processing: All tokens computed simultaneously (not sequential)
- Matrix Operations: GPUs excel at large matrix multiplications
- Single Forward Pass: Processes entire prompt at once
Prefill Complexity:
- Time: O(T²) where T = prompt length (due to attention)
- But in practice, GPUs make this very fast for T < 2048
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()

This samples the first output token (e.g., "Hi") using temperature and top-k sampling (explained in Deep Dive: Token Sampling).
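Before moving on to decoding, it helps to have a concrete mental model of the KV cache. The sketch below implements the shape described above in the simplest possible way; nanochat's actual KVCache in engine.py does more (for example, prefilling from another cache), so treat this as a conceptual model rather than the real class.

```python
# Conceptual KV cache: one pre-allocated tensor of shape
# (num_layers, 2, batch_size, num_heads, seq_len, head_dim) plus a write cursor.
import torch

class SimpleKVCache:
    def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers,
                 dtype=torch.bfloat16, device="cuda"):
        self.kv = torch.zeros(num_layers, 2, batch_size, num_heads, seq_len, head_dim,
                              dtype=dtype, device=device)
        self.pos = 0  # number of positions already cached

    def insert_kv(self, layer_idx, k, v):
        # k, v: (batch_size, num_heads, T_new, head_dim)
        T_new = k.size(2)
        start, end = self.pos, self.pos + T_new
        self.kv[layer_idx, 0, :, :, start:end] = k
        self.kv[layer_idx, 1, :, :, start:end] = v
        if layer_idx == self.kv.size(0) - 1:
            self.pos = end  # advance only once every layer has written this step
        # Return views over everything cached so far for this layer.
        return self.kv[layer_idx, 0, :, :, :end], self.kv[layer_idx, 1, :, :, :end]
```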
File: nanochat/engine.py (lines 222-296)
After prefill, we replicate the KV cache for each sample and enter the decode loop.
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
**kv_model_kwargs,
)
kv_cache_decode.prefill(kv_cache_prefill)
del kv_cache_prefill # Free memory
# 3) Initialize states for each sample
row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
# 4) Main generation loop
num_generated = 0
first_iteration = True
while True:
# Stop conditions
if max_tokens is not None and num_generated >= max_tokens:
break
if all(state.completed for state in row_states):
break
if first_iteration:
# Use the token we already sampled from prefill
sampled_tokens = [sampled_tokens[0]] * num_samples
first_iteration = False
else:
# Forward the model and get the next token for each row
logits = self.model.forward(ids, kv_cache=kv_cache_decode) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size) at last time step
next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1)
sampled_tokens = next_ids[:, 0].tolist()
# Process each row: choose next token, update state, handle tool use
token_column = []
token_masks = []
for i, state in enumerate(row_states):
# Determine if this token is forced (tool output) or sampled
is_forced = len(state.forced_tokens) > 0
token_masks.append(0 if is_forced else 1)
next_token = state.forced_tokens.popleft() if is_forced else sampled_tokens[i]
token_column.append(next_token)
# Update state
state.current_tokens.append(next_token)
# Check for completion
if next_token == assistant_end or next_token == bos:
state.completed = True
# Handle tool logic (calculator)
# ... (see Tool Use section)
# Yield the token column
yield token_column, token_masks
num_generated += 1
# Prepare ids for next iteration (single token forward pass)
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)

Decode Loop Mechanics:
Each iteration:
- Forward Pass: Process ONE new token through the model
- Sample: Get next token using temperature + top-k
- Update KV Cache: Append new K, V to cache
- Yield: Return token to caller (for streaming)
- Repeat: Until max_tokens or <|assistant_end|>
Example Decode Sequence:
Iteration 1: [assistant_start] → Sample "Hi" → Yield "Hi"
Iteration 2: ["Hi"] → Sample "!" → Yield "!"
Iteration 3: ["!"] → Sample " How" → Yield " How"
Iteration 4: [" How"] → Sample " can" → Yield " can"
Iteration 5: [" can"] → Sample " I" → Yield " I"
Iteration 6: [" I"] → Sample " help" → Yield " help"
Iteration 7: [" help"] → Sample "?" → Yield "?"
Iteration 8: ["?"] → Sample <|assistant_end|> → STOP
Decode Complexity:
- Time per token: O(T) where T = total sequence length so far
- Total time for N tokens: O(N × T) ≈ O(N²) worst case
- BUT: With KV cache, each iteration only processes 1 new token (very fast!)
Without KV Cache (naive approach):
Iteration 1: Process [assistant_start] → 1 token
Iteration 2: Process [assistant_start, "Hi"] → 2 tokens
Iteration 3: Process [assistant_start, "Hi", "!"] → 3 tokens
...
Iteration N: Process [all N tokens] → N tokens
Total: 1 + 2 + 3 + ... + N = O(N²) tokens processed
With KV Cache (our approach):
Prefill: Process [assistant_start] → 1 token
Iteration 1: Process ["Hi"] → 1 token
Iteration 2: Process ["!"] → 1 token
...
Iteration N: Process [one token] → 1 token
Total: N + 1 tokens processed (linear!)
For a typical ~100-token response, the naive approach processes roughly 5,050 token positions versus about 100 with the cache, so the KV cache saves on the order of 50x computation.
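The savings follow directly from the arithmetic above; a quick back-of-the-envelope check for a 100-token response:

```python
# Token positions processed during decoding, with and without a KV cache.
N = 100                           # tokens in a typical response
without_cache = N * (N + 1) // 2  # re-process the whole sequence each step
with_cache = N                    # one new token per step
print(without_cache, with_cache, without_cache / with_cache)  # 5050 100 50.5
```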
File: nanochat/gpt.py (lines 244-276)
Each iteration of the decode loop calls model.forward():
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size() # Batch size, sequence length
# Grab rotary embeddings for current sequence length
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
# Forward the trunk of the Transformer
x = self.transformer.wte(idx) # Token embeddings
x = norm(x) # RMSNorm after embedding
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x = norm(x) # Final RMSNorm
# Forward the lm_head (compute logits)
logits = self.lm_head(x)
logits = 15 * torch.tanh(logits / 15) # Logits softcap
return logits

Transformer Block (lines 126-135):
class Block(nn.Module):
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache) # Attention with residual
x = x + self.mlp(norm(x)) # MLP with residual
return x

Pre-Norm Architecture:
- RMSNorm BEFORE attention (not after, as in the original Transformer)
- RMSNorm BEFORE the MLP
- Residual connections add the original x back
- More stable training, better gradient flow
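The `norm` helper used throughout the block is an RMSNorm. A minimal, parameter-free version consistent with how it is used here is sketched below; whether the real `norm()` in nanochat/gpt.py carries a learnable scale, and which epsilon it uses, may differ.

```python
# Minimal, parameter-free RMSNorm, as used for the pre-norm blocks above (sketch).
import torch

def norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale each vector to unit RMS along the last dimension.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
```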
Attention (lines 66-110):
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
# Project to Q, K, V
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
# Apply Rotary Embeddings (RoPE)
cos, sin = cos_sin
q = apply_rotary_emb(q, cos, sin)
k = apply_rotary_emb(k, cos, sin)
# Apply QK Normalization
q, k = norm(q), norm(k)
# Transpose for attention: (B, T, H, D) → (B, H, T, D)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
# KV Cache: insert current k,v and get full view
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
# Scaled Dot-Product Attention
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Re-assemble heads and project back
y = y.transpose(1, 2).contiguous().view(B, T, -1)
y = self.c_proj(y)
return y

Key Optimizations:
- Rotary Embeddings (RoPE): Encodes position without learnable parameters
- QK Normalization: Stabilizes attention, prevents runaway attention scores
- Multi-Query Attention (MQA) support: the number of KV heads can be smaller than the number of Q heads to shrink the KV cache (the base config shown here uses 6 of each)
- Saves memory in KV cache
- Faster inference
- Minimal quality loss
- Flash Attention: PyTorch's scaled_dot_product_attention uses optimized kernels
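apply_rotary_emb is referenced in the attention code but not shown. The sketch below is one common formulation of the rotation (the "rotate-half" variant), assuming cos/sin are broadcastable to the input's shape; nanochat's own implementation may lay out the tensors differently.

```python
# Standard rotary position embedding applied to q or k of shape (B, T, H, D).
# Here cos and sin are assumed to have shape (1, T, 1, D/2).
import torch

def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)              # split the head dimension in half
    rotated = torch.cat((-x2, x1), dim=-1)   # 90-degree rotation of each pair
    return x * torch.cat((cos, cos), dim=-1) + rotated * torch.cat((sin, sin), dim=-1)
```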
MLP (lines 113-123):
class MLP(nn.Module):
def forward(self, x):
x = self.c_fc(x) # Linear: d_model → 4*d_model
x = F.relu(x).square() # ReLU² (not GELU!)
x = self.c_proj(x) # Linear: 4*d_model → d_model
return x

Why ReLU² instead of GELU?
- Simpler (no approximations)
- Faster to compute
- Works well in practice
- Used in several recent model architectures
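To see the difference numerically, the snippet below simply evaluates both activations on a few sample values (illustration only; the printed numbers are approximate):

```python
# Compare ReLU² and GELU on a few sample activations.
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(F.relu(x).square())  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 4.0000])
print(F.gelu(x))           # approx tensor([-0.0455, -0.1543, 0.0000, 0.3457, 1.9545])
```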
File: nanochat/engine.py (lines 156-172)
After the model produces logits, we sample the next token:
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
"""Sample a single next token from logits of shape (B, vocab_size). Returns (B, 1)."""
# Temperature = 0 → Greedy (always pick highest probability)
if temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
# Top-k filtering
if top_k is not None:
k = min(top_k, logits.size(-1))
vals, idx = torch.topk(logits, k, dim=-1) # Get top-k logits
vals = vals / temperature # Scale by temperature
probs = F.softmax(vals, dim=-1) # Convert to probabilities
choice = torch.multinomial(probs, num_samples=1, generator=rng)
return idx.gather(1, choice) # Map back to vocab indices
else:
# Standard temperature sampling (no top-k)
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1, generator=rng)

Detailed Example:
Let's say the model outputs logits for the vocabulary:
Vocab size = 50,000 tokens
Logits (raw scores):
"Hi" → 8.2
"Hello" → 7.9
"Hey" → 7.5
"Greet" → 6.1
...
"Zebra" → -3.2
Step-by-step with temperature=0.8, top_k=50:

1. Top-k Filtering:
   - Extract the top 50 logits (discard the remaining 49,950)
   - The top 50 might be: ["Hi": 8.2, "Hello": 7.9, "Hey": 7.5, ..., "Howdy": 4.3]

2. Temperature Scaling:
   - Divide each logit by 0.8:
     "Hi"    → 8.2 / 0.8 = 10.25
     "Hello" → 7.9 / 0.8 = 9.875
     "Hey"   → 7.5 / 0.8 = 9.375
   - Higher temperature (>1) → flatter distribution (more random)
   - Lower temperature (<1) → sharper distribution (more deterministic)

3. Softmax:
   - Convert to probabilities:
     "Hi"    → exp(10.25) / Z = 0.52 (52% chance)
     "Hello" → exp(9.875) / Z = 0.31 (31% chance)
     "Hey"   → exp(9.375) / Z = 0.12 (12% chance)
     ...       (remaining 47 tokens share ~5%)

4. Multinomial Sampling:
   Now we pick the next token. We use torch.multinomial(), which samples from a categorical distribution.

How it works:

Think of it like a weighted lottery:

Probabilities:
  "Hi"    → 0.52 (52% of the probability mass)
  "Hello" → 0.31 (31% of the probability mass)
  "Hey"   → 0.12 (12% of the probability mass)
  ...       (remaining 47 tokens share ~5%)

Visualization as a number line:

0.0              0.52      0.83  0.95           1.0
 |----------------|---------|-----|--------------|
       "Hi"        "Hello"  "Hey"  (other tokens)
       (52%)        (31%)   (12%)      (~5%)

The sampling process:
- Generate a random number r uniformly between 0.0 and 1.0
- Find which "bucket" it falls into:

r = random.uniform(0, 1)  # e.g., r = 0.67
cumulative = 0.0
for token, prob in zip(tokens, probabilities):
    cumulative += prob
    if r < cumulative:
        return token

# Example with r = 0.67:
# cumulative = 0.52 ("Hi")    → r (0.67) >= 0.52, continue
# cumulative = 0.83 ("Hello") → r (0.67) <  0.83, return "Hello"!

Different random numbers → different tokens:

r = 0.25 → falls in [0.00, 0.52) → "Hi"
r = 0.67 → falls in [0.52, 0.83) → "Hello"
r = 0.88 → falls in [0.83, 0.95) → "Hey"
r = 0.99 → falls in [0.95, 1.00) → some rare token

Tokens with higher probabilities occupy more of the number line, so they're more likely to be sampled. In our example, let's say we sample r = 0.25, which gives us "Hi" (the token ID varies by vocabulary).
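You can convince yourself of this weighted-lottery behavior by sampling from torch.multinomial many times and checking the empirical frequencies, using the illustrative probabilities from the example above:

```python
# Empirically check that torch.multinomial respects the given probabilities.
import torch

probs = torch.tensor([0.52, 0.31, 0.12, 0.05])  # "Hi", "Hello", "Hey", everything else
draws = torch.multinomial(probs, num_samples=10_000, replacement=True)
counts = torch.bincount(draws, minlength=4).float() / 10_000
print(counts)  # roughly tensor([0.52, 0.31, 0.12, 0.05]), up to sampling noise
```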
File: scripts/chat_web.py (lines 262-311)
As tokens are generated, they're immediately streamed to the client:
async def generate_stream(
worker: Worker,
tokens,
temperature=None,
max_new_tokens=None,
top_k=None
) -> AsyncGenerator[str, None]:
"""Generate assistant response with streaming."""
# Accumulate tokens for proper UTF-8 handling
accumulated_tokens = []
last_clean_text = ""
with worker.autocast_ctx:
for token_column, token_masks in worker.engine.generate(...):
token = token_column[0]
# Check for stopping tokens
if token == assistant_end or token == bos:
break
# Accumulate token
accumulated_tokens.append(token)
# Decode all accumulated tokens
current_text = worker.tokenizer.decode(accumulated_tokens)
# Only emit if no incomplete UTF-8 sequences
if not current_text.endswith('�'):
new_text = current_text[len(last_clean_text):]
if new_text:
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
last_clean_text = current_text
yield f"data: {json.dumps({'done': True})}\n\n"

UTF-8 Handling:
This is critical! Tokens don't always align with UTF-8 character boundaries.
Example with emoji "😀":
Token 1: First 2 bytes → decode → "�" (incomplete UTF-8)
Token 2: Next 2 bytes → decode → "😀" (complete!)
The code:
- Accumulates tokens
- Decodes ALL tokens so far
- Checks for '�' (the Unicode replacement character)
- Only emits when we have a complete UTF-8 sequence
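The same effect can be reproduced at the byte level, independent of the tokenizer: decoding a truncated multi-byte character with errors="replace" yields the replacement character, which is exactly what the streaming code watches for.

```python
# Decoding a truncated multi-byte character produces '�' until the rest arrives.
emoji_bytes = "😀".encode("utf-8")  # 4 bytes: b'\xf0\x9f\x98\x80'

partial = emoji_bytes[:2].decode("utf-8", errors="replace")
complete = emoji_bytes.decode("utf-8", errors="replace")

print(partial.endswith("�"))  # True  -> the stream would hold this chunk back
print(complete)               # 😀    -> now it is safe to emit
```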
Server-Sent Events Format:
data: {"token": "Hi", "gpu": 0}
data: {"token": "!", "gpu": 0}
data: {"token": " How", "gpu": 0}
data: {"done": true}
Each line:
- Starts with "data: "
- Contains a JSON payload
- Ends with \n\n (a double newline)
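Any HTTP client that can read a streaming response body can consume this format, not just the browser. Below is a minimal Python client sketch, assuming the server is running locally on port 8000; the endpoint and payload match the request shown earlier, and the rest is illustrative.

```python
# Minimal streaming client for the /chat/completions SSE endpoint.
import json
import requests

payload = {
    "messages": [{"role": "user", "content": "hello world"}],
    "temperature": 0.8,
    "top_k": 50,
    "max_tokens": 512,
}

with requests.post("http://localhost:8000/chat/completions", json=payload, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separator lines between events
        data = json.loads(line[len("data: "):])
        if data.get("done"):
            break
        print(data["token"], end="", flush=True)
print()
```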
File: nanochat/ui.html (lines 407-432)
The browser receives and displays tokens in real-time:
const reader = response.body.getReader();
const decoder = new TextDecoder();
let fullResponse = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
// Decode bytes to string
const chunk = decoder.decode(value);
const lines = chunk.split("\n");
for (const line of lines) {
if (line.startsWith("data: ")) {
try {
const data = JSON.parse(line.slice(6));
if (data.token) {
fullResponse += data.token;
assistantContent.textContent = fullResponse;
chatContainer.scrollTop = chatContainer.scrollHeight;
}
} catch (e) {
// Ignore parse errors (incomplete JSON)
}
}
}
}

File: scripts/chat_web.py (lines 353-373)
async def stream_and_release():
try:
async for chunk in generate_stream(...):
# Accumulate response for logging
chunk_data = json.loads(chunk.replace("data: ", "").strip())
if "token" in chunk_data:
response_tokens.append(chunk_data["token"])
yield chunk
finally:
# Log the assistant response
full_response = "".join(response_tokens)
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
logger.info("="*20)
# Release worker back to pool
await worker_pool.release_worker(worker)

Worker lifecycle:
- Request arrives
- Worker acquired from pool (blocks if all busy)
- Generation happens
- Worker released back to pool
- Next request can use this worker
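The key invariant is that the worker is always returned to the pool, even if generation raises or the client disconnects, which is why the release happens in a finally block. Conceptually, the pattern looks like the sketch below (not the actual chat_web.py code; `generate_fn` stands in for the streaming generator).

```python
# The acquire -> generate -> release pattern; release is guaranteed by finally.
async def stream_and_release(worker_pool, generate_fn):
    worker = await worker_pool.acquire_worker()
    try:
        async for chunk in generate_fn(worker):  # generation happens here
            yield chunk
    finally:
        # Runs even on errors or client disconnects, so a GPU is never leaked.
        await worker_pool.release_worker(worker)
```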