Skip to content

Instantly share code, notes, and snippets.

@vukrosic
Last active October 21, 2025 21:26
Show Gist options
  • Select an option

  • Save vukrosic/c8f370a1d041f58bd28fdea9339467f8 to your computer and use it in GitHub Desktop.

Select an option

Save vukrosic/c8f370a1d041f58bd28fdea9339467f8 to your computer and use it in GitHub Desktop.
RMSNorm Tutorial.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"name": "RMSNorm Tutorial.ipynb",
"authorship_tag": "ABX9TyNGNLw7bmZ0UDXJs+0jI0HU",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/vukrosic/c8f370a1d041f58bd28fdea9339467f8/rmsnorm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"Video: Master RMSNorm From Scratch - Step by Step Tutorial: https://youtu.be/HgSdYtPgJnU"
],
"metadata": {
"id": "scd2-_fRqCbZ"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DT5-i_3wjKeH",
"outputId": "4458979a-0464-4077-fc62-e255eb88a3a4"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input:\n",
"[1. 2. 3. 4. 5.]\n",
"\n",
"Input mean: 3.0000\n",
"Input std: 1.4142\n",
"\n",
"==================================================\n",
"After RMSNorm:\n",
"[0.30151133 0.60302266 0.90453399 1.20604532 1.50755665]\n",
"\n",
"Output mean: 0.9045\n",
"Output RMS: 1.0000\n",
"\n",
"==================================================\n",
"Step-by-step calculation:\n",
"1. Squared values: [ 1. 4. 9. 16. 25.]\n",
"2. Mean of squares: 11.0000\n",
"3. RMS (sqrt of mean): 3.3166\n",
"4. Divide input by RMS to get normalized output\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"def rmsnorm(x, eps=1e-6, axis=-1):\n",
"    \"\"\"\n",
"    Root Mean Square Normalization\n",
"\n",
"    Scales x so that its root-mean-square along `axis` is ~1:\n",
"        y = x / sqrt(mean(x**2, axis) + eps)\n",
"\n",
"    Args:\n",
"        x: Input array (or array-like) of any shape\n",
"        eps: Small constant for numerical stability (keeps the\n",
"            denominator non-zero for an all-zero input)\n",
"        axis: Axis to normalize over. Defaults to the last axis, so each\n",
"            vector is normalized independently, like the PyTorch RMSNorm\n",
"            module (dim=-1). For a 1-D input this is the whole array.\n",
"\n",
"    Returns:\n",
"        Normalized array with the same shape as x\n",
"    \"\"\"\n",
"    x = np.asarray(x)\n",
"    # Calculate RMS: sqrt of mean of squares. keepdims=True keeps the\n",
"    # reduced axis (size 1) so rms broadcasts back against x below.\n",
"    rms = np.sqrt(np.mean(x**2, axis=axis, keepdims=True) + eps)\n",
"\n",
"    # Normalize by dividing by RMS\n",
"    return x / rms\n",
"\n",
"# Example input\n",
"input_vector = np.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
"\n",
"print(\"Input:\")\n",
"print(input_vector)\n",
"print(f\"\\nInput mean: {np.mean(input_vector):.4f}\")\n",
"print(f\"Input std: {np.std(input_vector):.4f}\")\n",
"\n",
"# Apply RMSNorm\n",
"output = rmsnorm(input_vector)\n",
"\n",
"print(\"\\n\" + \"=\"*50)\n",
"print(\"After RMSNorm:\")\n",
"print(output)\n",
"# RMSNorm does not subtract the mean, so the mean stays non-zero;\n",
"# only the scale (RMS) is normalized to ~1.\n",
"print(f\"\\nOutput mean: {np.mean(output):.4f}\")\n",
"print(f\"Output RMS: {np.sqrt(np.mean(output**2)):.4f}\")\n",
"\n",
"# Show the calculation step by step\n",
"print(\"\\n\" + \"=\"*50)\n",
"print(\"Step-by-step calculation:\")\n",
"squared = input_vector**2\n",
"mean_of_squares = np.mean(squared)\n",
"manual_rms = np.sqrt(mean_of_squares)\n",
"print(f\"1. Squared values: {squared}\")\n",
"print(f\"2. Mean of squares: {mean_of_squares:.4f}\")\n",
"print(f\"3. RMS (sqrt of mean): {manual_rms:.4f}\")\n",
"print(\"4. Divide input by RMS to get normalized output\")\n",
"\n",
"# Sanity check: the manual recipe reproduces rmsnorm() (eps is negligible here)\n",
"assert np.allclose(input_vector / manual_rms, output)"
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# same as above (NumPy, no learnable scale), normalized per vector\n",
"def rmsnorm_numpy(x, eps=1e-6, axis=-1):\n",
"    \"\"\"RMSNorm using NumPy.\n",
"\n",
"    Returns x / sqrt(mean(x**2, axis) + eps). Reduces over the last axis by\n",
"    default, matching the PyTorch RMSNorm class below (dim=-1); for a 1-D\n",
"    input this is the whole array.\n",
"    \"\"\"\n",
"    x = np.asarray(x)\n",
"    rms = np.sqrt(np.mean(x**2, axis=axis, keepdims=True) + eps)\n",
"    return x / rms\n",
"\n",
"class RMSNorm(nn.Module):\n",
"    \"\"\"RMSNorm using PyTorch.\n",
"\n",
"    forward(x) = x / sqrt(mean(x**2, dim=-1) + eps) * weight\n",
"\n",
"    Args:\n",
"        dim: Length of the learnable `weight`; it is multiplied elementwise\n",
"            against the last dimension of the input, so it should match it\n",
"        eps: Small constant for numerical stability\n",
"    \"\"\"\n",
"    def __init__(self, dim, eps=1e-6):\n",
"        super().__init__()\n",
"        self.eps = eps\n",
"        # Learnable scale parameter (like gamma in LayerNorm); starts at 1,\n",
"        # so an untrained layer is a pure normalization\n",
"        self.weight = nn.Parameter(torch.ones(dim))\n",
"\n",
"    def forward(self, x):\n",
"        # Calculate RMS over the last dimension (keepdim=True so it broadcasts)\n",
"        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)\n",
"        # Normalize and scale\n",
"        return (x / rms) * self.weight\n",
"\n",
"# Example input\n",
"input_vector = np.array([1.0, 2.0, 3.0, 4.0, 5.0])\n",
"\n",
"print(\"=\"*60)\n",
"print(\"NUMPY VERSION\")\n",
"print(\"=\"*60)\n",
"print(f\"Input: {input_vector}\")\n",
"print(f\"Input RMS: {np.sqrt(np.mean(input_vector**2)):.4f}\")\n",
"\n",
"output_np = rmsnorm_numpy(input_vector)\n",
"print(f\"\\nAfter RMSNorm: {output_np}\")\n",
"print(f\"Output RMS: {np.sqrt(np.mean(output_np**2)):.4f}\")\n",
"print(\"Values are now scaled to have RMS ≈ 1.0\")\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"PYTORCH VERSION\")\n",
"print(\"=\"*60)\n",
"\n",
"# Convert the same NumPy input to a float32 PyTorch tensor\n",
"input_tensor = torch.from_numpy(input_vector).float()\n",
"print(f\"Input: {input_tensor}\")\n",
"\n",
"# Create RMSNorm layer sized to the input's last dimension\n",
"rmsnorm_layer = RMSNorm(dim=input_tensor.shape[-1])\n",
"print(f\"Learnable weights: {rmsnorm_layer.weight.data}\")\n",
"\n",
"output_torch = rmsnorm_layer(input_tensor)\n",
"print(f\"\\nAfter RMSNorm: {output_torch}\")\n",
"print(f\"Output RMS: {torch.sqrt(torch.mean(output_torch**2)):.4f}\")\n",
"\n",
"# Sanity check: with weight == 1 the PyTorch layer matches the NumPy function\n",
"assert np.allclose(output_torch.detach().numpy(), output_np, atol=1e-6)\n",
"\n",
"# Show with different scale inputs\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"SAME PATTERN, DIFFERENT SCALES:\")\n",
"print(\"=\"*60)\n",
"small = np.array([0.1, 0.2, 0.3, 0.4, 0.5])\n",
"large = np.array([10.0, 20.0, 30.0, 40.0, 50.0])\n",
"\n",
"print(f\"Small input RMS: {np.sqrt(np.mean(small**2)):.4f}\")\n",
"print(f\"After RMSNorm RMS: {np.sqrt(np.mean(rmsnorm_numpy(small)**2)):.4f}\")\n",
"print(f\"\\nLarge input RMS: {np.sqrt(np.mean(large**2)):.4f}\")\n",
"print(f\"After RMSNorm RMS: {np.sqrt(np.mean(rmsnorm_numpy(large)**2)):.4f}\")\n",
"print(\"\\n→ Both normalized to similar scale!\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vfLWQTBFnN1E",
"outputId": "fb31bf07-bc75-491f-892f-4c6097b3a472"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"============================================================\n",
"NUMPY VERSION\n",
"============================================================\n",
"Input: [1. 2. 3. 4. 5.]\n",
"Input RMS: 3.3166\n",
"\n",
"After RMSNorm: [0.30151133 0.60302266 0.90453399 1.20604532 1.50755665]\n",
"Output RMS: 1.0000\n",
"Values are now scaled to have RMS ≈ 1.0\n",
"\n",
"============================================================\n",
"PYTORCH VERSION\n",
"============================================================\n",
"Input: tensor([1., 2., 3., 4., 5.])\n",
"Learnable weights: tensor([1., 1., 1., 1., 1.])\n",
"\n",
"After RMSNorm: tensor([0.3015, 0.6030, 0.9045, 1.2060, 1.5076], grad_fn=<MulBackward0>)\n",
"Output RMS: 1.0000\n",
"\n",
"============================================================\n",
"SAME PATTERN, DIFFERENT SCALES:\n",
"============================================================\n",
"Small input RMS: 0.3317\n",
"After RMSNorm RMS: 1.0000\n",
"\n",
"Large input RMS: 33.1662\n",
"After RMSNorm RMS: 1.0000\n",
"\n",
"→ Both normalized to similar scale!\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Same RMSNorm as in the previous cell, redefined so this cell runs on its own\n",
"class RMSNorm(nn.Module):\n",
"    \"\"\"RMSNorm with learnable scale.\n",
"\n",
"    forward(x) = x / sqrt(mean(x**2, dim=-1) + eps) * weight\n",
"    \"\"\"\n",
"    def __init__(self, dim, eps=1e-6):\n",
"        super().__init__()\n",
"        self.eps = eps\n",
"        # Per-feature scale, initialised to 1 (neutral scaling)\n",
"        self.weight = nn.Parameter(torch.ones(dim))\n",
"\n",
"    def forward(self, x):\n",
"        # RMS over the last dimension; keepdim=True so it broadcasts against x\n",
"        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)\n",
"        return (x / rms) * self.weight\n",
"\n",
"class SimpleNet(nn.Module):\n",
"    \"\"\"Simple neural network with RMSNorm.\n",
"\n",
"    Linear(input_dim -> hidden_dim) -> RMSNorm -> ReLU\n",
"        -> Linear(hidden_dim -> output_dim)\n",
"    \"\"\"\n",
"    def __init__(self, input_dim, hidden_dim, output_dim):\n",
"        super().__init__()\n",
"        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
"        self.rmsnorm = RMSNorm(hidden_dim)\n",
"        self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        x = self.fc1(x)\n",
"        x = self.rmsnorm(x)  # Normalize after first layer\n",
"        x = self.relu(x)\n",
"        x = self.fc2(x)\n",
"        return x\n",
"\n",
"# Create a simple dataset: learn XOR-like pattern\n",
"torch.manual_seed(42)\n",
"X = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])\n",
"y = torch.tensor([[0.], [1.], [1.], [0.]])  # XOR\n",
"\n",
"print(\"=\"*60)\n",
"print(\"TRAINING A SIMPLE NEURAL NETWORK WITH RMSNorm\")\n",
"print(\"=\"*60)\n",
"print(\"Task: Learn XOR function\")\n",
"print(f\"Input:\\n{X}\")\n",
"print(f\"Target:\\n{y.squeeze()}\")\n",
"\n",
"# Create model\n",
"model = SimpleNet(input_dim=2, hidden_dim=4, output_dim=1)\n",
"criterion = nn.MSELoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.1)\n",
"\n",
"# Track RMSNorm weights during training\n",
"weight_history = []\n",
"loss_history = []\n",
"\n",
"print(f\"\\nInitial RMSNorm weights: {model.rmsnorm.weight.data.numpy()}\")\n",
"\n",
"# Train (full batch: all four XOR examples every step)\n",
"epochs = 500\n",
"for epoch in range(epochs):\n",
"    optimizer.zero_grad()\n",
"    output = model(X)\n",
"    loss = criterion(output, y)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"    # Record weights (after this step's update) and the loss computed before it\n",
"    weight_history.append(model.rmsnorm.weight.detach().clone().numpy())\n",
"    loss_history.append(loss.item())\n",
"\n",
"    if (epoch + 1) % 100 == 0:\n",
"        print(f\"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.6f}\")\n",
"\n",
"print(f\"\\nFinal RMSNorm weights: {model.rmsnorm.weight.data.numpy()}\")\n",
"\n",
"# Test the model\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"FINAL PREDICTIONS:\")\n",
"print(\"=\"*60)\n",
"with torch.no_grad():\n",
"    predictions = model(X)\n",
"    for inp, pred, target in zip(X, predictions, y):\n",
"        print(f\"Input: {inp.numpy()} → Predicted: {pred.item():.4f}, Target: {target.item():.4f}\")\n",
"\n",
"# Plot the weight evolution\n",
"weights_over_time = np.array(weight_history)  # shape: (epochs, hidden_dim)\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# Plot 1: RMSNorm weights over time (one line per hidden dimension)\n",
"for i in range(weights_over_time.shape[1]):\n",
"    ax1.plot(weights_over_time[:, i], label=f'Weight {i}', linewidth=2)\n",
"# Add the reference line before ax1.legend() so its label appears in the legend\n",
"ax1.axhline(y=1.0, color='red', linestyle='--', alpha=0.5, label='Initial value')\n",
"ax1.set_xlabel('Epoch', fontsize=12)\n",
"ax1.set_ylabel('Weight Value', fontsize=12)\n",
"ax1.set_title('RMSNorm Learnable Weights During Training', fontsize=14, fontweight='bold')\n",
"ax1.legend()\n",
"ax1.grid(True, alpha=0.3)\n",
"\n",
"# Plot 2: Training loss\n",
"ax2.plot(loss_history, color='orange', linewidth=2)\n",
"ax2.set_xlabel('Epoch', fontsize=12)\n",
"ax2.set_ylabel('Loss (MSE)', fontsize=12)\n",
"ax2.set_title('Training Loss Over Time', fontsize=14, fontweight='bold')\n",
"ax2.grid(True, alpha=0.3)\n",
"ax2.set_yscale('log')\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('rmsnorm_training.png', dpi=150, bbox_inches='tight')\n",
"print(\"\\n✓ Plot saved as 'rmsnorm_training.png'\")\n",
"plt.show()\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"KEY OBSERVATIONS:\")\n",
"print(\"=\"*60)\n",
"print(\"• RMSNorm weights started at 1.0 (neutral scaling)\")\n",
"print(\"• During training, they learned to scale different dimensions\")\n",
"print(\"• This helps the model emphasize important features\")\n",
"print(\"• Each weight adjusts the importance of one hidden dimension\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 980
},
"id": "i_FUTNqpgmuq",
"outputId": "0d306d57-6130-4389-d7a8-421e04c4bf18"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"============================================================\n",
"TRAINING A SIMPLE NEURAL NETWORK WITH RMSNorm\n",
"============================================================\n",
"Task: Learn XOR function\n",
"Input:\n",
"tensor([[0., 0.],\n",
" [0., 1.],\n",
" [1., 0.],\n",
" [1., 1.]])\n",
"Target:\n",
"tensor([0., 1., 1., 0.])\n",
"\n",
"Initial RMSNorm weights: [1. 1. 1. 1.]\n",
"Epoch 100/500, Loss: 0.000022\n",
"Epoch 200/500, Loss: 0.000000\n",
"Epoch 300/500, Loss: 0.000235\n",
"Epoch 400/500, Loss: 0.000000\n",
"Epoch 500/500, Loss: 0.000000\n",
"\n",
"Final RMSNorm weights: [0.78946674 1.6595931 0.78331435 0.54245394]\n",
"\n",
"============================================================\n",
"FINAL PREDICTIONS:\n",
"============================================================\n",
"Input: [0. 0.] → Predicted: -0.0000, Target: 0.0000\n",
"Input: [0. 1.] → Predicted: 1.0000, Target: 1.0000\n",
"Input: [1. 0.] → Predicted: 1.0000, Target: 1.0000\n",
"Input: [1. 1.] → Predicted: -0.0000, Target: 0.0000\n",
"\n",
"✓ Plot saved as 'rmsnorm_training.png'\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1400x500 with 2 Axes>"
],
"image/png": "\n"
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"============================================================\n",
"KEY OBSERVATIONS:\n",
"============================================================\n",
"• RMSNorm weights started at 1.0 (neutral scaling)\n",
"• During training, they learned to scale different dimensions\n",
"• This helps the model emphasize important features\n",
"• Each weight adjusts the importance of one hidden dimension\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment