Last active
February 10, 2026 03:45
Gist: LiutongZhou/139eddad3ea732b9c5f06b97cf702799
Jax Distributed Zero to Hero
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "8bfe2b28-24e4-4fa1-a761-79cf1985b09b", | |
| "metadata": { | |
| "editable": true, | |
| "slideshow": { | |
| "slide_type": "" | |
| }, | |
| "tags": [] | |
| }, | |
| "source": [ | |
| "# JAX Sharding and Parallel Computing - All-In-One Guide\n", | |
| "From Zero to Hero (v0.8+)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "393e9dce-d18b-40b7-932b-07b2e32b83a4", | |
| "metadata": {}, | |
| "source": [ | |
| "**TOC**\n", | |
| "1. [Imports](#Imports)\n", | |
| "2. [Create Device Mesh and Sharding Spec](#Create-Device-Mesh-and-Sharding-Spec)\n", | |
| "3. [Explicit Sharding](#Explicit-Sharding)\n", | |
| "4. [Manual Sharding](#Manual-Sharding)\n", | |
| " 1. [Simple Matrix Multiplication](#Simple-Matrix-Multiplication)\n", | |
| " 2. [The Collective Communication Ops](#The-Collective-Communication-Ops)\n", | |
| " 1. [PPermute](#PPermute)\n", | |
| " 2. [Reduce-Scatter](#Reduce-Scatter)\n", | |
| " 3. [All-Gather](#All-Gather)\n", | |
| " 4. [All-Reduce](#All-Reduce)\n", | |
| " 5. [All-to-All](#All-to-All)\n", | |
| "5. [Matrix Multiplication](#Matrix-Multiplication)\n", | |
| " 1. [Naive Version](#Naive-Version)\n", | |
| " 2. [FSDP Forward](#FSDP-Forward)\n", | |
| " 3. [Profiling `matmul_fsdp`](#Profiling-matmul_fsdp)\n", | |
| " 4. [`ppermute` Loop Version](#ppermute-Loop-Version)\n", | |
| " 5. [Matmul Reduce Scatter](#Matmul-Reduce-Scatter)\n", | |
| "6. [Practical Examples](#Practical-Examples)\n", | |
| " 1. [Distributed Training of Neural Networks](#Distributed-Training-of-Neural-Networks)\n", | |
| " 1. [Data Parallelism](#Data-Parallelism)\n", | |
| " 2. [FSDP](#FSDP)\n", | |
| " 3. [TP](#TP)\n", | |
| " 4. [FSDP + TP](#FSDP-+-TP)\n", | |
| " 5. [MOE Parallel with Token Dropping (Advanced)](#MOE-Parallel-with-Token-Dropping)\n", | |
| " 2. [Knowledge Distillation (Naive Pipeline Parallelism)](#Knowledge-Distillation) " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "566d4b26-f47f-4345-b058-1519f9118c10", | |
| "metadata": {}, | |
| "source": [ | |
| "## Imports" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "e1c0821c-c6aa-430f-b7f2-502db736029e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import itertools\n", | |
| "from functools import partial\n", | |
| "from typing import Iterator\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "from jax.sharding import (\n", | |
| " AxisType,\n", | |
| " Mesh,\n", | |
| " NamedSharding,\n", | |
| " PartitionSpec as P,\n", | |
| " auto_axes,\n", | |
| " explicit_axes,\n", | |
| " get_abstract_mesh,\n", | |
| " reshard,\n", | |
| ")\n", | |
| "from jaxtyping import Array, Float, Int, PyTree" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "f0d55a5e-370b-4788-8247-0e174f7197b7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[CpuDevice(id=0),\n", | |
| " CpuDevice(id=1),\n", | |
| " CpuDevice(id=2),\n", | |
| " CpuDevice(id=3),\n", | |
| " CpuDevice(id=4),\n", | |
| " CpuDevice(id=5),\n", | |
| " CpuDevice(id=6),\n", | |
| " CpuDevice(id=7)]" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "USE_CPU = True # Set to True to use CPU to simulate multiple GPUs/TPUs\n", | |
| "\n", | |
| "if USE_CPU:\n", | |
| " if not jax._src.xla_bridge.backends_are_initialized():\n", | |
| " jax.config.update(\"jax_platforms\", \"cpu\")\n", | |
| " jax.config.update(\"jax_num_cpu_devices\", 8)\n", | |
| " # Alternatively (requires `import os`; set these before JAX initializes its backends):\n", | |
| " # os.environ['JAX_PLATFORMS'] = 'cpu'\n", | |
| " # os.environ['XLA_FLAGS'] = \"--xla_force_host_platform_device_count=8\"\n", | |
| "jax.devices()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "418b59ac-0f13-41a3-9ad3-3601339bdbd6", | |
| "metadata": {}, | |
| "source": [ | |
| "## Create Device Mesh and Sharding Spec" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "390707ba-f4da-4972-94b5-1acef60dcaca", | |
| "metadata": {}, | |
| "source": [ | |
| "In JAX, a device `Mesh` is a logical, N-dimensional arrangement of the available physical devices.\n", | |
| "\n", | |
| "If you have multiple (GPU/TPU) devices, you can arrange them into a `Mesh` whose axes have names. Sharding specs then refer to mesh axes by those names. " | |
| ] | |
| }, | |
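As a mental model only, a mesh can be pictured as an N-d array of device ids plus one name per axis. The NumPy sketch below builds a 4 x 2 grid named `("dp", "tp")` over 8 ids. It assumes row-major placement of device ids, which is consistent with the CPU outputs in this notebook. On real TPUs/GPUs, `jax.make_mesh` may choose a topology-aware order instead. `device_at` is a hypothetical helper, not a JAX API.

```python
import numpy as np

# Assumption: devices 0..7 are placed row-major, so the last axis ("tp") varies fastest.
axis_names = ("dp", "tp")
device_grid = np.arange(8).reshape(4, 2)  # rows index "dp", columns index "tp"
axis_sizes = dict(zip(axis_names, device_grid.shape))  # {"dp": 4, "tp": 2}


def device_at(**coords: int) -> int:
    """Hypothetical helper: look up a device id by named mesh coordinates."""
    return int(device_grid[tuple(coords[name] for name in axis_names)])


# Devices that share a "dp" coordinate form a "tp" group (a row), and vice versa.
tp_groups = device_grid.tolist()    # [[0, 1], [2, 3], [4, 5], [6, 7]]
dp_groups = device_grid.T.tolist()  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(axis_sizes, device_at(dp=2, tp=1), tp_groups, dp_groups)
```

In this picture, a collective over `"tp"` runs independently inside each row, and a collective over `"dp"` runs independently inside each column.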
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "641139cf-d4b3-4a9c-a2c0-e72a34aa8efd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Current mesh is: AbstractMesh('dp': 4, 'tp': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)\n", | |
| "Current mesh is: AbstractMesh((), axis_types=())\n", | |
| "Current mesh is: AbstractMesh('dp': 4, 'tp': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Device Mesh = named axes over the device set -> dp x tp devices\n", | |
| "global_mesh = jax.make_mesh(\n", | |
| " (4, 2),\n", | |
| " (\"dp\", \"tp\"),\n", | |
| " axis_types=(AxisType.Explicit, AxisType.Explicit),\n", | |
| " # AxisType.Explicit is the default in v0.9+ (New pattern)\n", | |
| " # Only use AxisType.Explicit or AxisType.Manual. Do not use AxisType.Auto (legacy default) for production\n", | |
| ")\n", | |
| "\n", | |
| "with jax.set_mesh(global_mesh):\n", | |
| " print(f\"Current mesh is: {get_abstract_mesh()}\")\n", | |
| "\n", | |
| "print(f\"Current mesh is: {get_abstract_mesh()}\")\n", | |
| "jax.set_mesh(\n", | |
| " global_mesh\n", | |
| ") # <- careful: without `with`, this sets the mesh globally for the rest of the process\n", | |
| "print(f\"Current mesh is: {get_abstract_mesh()}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "71acc0fc-ae71-4007-8338-589621a98866", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Define Tensor Sharding Spec\n", | |
| "\n", | |
| "## Data Parallel: P('dp', None)\n", | |
| "## Split dim 0 into partitions placed over device mesh axis 'dp'; dim 1 is not split.\n", | |
| "## Mesh axis 'tp' does not appear in the spec, so each partition is replicated across 'tp'\n", | |
| "shard_dp = NamedSharding(global_mesh, P(\"dp\", None))\n", | |
| "## Tensor Parallel: P(None, 'tp') splits dim 1 over 'tp' and replicates across 'dp'\n", | |
| "shard_tp = NamedSharding(global_mesh, P(None, \"tp\"))\n", | |
| "shard_fully_replicate = NamedSharding(global_mesh, P()) # fully replicated" | |
| ] | |
| }, | |
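To see what `shard_dp` and `shard_tp` place where, here is a pure-Python sketch that computes the index range each device holds. `local_slices` is a hypothetical helper; JAX exposes similar information through `Sharding.devices_indices_map`. The sketch makes two assumptions. First, every sharded dimension divides evenly. Second, for a tuple entry such as `("tp", "dp")`, the first listed axis varies slowest.

```python
def local_slices(shape, spec, mesh_axes, coords):
    """Hypothetical helper: (start, stop) per dimension held by the device at `coords`."""
    spec = tuple(spec) + (None,) * (len(shape) - len(spec))  # missing entries mean "not split"
    ranges = []
    for dim_size, entry in zip(shape, spec):
        if entry is None:  # not split: every device holds the whole dimension
            ranges.append((0, dim_size))
            continue
        names = (entry,) if isinstance(entry, str) else tuple(entry)
        idx, parts = 0, 1
        for name in names:  # row-major over the listed axes; the first name varies slowest
            idx = idx * mesh_axes[name] + coords[name]
            parts *= mesh_axes[name]
        block = dim_size // parts
        ranges.append((idx * block, (idx + 1) * block))
    return tuple(ranges)


mesh_axes = {"dp": 4, "tp": 2}
for dp in range(4):
    for tp in range(2):
        coords = {"dp": dp, "tp": tp}
        x_part = local_slices((1024, 4096), ("dp", None), mesh_axes, coords)  # like shard_dp
        w_part = local_slices((4096, 256), (None, "tp"), mesh_axes, coords)   # like shard_tp
        print(coords, "x rows/cols:", x_part, "w rows/cols:", w_part)
```

Devices `(dp=1, tp=0)` and `(dp=1, tp=1)` report the same `x` range `((256, 512), (0, 4096))`. This is the replication across `tp` described in the comments above.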
| { | |
| "cell_type": "markdown", | |
| "id": "74664a76-43dd-4753-a2c5-c177a01eaa75", | |
| "metadata": {}, | |
| "source": [ | |
| "You can create as many meshes and sharding specs as you need. \n", | |
| "\n", | |
| "Meshes may even overlap: the same device (e.g. `GPU:0`) can belong to both `mesh_a` and `mesh_b`. \n", | |
| "\n", | |
| "For a use case with two meshes, see [Practical Examples - **Knowledge Distillation**](#Knowledge-Distillation)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "f59789a6-b434-4ae8-9e91-63c82baa16b9", | |
| "metadata": {}, | |
| "source": [ | |
| "## Explicit Sharding" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1507f09f-2ef9-4d81-beef-ca1cb4260e11", | |
| "metadata": {}, | |
| "source": [ | |
| "Given a 2D device mesh with axis names `(\"dp\", \"tp\")`:\n", | |
| "* Notation: `float32[1024@dp,4096]`\n", | |
| "* Interpretation:\n", | |
| " * a tensor of `dtype=float32` and of shape `(1024, 4096)`\n", | |
| " * sharded (split) along `dim=0` into `dp` partitions placed across the device mesh axis `dp`\n", | |
| " * not split along `dim=1`; since `tp` does not appear in the spec, each partition is replicated across the `tp` devices\n", | |
| "\n", | |
| "Similarly, given a 6D device mesh with axis names `(\"pp\", \"dp\", \"fsdp\", \"sp\", \"tp\", \"ep\")`:\n", | |
| "* Notation: `bfloat16[512, 2048@(sp, tp), 128_000]`\n", | |
| "* Interpretation:\n", | |
| " * a tensor of `dtype=bfloat16` and of shape `(512, 2048, 128_000)`\n", | |
| " * sharded (split) along `dim=1` into `sp x tp` partitions placed across the device mesh axes `sp` and `tp`\n", | |
| " * not split along `dim=0` or `dim=2`; each partition is replicated across the remaining mesh axes, giving `pp x dp x fsdp x ep` replicas" | |
| ] | |
| }, | |
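The `@` notation converts to local shapes with simple arithmetic. Divide each split dimension by the product of its mesh axis sizes. The replica count is the product of the sizes of the mesh axes that do not appear in the spec. The helper `local_shape_and_replicas` below is hypothetical. The 6D mesh sizes are made-up values, chosen only so that the numbers divide evenly.

```python
from math import prod


def local_shape_and_replicas(global_shape, spec, mesh_axes):
    """Hypothetical helper: per-device block shape and number of identical copies of each block."""
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    local, used = [], set()
    for dim, entry in zip(global_shape, spec):
        names = () if entry is None else ((entry,) if isinstance(entry, str) else tuple(entry))
        parts = prod(mesh_axes[n] for n in names)  # prod of an empty sequence is 1
        assert dim % parts == 0, f"dim {dim} is not divisible by {parts}"
        local.append(dim // parts)
        used.update(names)
    # Each mesh axis absent from the spec holds identical copies of the same block.
    replicas = prod(size for name, size in mesh_axes.items() if name not in used)
    return tuple(local), replicas


# float32[1024@dp,4096] on Mesh(dp=4, tp=2)
print(local_shape_and_replicas((1024, 4096), ("dp", None), {"dp": 4, "tp": 2}))  # ((256, 4096), 2)

# bfloat16[512, 2048@(sp, tp), 128_000] on a 6D mesh with illustrative sizes
mesh_6d = {"pp": 2, "dp": 2, "fsdp": 2, "sp": 2, "tp": 4, "ep": 2}
print(local_shape_and_replicas((512, 2048, 128_000), (None, ("sp", "tp"), None), mesh_6d))
# ((512, 256, 128000), 16)
```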
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "073ebb9a-afc4-4356-b614-e995bf753f40", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x)=ShapedArray(float32[1024@dp,4096])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x = jax.random.normal(jax.random.key(0), (1024, 4096))\n", | |
| "# put tensor from default device (CPU) to accelerator device\n", | |
| "# and explicitly shard x along axis 0 to distribute partitions across mesh axis dp\n", | |
| "x = jax.device_put(x, shard_dp)\n", | |
| "print(f\"{jax.typeof(x)=}\") # float32[1024@dp,4096]\n", | |
| "# jax.debug.visualize_array_sharding(x) # to see how x is sharded horizontally" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "2a1e26ab-8e7e-495e-8d25-47780821afbc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w)=ShapedArray(float32[4096,256@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "w = jax.random.normal(jax.random.key(1), (4096, 256))\n", | |
| "# explicitly shard model weight w along axis 1 to distribute partitions across mesh axis tp\n", | |
| "w = jax.device_put(w, shard_tp)\n", | |
| "print(f\"{jax.typeof(w)=}\") # float32[4096,256@tp]\n", | |
| "# jax.debug.visualize_array_sharding(w) # to see how w is sharded vertically" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "1f4a8be4-e807-4b05-8e85-eab9b6c5ae9f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Current mesh is: AbstractMesh('dp': 4, 'tp': 2, axis_types=(Explicit, Explicit), device_kind=cpu, num_cores=None)\n", | |
| "\n", | |
| "Shard of x: NamedSharding(mesh=Mesh('dp': 4, 'tp': 2, axis_types=(Auto, Auto)), spec=PartitionSpec('dp',), memory_kind=device)\n", | |
| "Shard of w: NamedSharding(mesh=Mesh('dp': 4, 'tp': 2, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'tp'), memory_kind=device)\n", | |
| "\n", | |
| "jax.typeof(y)=ShapedArray(float32[1024@dp,256@tp])\n", | |
| "jax.typeof(y_resharded)=ShapedArray(float32[1024,256])\n", | |
| "jax.typeof(y_resharded_2)=ShapedArray(float32[1024@tp,256@dp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "def f(x: Float[Array, \"batch hidden\"], w: Float[Array, \"hidden out_dim\"]) -> tuple[\n", | |
| " Float[Array, \"batch out_dim\"],\n", | |
| " Float[Array, \"batch out_dim\"],\n", | |
| " Float[Array, \"batch out_dim\"],\n", | |
| "]:\n", | |
| " # float32[1024@dp,4096] @ float32[4096,256@tp] -> (1024@dp, 256@tp)\n", | |
| " jax.debug.inspect_array_sharding(x, callback=lambda x: print(f\"\\nShard of x: {x}\"))\n", | |
| " jax.debug.inspect_array_sharding(w, callback=lambda w: print(f\"Shard of w: {w}\\n\"))\n", | |
| " # ^ use this to debug inside jit\n", | |
| " y = x @ w\n", | |
| "\n", | |
| " print(f\"Current mesh is: {get_abstract_mesh()}\")\n", | |
| " # Note: for explicit axis type, with_sharding_constraint acts like an assert check\n", | |
| " y = jax.lax.with_sharding_constraint(y, NamedSharding(global_mesh, P(\"dp\", \"tp\")))\n", | |
| " # Only with AxisType.Auto does it (implicitly) act like a reshard.\n", | |
| "\n", | |
| " # Use reshard explicitly when you need a different layout, so the communication cost stays visible\n", | |
| " y_resharded = reshard(y, P()) # <- all-gather\n", | |
| " y_resharded_2 = reshard(y, P(\"tp\", \"dp\")) # <- all-to-all\n", | |
| " return y, y_resharded, y_resharded_2\n", | |
| "\n", | |
| "\n", | |
| "y, y_resharded, y_resharded_2 = f(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(f\"{jax.typeof(y_resharded)=}\")\n", | |
| "print(f\"{jax.typeof(y_resharded_2)=}\")\n", | |
| "assert jnp.all(y == y_resharded)\n", | |
| "assert jnp.all(y_resharded == y_resharded_2)" | |
| ] | |
| }, | |
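The comments above name the collectives that `reshard` triggers. One rough way to compare their cost is to count how many elements each device must receive from its peers. The sketch below does that in pure Python for the `float32[1024@dp,256@tp]` output `y`. It is a back-of-the-envelope estimate that assumes evenly divisible 2-D blocks. It says nothing about which collective XLA actually emits or how XLA schedules the transfers. All helper names are hypothetical.

```python
from itertools import product

MESH = {"dp": 4, "tp": 2}


def held_ranges(shape, spec, coords):
    """Hypothetical helper: (start, stop) per dim for a 2-D spec whose entries are one axis name or None."""
    ranges = []
    for dim, name in zip(shape, spec):
        parts = 1 if name is None else MESH[name]
        idx = 0 if name is None else coords[name]
        size = dim // parts
        ranges.append((idx * size, (idx + 1) * size))
    return ranges


def overlap(a, b):
    """Length of the intersection of two half-open ranges."""
    return max(0, min(a[1], b[1]) - max(a[0], b[0]))


def elements_received(shape, src_spec, dst_spec):
    """Total elements that devices must fetch from peers to move from src_spec to dst_spec."""
    total = 0
    for dp, tp in product(range(MESH["dp"]), range(MESH["tp"])):
        coords = {"dp": dp, "tp": tp}
        src_r, src_c = held_ranges(shape, src_spec, coords)
        dst_r, dst_c = held_ranges(shape, dst_spec, coords)
        needed = (dst_r[1] - dst_r[0]) * (dst_c[1] - dst_c[0])
        already_local = overlap(src_r, dst_r) * overlap(src_c, dst_c)
        total += needed - already_local
    return total


shape = (1024, 256)  # y: float32[1024@dp,256@tp]
to_replicated = elements_received(shape, ("dp", "tp"), (None, None))  # like reshard(y, P())
to_swapped = elements_received(shape, ("dp", "tp"), ("tp", "dp"))     # like reshard(y, P("tp", "dp"))
print(to_replicated, to_swapped)  # 1835008 196608
```

Going to `P()` makes every device fetch the 7/8 of `y` that it lacks. By this count, that is roughly 9x the traffic of the layout swap, which is why it pays to make such reshards explicit.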
| { | |
| "cell_type": "markdown", | |
| "id": "28d2d434-a823-4c43-b729-4e83b084a181", | |
| "metadata": {}, | |
| "source": [ | |
| "## Manual Sharding\n", | |
| "`jax.shard_map` (plus `jax.lax.*` collectives) is strongly preferred for production code. You write the per-device program yourself, and every communication step is spelled out explicitly. " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "94a4f592-1298-4ee9-81db-d03bc34f43f6", | |
| "metadata": {}, | |
| "source": [ | |
| "### Simple Matrix Multiplication" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9664e50e-f072-4b0d-ba87-2edd81ae40ab", | |
| "metadata": {}, | |
| "source": [ | |
| "Given a device mesh `Mesh('dp': 4, 'tp': 2)`:\n", | |
| "* Notation: `float32[256{dp=4}, 4096]`\n", | |
| "* Interpretation:\n", | |
| " * a local tensor block of dtype `float32` and shape `(256, 4096)`\n", | |
| " * the block comes from splitting `dim=0` over `dp`, so its values vary across the device mesh axis `dp` and are replicated across the device mesh axis `tp`. JAX prints this type as `float32[256,4096]{dp}`, where `{...}` lists the mesh axes the values vary over" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "4eaac26d-58dd-4d35-bb6c-9e5b93cfc3cd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(float32[256,4096]{dp})\n", | |
| "jax.typeof(w_block)=ShapedArray(float32[4096,128]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(float32[256,128]{tp,dp})\n", | |
| "jax.typeof(y)=ShapedArray(float32[1024@dp,256@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x = jax.random.normal(jax.random.key(0), (1024, 4096))\n", | |
| "w = jax.random.normal(jax.random.key(1), (4096, 256))\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit # Compile the manual parallel code. For easier debugging, you can comment out jit compilation\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(None, \"tp\")), out_specs=P(\"dp\", \"tp\")\n", | |
| ")\n", | |
| "def matmul(\n", | |
| " x_block: Float[Array, \"dp_block hidden\"], w_block: Float[Array, \"hidden tp_block\"]\n", | |
| ") -> Float[Array, \"dp_block tp_block\"]:\n", | |
| " \"\"\"Per-device body: runs once on each of the dp x tp = 8 devices, concurrently.\n", | |
| "\n", | |
| " Each device multiplies only its local x and w blocks.\n", | |
| " \"\"\"\n", | |
| " # f32[1024@dp=4, 4096] @ f32[4096, 256@tp=2]\n", | |
| " # f32[256{dp=4}, 4096] @ f32[4096, 128{tp=2}]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " print(f\"{jax.typeof(w_block)=}\")\n", | |
| " y_block = x_block @ w_block # float32[256{dp=4},128{tp=2}]\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
| " # shard_map assembles the per-device y_block results according to out_specs=P(\"dp\", \"tp\"):\n", | |
| " # blocks are tiled along dim=0 by their dp index and along dim=1 by their tp index, giving\n", | |
| " # float32[1024@dp,256@tp]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = matmul(x, w) # with shard_map, no need to do device_put\n", | |
| "print(f\"{jax.typeof(y)=}\") # float32[1024@dp,256@tp]" | |
| ] | |
| }, | |
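The block bookkeeping that `jax.shard_map` performs for `matmul` can be mimicked with NumPy in a single process. This sketch shows only the data layout implied by `in_specs` and `out_specs`. It does not show how JAX compiles or schedules the per-device work. The random inputs are stand-ins for `x` and `w`.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1024, 4096), dtype=np.float32)
w = rng.standard_normal((4096, 256), dtype=np.float32)
DP, TP = 4, 2

# in_specs=(P("dp", None), P(None, "tp")): device (i, j) sees row block i of x
# and column block j of w.
x_blocks = np.split(x, DP, axis=0)  # 4 blocks of shape (256, 4096)
w_blocks = np.split(w, TP, axis=1)  # 2 blocks of shape (4096, 128)

# The body of `matmul` runs once per device, using only that device's local blocks.
y_blocks = [[x_blocks[i] @ w_blocks[j] for j in range(TP)] for i in range(DP)]

# out_specs=P("dp", "tp"): tile the (256, 128) results back into one array,
# dp index along dim 0 and tp index along dim 1.
y = np.block(y_blocks)
print(y.shape)  # (1024, 256)
```

Each `(i, j)` pair in the list comprehension is the work that `shard_map` assigns to one device.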
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "afe14164-2215-4c3e-97ca-92f8b16f81b6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "13.4 ms ± 260 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "def f(x, w):\n", | |
| " return x @ w\n", | |
| "\n", | |
| "\n", | |
| "%timeit jax.block_until_ready(f(x,w)) # baseline: the matmul is not split across devices" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "0e9d55c3-38af-486f-8615-aab8213474be", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4.44 ms ± 79 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(matmul(x,w)) # sharded: each device computes 1/8 of the output" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "bdc9ed88-89b7-4eaf-aa24-de43a0a68ac5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "assert jnp.allclose(reshard(y, P()), x @ w, atol=1e-4) # verify correctness" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "edbc2ef5-77c1-4f93-8d64-da17c9427912", | |
| "metadata": {}, | |
| "source": [ | |
| "### The Collective Communication Ops" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2f66b5db-de48-4487-84df-d879f1d7558c", | |
| "metadata": {}, | |
| "source": [ | |
| "#### PPermute \n", | |
| "Point-to-point communication.\n", | |
| "\n", | |
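| "Communication cost: $O(n)$ in total, where $n$ is the number of devices along the device mesh axis (`jax.lax.axis_size(\"i\")`); each device sends one block.\n", | |
| "\n", | |
| "Use cases: sending activations between stages in pipeline parallelism; rotating key/value blocks around the ring in ring attention." | |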
| ] | |
| }, | |
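A pure-Python model makes the `perm` semantics concrete. `ppermute_sim` and `ring_shift_sim` are hypothetical stand-ins for `jax.lax.ppermute` and `ring_shift_simple_version` along a single mesh axis. Each list entry plays the role of one device's block. As documented for `jax.lax.ppermute`, a device that is not the destination of any pair receives zeros.

```python
def ppermute_sim(blocks, perm, fill=0):
    """Hypothetical model of ppermute on one axis: blocks[k] is held by the device at index k."""
    out = [fill] * len(blocks)  # devices that receive nothing get zeros
    for src, dst in perm:
        out[dst] = blocks[src]
    return out


def ring_shift_sim(blocks, offset=1):
    """Same (src, dst) pairs as ring_shift_simple_version."""
    n = len(blocks)
    perm = [(i, (i + offset) % n) for i in range(n)]
    return ppermute_sim(blocks, perm)


row = [0, 1, 2, 3]  # one row of the tp x dp demo that follows, sharded over "dp"
print(ring_shift_sim(row, offset=1))   # [3, 0, 1, 2]: each device receives from its predecessor
print(ring_shift_sim(row, offset=-1))  # [1, 2, 3, 0]: each device receives from its successor
# Pipeline-style send without wrap-around: the first device receives nothing.
print(ppermute_sim(row, [(i, i + 1) for i in range(3)]))  # [0, 0, 1, 2]
```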
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "1b97435a-c146-4818-83b3-885fd4ad4e02", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def ring_shift_simple_version(block: Array, axis_name: str, offset: int = 1) -> Array:\n", | |
| " \"\"\"Ring shift forward or backward along the device mesh axis\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " block : Array\n", | |
| " axis_name: str\n", | |
| " offset : int\n", | |
| " If positive, ring shift forward; if negative, ring shift backward.\n", | |
| " Default is shift forward with step size=1\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " receive : Array\n", | |
| " the received block from previous device if shift forward\n", | |
| " the received block from next device if shift backward\n", | |
| " \"\"\"\n", | |
| " num_devices = jax.lax.axis_size(axis_name) # number of devices along the device mesh axis\n", | |
| " src_to_dst = [(idx, (idx + offset) % num_devices) for idx in range(num_devices)]\n", | |
| " return jax.lax.ppermute(block, axis_name, perm=src_to_dst)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "489562fc-4c6d-4d97-b9b3-84d751299ab7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[0 1 2 3]\n", | |
| " [4 5 6 7]]\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_prev)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_next)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[2@tp,4@dp])\n", | |
| "[[3 0 1 2]\n", | |
| " [7 4 5 6]]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[2@tp,4@dp])\n", | |
| "[[1 2 3 0]\n", | |
| " [5 6 7 4]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=P(\"tp\", \"dp\"), out_specs=(P(\"tp\", \"dp\"), P(\"tp\", \"dp\"))\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"tp_block dp_block\"],\n", | |
| ") -> tuple[Int[Array, \"tp_block dp_block\"], Int[Array, \"tp_block dp_block\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[2@tp,4@dp] -> i32[1{tp=2},1{dp=4}] varying -> i32[2@tp, 4@dp]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| "\n", | |
| " axis_name = \"dp\"\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift_simple_version(x_block, axis_name, offset=1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_prev)=}\")\n", | |
| "\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift_simple_version(x_block, axis_name, offset=-1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_next)=}\")\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8).reshape(2, 4)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "81177087-9ad9-49da-9eab-a3a933dc2eb9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[[0 1]\n", | |
| " [2 3]\n", | |
| " [4 5]\n", | |
| " [6 7]]\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_prev)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(x_block_recv_from_next)=ShapedArray(int32[1,1]{tp,dp})\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[4@dp,2@tp])\n", | |
| "[[6 7]\n", | |
| " [0 1]\n", | |
| " [2 3]\n", | |
| " [4 5]]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[4@dp,2@tp])\n", | |
| "[[2 3]\n", | |
| " [4 5]\n", | |
| " [6 7]\n", | |
| " [0 1]]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=P(\"dp\", \"tp\"), out_specs=(P(\"dp\", \"tp\"), P(\"dp\", \"tp\"))\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"dp_block tp_block\"],\n", | |
| ") -> tuple[Int[Array, \"dp_block tp_block\"], Int[Array, \"dp_block tp_block\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[4@dp,2@tp] -> i32[1{dp=4},1{tp=2}] varying -> i32[4@dp, 2@tp]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| "\n", | |
| " axis_name = \"dp\"\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift_simple_version(x_block, axis_name, offset=1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_prev)=}\")\n", | |
| "\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift_simple_version(x_block, axis_name, offset=-1)\n", | |
| " print(f\"{jax.typeof(x_block_recv_from_next)=}\")\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8).reshape(4, 2)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3885b2c3-14d3-4e03-b457-bd7b60f0e9ce", | |
| "metadata": {}, | |
| "source": [ | |
| "<div class=\"alert alert-danger\">\n", | |
| "`ring_shift_simple_version` is only correct when ppermuting along a single device mesh axis. Along multiple axes it can silently return the blocks in the wrong order, as the next cell shows.\n", | |
| "</div>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "88ab9978-dc2b-4acf-86f9-3f05d3c07688", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0 1 2 3 4 5 6 7]\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[7 4 5 6 0 1 2 3]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[4 5 6 7 1 2 3 0]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=P((\"tp\", \"dp\")),\n", | |
| " out_specs=(P((\"tp\", \"dp\")), P((\"tp\", \"dp\"))),\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"n\"],\n", | |
| ") -> tuple[Int[Array, \"n\"], Int[Array, \"n\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[8@(tp,dp)] -> i32[1]{tp=2,dp=4} varying -> i32[8@(tp,dp)]\n", | |
| " axis_name = (\"tp\", \"dp\")\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift_simple_version(x_block, axis_name, offset=1)\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift_simple_version(x_block, axis_name, offset=-1)\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ecd004eb-bca2-4007-a72c-b516c7fc95af", | |
| "metadata": {}, | |
| "source": [ | |
| "<div class=\"alert alert-warning\">Why?</div>\n", | |
| "\n", | |
| "Assuming the device mesh is `Mesh(dp=2, tp=2)` and the input partition spec is `P((\"tp\", \"dp\"))` for `[0, 1, 2, 3]`, we would expect a ring shift backward to return `[1,2,3,0]` of type `i32[4@(tp,dp)]`. \n", | |
| "\n", | |
| "First, `[0,1,2,3]` is partitioned and placed onto\n", | |
| "```\n", | |
| "[[(dp=0,tp=0)(device_id=0) 0, (dp=0,tp=1)(device_id=1) 2],\n", | |
| " [(dp=1,tp=0)(device_id=2) 1, (dp=1,tp=1)(device_id=3) 3]]\n", | |
| "```\n", | |
| "according to the input partition spec `P((\"tp\", \"dp\"))`, in which the last listed axis (`\"dp\"`) varies fastest. \n", | |
| "\n", | |
| "However, the linear indices in the `(src, dst)` pairs passed as `perm` to `jax.lax.ppermute(block, axis_names, perm=src_to_dst)` refer to a device's position in mesh axis order (set by the `Mesh`), not to the logical position implied by `P((\"tp\", \"dp\"))`.\n", | |
| "\n", | |
| "The `src_to_dst = [(idx, (idx + offset) % num_devices) for idx in range(num_devices)]` logic is fine when you `ppermute` along a single axis. Along multiple axes, it silently scrambles the data order whenever the mesh axis order and the logical axis order differ. \n", | |
| "\n", | |
| "The fix is to map each logical index to its mesh index and pass mesh indices to `jax.lax.ppermute`." | |
| ] | |
| }, | |
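The mismatch can be replayed without JAX on the same `Mesh(dp=4, tp=2)` used in the cells above. The sketch assumes, as the explanation does, that `ppermute` numbers devices row-major in mesh axis order. Under that assumption, it reproduces the wrong `[7 4 5 6 0 1 2 3]` printed by `ring_shift_simple_version`. It also reproduces the `[7 0 1 2 3 4 5 6]` that the generalized `ring_shift` returns. `linear`, `apply_perm`, `naive_perm`, and `fixed_perm` are hypothetical names.

```python
MESH = {"dp": 4, "tp": 2}  # mesh axis order: "dp" (slowest), then "tp"
LOGICAL = ("tp", "dp")     # axis order inside P(("tp", "dp"))
N = 8


def linear(coords, axes):
    """Row-major linear index of named coordinates; `axes` is listed slowest-first."""
    idx = 0
    for name in axes:
        idx = idx * MESH[name] + coords[name]
    return idx


all_coords = [{"dp": d, "tp": t} for d in range(MESH["dp"]) for t in range(MESH["tp"])]
mesh_idx_of = [0] * N  # logical index -> mesh index
held = [0] * N         # mesh index -> element of arange(8) stored on that device
for c in all_coords:
    mesh_idx_of[linear(c, LOGICAL)] = linear(c, tuple(MESH))
    held[linear(c, tuple(MESH))] = linear(c, LOGICAL)


def apply_perm(perm):
    """Assumed ppermute semantics: (src, dst) are mesh indices. Result is read in logical order."""
    received = [0] * N
    for src, dst in perm:
        received[dst] = held[src]
    return [received[mesh_idx_of[k]] for k in range(N)]


def naive_perm(offset):  # the pairs ring_shift_simple_version builds
    return [(i, (i + offset) % N) for i in range(N)]


def fixed_perm(offset):  # the pairs the generalized ring_shift builds
    return [(mesh_idx_of[k], mesh_idx_of[(k + offset) % N]) for k in range(N)]


print(apply_perm(naive_perm(1)))  # [7, 4, 5, 6, 0, 1, 2, 3]
print(apply_perm(fixed_perm(1)))  # [7, 0, 1, 2, 3, 4, 5, 6]
```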
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "7096fe4d-8152-4f6d-a20b-e6c07c70701d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# A generalized implementation\n", | |
| "def ring_shift(\n", | |
| " block: Array, axis_name: str | tuple[str, ...], offset: int = 1\n", | |
| ") -> Array:\n", | |
| " \"\"\"Ring shift forward or backward along the device mesh axis\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " block : Array\n", | |
| " axis_name: str | tuple[str, ...]\n", | |
| " Single axis name or tuple of axis names for multi-axis ring shift\n", | |
| " offset : int\n", | |
| " If positive, ring shift forward; if negative, ring shift backward.\n", | |
| " Default is shift forward with step size=1\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " receive : Array\n", | |
| " the received block from previous device if shift forward\n", | |
| " the received block from next device if shift backward\n", | |
| " \"\"\"\n", | |
| " if isinstance(axis_name, str):\n", | |
| " axis_name = (axis_name,)\n", | |
| " logical_axes = {a: jax.lax.axis_size(a) for a in axis_name}\n", | |
| " mesh = get_abstract_mesh()\n", | |
| " mesh_axes = {a: logical_axes[a] for a in mesh.axis_names if a in logical_axes}\n", | |
| " num_devices = jax.lax.axis_size(axis_name)\n", | |
| "\n", | |
| " def linear_idx_to_coordinates(idx: int, axes: dict[str, int]) -> dict[str, int]:\n", | |
| " \"\"\"Convert linear idx to coordinates according to the axes order\"\"\"\n", | |
| " coords: dict[str, int] = {}\n", | |
| " for name, size in reversed(axes.items()): # last axis varies fastest\n", | |
| " idx, rem = divmod(idx, size)\n", | |
| " coords[name] = rem\n", | |
| " return dict(reversed(coords.items()))\n", | |
| "\n", | |
| " def coordinates_to_linear_idx(coords: dict[str, int], axes: dict[str, int]) -> int:\n", | |
| " \"\"\"Linearize coordinates according to the axes order to linear idx\"\"\"\n", | |
| " idx = 0\n", | |
| " stride = 1\n", | |
| " for name, size in reversed(axes.items()): # last axis has stride 1\n", | |
| " idx += coords[name] * stride\n", | |
| " stride *= size\n", | |
| " return idx\n", | |
| "\n", | |
| " # construct mapping from logical axis idx to physical mesh axis idx\n", | |
| " logical_to_mesh = []\n", | |
| " for logical_idx in range(num_devices):\n", | |
| " coords = linear_idx_to_coordinates(logical_idx, logical_axes)\n", | |
| " mesh_idx = coordinates_to_linear_idx(coords, mesh_axes)\n", | |
| " logical_to_mesh.append(mesh_idx)\n", | |
| "\n", | |
| " src_to_dst = []\n", | |
| " for src_logical_idx in range(num_devices):\n", | |
| " dst_logical_idx = (src_logical_idx + offset) % num_devices\n", | |
| " src_mesh_idx = logical_to_mesh[src_logical_idx]\n", | |
| " dst_mesh_idx = logical_to_mesh[dst_logical_idx]\n", | |
| " src_to_dst.append((src_mesh_idx, dst_mesh_idx))\n", | |
| " return jax.lax.ppermute(block, axis_name, perm=src_to_dst)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "46da464b-080e-4248-8993-717a9932bc24", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[0 1 2 3 4 5 6 7]\n", | |
| "jax.typeof(y_shift_forward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[7 0 1 2 3 4 5 6]\n", | |
| "jax.typeof(y_shift_backward)=ShapedArray(int32[8@(tp,dp)])\n", | |
| "[1 2 3 4 5 6 7 0]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=P((\"tp\", \"dp\")),\n", | |
| " out_specs=(P((\"tp\", \"dp\")), P((\"tp\", \"dp\"))),\n", | |
| ")\n", | |
| "def ring_shift_demo(\n", | |
| " x_block: Int[Array, \"n\"],\n", | |
| ") -> tuple[Int[Array, \"n\"], Int[Array, \"n\"]]:\n", | |
| " \"\"\"Demo the ring shift pattern\"\"\"\n", | |
| " # i32[8@(tp,dp)] -> i32[1]{tp=2,dp=4} varying -> i32[8@(tp,dp)]\n", | |
| " axis_name = (\"tp\", \"dp\")\n", | |
| " # ring shift forward\n", | |
| " x_block_recv_from_prev = ring_shift(x_block, axis_name, offset=1)\n", | |
| " # ring shift backward\n", | |
| " x_block_recv_from_next = ring_shift(x_block, axis_name, offset=-1)\n", | |
| " return x_block_recv_from_prev, x_block_recv_from_next\n", | |
| "\n", | |
| "\n", | |
| "y = jnp.arange(8)\n", | |
| "print(y)\n", | |
| "y_shift_forward, y_shift_backward = ring_shift_demo(y)\n", | |
| "\n", | |
| "print(f\"{jax.typeof(y_shift_forward)=}\")\n", | |
| "print(y_shift_forward)\n", | |
| "print(f\"{jax.typeof(y_shift_backward)=}\")\n", | |
| "print(y_shift_backward)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "116b4c35-b55a-409f-b427-7dbaab5de87a", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Reduce-Scatter \n", | |
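| "Reduce-Scatter along a device mesh axis $i$ of size $n$ can be implemented as $n-1$ ring-shift (PPermute) steps along that axis.\n", | |
| "\n", | |
| "Communication cost: $O(n(n-1))$ chunk transfers in total. In each of the $n-1$ steps, each of the $n$ devices sends one chunk holding $1/n$ of its local data.\n", | |
| "\n", | |
| "Use cases: FSDP gradient updates; tensor-parallel forward pass and gradient updates" | |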
| ] | |
| }, | |
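The claim that reduce-scatter needs only $n-1$ point-to-point steps can be checked with a NumPy simulation of the textbook ring algorithm. The starting chunk and the shift direction are choices made for this sketch, and XLA does not necessarily emit this exact schedule. Each loop iteration is one forward ring shift in which every device sends a single chunk, which matches the $n(n-1)$ chunk transfers counted above.

```python
import numpy as np


def ring_reduce_scatter(xs):
    """Simulate reduce-scatter on a ring of n devices using n - 1 forward shifts.

    xs[d] is the local array on device d (length divisible by n). Returns a list whose
    entry d is chunk d of sum(xs); no simulated device ever holds the full sum.
    """
    n = len(xs)
    chunks = [np.split(np.asarray(x), n) for x in xs]  # chunks[d][k]: chunk k on device d
    # Start: device d holds its own contribution to chunk (d - 1) mod n.
    acc = [chunks[d][(d - 1) % n] for d in range(n)]
    for step in range(1, n):
        # One ppermute (ring shift forward): device d receives the partial sum of device d - 1.
        acc = [acc[(d - 1) % n] for d in range(n)]
        # That partial sum belongs to chunk (d - 1 - step) mod n; add the local piece of it.
        acc = [acc[d] + chunks[d][(d - 1 - step) % n] for d in range(n)]
    return acc


# Same data as the reduce_scatter_demo cell: [0, 1] on tp=0 and [2, 3] on tp=1.
print(ring_reduce_scatter([np.array([0, 1]), np.array([2, 3])]))  # chunks [2] and [4]
```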
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "41ff0e30-f9d1-4d59-ad37-6404e6e2ece0", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(int32[2]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[1]{tp})\n", | |
| "y=Array([2, 4], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[2@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " mesh=None, in_specs=P(\"tp\"), out_specs=P(\"tp\")\n", | |
| ") # mesh=None: use the context mesh set by set_mesh(mesh) or `with set_mesh(mesh):`\n", | |
| "def reduce_scatter_demo(x_block: Float[Array, \"tp_block\"]) -> Float[Array, \"tp_block\"]:\n", | |
| "    # i32[4@tp=2] -> i32[2{tp=2}] (varying) -> i32[1{tp=2}] -> concatenate along dim=0 -> i32[2@tp]\n", | |
| "    print(f\"{jax.typeof(x_block)=}\")\n", | |
| "    y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=0, tiled=True)\n", | |
| "    # tiled=True: its transpose (all_gather) concatenates the results, so psum_scatter keeps dim=0 and only shrinks it\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[1]{tp}\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(4))\n", | |
| "# x = [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> sum: [2, 4] -> scatter: [2] on tp=0, [4] on tp=1 -> int32[2@tp]\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "f69ff137-cc93-4404-8b33-2cf7739eda74", | |
| "metadata": {}, | |
| "source": [ | |
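| "`jax.lax.psum_scatter` and `jax.lax.all_gather` are transposes of each other, which is why the `tiled` comments below describe `psum_scatter` in terms of what its transpose `all_gather` does." | |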
| ] | |
| }, | |
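| { | |
| "cell_type": "markdown", | |
| "id": "c4f8a2d6-1e3b-47c9-b5a0-8d2e6f1c3a94", | |
| "metadata": {}, | |
| "source": [ | |
| "One way to check this transpose relationship numerically: for any per-device input `x` and cotangent `y`, the inner product of `all_gather(x)` with `y`, summed over devices, equals the inner product of `x` with `psum_scatter(y)`. The NumPy sketch below assumes the `tiled=True` variants along a single mesh axis; `all_gather_tiled` and `psum_scatter_tiled` are hypothetical single-process stand-ins (device `d` = list index `d`), not JAX APIs." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "d9b3e7f1-6a2c-4d85-9f1e-2b7c4a8d0e56", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "rng = np.random.default_rng(0)\n", | |
| "n, c = 4, 3  # n devices along the mesh axis, chunk length c\n", | |
| "\n", | |
| "\n", | |
| "def all_gather_tiled(blocks):\n", | |
| "    # every device ends up with the concatenation of all blocks\n", | |
| "    full = np.concatenate(blocks)\n", | |
| "    return [full.copy() for _ in blocks]\n", | |
| "\n", | |
| "\n", | |
| "def psum_scatter_tiled(blocks):\n", | |
| "    # elementwise sum over devices, then device d keeps chunk d\n", | |
| "    return np.split(np.sum(blocks, axis=0), len(blocks))\n", | |
| "\n", | |
| "\n", | |
| "x = [rng.normal(size=c) for _ in range(n)]      # per-device inputs to all_gather\n", | |
| "y = [rng.normal(size=n * c) for _ in range(n)]  # per-device cotangents of its output\n", | |
| "lhs = sum(np.dot(a, b) for a, b in zip(all_gather_tiled(x), y))\n", | |
| "rhs = sum(np.dot(a, b) for a, b in zip(x, psum_scatter_tiled(y)))\n", | |
| "print(np.isclose(lhs, rhs))  # the two maps are adjoint (transposes) of each other" | |
| ] | |
| }, | |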
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "bf72425a-1688-4d57-9f1c-511758743a1e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(int32[2,2]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[1,2]{tp})\n", | |
| "y=Array([[ 2, 4],\n", | |
| " [10, 12]], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[2@tp,2])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| "    # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[1{tp=2}, 2] -> concatenate along dim=0 -> int32[2@tp, 2]\n", | |
| "    print(f\"{jax.typeof(x_block)=}\")\n", | |
| "    y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=0, tiled=True)\n", | |
| "    # tiled=True: its transpose (all_gather) concatenates the results along dim=0,\n", | |
| "    # so psum_scatter keeps that dimension and only shrinks it\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[1{tp=2},2]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [[2, 4]] tp=0 int32[1{tp=2},2] Final [[2, 4]\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [[10,12]] tp=1 int32[1{tp=2},2] Output [10,12]]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=0, tiled =True -> int32[2@tp,2]\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "fd9eddbb-4f7a-4504-929a-3a560e7a0e83", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2,1]{tp})\n", | |
| "y=Array([[ 2, 4],\n", | |
| " [10, 12]], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[2,2@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(None, \"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| "    # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[2, 1{tp=2}] -> concatenate along dim=1 -> int32[2, 2@tp]\n", | |
| "    y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=1, tiled=True)\n", | |
| "    # tiled=True: its transpose (all_gather) concatenates the results along dim=1,\n", | |
| "    # so psum_scatter keeps that dimension and only shrinks it\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2,1{tp=2}]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [[2], | [[4], int32[2,1{tp}] (left on tp=0) Final [[ 2, 4],\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [10]] | [12]] int32[2,1{tp}] (right on tp=1) Output [10,12]]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=1, tiled =True ----> # int32[2,2@tp]\n", | |
| "#                                                                concat per out_specs along dim=1\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "b028daf9-4e71-4ead-8b82-019927e52a68", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2]{tp})\n", | |
| "y=Array([ 2, 4, 10, 12], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[4@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"(batch tp_block)\"]:\n", | |
| "    # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> int32[2]{tp=2} -> concatenate along dim=0 -> int32[4@tp]\n", | |
| "    y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=0, tiled=False)\n", | |
| "    # tiled=False (default): its transpose (all_gather) stacks the results along a new dim=0,\n", | |
| "    # so psum_scatter removes that dimension\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2]{tp=2}\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [2, 4] tp=0 int32[2]{tp} Final\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [10,12] tp=1 int32[2]{tp} Output [2,4,10,12]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=0, tiled=False -> concatenate along dim=0 -> int32[4@tp]\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "4017a413-562f-46e4-ba29-5e0046fe250c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2,1]{tp})\n", | |
| "y=Array([[ 2],\n", | |
| " [10],\n", | |
| " [ 4],\n", | |
| " [12]], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[4@tp,1])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| " # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[2{tp=2},1] -> int32[4@tp,1]\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=1, tiled=True)\n", | |
| "    # tiled=True: its transpose (all_gather) concatenates the results along dim=1,\n", | |
| "    # so psum_scatter keeps that dimension and only shrinks it\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2{tp=2},1]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [[2], | [[4], int32[2{tp},1] (left on tp=0) Final [[ 2],\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [10]] | [12]] int32[2{tp},1] (right on tp=1) Output [10],\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=1, tiled =True ----> [ 4],\n", | |
| "#                                                                concat per out_specs along dim=0       [12]]     # int32[4@tp,1]\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "id": "c919afd4-e908-4a3a-b749-af15b8588c09", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y_block)=ShapedArray(int32[2]{tp})\n", | |
| "y=Array([ 2, 10, 4, 12], dtype=int32)\n", | |
| "jax.typeof(y)=ShapedArray(int32[4@tp])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(None, \"tp\"), out_specs=P(\"tp\"))\n", | |
| "def reduce_scatter_demo(\n", | |
| " x_block: Float[Array, \"batch tp_block\"],\n", | |
| ") -> Float[Array, \"batch tp_block\"]:\n", | |
| " # i32[2, 4 @ tp] -> i32[2, 2{tp=2}] (varying) -> i32[2]{tp=2} -> int32[4@tp]\n", | |
| " y_block = jax.lax.psum_scatter(x_block, \"tp\", scatter_dimension=1, tiled=False)\n", | |
| "    # tiled=False: its transpose (all_gather) stacks the results along a new dim=1,\n", | |
| "    # so psum_scatter removes that dimension\n", | |
| " print(f\"{jax.typeof(y_block)=}\") # int32[2]{tp=2}\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "y = reduce_scatter_demo(jnp.arange(8).reshape(2, 4))\n", | |
| "# [[0, 1] [[2, 3], [[2, 4], [2, 10] int32[2{tp=0}] Final [2,10, 4,12]\n", | |
| "# [4, 5]] [6, 7]] [10,12]] [4, 12] int32[2{tp=1}] Output int32[4@tp]\n", | |
| "# tp=0 + tp=1 --> reduce sum --> scatter dim=1, tiled=False ---->\n", | |
| "print(f\"{y=}\")\n", | |
| "print(f\"{jax.typeof(y)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "41fd2bf9-0e83-49ef-b6a9-ca6f61028d39", | |
| "metadata": {}, | |
| "source": [ | |
| "#### All-Gather\n", | |
| "All-Gather along a device mesh axis $i$ of size $n$, where each device holds a chunk of size $c$, is equivalent to a Gather ($O((n-1)c)$) followed by a Broadcast ($O((n-1)nc)$).\n", | |
| "\n", | |
| "Communication cost: $O((n-1)c + (n-1)nc) = O((n^2-1)c)$, i.e. $O(n^2)$ for a fixed chunk size $c$.\n", | |
| "\n", | |
| "In practice, All-Gather is also implemented as $n-1$ PPermute steps: at each step every device forwards the chunk it most recently received (initially its own) to its ring neighbor, so after $n-1$ steps every device holds all $n$ chunks.\n", | |
| "\n", | |
| "Use case: weight prefetching during the FSDP forward pass" | |
| ] | |
| }, | |
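| { | |
| "cell_type": "markdown", | |
| "id": "e1a6c9d3-4b7f-4a2e-8c5d-7f0b3e9a2c18", | |
| "metadata": {}, | |
| "source": [ | |
| "The ring version of All-Gather can be sketched the same way. In this minimal NumPy simulation (hypothetical helper `ring_all_gather`, single process, device `d` = list index `d`), each PPermute step shifts the chunk that every device is currently forwarding by one position around the ring. After $n-1$ steps every device holds all $n$ chunks, and the transfer counter reads $n(n-1)$, which is also $O(n^2)$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "f5d2b8e4-9c1a-4f63-b7e0-3a6d1c9e5b72", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def ring_all_gather(blocks):\n", | |
| "    # Hypothetical helper: simulate a ring all-gather over n = len(blocks) devices.\n", | |
| "    # Returns (gathered, transfers): gathered[d] is the concatenation of all blocks\n", | |
| "    # (tiled=True) as assembled on device d; transfers counts chunk sends.\n", | |
| "    n = len(blocks)\n", | |
| "    held = [{d: blocks[d]} for d in range(n)]  # chunks on each device, keyed by owner\n", | |
| "    buf = list(blocks)      # chunk each device forwards next (initially its own)\n", | |
| "    owner = list(range(n))  # original owner of the chunk in buf\n", | |
| "    transfers = 0\n", | |
| "    for _ in range(n - 1):\n", | |
| "        # one PPermute with offset=+1: device d receives what device d - 1 forwards\n", | |
| "        buf = [buf[(d - 1) % n] for d in range(n)]\n", | |
| "        owner = [owner[(d - 1) % n] for d in range(n)]\n", | |
| "        transfers += n\n", | |
| "        for d in range(n):\n", | |
| "            held[d][owner[d]] = buf[d]\n", | |
| "    gathered = [np.concatenate([held[d][k] for k in range(n)]) for d in range(n)]\n", | |
| "    return gathered, transfers\n", | |
| "\n", | |
| "\n", | |
| "gathered, transfers = ring_all_gather([np.array([0, 1]), np.array([2, 3])])\n", | |
| "print([g.tolist() for g in gathered], transfers)" | |
| ] | |
| }, | |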
| { | |
| "cell_type": "code", | |
| "execution_count": 24, | |
| "id": "7ca3f8e3-7cbd-4349-a72e-4a63f92c6932", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w_block)=ShapedArray(int32[2]{tp})\n", | |
| "jax.typeof(w)=ShapedArray(int32[4])\n", | |
| "[0 1 2 3]\n", | |
| "int32[4]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P())\n", | |
| "def all_gather_invariant_demo(\n", | |
| " w_block: Float[Array, \"tp_block\"],\n", | |
| ") -> Float[Array, \"(tp tp_block)\"]:\n", | |
| "    # i32[4@tp] -> i32[2{tp=2}] -> concatenate along axis=0 -> i32[4] -> replicate\n", | |
| " print(f\"{jax.typeof(w_block)=}\") # int32[2]{tp} <- varying manual axis tp\n", | |
| " w = jax.lax.all_gather_invariant(w_block, \"tp\", axis=0, tiled=True)\n", | |
| " # concatenate if tiled = True, stack if tiled=False (default)\n", | |
| " print(f\"{jax.typeof(w)=}\") # int32[4] <- unvarying\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_invariant_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather -> [0, 1, 2, 3] replicated across all device mesh axes\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 25, | |
| "id": "760b2266-3b3d-4c56-aa19-42ff0256c3b2", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w)=ShapedArray(int32[2,2])\n", | |
| "[[0 1]\n", | |
| " [2 3]]\n", | |
| "int32[2,2]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P())\n", | |
| "def all_gather_invariant_demo(\n", | |
| " w_block: Float[Array, \"tp_block\"],\n", | |
| ") -> Float[Array, \"tp tp_block\"]:\n", | |
| " # i32[4@tp] -> i32[2{tp=2}] -> stack along axis=0 -> i32[2,2] -> replicate\n", | |
| " w = jax.lax.all_gather_invariant(w_block, \"tp\", axis=0, tiled=False)\n", | |
| " # concatenate if tiled = True, stack if tiled=False (default)\n", | |
| " print(f\"{jax.typeof(w)=}\") # int32[2,2] <- unvarying\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_invariant_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather -> [[0, 1], [2, 3]] across tp\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "id": "eb655670-a8c1-4cb1-9e06-d6619d31354c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w)=ShapedArray(int32[2,2])\n", | |
| "[[0 2]\n", | |
| " [1 3]]\n", | |
| "int32[2,2]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P())\n", | |
| "def all_gather_invariant_demo(\n", | |
| " w_block: Float[Array, \"tp_block\"],\n", | |
| ") -> Float[Array, \"tp tp_block\"]:\n", | |
| " # i32[4@tp] -> i32[2{tp=2}] -> stack along axis=1 -> i32[2,2] -> replicate\n", | |
| " w = jax.lax.all_gather_invariant(w_block, \"tp\", axis=1, tiled=False)\n", | |
| " # concatenate if tiled = True, stack if tiled=False (default)\n", | |
| " print(f\"{jax.typeof(w)=}\") # int32[2,2] <- unvarying\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_invariant_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather stack along dim=1 -> [[0, 2],\n", | |
| "#                                                                                    [1, 3]] replicated across all device mesh axes\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "efccc5f9-b41b-4329-a606-46e8d7fea17c", | |
| "metadata": {}, | |
| "source": [ | |
| "Note that `jax.lax.all_gather_invariant` is a varying-to-unvarying op, whereas `jax.lax.all_gather` is a varying-to-varying op. The difference comes from their transposes: the transpose of `jax.lax.all_gather` is a reduce-scatter (`jax.lax.psum_scatter`), which is varying-to-varying, whereas the transpose of `jax.lax.all_gather_invariant` is `jax.lax.dynamic_slice` (each device slices out its own block), which is unvarying-to-varying.\n", | |
| "\n", | |
| "See [varying manual mesh axis (vma)](https://docs.jax.dev/en/latest/notebooks/shard_map.html#tracking-how-values-vary-over-manual-mesh-axes-and-check-vma-true) for more information. " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 27, | |
| "id": "6a75bd59-f1fa-4ad5-ad3c-ee93a089c581", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(w_block)=ShapedArray(int32[2]{tp})\n", | |
| "jax.typeof(w)=ShapedArray(int32[4]{tp})\n", | |
| "[0 1 2 3 0 1 2 3]\n", | |
| "int32[8@tp]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"tp\"), out_specs=P(\"tp\"))\n", | |
| "def all_gather_demo(w_block: Float[Array, \"tp_block\"]) -> Float[Array, \"(tp tp_block)\"]:\n", | |
| " # i32[4@tp] -> i32[2{tp=2}] -> concatenate along axis=0 -> i32[4]{tp} varying -> concatenate along axis=0 -> i32[8@tp]\n", | |
| " print(f\"{jax.typeof(w_block)=}\") # int32[2]{tp} <- varying manual axis tp\n", | |
| " w = jax.lax.all_gather(w_block, \"tp\", axis=0, tiled=True)\n", | |
| " print(f\"{jax.typeof(w)=}\") # Attention: int32[4]{tp} <- varying\n", | |
| " # Note: because w is varying along tp, the output sharding spec cannot be P(), but has to be P(\"tp\")\n", | |
| " return w\n", | |
| "\n", | |
| "\n", | |
| "y = all_gather_demo(jnp.arange(4))\n", | |
| "# [0, 1, 2, 3] -> [0, 1] on tp=0, [2, 3] on tp=1 -> all gather -> [0, 1, 2, 3] varying across tp -> concatenate tp partitions along dim=0 -> i32[8@tp]\n", | |
| "print(y)\n", | |
| "print(jax.typeof(y))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b33d8d1a-c223-438d-afa4-d99eefae1c26", | |
| "metadata": {}, | |
| "source": [ | |
| "#### All-Reduce\n", | |
| "All-Reduce along a device mesh axis $i$ of size $n$ can be implemented as a Reduce-Scatter ($O(n^2)$) followed by an All-Gather ($O(n^2)$).\n", | |
| "\n", | |
| "Communication cost: $O(2n^2)$, i.e. twice the cost of a Reduce-Scatter or an All-Gather.\n", | |
| "\n", | |
| "Use case: data-parallel gradient update; tensor-parallel (row-parallel linear) forward pass" | |
| ] | |
| }, | |
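| { | |
| "cell_type": "markdown", | |
| "id": "0a7e3c5f-2d9b-4b18-a6c4-9e1f7d3b5a20", | |
| "metadata": {}, | |
| "source": [ | |
| "A minimal NumPy sketch of this decomposition, using hypothetical single-process helpers that model only the semantics (no ring scheduling): `all_reduce` is literally `all_gather(reduce_scatter(...))`, and every device ends up with the elementwise sum over devices, which is what `jax.lax.psum` returns. The reduce-scatter step assumes each local block splits evenly into $n$ chunks." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "1b8f4d6a-3e0c-4c29-b7d5-0f2a8e4c6b31", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def reduce_scatter(blocks):\n", | |
| "    # device d gets chunk d of the elementwise sum over devices (tiled=True semantics)\n", | |
| "    return np.split(np.sum(blocks, axis=0), len(blocks))\n", | |
| "\n", | |
| "\n", | |
| "def all_gather(blocks):\n", | |
| "    # every device gets the concatenation of all chunks (tiled=True semantics)\n", | |
| "    full = np.concatenate(blocks)\n", | |
| "    return [full.copy() for _ in blocks]\n", | |
| "\n", | |
| "\n", | |
| "def all_reduce(blocks):\n", | |
| "    # psum expressed as reduce-scatter followed by all-gather\n", | |
| "    return all_gather(reduce_scatter(blocks))\n", | |
| "\n", | |
| "\n", | |
| "n = 4\n", | |
| "blocks = [np.arange(4) + 4 * d for d in range(n)]  # device d holds [4d, 4d+1, 4d+2, 4d+3]\n", | |
| "result = all_reduce(blocks)\n", | |
| "print([r.tolist() for r in result])" | |
| ] | |
| }, | |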
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "id": "c6ae459f-78e8-48ba-9081-8949d2379426", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(int32[2]{dp})\n", | |
| "jax.typeof(y_all_reduced)=ShapedArray(int32[2])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=P(\"dp\"), out_specs=P())\n", | |
| "def all_reduce_1d(x_block: Float[Array, \"block_size\"]) -> Float[Array, \"block_size\"]:\n", | |
| " # i32[8@dp] -> i32[2{dp=4}] -> all-reduce -> i32[2] unvarying -> replicate\n", | |
| " print(f\"{jax.typeof(x_block)=}\") # int32[2]{dp} <- varying manual axis dp\n", | |
| " x_block = x_block + 1 # parallel execution\n", | |
| " y_all_reduced = jax.lax.psum(x_block, \"dp\")\n", | |
| "    print(f\"{jax.typeof(y_all_reduced)=}\") # int32[2] <- unvarying over manual axes dp and tp\n", | |
| " return y_all_reduced\n", | |
| "\n", | |
| "\n", | |
| "y = all_reduce_1d(jnp.arange(8) - 1)\n", | |
| "assert jnp.all(y == jnp.sum(jnp.asarray([[0, 1], [2, 3], [4, 5], [6, 7]]), axis=0))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 29, | |
| "id": "58309258-4142-4599-8d43-16b10499a87f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(x_block)=ShapedArray(float32[256,4096]{dp})\n", | |
| "jax.typeof(w_block)=ShapedArray(float32[4096,128]{tp})\n", | |
| "jax.typeof(y_block)=ShapedArray(float32[256,128]{tp,dp})\n", | |
| "jax.typeof(y_block_all_reduced)=ShapedArray(float32[256,128])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(in_specs=(P(\"dp\"), P(None, \"tp\")), out_specs=P())\n", | |
| "def all_reduce_2d(\n", | |
| " x_block: Float[Array, \"dp_block hidden\"], w_block: Float[Array, \"hidden tp_block\"]\n", | |
| ") -> Float[Array, \"dp_block tp_block\"]:\n", | |
| " # f32[1024@dp=4, 4096] @ f32[4096, 256@tp=2]\n", | |
| " # f32[256{dp=4}, 4096] @ f32[4096, 128{tp=2}] -> f32[256{dp=4}, 128{tp=2}]\n", | |
| " print(f\"{jax.typeof(x_block)=}\") # float32[256,4096]{dp=4, None}\n", | |
| " print(f\"{jax.typeof(w_block)=}\") # float32[4096,128]{None, tp=2}\n", | |
| " y_block = x_block @ w_block\n", | |
| " # float32[256,128]{tp,dp} <- varying manual axis along dp and tp\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
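| "    # demo only: psum over both axes adds blocks from different batch rows (dp) and different output columns (tp)\n", | |
| "    y_block_all_reduced = jax.lax.psum(y_block, (\"dp\", \"tp\"))\n", | |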
| " # float32[256,128] <- unvarying manual axis\n", | |
| " print(f\"{jax.typeof(y_block_all_reduced)=}\")\n", | |
| " return y_block_all_reduced\n", | |
| "\n", | |
| "\n", | |
| "y = all_reduce_2d(x, w)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e93a7b61-7f84-432a-af0a-749bd6d7e77f", | |
| "metadata": {}, | |
| "source": [ | |
| "#### All-to-All\n", | |
| "Use case: MoE (mixture-of-experts) expert routing\n", | |
| "\n", | |
| "Think of `all_to_all` as a **block transpose between a local axis and a mesh axis**. Every device sends a slice to every other device, and receives a slice from every other device. \n", | |
| "\n", | |
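| "`all_to_all(x, axis_name, split_axis, concat_axis, tiled=True)`: each device splits its local block `x` along `split_axis` into `jax.lax.axis_size(axis_name)` blocks (the product of the sizes of the named mesh axes, so the size of `split_axis` must be divisible by it), sends block `i` to device `i` along `axis_name`, receives one block from every device (including itself), and concatenates the received blocks, ordered by sender, along `concat_axis`. With `tiled=False`, the size of `split_axis` must equal the axis size: that dimension is removed and the received blocks are stacked along a new `concat_axis` dimension instead." | |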
| ] | |
| }, | |
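| { | |
| "cell_type": "markdown", | |
| "id": "2c9a5e7b-4f1d-4d3a-8e6b-1a3b9f5d7c42", | |
| "metadata": {}, | |
| "source": [ | |
| "The `tiled=True` semantics above fit in a few lines of NumPy. This is a minimal single-process sketch with a hypothetical helper `all_to_all_tiled` (device `d` along `axis_name` = list index `d`); `np.split` raises an error if the `split_axis` size is not divisible by the number of devices, mirroring the requirement above. If device `d` holds row `d` of the 8x8 matrix used in the next cell, the exchange leaves column `j` on device `j`, i.e. a transpose." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "3d0b6f8c-5a2e-4e4b-9f7c-2b4c0a6e8d53", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def all_to_all_tiled(blocks, split_axis, concat_axis):\n", | |
| "    # Hypothetical helper: blocks[d] is device d's local block, n = number of devices.\n", | |
| "    n = len(blocks)\n", | |
| "    # device d splits its block into n pieces along split_axis; piece j is sent to device j\n", | |
| "    pieces = [np.split(b, n, axis=split_axis) for b in blocks]\n", | |
| "    # device j concatenates the pieces it received, ordered by sender d, along concat_axis\n", | |
| "    return [\n", | |
| "        np.concatenate([pieces[d][j] for d in range(n)], axis=concat_axis)\n", | |
| "        for j in range(n)\n", | |
| "    ]\n", | |
| "\n", | |
| "\n", | |
| "x = 10 * np.arange(8)[:, None] + np.arange(8)[None, :]\n", | |
| "rows = [x[d] for d in range(8)]  # device d holds row d of x\n", | |
| "out = np.stack(all_to_all_tiled(rows, split_axis=0, concat_axis=0))\n", | |
| "print(np.array_equal(out, x.T))" | |
| ] | |
| }, | |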
| { | |
| "cell_type": "code", | |
| "execution_count": 30, | |
| "id": "466d8ff2-4657-46c9-985c-b339df1d3470", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "x=Array([[ 0, 1, 2, 3, 4, 5, 6, 7],\n", | |
| " [10, 11, 12, 13, 14, 15, 16, 17],\n", | |
| " [20, 21, 22, 23, 24, 25, 26, 27],\n", | |
| " [30, 31, 32, 33, 34, 35, 36, 37],\n", | |
| " [40, 41, 42, 43, 44, 45, 46, 47],\n", | |
| " [50, 51, 52, 53, 54, 55, 56, 57],\n", | |
| " [60, 61, 62, 63, 64, 65, 66, 67],\n", | |
| " [70, 71, 72, 73, 74, 75, 76, 77]], dtype=int32)\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[8]{tp,dp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[8]{tp,dp})\n", | |
| "x_all_to_all=Array([[ 0, 10, 20, 30, 40, 50, 60, 70],\n", | |
| " [ 1, 11, 21, 31, 41, 51, 61, 71],\n", | |
| " [ 2, 12, 22, 32, 42, 52, 62, 72],\n", | |
| " [ 3, 13, 23, 33, 43, 53, 63, 73],\n", | |
| " [ 4, 14, 24, 34, 44, 54, 64, 74],\n", | |
| " [ 5, 15, 25, 35, 45, 55, 65, 75],\n", | |
| " [ 6, 16, 26, 36, 46, 56, 66, 76],\n", | |
| " [ 7, 17, 27, 37, 47, 57, 67, 77]], dtype=int32)\n", | |
| "jax.typeof(x_all_to_all)=ShapedArray(int32[8@(dp,tp),8])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " in_specs=P((\"dp\", \"tp\")),\n", | |
| " out_specs=P((\"dp\", \"tp\")),\n", | |
| ")\n", | |
| "def all_to_all_demo(x_block: Array) -> Array:\n", | |
| " # i32[64@(dp,tp)] -> i32[8{dp,tp}] -> split along axis=0 to dp x tp partitions -> all-to-all send\n", | |
| " # each local split i32[1{dp,tp}] to the corresponding device on mesh axis (\"dp\", \"tp\")\n", | |
| "    # -> each (dp,tp) device receives dp x tp partitions -> concatenate along axis=0 -> i32[8{dp,tp}] -> i32[64@(dp,tp)]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " y_block = jax.lax.all_to_all(\n", | |
| " x_block, axis_name=(\"dp\", \"tp\"), split_axis=0, concat_axis=0, tiled=True\n", | |
| " )\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "x = 10 * jnp.arange(8)[:, None] + jnp.arange(8)[None, :]\n", | |
| "print(f\"{x=}\")\n", | |
| "x_all_to_all: Array = all_to_all_demo(x.ravel()).reshape(x.shape)\n", | |
| "print(f\"{x_all_to_all=}\")\n", | |
| "print(f\"{jax.typeof(x_all_to_all)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "id": "a57c6abf-f855-4efd-8c0c-0bae94ec574e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "x=Array([[ 0, 1, 2, 3, 4, 5, 6, 7],\n", | |
| " [10, 11, 12, 13, 14, 15, 16, 17],\n", | |
| " [20, 21, 22, 23, 24, 25, 26, 27],\n", | |
| " [30, 31, 32, 33, 34, 35, 36, 37],\n", | |
| " [40, 41, 42, 43, 44, 45, 46, 47],\n", | |
| " [50, 51, 52, 53, 54, 55, 56, 57],\n", | |
| " [60, 61, 62, 63, 64, 65, 66, 67],\n", | |
| " [70, 71, 72, 73, 74, 75, 76, 77]], dtype=int32)\n", | |
| "jax.typeof(x_block)=ShapedArray(int32[2,4]{tp,dp})\n", | |
| "jax.typeof(y_block)=ShapedArray(int32[4,2]{tp,dp})\n", | |
| "x_all_to_all=Array([[ 0, 10, 20, 30, 40, 50, 60, 70],\n", | |
| " [ 1, 11, 21, 31, 41, 51, 61, 71],\n", | |
| " [ 2, 12, 22, 32, 42, 52, 62, 72],\n", | |
| " [ 3, 13, 23, 33, 43, 53, 63, 73],\n", | |
| " [ 4, 14, 24, 34, 44, 54, 64, 74],\n", | |
| " [ 5, 15, 25, 35, 45, 55, 65, 75],\n", | |
| " [ 6, 16, 26, 36, 46, 56, 66, 76],\n", | |
| " [ 7, 17, 27, 37, 47, 57, 67, 77]], dtype=int32)\n", | |
| "jax.typeof(x_all_to_all)=ShapedArray(int32[8@(tp,dp),8])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.shard_map(\n", | |
| " in_specs=P(\"dp\", \"tp\"),\n", | |
| " out_specs=P(\n", | |
| " (\"tp\", \"dp\")\n", | |
| " ), # exercise question: why P((\"tp\", \"dp\"))? What if P((\"dp\", \"tp\"))?\n", | |
| ")\n", | |
| "def all_to_all_demo(x_block: Array) -> Array:\n", | |
| "    # i32[8@dp, 8@tp] -> local i32[2, 4]{dp,tp} -> tiled=False: split_axis=1 (size 4 == #dp) is unstacked into 4 columns of shape i32[2]\n", | |
| "    # -> column j is sent to device dp=j (same tp index) -> each device receives one i32[2] column from every dp index\n", | |
| "    # -> the received columns are stacked along a new axis=0 -> i32[4, 2] -> reshape to i32[1, 8] (one full column of x)\n", | |
| "    # -> out_specs concatenates the local rows along axis=0 -> i32[8@(tp,dp), 8]\n", | |
| " print(f\"{jax.typeof(x_block)=}\")\n", | |
| " y_block = jax.lax.all_to_all(\n", | |
| " x_block, axis_name=\"dp\", split_axis=1, concat_axis=0, tiled=False\n", | |
| " )\n", | |
| " print(f\"{jax.typeof(y_block)=}\")\n", | |
| " return y_block.reshape(1, -1)\n", | |
| "\n", | |
| "\n", | |
| "x = 10 * jnp.arange(8)[:, None] + jnp.arange(8)[None, :]\n", | |
| "print(f\"{x=}\")\n", | |
| "x_all_to_all: Array = all_to_all_demo(x)\n", | |
| "print(f\"{x_all_to_all=}\")\n", | |
| "print(f\"{jax.typeof(x_all_to_all)=}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1a6afc9a-0517-4e88-80c8-03ab626e9f97", | |
| "metadata": {}, | |
| "source": [ | |
| "## Matrix Multiplication " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e6503226-47ed-4a1a-b54c-090c906769b5", | |
| "metadata": {}, | |
| "source": [ | |
| "### Naive Version" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "id": "82da5ba3-fb46-43d9-a868-f1780d48e319", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "13.6 ms ± 145 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "x = jax.random.normal(jax.random.key(0), (1024, 4096))\n", | |
| "w = jax.random.normal(jax.random.key(1), (4096, 256))\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "def matmul(\n", | |
| " x: Float[Array, \"batch n_in\"], w: Float[Array, \"n_in n_out\"]\n", | |
| ") -> Float[Array, \"batch n_out\"]:\n", | |
| " \"\"\"Naive matrix multiplication\"\"\"\n", | |
| " return x @ w\n", | |
| "\n", | |
| "\n", | |
| "y = matmul(x, w)\n", | |
| "\n", | |
| "%timeit y = matmul(x, w).block_until_ready()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e2e48a80-b324-4015-b7aa-024c14999a53", | |
| "metadata": {}, | |
| "source": [ | |
| "### FSDP Forward\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "id": "e093d979-1354-4f4c-a4e3-c53cf8677342", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "6.59 ms ± 257 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(\"dp\", None)), out_specs=P(\"dp\", None)\n", | |
| ")\n", | |
| "def matmul_fsdp(\n", | |
| " x_block: Float[Array, \"batch_local n_in\"],\n", | |
| " w_block: Float[Array, \"n_in_local n_out\"],\n", | |
| ") -> Float[Array, \"batch_local n_out\"]:\n", | |
| " \"\"\"FSDP matrix multiplication\n", | |
| "\n", | |
| "    Data and weights are both sharded along the dp (FSDP) device mesh axis,\n", | |
| "    so the weights must be all-gathered during the forward pass. Advanced\n", | |
| "    implementations use prefetching to overlap computation in the current\n", | |
| "    layer with communication (all-gathering the next layer's weights).\n", | |
| " \"\"\"\n", | |
| " # f32[1024@dp, 4096] @ f32[4096@dp, 256]\n", | |
| " # f32[256{dp=4}, 4096] f32[1024{dp=4}, 256]\n", | |
| "\n", | |
| " w_full = jax.lax.all_gather(w_block, \"dp\", axis=0, tiled=True) # f32[4x1024, 256]\n", | |
| " y_block = x_block @ w_full\n", | |
| " # f32[256{dp=4}, 4096] @ f32[4096, 256] -> f32[256{dp=4}, 256]\n", | |
| " return y_block\n", | |
| "\n", | |
| "\n", | |
| "assert jnp.allclose(matmul_fsdp(x, w), y) # verify correctness\n", | |
| "%timeit matmul_fsdp(x,w).block_until_ready() # faster than the single-device matmul above" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "30fd9e0a-c390-47a2-ac17-ce39a098e10a", | |
| "metadata": {}, | |
| "source": [ | |
| "### Profiling matmul_fsdp" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "id": "b9f012a6-0395-4a13-9dcd-511e10596e5b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "E0201 17:26:43.205688 741064 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n", | |
| "E0201 17:26:43.231453 741064 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "jax.profiler.start_trace(log_dir=\"./jax_profiler\")\n", | |
| "_ = matmul_fsdp(x, w).block_until_ready()\n", | |
| "# ^ call block_until_ready() inside the trace so the asynchronous computation is captured\n", | |
| "jax.profiler.stop_trace()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "id": "b3e9318d-e0a0-4c57-ac7e-91eba7007e47", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# %pip install tensorboard-plugin-profile # <- uncomment to install profile plugin if it does not render\n", | |
| "%load_ext tensorboard" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "id": "f35e62e9-1d16-4b9b-a41e-fa3fd5ebc4d3", | |
| "metadata": { | |
| "collapsed": true, | |
| "jupyter": { | |
| "outputs_hidden": true | |
| }, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <iframe id=\"tensorboard-frame-66630db69001805b\" width=\"100%\" height=\"800\" frameborder=\"0\">\n", | |
| " </iframe>\n", | |
| " <script>\n", | |
| " (function() {\n", | |
| " const frame = document.getElementById(\"tensorboard-frame-66630db69001805b\");\n", | |
| " const url = new URL(\"/\", window.location);\n", | |
| " const port = 6006;\n", | |
| " if (port) {\n", | |
| " url.port = port;\n", | |
| " }\n", | |
| " frame.src = url;\n", | |
| " })();\n", | |
| " </script>\n", | |
| " " | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "%tensorboard --logdir ./jax_profiler" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "4d8320de-7aad-42d9-86d7-768611b04039", | |
| "metadata": {}, | |
| "source": [ | |
| "As the trace viewer shows, the matmul has to wait until the all-gather of the weights is done. Can we do better?" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "e125bce8-eaac-4979-bbe5-ba1b89b6ab93", | |
| "metadata": {}, | |
| "source": [ | |
| "### `ppermute` Loop Version\n", | |
| "Use a `ppermute` ring shift to overlap communication with computation." | |
| ] | |
| }, | |
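| { | |
| "cell_type": "markdown", | |
| "id": "4e1c7a9d-6b3f-4f5c-a08d-3c5d1b7f9e64", | |
| "metadata": {}, | |
| "source": [ | |
| "Before the JAX version, here is a minimal NumPy simulation of the same schedule (single process, device `d` = list index `d`, small illustrative sizes). Each device first multiplies its own `w` row block by the matching column slice of its `x` rows; after each forward ring shift it multiplies the newly arrived `w` block by the column slice belonging to that block's original owner. Summing these partial products reproduces `x @ w` without any device ever holding the full `w`." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5f2d8b0e-7c4a-406d-b19e-4d6e2c8a0f75", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "rng = np.random.default_rng(0)\n", | |
| "n, batch, n_in, n_out = 4, 8, 16, 5  # illustrative sizes; batch and n_in divisible by n\n", | |
| "x = rng.normal(size=(batch, n_in))\n", | |
| "w = rng.normal(size=(n_in, n_out))\n", | |
| "\n", | |
| "x_blocks = np.split(x, n, axis=0)  # like P('dp', None): batch rows sharded\n", | |
| "w_blocks = np.split(w, n, axis=0)  # like P('dp', None): n_in rows sharded\n", | |
| "k = n_in // n  # width of the column slice of x_block that matches one w block\n", | |
| "\n", | |
| "w_held = list(w_blocks)\n", | |
| "outs = [x_blocks[d][:, d * k:(d + 1) * k] @ w_held[d] for d in range(n)]\n", | |
| "for i in range(1, n):\n", | |
| "    # one ring shift (offset=+1): device d receives the w block currently held by device d - 1\n", | |
| "    w_held = [w_held[(d - 1) % n] for d in range(n)]\n", | |
| "    for d in range(n):\n", | |
| "        src = (d - i) % n  # original owner of the w block now on device d\n", | |
| "        outs[d] = outs[d] + x_blocks[d][:, src * k:(src + 1) * k] @ w_held[d]\n", | |
| "\n", | |
| "print(np.allclose(np.concatenate(outs), x @ w))" | |
| ] | |
| }, | |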
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "id": "9231dffd-5a10-4c4e-b8c0-c08e906c0a5f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "7.17 ms ± 202 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(\"dp\", None)), out_specs=P(\"dp\", None)\n", | |
| ")\n", | |
| "def matmul_fsdp_allgather_overlapped(\n", | |
| " x_block: Float[Array, \"batch_local n_in\"],\n", | |
| " w_block: Float[Array, \"n_in_local n_out\"],\n", | |
| ") -> Float[Array, \"batch_local n_out\"]:\n", | |
| " # global: f32[1024@dp, 4096] @ f32[4096@dp, 256]\n", | |
| " # local: f32[256{dp=4}, 4096] @ f32[1024{dp=4}, 256] -> f32[256, 256]\n", | |
| " num_devices = jax.lax.axis_size(\"dp\")\n", | |
| " device_idx = jax.lax.axis_index(\"dp\")\n", | |
| "\n", | |
| " x_block_v_slice_size = x_block.shape[1] // num_devices\n", | |
| "\n", | |
| " def _get_x_block_v_slice(i: int) -> Float[Array, \"batch_local n_in_local\"]:\n", | |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=i * x_block_v_slice_size,\n", | |
| " slice_size=x_block_v_slice_size,\n", | |
| " axis=1,\n", | |
| " )\n", | |
| "\n", | |
| " # in a forward ring shift, the program receives the previous w_block in a roll\n", | |
| " # we vertically slice the corresponding chunk of x_block for matrix multiplication in col @ row format\n", | |
| " # and reduce the results on the fly\n", | |
| "\n", | |
| " # start from the current col chunk @ current w row: f32[batch_local, n_out]\n", | |
| " out_block = _get_x_block_v_slice(device_idx) @ w_block\n", | |
| " # ring shift forward for num_devices - 1 steps\n", | |
| " for i in range(1, num_devices):\n", | |
| " w_block = ring_shift(w_block, axis_name=\"dp\", offset=1)\n", | |
| " out_block += _get_x_block_v_slice((device_idx - i) % num_devices) @ w_block\n", | |
| " return out_block\n", | |
| "\n", | |
| "\n", | |
| "assert jnp.allclose(matmul_fsdp_allgather_overlapped(x, w), y, atol=1e-4)\n", | |
| "%timeit matmul_fsdp_allgather_overlapped(x,w).block_until_ready() # FSDP is faster than the non-parallel version" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b5c9b6f0-0c34-436e-8356-b109f10fd893", | |
| "metadata": {}, | |
| "source": [ | |
| "This implementation overlaps communication with computation and never materializes the full weight `w` of shape `[n_in, n_out]` on any device. However, it permutes in only one direction around the ring, so on TPU it uses only **half of the interconnect bandwidth**. To permute bidirectionally, split each block in half and send one half in each direction:" | |
| ] | |
| }, | |
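| { | |
| "cell_type": "markdown", | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a03", | |
| "metadata": {}, | |
| "source": [ | |
| "The bidirectional schedule can be checked with a plain NumPy simulation. It uses lists in place of devices and illustrative sizes; it is not JAX code. After `i` steps, device `d` holds the top half of block `(d - i) % n` and the bottom half of block `(d + i) % n`. It multiplies each half with the matching column slice of its local `x` shard, so every block of `w` is used exactly once." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a04", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# Plain NumPy lists stand in for devices; sizes are illustrative assumptions.\n", | |
| "n_sim = 4\n", | |
| "batch_sim, n_in_sim, n_out_sim = 8, 16, 3\n", | |
| "rng_sim = np.random.default_rng(0)\n", | |
| "x_sim = rng_sim.standard_normal((batch_sim, n_in_sim))\n", | |
| "w_sim = rng_sim.standard_normal((n_in_sim, n_out_sim))\n", | |
| "\n", | |
| "rows_sim = n_in_sim // n_sim  # rows of w owned by each device\n", | |
| "half_sim = rows_sim // 2\n", | |
| "x_local_sim = np.split(x_sim, n_sim, axis=0)  # FSDP: device d owns a batch shard of x\n", | |
| "w_up_sim = [wb[:half_sim] for wb in np.split(w_sim, n_sim, axis=0)]  # top halves\n", | |
| "w_down_sim = [wb[half_sim:] for wb in np.split(w_sim, n_sim, axis=0)]  # bottom halves\n", | |
| "\n", | |
| "\n", | |
| "def x_cols_sim(d, k, bottom):\n", | |
| "    \"\"\"Columns of device d's x shard that multiply the top (0) or bottom (1) half of w block k.\"\"\"\n", | |
| "    start = k * rows_sim + bottom * half_sim\n", | |
| "    return x_local_sim[d][:, start : start + half_sim]\n", | |
| "\n", | |
| "\n", | |
| "out_sim = [\n", | |
| "    x_cols_sim(d, d, 0) @ w_up_sim[d] + x_cols_sim(d, d, 1) @ w_down_sim[d]\n", | |
| "    for d in range(n_sim)\n", | |
| "]\n", | |
| "for i in range(1, n_sim):\n", | |
| "    # forward shift: device d receives the top half from d - 1; backward shift: bottom half from d + 1\n", | |
| "    w_up_sim = [w_up_sim[(d - 1) % n_sim] for d in range(n_sim)]\n", | |
| "    w_down_sim = [w_down_sim[(d + 1) % n_sim] for d in range(n_sim)]\n", | |
| "    for d in range(n_sim):\n", | |
| "        out_sim[d] += x_cols_sim(d, (d - i) % n_sim, 0) @ w_up_sim[d]\n", | |
| "        out_sim[d] += x_cols_sim(d, (d + i) % n_sim, 1) @ w_down_sim[d]\n", | |
| "\n", | |
| "# concatenating the per-device outputs along the batch dim recovers the full product\n", | |
| "assert np.allclose(np.concatenate(out_sim, axis=0), x_sim @ w_sim)" | |
| ] | |
| }, | |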
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "id": "77ee42a6-cb84-4411-83cf-719f6065d283", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "7.37 ms ± 132 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh, in_specs=(P(\"dp\", None), P(\"dp\", None)), out_specs=P(\"dp\", None)\n", | |
| ")\n", | |
| "def matmul_fsdp_allgather_overlapped(\n", | |
| " x_block: Float[Array, \"batch_local n_in\"],\n", | |
| " w_block: Float[Array, \"n_in_local n_out\"],\n", | |
| ") -> Float[Array, \"batch_local n_out\"]:\n", | |
| " # global: f32[1024@dp, 4096] @ f32[4096@dp, 256]\n", | |
| " # local: f32[256{dp=4}, 4096] @ f32[1024{dp=4}, 256] -> f32[256, 256]\n", | |
| " num_devices = jax.lax.axis_size(\"dp\")\n", | |
| " device_idx = jax.lax.axis_index(\"dp\")\n", | |
| "\n", | |
| " x_block_v_slice_size, _remainder = divmod(x_block.shape[1], num_devices * 2)\n", | |
| " assert _remainder == 0, f\"{x_block.shape[1]=} must be divisible by {2*num_devices=}\"\n", | |
| "\n", | |
| " def _get_x_block_v_slice(\n", | |
| " i: int, is_right_half: bool = False\n", | |
| " ) -> Float[Array, \"batch_local n_in_local//2\"]:\n", | |
| " \"\"\"slice the local x_block vertically into num_devices chunks, and split each chunk into left and right halves\"\"\"\n", | |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=(2 * i + is_right_half) * x_block_v_slice_size,\n", | |
| " slice_size=x_block_v_slice_size,\n", | |
| " axis=1,\n", | |
| " )\n", | |
| "\n", | |
| " # Initialization: split w_block horizontally into two halves (the rows) and\n", | |
| " # get the corresponding vertical chunks of the local x_block (the cols) to run col @ row matrix multiplication\n", | |
| " # and reduce on the fly\n", | |
| " w_block_up, w_block_bottom = jnp.split(w_block, 2, axis=0)\n", | |
| " x_block_chunk_left = _get_x_block_v_slice(device_idx, is_right_half=False)\n", | |
| " x_block_chunk_right = _get_x_block_v_slice(device_idx, is_right_half=True)\n", | |
| " out_block = x_block_chunk_left @ w_block_up\n", | |
| " out_block += x_block_chunk_right @ w_block_bottom # f32[batch_local, n_out]\n", | |
| "\n", | |
| " # bidirectional ring shift w_block for num_devices - 1 steps to finish the whole x_block @ w_block\n", | |
| " for i in range(1, num_devices):\n", | |
| " # ring shift forward and receive w_block_up from prev device\n", | |
| " w_block_up = ring_shift(w_block_up, axis_name=\"dp\", offset=1)\n", | |
| " # ring shift backward and receive w_block_bottom from next device\n", | |
| " w_block_bottom = ring_shift(w_block_bottom, axis_name=\"dp\", offset=-1)\n", | |
| " x_block_chunk_left = _get_x_block_v_slice(\n", | |
| " (device_idx - i) % num_devices, is_right_half=False\n", | |
| " )\n", | |
| " x_block_chunk_right = _get_x_block_v_slice(\n", | |
| " (device_idx + i) % num_devices, is_right_half=True\n", | |
| " )\n", | |
| " out_block += x_block_chunk_left @ w_block_up\n", | |
| " out_block += x_block_chunk_right @ w_block_bottom\n", | |
| " return out_block\n", | |
| "\n", | |
| "\n", | |
| "assert jnp.allclose(matmul_fsdp_allgather_overlapped(x, w), y, atol=1e-4)\n", | |
| "%timeit matmul_fsdp_allgather_overlapped(x,w).block_until_ready() # FSDP is faster than the non-parallel version" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0661e107-65e2-4a2c-8b0a-40806684b1fa", | |
| "metadata": {}, | |
| "source": [ | |
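| "Profile the bidirectional ring-shift version to verify that the blocking all-gather is gone. Note that in the timings above, the bidirectional version (7.37 ms) is not faster than the unidirectional one (7.17 ms) on this setup." | |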
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "id": "08ed557c-643b-445b-9919-8cbfa30bb24a", | |
| "metadata": { | |
| "collapsed": true, | |
| "jupyter": { | |
| "outputs_hidden": true | |
| }, | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "E0130 16:19:40.255039 3605 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n", | |
| "E0130 16:19:40.271309 3605 python_hooks.cc:416] Can't import tensorflow.python.profiler.trace\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <iframe id=\"tensorboard-frame-d812e4d4352a8fe5\" width=\"100%\" height=\"800\" frameborder=\"0\">\n", | |
| " </iframe>\n", | |
| " <script>\n", | |
| " (function() {\n", | |
| " const frame = document.getElementById(\"tensorboard-frame-d812e4d4352a8fe5\");\n", | |
| " const url = new URL(\"/\", window.location);\n", | |
| " const port = 6007;\n", | |
| " if (port) {\n", | |
| " url.port = port;\n", | |
| " }\n", | |
| " frame.src = url;\n", | |
| " })();\n", | |
| " </script>\n", | |
| " " | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "with jax.profiler.trace(log_dir=\"./jax_profiler\"):\n", | |
| " _ = matmul_fsdp_allgather_overlapped(x, w).block_until_ready()\n", | |
| " # ^ call block_until_ready() inside the profiler context so the trace captures the whole execution\n", | |
| "\n", | |
| "%tensorboard --logdir=\"./jax_profiler\"" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "88f484f3-9189-45e7-916d-daea6233f07a", | |
| "metadata": {}, | |
| "source": [ | |
| "### Matmul Reduce Scatter\n", | |
| "Use cases: tensor-parallel (TP) forward step; FSDP backward step (reduce-scattering the weight gradients)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 36, | |
| "id": "ac2a8cc9-2180-4579-9993-8d904c077d66", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y)=ShapedArray(float32[1024@(dp,tp),256])\n", | |
| "True\n", | |
| "2.85 ms ± 101 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=(P(None, (\"dp\", \"tp\")), P((\"dp\", \"tp\"), None)),\n", | |
| " out_specs=P((\"dp\", \"tp\"), None),\n", | |
| ")\n", | |
| "def matmul_psumscatter(\n", | |
| " x_block: Float[Array, \"batch n_in_block\"],\n", | |
| " w_block: Float[Array, \"n_in_block n_out\"],\n", | |
| ") -> Float[Array, \"batch_scatter n_out\"]:\n", | |
| " # f32[1024, 4096@(dp,tp)] @ f32[4096@(dp,tp), 256]\n", | |
| " # f32[1024, 512{dp=4,tp=2}] @ f32[512{dp=4,tp=2}, 256] -> f32[1024, 256]{dp,tp}\n", | |
| " # -> psum -> f32[1024,256] -> scatter -> f32[1024@(dp,tp), 256]\n", | |
| " out_block = x_block @ w_block\n", | |
| " out_block_reduce_scatterred = jax.lax.psum_scatter(\n", | |
| " out_block, axis_name=(\"dp\", \"tp\"), scatter_dimension=0, tiled=True\n", | |
| " )\n", | |
| " return out_block_reduce_scatterred\n", | |
| "\n", | |
| "\n", | |
| "y = matmul_psumscatter(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(jnp.allclose(y, x @ w, atol=1e-4))\n", | |
| "%timeit _ = matmul_psumscatter(x, w).block_until_ready()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "38629999-d23c-4c5b-9d0a-828a6cbd434a", | |
| "metadata": {}, | |
| "source": [ | |
| "However, the reduce-scatter cannot start until the entire local matmul has finished. To overlap communication with computation, we can instead implement `psum_scatter` as N-1 `ppermute` ring shifts and interleave each shift with a local matmul on one row block of `x_block`." | |
| ] | |
| }, | |
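| { | |
| "cell_type": "markdown", | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a05", | |
| "metadata": {}, | |
| "source": [ | |
| "The ring reduce-scatter schedule can be checked with a plain NumPy simulation. It uses lists in place of devices and illustrative sizes; it is not JAX code. Device `d` owns a column block of `x` and the matching row block of `w`. It starts the running sum destined for its next peer, `(d + 1) % n`. At each backward ring shift, the sum moves to the previous device, which adds its own partial product. After `n - 1` shifts, each row block of `x @ w` is fully reduced on the device that owns it." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a06", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "# Plain NumPy lists stand in for devices; sizes are illustrative assumptions.\n", | |
| "n_rs = 4\n", | |
| "batch_rs, n_in_rs, n_out_rs = 8, 12, 3\n", | |
| "rng_rs = np.random.default_rng(1)\n", | |
| "x_rs = rng_rs.standard_normal((batch_rs, n_in_rs))\n", | |
| "w_rs = rng_rs.standard_normal((n_in_rs, n_out_rs))\n", | |
| "\n", | |
| "# device d owns a column block of x and the matching row block of w\n", | |
| "x_local_rs = np.split(x_rs, n_rs, axis=1)\n", | |
| "w_local_rs = np.split(w_rs, n_rs, axis=0)\n", | |
| "rows_rs = batch_rs // n_rs  # output rows that end up on each device\n", | |
| "\n", | |
| "\n", | |
| "def partial_rows_rs(d, k):\n", | |
| "    \"\"\"Device d's partial product for output row block k.\"\"\"\n", | |
| "    return x_local_rs[d][k * rows_rs : (k + 1) * rows_rs] @ w_local_rs[d]\n", | |
| "\n", | |
| "\n", | |
| "# each device starts the running sum destined for its next peer (d + 1) % n_rs\n", | |
| "acc_rs = [partial_rows_rs(d, (d + 1) % n_rs) for d in range(n_rs)]\n", | |
| "for i in range(n_rs - 2, -1, -1):\n", | |
| "    # backward ring shift: device d receives the running sum from device d + 1 ...\n", | |
| "    acc_rs = [acc_rs[(d + 1) % n_rs] for d in range(n_rs)]\n", | |
| "    # ... and adds its own partial product for that row block\n", | |
| "    acc_rs = [acc_rs[d] + partial_rows_rs(d, (d - i) % n_rs) for d in range(n_rs)]\n", | |
| "\n", | |
| "# device d now holds row block d of x @ w, i.e. the reduce-scattered result\n", | |
| "assert np.allclose(np.concatenate(acc_rs, axis=0), x_rs @ w_rs)" | |
| ] | |
| }, | |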
| { | |
| "cell_type": "code", | |
| "execution_count": 37, | |
| "id": "95500c5d-a7cf-4ba9-8dc0-e5de76692252", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y)=ShapedArray(float32[1024@(dp,tp),256])\n", | |
| "True\n", | |
| "4.72 ms ± 104 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=(P(None, (\"dp\", \"tp\")), P((\"dp\", \"tp\"), None)),\n", | |
| " out_specs=P((\"dp\", \"tp\"), None),\n", | |
| ")\n", | |
| "def matmul_psumscatter_overlapped(\n", | |
| " x_block: Float[Array, \"batch n_in_block\"],\n", | |
| " w_block: Float[Array, \"n_in_block n_out\"],\n", | |
| ") -> Float[Array, \"batch_scatter n_out\"]:\n", | |
| " # f32[1024, 4096@(dp,tp)] @ f32[4096@(dp,tp), 256]\n", | |
| " # f32[1024, 512{dp=4,tp=2}] @ f32[512{dp=4,tp=2}, 256] -> f32[1024, 256]{dp,tp}\n", | |
| " # -> psum -> f32[1024,256] -> scatter -> f32[1024@(dp,tp), 256]\n", | |
| " num_devices = jax.lax.axis_size((\"dp\", \"tp\"))\n", | |
| " device_idx = jax.lax.axis_index((\"dp\", \"tp\"))\n", | |
| " x_block_h_slice_size, _remainder = divmod(x_block.shape[0], num_devices)\n", | |
| " assert _remainder == 0\n", | |
| "\n", | |
| " def _get_x_block_h_slice(i: int):\n", | |
| " \"\"\"slice local x_block into dp x tp rows and get the ith slice\"\"\"\n", | |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=i * x_block_h_slice_size,\n", | |
| " slice_size=x_block_h_slice_size,\n", | |
| " axis=0,\n", | |
| " )\n", | |
| "\n", | |
| " # each device starts the partial sum destined for its next peer, (device_idx + 1) % num_devices, so it can be ring-shifted backward\n", | |
| " y_block_scatterred = (\n", | |
| " _get_x_block_h_slice((device_idx - (num_devices - 1)) % num_devices) @ w_block\n", | |
| " )\n", | |
| " # ring-shift the running sum backward and add the local matmul at each step; after num_devices - 1 steps it is fully reduced on device_idx\n", | |
| " for i in range(num_devices - 2, -1, -1):\n", | |
| " y_block_scatterred = ring_shift(\n", | |
| " y_block_scatterred, axis_name=(\"dp\", \"tp\"), offset=-1\n", | |
| " )\n", | |
| " y_block_scatterred += (\n", | |
| " _get_x_block_h_slice((device_idx - i) % num_devices) @ w_block\n", | |
| " )\n", | |
| " return y_block_scatterred\n", | |
| "\n", | |
| "\n", | |
| "y = matmul_psumscatter_overlapped(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(jnp.allclose(y, x @ w, atol=1e-4))\n", | |
| "%timeit _ = matmul_psumscatter_overlapped(x, w).block_until_ready()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b52a612c-b4fb-49b4-8f97-400cb7ad3d49", | |
| "metadata": {}, | |
| "source": [ | |
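| "**Bidirectional ring shift**\n", | |
| "\n", | |
| "Split each row block in half and send the two running sums in opposite directions around the ring. On this setup, both hand-written ring versions (4.72 ms above and 9.17 ms below) are slower than the single `psum_scatter` call (2.85 ms)." | |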
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "id": "2e080710-32bf-437d-95b0-c1343af2f1a4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "jax.typeof(y)=ShapedArray(float32[1024@(dp,tp),256])\n", | |
| "True\n", | |
| "9.17 ms ± 143 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=global_mesh,\n", | |
| " in_specs=(P(None, (\"dp\", \"tp\")), P((\"dp\", \"tp\"), None)),\n", | |
| " out_specs=P((\"dp\", \"tp\"), None),\n", | |
| ")\n", | |
| "def matmul_psumscatter_overlapped(\n", | |
| " x_block: Float[Array, \"batch n_in_block\"],\n", | |
| " w_block: Float[Array, \"n_in_block n_out\"],\n", | |
| ") -> Float[Array, \"batch_scatter n_out\"]:\n", | |
| " # f32[1024, 4096@(dp,tp)] @ f32[4096@(dp,tp), 256]\n", | |
| " # f32[1024, 512{dp=4,tp=2}] @ f32[512{dp=4,tp=2}, 256] -> f32[1024, 256]{dp,tp}\n", | |
| " # -> psum -> f32[1024,256] -> scatter -> f32[1024@(dp,tp), 256]\n", | |
| " num_devices = jax.lax.axis_size((\"dp\", \"tp\"))\n", | |
| " device_idx = jax.lax.axis_index((\"dp\", \"tp\"))\n", | |
| " # horizontally slice x_block into num_devices row blocks and split each row block into top and bottom halves\n", | |
| " x_block_h_slice_size, _remainder = divmod(x_block.shape[0], num_devices * 2)\n", | |
| " assert _remainder == 0, f\"{x_block.shape[0]=} must be divisible by {num_devices * 2=}\"\n", | |
| "\n", | |
| " def _get_x_block_h_slice(i: int, is_bottom_half: bool = False):\n", | |
| " return jax.lax.dynamic_slice_in_dim(\n", | |
| " x_block,\n", | |
| " start_index=(2 * i + is_bottom_half) * x_block_h_slice_size,\n", | |
| " slice_size=x_block_h_slice_size,\n", | |
| " axis=0,\n", | |
| " )\n", | |
| "\n", | |
| " # Top halves (backward ring shift): each device starts the partial sum destined for its next peer\n", | |
| " # and passes it backward, adding its local matmul at each step, so it ends up fully reduced on device_idx\n", | |
| " x_block_h_slice_top = _get_x_block_h_slice(\n", | |
| " (device_idx - (num_devices - 1)) % num_devices, is_bottom_half=False\n", | |
| " )\n", | |
| " y_block_scatterred_top = x_block_h_slice_top @ w_block\n", | |
| " # Bottom halves (forward ring shift): each device starts the partial sum destined for its previous peer\n", | |
| " # and passes it forward, adding its local matmul at each step, so it ends up fully reduced on device_idx\n", | |
| " x_block_h_slice_bottom = _get_x_block_h_slice(\n", | |
| " (device_idx - 1) % num_devices, is_bottom_half=True\n", | |
| " )\n", | |
| " y_block_scatterred_bottom = x_block_h_slice_bottom @ w_block\n", | |
| "\n", | |
| " for i, j in zip(range(num_devices - 2, -1, -1), range(2, num_devices + 1)):\n", | |
| " y_block_scatterred_top = ring_shift(\n", | |
| " y_block_scatterred_top, axis_name=(\"dp\", \"tp\"), offset=-1\n", | |
| " )\n", | |
| " y_block_scatterred_bottom = ring_shift(\n", | |
| " y_block_scatterred_bottom, axis_name=(\"dp\", \"tp\"), offset=1\n", | |
| " )\n", | |
| " x_block_h_slice_top = _get_x_block_h_slice(\n", | |
| " (device_idx - i) % num_devices, is_bottom_half=False\n", | |
| " )\n", | |
| " y_block_scatterred_top += x_block_h_slice_top @ w_block\n", | |
| " x_block_h_slice_bottom = _get_x_block_h_slice(\n", | |
| " (device_idx - j) % num_devices, is_bottom_half=True\n", | |
| " )\n", | |
| " y_block_scatterred_bottom += x_block_h_slice_bottom @ w_block\n", | |
| " return jnp.concat((y_block_scatterred_top, y_block_scatterred_bottom), axis=0)\n", | |
| "\n", | |
| "\n", | |
| "y = matmul_psumscatter_overlapped(x, w)\n", | |
| "print(f\"{jax.typeof(y)=}\")\n", | |
| "print(jnp.allclose(y, x @ w, atol=1e-4))\n", | |
| "%timeit _ = matmul_psumscatter_overlapped(x, w).block_until_ready()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "f7967a90-aa18-4e97-829e-082337c90e81", | |
| "metadata": {}, | |
| "source": [ | |
| "## Practical Examples" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b082f6a8-970c-4651-b7f1-fedbb9c5189d", | |
| "metadata": {}, | |
| "source": [ | |
| "### Distributed Training of Neural Networks\n", | |
| "\n", | |
| "Set up a basic neural network example." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "id": "44a4bba7-55c5-4ecc-adcb-867cbb5571a9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from dataclasses import dataclass, field\n", | |
| "from typing import Any, Callable\n", | |
| "\n", | |
| "\n", | |
| "# Register dataclass as PyTree\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass\n", | |
| "class Linear[T: (jax.Array, P)]:\n", | |
| " w: T\n", | |
| " b: T\n", | |
| "\n", | |
| "\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass\n", | |
| "class Batch[T: (jax.Array, P)]:\n", | |
| " inputs: T\n", | |
| " targets: T\n", | |
| "\n", | |
| "\n", | |
| "# Type aliases\n", | |
| "type LinearParams = Linear[jax.Array]\n", | |
| "type LinearPSpec = Linear[P]\n", | |
| "type BatchData = Batch[jax.Array]\n", | |
| "type BatchPSpec = Batch[P]\n", | |
| "\n", | |
| "\n", | |
| "# @jax.tree_util.register_pytree_node_class\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass\n", | |
| "class Model:\n", | |
| " \"\"\"JAX-compatible Model container -- Registered as a PyTree\"\"\"\n", | |
| "\n", | |
| " # Dynamic Fields (JAX Pytree Children)\n", | |
| " layers: list[LinearParams] | list[LinearPSpec]\n", | |
| "\n", | |
| " # Static Fields (JAX Aux Data)\n", | |
| " layer_sizes: tuple[int, ...] = field(metadata={\"static\": True})\n", | |
| "\n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " layer_sizes: tuple[int, ...],\n", | |
| " key: jax.Array | None = None,\n", | |
| " layers: list[LinearParams] | list[LinearPSpec] | None = None,\n", | |
| " ):\n", | |
| " self.layer_sizes = layer_sizes\n", | |
| " # Not a dataclass field, so it is neither a pytree child nor static aux data (it is not preserved by flatten/unflatten)\n", | |
| " self.key = key\n", | |
| "\n", | |
| " if layers is not None:\n", | |
| " self.layers = layers\n", | |
| " elif key is not None:\n", | |
| " self._init_weights()\n", | |
| " else:\n", | |
| " self.layers = []\n", | |
| "\n", | |
| " @property\n", | |
| " def num_layers(self) -> int:\n", | |
| " return len(self.layer_sizes) - 1\n", | |
| "\n", | |
| " def _init_weights(self) -> None:\n", | |
| " \"\"\"Random init self.layers as a list of LinearParams\"\"\"\n", | |
| " assert self.key is not None, \"Did you set key in model?\"\n", | |
| " keys = jax.random.split(self.key, self.num_layers)\n", | |
| " self.layers = []\n", | |
| " for k, n_in, n_out in zip(keys, self.layer_sizes[:-1], self.layer_sizes[1:]):\n", | |
| " k1, k2 = jax.random.split(k)\n", | |
| " w = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", | |
| " b = jax.random.normal(k2, (n_out,))\n", | |
| " self.layers.append(Linear(w, b))\n", | |
| "\n", | |
| " def predict(\n", | |
| " self,\n", | |
| " inputs: Float[Array, \"b in\"],\n", | |
| " linear_fwd_fn: Callable[[jax.Array, LinearParams], jax.Array],\n", | |
| " ) -> Float[Array, \"b out\"]:\n", | |
| " x = inputs\n", | |
| " for layer in self.layers:\n", | |
| " x = linear_fwd_fn(x, layer)\n", | |
| " x = jax.nn.relu(x)\n", | |
| " return x\n", | |
| "\n", | |
| " def init_random_batch(self, key: jax.Array, batch_size: int = 32) -> BatchData:\n", | |
| " k1, k2 = jax.random.split(key)\n", | |
| " inputs = jax.random.normal(k1, (batch_size, self.layer_sizes[0]))\n", | |
| " targets = jax.random.normal(k2, (batch_size, self.layer_sizes[-1]))\n", | |
| " return Batch(inputs, targets)\n", | |
| "\n", | |
| " # if using @jax.tree_util.register_pytree_node_class instead of @jax.tree_util.register_dataclass\n", | |
| " # --- Reference Implementation of JAX Pytree Conversion ---\n", | |
| " # def tree_flatten(self) -> tuple[tuple[\"Dynamic\", ...], tuple[\"Static\", ...]]:\n", | |
| " # return (self.layers,), (self.layer_sizes,)\n", | |
| "\n", | |
| " # @classmethod\n", | |
| " # def tree_unflatten(\n", | |
| " # cls,\n", | |
| " # aux_data: tuple[\"Static\", ...],\n", | |
| " # children: tuple[\"Dynamic\", ...],\n", | |
| " # ) -> \"Model\":\n", | |
| " # (layers,) = children\n", | |
| " # (layer_sizes,) = aux_data\n", | |
| " # return cls(layer_sizes=layer_sizes, layers=layers)\n", | |
| "\n", | |
| "\n", | |
| "# Initialize model and data\n", | |
| "key_model, key_batch = jax.random.split(jax.random.key(0))\n", | |
| "model = Model(layer_sizes=(32, 128, 256, 128, 4), key=key_model)\n", | |
| "batch = model.init_random_batch(key=key_batch, batch_size=16)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3489e9a9-9d93-4e7c-b644-bc2424fe27af", | |
| "metadata": {}, | |
| "source": [ | |
| "Baseline: Single-device, no parallelism " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "id": "79806266-8fb9-4d54-8f90-dfe3f37a29ba", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Baseline Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "def linear(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " return x @ layer.w + layer.b\n", | |
| "\n", | |
| "\n", | |
| "# value_and_grad returns (loss, grads); argnums=0 differentiates with respect to the first positional argument (the model).\n", | |
| "@jax.jit\n", | |
| "@partial(jax.value_and_grad, argnums=0)\n", | |
| "def loss_baseline(model: Model, batch: BatchData) -> jax.Array:\n", | |
| " preds = model.predict(batch.inputs, linear_fwd_fn=linear)\n", | |
| " # MSE loss\n", | |
| " return jnp.mean(jnp.sum((preds - batch.targets) ** 2, axis=-1))\n", | |
| "\n", | |
| "\n", | |
| "loss_baseline_, grads_ = loss_baseline(model, batch)\n", | |
| "print(\"Baseline Loss:\", loss_baseline_)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 41, | |
| "id": "958ada43-0916-4fc7-984b-cd7c348d1636", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "459 μs ± 21.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_baseline(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b5f60638-69cd-49ab-ac03-89fd7ba82534", | |
| "metadata": {}, | |
| "source": [ | |
| "Note: we omit the weight update step for simplicity." | |
| ] | |
| }, | |
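| { | |
| "cell_type": "markdown", | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a07", | |
| "metadata": {}, | |
| "source": [ | |
| "For completeness, a plain SGD update is a single `jax.tree.map` over the parameter and gradient pytrees. The sketch below uses small dict pytrees and a hypothetical `learning_rate`. The same call should apply to the notebook's `Model` and the gradients returned by `jax.value_and_grad`, since both are pytrees with the same structure." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a08", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "# Stand-in pytrees (assumption); the notebook's Model and its grads are pytrees of the same kind\n", | |
| "toy_params = {\"w\": jnp.ones((3, 2)), \"b\": jnp.zeros((2,))}\n", | |
| "toy_grads = {\"w\": jnp.full((3, 2), 0.5), \"b\": jnp.ones((2,))}\n", | |
| "learning_rate = 0.1  # hypothetical hyperparameter\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "def sgd_step_sketch(params, grads):\n", | |
| "    # plain SGD applied leaf by leaf: p <- p - learning_rate * g\n", | |
| "    return jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)\n", | |
| "\n", | |
| "\n", | |
| "new_toy_params = sgd_step_sketch(toy_params, toy_grads)\n", | |
| "assert jnp.allclose(new_toy_params[\"w\"], 0.95)\n", | |
| "assert jnp.allclose(new_toy_params[\"b\"], -0.1)" | |
| ] | |
| }, | |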
| { | |
| "cell_type": "markdown", | |
| "id": "224d9f35-9939-4ccd-9336-fdbc3f40609e", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Data Parallelism \n", | |
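| "Shard the data along the batch dimension over the data-parallel mesh axes (here the combined `(dp, tp)` axes, i.e. all devices), and replicate the model weights on every device.\n", | |
| "\n", | |
| "* Forward: to get the global (total or mean) loss, all-reduce (`psum`/`pmean`) the local losses over those mesh axes.\n", | |
| "* Backward: the gradients of the replicated weights must be all-reduced (`psum`) so that every replica applies the same update and the weights stay in sync. Because `jax.value_and_grad` wraps the `shard_map` below, JAX inserts this all-reduce automatically when differentiating through it." | |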
| ] | |
| }, | |
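| { | |
| "cell_type": "markdown", | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a09", | |
| "metadata": {}, | |
| "source": [ | |
| "The reason averaging is the right reduction can be checked on a single device. With equal-sized batch shards, the average of the per-shard gradients of the local mean loss equals the full-batch gradient. This is a minimal sketch; the toy linear model, the shapes, and the Python loop standing in for the replicas are all assumptions." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a0a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "\n", | |
| "def toy_loss(w, x, y):\n", | |
| "    \"\"\"Toy linear model with the same per-example MSE as the notebook (an assumption).\"\"\"\n", | |
| "    return jnp.mean(jnp.sum((x @ w - y) ** 2, axis=-1))\n", | |
| "\n", | |
| "\n", | |
| "kx, ky, kw = jax.random.split(jax.random.key(0), 3)\n", | |
| "x_toy = jax.random.normal(kx, (16, 5))\n", | |
| "y_toy = jax.random.normal(ky, (16, 2))\n", | |
| "w_toy = jax.random.normal(kw, (5, 2))\n", | |
| "\n", | |
| "num_replicas = 4  # simulated data-parallel replicas, each with an equal batch shard\n", | |
| "shards = zip(jnp.split(x_toy, num_replicas), jnp.split(y_toy, num_replicas))\n", | |
| "\n", | |
| "# each replica computes the gradient of its local mean loss ...\n", | |
| "local_grads = [jax.grad(toy_loss)(w_toy, xs, ys) for xs, ys in shards]\n", | |
| "# ... and averaging them (psum / num_replicas, i.e. pmean) gives the full-batch gradient\n", | |
| "grad_dp = sum(local_grads) / num_replicas\n", | |
| "\n", | |
| "assert jnp.allclose(grad_dp, jax.grad(toy_loss)(w_toy, x_toy, y_toy), atol=1e-4)" | |
| ] | |
| }, | |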
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "id": "781aace1-1042-4543-a678-f8e555224c77", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "DP Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# fully replicate model weights\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(), b=P())\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along the batch dim over the combined (dp, tp) mesh axes\n", | |
| "batch_pspec: BatchPSpec = Batch(\n", | |
| " inputs=P((\"dp\", \"tp\"), None), targets=P((\"dp\", \"tp\"), None)\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_dp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16 @ (dp=4, tp=2), 32] -> f32[2{dp,tp}, 32]\n", | |
| " # w, b: fully replicated: x @ w + b -> y f32[2{dp,tp}, 4]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear)\n", | |
| " # local per-example MSE loss\n", | |
| " loss_per_example = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " # local mean MSE loss\n", | |
| " loss_local = jnp.mean(loss_per_example)\n", | |
| " # all reduce mean across devices -> total mean mse loss\n", | |
| " return jax.lax.pmean(loss_local, (\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "loss_dp_, grads_dp_ = loss_dp(model, batch)\n", | |
| "print(\"DP Loss:\", loss_dp_)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "id": "58107fd8-3790-4b88-892e-02aa8301e124", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.11 ms ± 94.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_dp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "id": "fac73302-8cb3-4850-bb41-1ce3fbc17f7d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def check_all_close(a: PyTree[Array], b: PyTree[Array], atol=1e-5) -> bool:\n", | |
| " return jax.tree.all(jax.tree.map(partial(jnp.allclose, atol=atol), a, b))\n", | |
| "\n", | |
| "\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_dp_, grads_dp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "90392021-fe5d-45fc-9106-82e31c3fb793", | |
| "metadata": {}, | |
| "source": [ | |
| "#### FSDP\n", | |
| "In addition to DP, shard the model weights across the same devices that hold the data shards (here over the `dp` and `tp` mesh axes).\n", | |
| "* Forward: all-gather (prefetch) each layer's weights before its forward computation, and release the gathered copy once that layer is done instead of keeping it in device memory.\n", | |
| "* Backward: all-gather the weights again (rematerialization) before computing each layer's gradients, then reduce-scatter the weight gradients so they are summed across devices and remain sharded like the weights.\n", | |
| "\n", | |
| "A few words about `jax.checkpoint`:\n", | |
| "* Standard backprop: when you run `jax.grad(loss)(x, y)`, JAX saves the intermediate values (residuals) of the forward pass, such as `h = x @ w`, for use in the backward pass. This is fast but memory-intensive, especially for deep architectures with many layers.\n", | |
| "* Activation checkpointing: with `jax.checkpoint`, JAX by default saves only the inputs of the checkpointed function and recomputes the intermediate activations on the fly during the backward pass, trading extra compute for lower memory. A `policy` function lets you choose which intermediates are still saved." | |
| ] | |
| }, | |
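| { | |
| "cell_type": "markdown", | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a0b", | |
| "metadata": {}, | |
| "source": [ | |
| "This minimal sketch shows that a `jax.checkpoint` policy changes what is stored for the backward pass but not the gradient. The toy layer and shapes are assumptions. `jax.checkpoint_policies.nothing_saveable` keeps only the inputs. The lambda follows the notebook's policy style but excludes `dot_general` (matmul) outputs instead of `all_gather` outputs, since this toy example has no collectives. `print_saved_residuals` from `jax.ad_checkpoint` prints the residuals each variant keeps for the backward pass." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "5a0e7c21-3d94-4b6f-9f12-8c7b1e2d4a0c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from functools import partial\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "from jax.ad_checkpoint import print_saved_residuals\n", | |
| "\n", | |
| "\n", | |
| "def toy_layer(x, w):\n", | |
| "    h = x @ w  # intermediate activation\n", | |
| "    return jnp.tanh(h)\n", | |
| "\n", | |
| "\n", | |
| "# keep only the inputs; h and tanh(h) are recomputed in the backward pass\n", | |
| "toy_layer_remat = jax.checkpoint(toy_layer, policy=jax.checkpoint_policies.nothing_saveable)\n", | |
| "# notebook-style policy: save every intermediate except the outputs of dot_general (the matmul)\n", | |
| "toy_layer_custom = jax.checkpoint(\n", | |
| "    toy_layer, policy=lambda prim, *_, **__: prim.name != \"dot_general\"\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "def toy_objective(layer_fn, w, x):\n", | |
| "    return jnp.sum(layer_fn(x, w) ** 2)\n", | |
| "\n", | |
| "\n", | |
| "kx_ckpt, kw_ckpt = jax.random.split(jax.random.key(0))\n", | |
| "x_ckpt = jax.random.normal(kx_ckpt, (4, 8))\n", | |
| "w_ckpt = jax.random.normal(kw_ckpt, (8, 8))\n", | |
| "\n", | |
| "grad_plain = jax.grad(partial(toy_objective, toy_layer))(w_ckpt, x_ckpt)\n", | |
| "grad_remat = jax.grad(partial(toy_objective, toy_layer_remat))(w_ckpt, x_ckpt)\n", | |
| "grad_custom = jax.grad(partial(toy_objective, toy_layer_custom))(w_ckpt, x_ckpt)\n", | |
| "# checkpointing changes what is stored for the backward pass, not the gradient itself\n", | |
| "assert jnp.allclose(grad_plain, grad_remat, atol=1e-5)\n", | |
| "assert jnp.allclose(grad_plain, grad_custom, atol=1e-5)\n", | |
| "\n", | |
| "# print the residuals each variant saves for the backward pass\n", | |
| "print_saved_residuals(partial(toy_objective, toy_layer), w_ckpt, x_ckpt)\n", | |
| "print_saved_residuals(partial(toy_objective, toy_layer_remat), w_ckpt, x_ckpt)" | |
| ] | |
| }, | |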
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "id": "7b6de1dc-b642-46c8-a524-ef07570a9578", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FSDP Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Shard model weights over dp x tp axis\n", | |
| "linear_pspec: LinearPSpec = Linear(\n", | |
| " w=P((\"dp\", \"tp\"), None),\n", | |
| " b=P(\"tp\"), # can also be P() to replicate the bias, since it's cheap\n", | |
| ")\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along batch dim over dp x tp mesh axis. Same as DP\n", | |
| "batch_pspec: BatchPSpec = Batch(\n", | |
| " inputs=P((\"dp\", \"tp\"), None), targets=P((\"dp\", \"tp\"), None)\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "# save every intermediate except all_gather outputs, so the gathered weights are not kept for the backward pass and are re-all-gathered (rematerialized) there\n", | |
| "@partial(jax.checkpoint, policy=lambda p, *_, **__: p.name != \"all_gather\")\n", | |
| "def linear_fsdp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"All-gather the sharded weight and bias onto every device, then apply the layer.\"\"\"\n", | |
| " w = jax.lax.all_gather(layer.w, (\"dp\", \"tp\"), axis=0, tiled=True)\n", | |
| " b = jax.lax.all_gather(layer.b, \"tp\", axis=0, tiled=True)\n", | |
| " return x @ w + b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_fsdp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16 @ (dp=4, tp=2), 32] -> f32[2{dp,tp}, 32]\n", | |
| " # w, b: all-gathered before local computation: x @ w + b -> y f32[2{dp,tp}, 4]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_fsdp)\n", | |
| " # local per-example MSE loss\n", | |
| " loss_per_example = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " # local mean MSE loss\n", | |
| " loss_local = jnp.mean(loss_per_example)\n", | |
| " # all-reduce mean across (dp, tp) devices -> total mean MSE loss\n", | |
| " return jax.lax.pmean(loss_local, (\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "loss_fsdp_, grads_fsdp_ = loss_fsdp(model, batch)\n", | |
| "print(\"FSDP Loss:\", loss_fsdp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_fsdp_, grads_fsdp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "id": "24b3a8bc-8e5f-4dd6-9826-84af1e65bfb8", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "3.21 ms ± 344 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_fsdp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 47, | |
| "id": "63199c5d-4283-47be-a02b-8695d39dfa0a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FSDP Loss: 4.0431547\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# This is also doing FSDP but with a different sharding spec for model weights\n", | |
| "# Shard model weights along dim=0 over dp axis and along dim=1 over tp axis\n", | |
| "linear_pspec: LinearPSpec = Linear(\n", | |
| " w=P(\"dp\", \"tp\"),\n", | |
| " b=P(\"tp\"), # can also be P() to replicate the bias, since it's cheap\n", | |
| ")\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along batch dim over dp x tp mesh axis. Same as DP\n", | |
| "batch_pspec: BatchPSpec = Batch(\n", | |
| " inputs=P((\"dp\", \"tp\"), None), targets=P((\"dp\", \"tp\"), None)\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "# save every intermediate except all_gather outputs, so the gathered weights are not kept for the backward pass and are re-all-gathered (rematerialized) there\n", | |
| "@partial(jax.checkpoint, policy=lambda p, *_, **__: p.name != \"all_gather\")\n", | |
| "def linear_fsdp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"All-gather the dp x tp weight shards into the full weight, then apply the layer locally\"\"\"\n", | |
| " w_partial = jax.lax.all_gather(layer.w, \"tp\", axis=1, tiled=True)\n", | |
| " w = jax.lax.all_gather(w_partial, \"dp\", axis=0, tiled=True)\n", | |
| " b = jax.lax.all_gather(layer.b, \"tp\", axis=0, tiled=True)\n", | |
| " return x @ w + b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_fsdp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16 @ (dp=4, tp=2), 32] -> f32[2{dp,tp}, 32]\n", | |
| " # w, b: all-gathered before local computation: x @ w + b -> y f32[2{dp,tp}, 4]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_fsdp)\n", | |
| " # local per-example MSE loss\n", | |
| " loss_per_example = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " # local mean MSE loss\n", | |
| " loss_local = jnp.mean(loss_per_example)\n", | |
| " # all-reduce mean across all dp x tp devices -> global mean MSE loss\n", | |
| " return jax.lax.pmean(loss_local, (\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "loss_fsdp_, grads_fsdp_ = loss_fsdp(model, batch)\n", | |
| "print(\"FSDP Loss:\", loss_fsdp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_fsdp_, grads_fsdp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 48, | |
| "id": "fa598b1d-164b-4a9b-bc23-664ed23b4e3f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.63 ms ± 159 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_fsdp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0095bc89-ef43-4ef4-bbd9-506779c3c9f8", | |
| "metadata": {}, | |
| "source": [ | |
| "#### TP\n", | |
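| "Shard the inputs (e.g. data) along the column (feature) dim over the `tp` mesh axis, and shard the weights along the row dim over the `tp` mesh axis.\n", | |
| "\n", | |
| "Each `tp` device multiplies its column block of the inputs by its row block of the weights, producing a full-shape partial sum of the output. Summing the partials across `tp` (all-reduce) gives the full output. If further matmuls follow, scatter the result along its column dim over `tp` so the next layer again sees column-sharded inputs; `linear_tp` below fuses the sum and the scatter into a single reduce-scatter (`jax.lax.psum_scatter`).\n" | |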
| ] | |
| }, | |
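| { | |
| "cell_type": "markdown", | |
| "id": "3e7c1b90-6d2a-4f5e-8b41-a9c0d2e4f611", | |
| "metadata": {}, | |
| "source": [ | |
| "The column @ row identity behind TP can be checked without any devices. In the minimal sketch below (plain `jnp`, with a hypothetical `tp = 2` and arbitrary shapes), `x` is split into column blocks and `w` into row blocks, one per shard; each shard's product is a full-shape partial sum, the all-reduce is their sum, and the reduce-scatter is that sum split along its columns:\n", | |
| "\n", | |
| "```python\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "tp = 2  # hypothetical number of tp shards\n", | |
| "x = jax.random.normal(jax.random.key(0), (16, 32))  # [batch, in]\n", | |
| "w = jax.random.normal(jax.random.key(1), (32, 4))  # [in, out]\n", | |
| "\n", | |
| "x_blocks = jnp.split(x, tp, axis=1)  # shard i holds the i-th column block of x\n", | |
| "w_blocks = jnp.split(w, tp, axis=0)  # shard i holds the i-th row block of w\n", | |
| "\n", | |
| "# Each shard computes a full-shape [16, 4] partial product\n", | |
| "partials = [xb @ wb for xb, wb in zip(x_blocks, w_blocks)]\n", | |
| "\n", | |
| "# All-reduce (psum): every shard ends up with the complete output\n", | |
| "y_allreduce = sum(partials)\n", | |
| "\n", | |
| "# Reduce-scatter (psum_scatter with scatter_dimension=1, tiled=True):\n", | |
| "# shard i keeps only the i-th column block of the summed output\n", | |
| "y_reduce_scattered = jnp.split(y_allreduce, tp, axis=1)\n", | |
| "```\n", | |
| "\n", | |
| "Concatenating `y_reduce_scattered` along `axis=1` recovers `x @ w`. In `linear_tp` the output feature dim is split the same way, which is why the targets are also sharded with `P(None, \"tp\")`." | |
| ] | |
| }, | |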
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "id": "61af0444-7eed-408f-86fd-74877db9d1b0", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "TP Loss: 4.0431542\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Shard model weights row-wise over tp axis\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(\"tp\", None), b=P(\"tp\"))\n", | |
| "# b must match the tp-sharded output columns; a replicated P() bias would first need slicing to each device's columns\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data along feature (column) dim over tp mesh axis.\n", | |
| "batch_pspec: BatchPSpec = Batch(inputs=P(None, \"tp\"), targets=P(None, \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "def linear_tp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"Local column-block x row-block matmul, then reduce-scatter the partial sums over tp\"\"\"\n", | |
| " matmul_local = x @ layer.w\n", | |
| " matmul_reduce_scattered = jax.lax.psum_scatter(\n", | |
| " matmul_local, \"tp\", scatter_dimension=1, tiled=True\n", | |
| " )\n", | |
| " return matmul_reduce_scattered + layer.b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@partial(jax.value_and_grad, argnums=0)\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_tp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16, 32@tp] -> f32[16, 16{tp=2}]\n", | |
| " # w: sharded as rows over tp: x @ w -> col x row -> partial sums -> psum_scatter\n", | |
| " # along the column dim over tp, then + b (sharded over tp) -> y f32[16, 4@tp] = f32[16, 2{tp=2}]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_tp)\n", | |
| " loss_local = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " loss_per_example = jax.lax.psum(loss_local, \"tp\") # f32[16, ]\n", | |
| " return jnp.mean(loss_per_example)\n", | |
| "\n", | |
| "\n", | |
| "loss_tp_, grads_tp_ = loss_tp(model, batch)\n", | |
| "print(\"TP Loss:\", loss_tp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_tp_, grads_tp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "id": "5b569900-0f23-4bbe-baed-6b5969d7ba7f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "811 μs ± 31.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_tp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b6ba5afb-7572-40a0-be8a-5ead0cd0b681", | |
| "metadata": {}, | |
| "source": [ | |
| "#### FSDP + TP" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 51, | |
| "id": "3bfb53b2-0979-4347-9b88-17a041f79590", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FSDP_TP Loss: 4.0431542\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Shard weights for FSDP + TP\n", | |
| "linear_pspec: LinearPSpec = Linear(\n", | |
| " w=P(\n", | |
| " (\"tp\", \"dp\"),\n", | |
| " ),\n", | |
| " b=P(\"tp\"),\n", | |
| ")\n", | |
| "model_pspec = Model(\n", | |
| " layer_sizes=model.layer_sizes, layers=[linear_pspec] * model.num_layers\n", | |
| ")\n", | |
| "# shard data for FSDP + TP: batch dim over dp, feature (column) dim over tp\n", | |
| "batch_pspec: BatchPSpec = Batch(inputs=P(\"dp\", \"tp\"), targets=P(\"dp\", \"tp\"))\n", | |
| "\n", | |
| "\n", | |
| "@partial(jax.checkpoint, policy=lambda p, *_, **__: p.name != \"all_gather\")\n", | |
| "def linear_fsdp_tp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " w = jax.lax.all_gather(layer.w, \"dp\", axis=0, tiled=True)\n", | |
| " # Note: if your w partition is P(\"tp\", \"dp\") rather than P((\"tp\", \"dp\")), all-gather along axis=1 instead. Both work.\n", | |
| " # If you also shard b across dp, e.g. P((\"tp\", \"dp\")), all-gather b over the dp axis along axis=0 as well.\n", | |
| " b = layer.b\n", | |
| " matmul_local = x @ w\n", | |
| " matmul_local_reduce_scattered = jax.lax.psum_scatter(\n", | |
| " matmul_local, \"tp\", scatter_dimension=1, tiled=True\n", | |
| " )\n", | |
| " return matmul_local_reduce_scattered + b\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.value_and_grad\n", | |
| "@jax.shard_map(mesh=global_mesh, in_specs=(model_pspec, batch_pspec), out_specs=P())\n", | |
| "def loss_fsdp_tp(model_sharded: Model, batch_sharded: BatchData) -> jax.Array:\n", | |
| " # Inputs and Targets: f32[16@dp, 32@tp] -> f32[4{dp=4}, 16{tp=2}]\n", | |
| " # w: rows sharded over tp x dp: all-gather over dp and concatenate along axis=0 -> w[@tp, ...]\n", | |
| " # x @ w -> col x row -> reduce_scatter along dim=1 -> matmul@[..., @tp] + b\n", | |
| " # (sharded over tp) -> y f32[4{dp=4}, 4@tp]\n", | |
| " preds = model_sharded.predict(batch_sharded.inputs, linear_fwd_fn=linear_fsdp_tp)\n", | |
| " # same as tp\n", | |
| " loss_local = jnp.sum((preds - batch_sharded.targets) ** 2, axis=-1)\n", | |
| " loss_per_example = jax.lax.psum(loss_local, \"tp\") # f32[16@dp, ] = f32[4{dp=4}]\n", | |
| " return jax.lax.pmean(jnp.mean(loss_per_example), axis_name=\"dp\")\n", | |
| "\n", | |
| "\n", | |
| "loss_fsdp_tp_, grads_fsdp_tp_ = loss_fsdp_tp(model, batch)\n", | |
| "print(\"FSDP_TP Loss:\", loss_fsdp_tp_)\n", | |
| "# Verify correctness\n", | |
| "assert check_all_close((loss_fsdp_tp_, grads_fsdp_tp_), (loss_baseline_, grads_))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 52, | |
| "id": "640d3c98-b757-4fde-8369-cd1ccd86e527", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2.01 ms ± 95.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit jax.block_until_ready(loss_fsdp_tp(model, batch))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "57a08331-15db-483c-b41c-fdd025759fe3", | |
| "metadata": {}, | |
| "source": [ | |
| "In practice, you may use a 4-D device mesh `(\"dp\", \"fsdp\", \"tp\", \"ep\")`.\n", | |
| "\n", | |
| "You could shard your data with `P((\"dp\", \"fsdp\"), \"tp\")` and your Linear layer weights with `P(\"tp\", \"fsdp\")`. In the forward pass, all-gather the weights over the `fsdp` mesh axis along `dim=1`, the dimension that axis shards; the `tp`-sharded contraction dim is then handled exactly as in TP." | |
| ] | |
| }, | |
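| { | |
| "cell_type": "markdown", | |
| "id": "7a4d2c18-9e3b-4c6f-a205-5b8e1f0c3d72", | |
| "metadata": {}, | |
| "source": [ | |
| "A minimal sketch of this layout with `jax.shard_map` (same v0.8+ API as above). So that it runs on any host, the hypothetical mesh places every available device on `fsdp` and gives `dp`, `tp` and `ep` size 1; on real hardware you would pick all four sizes. Because the contraction dim of `w` is sharded over `tp`, the local matmul yields partial sums, reduced here with `psum` over `tp`:\n", | |
| "\n", | |
| "```python\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "import numpy as np\n", | |
| "from jax.sharding import Mesh, PartitionSpec as P\n", | |
| "\n", | |
| "n = len(jax.devices())\n", | |
| "# Hypothetical 4-D mesh: every device on fsdp; dp, tp and ep have size 1 here\n", | |
| "mesh = Mesh(np.array(jax.devices()).reshape(1, n, 1, 1), (\"dp\", \"fsdp\", \"tp\", \"ep\"))\n", | |
| "\n", | |
| "batch, d_in, d_out = 8 * n, 8, 4 * n  # divisible by the mesh axis sizes used below\n", | |
| "x = jax.random.normal(jax.random.key(0), (batch, d_in))\n", | |
| "w = jax.random.normal(jax.random.key(1), (d_in, d_out))\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| "    mesh=mesh,\n", | |
| "    in_specs=(P((\"dp\", \"fsdp\"), \"tp\"), P(\"tp\", \"fsdp\")),\n", | |
| "    out_specs=P((\"dp\", \"fsdp\"), None),\n", | |
| ")\n", | |
| "def linear_4d(x_local, w_local):\n", | |
| "    # w_local: [d_in / tp, d_out / fsdp] -> gather the fsdp shards along dim=1\n", | |
| "    w_rows = jax.lax.all_gather(w_local, \"fsdp\", axis=1, tiled=True)  # [d_in / tp, d_out]\n", | |
| "    # x_local: [batch / (dp * fsdp), d_in / tp]; the contraction dim is split over tp\n", | |
| "    partial = x_local @ w_rows\n", | |
| "    return jax.lax.psum(partial, \"tp\")  # sum the partial products over tp\n", | |
| "```\n", | |
| "\n", | |
| "With a real `tp` size above 1, replacing the `psum` with `jax.lax.psum_scatter(..., scatter_dimension=1, tiled=True)`, as in `linear_fsdp_tp`, keeps the output column-sharded over `tp` for the next layer (and `out_specs` would then shard that dim over `tp`)." | |
| ] | |
| }, | |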
| { | |
| "cell_type": "markdown", | |
| "id": "2de60ea3-f57d-4efe-85a3-3e3d8e3c85ee", | |
| "metadata": {}, | |
| "source": [ | |
| "#### MOE Parallel with Token Dropping\n", | |
| "Due to JAX [issue #34168](https://github.com/jax-ml/jax/issues/34168), this code does not run on GPU, where `ragged_dot` is not yet well supported." | |
| ] | |
| }, | |
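| { | |
| "cell_type": "markdown", | |
| "id": "b21f6e04-8c7d-4a93-9e5b-0d4c3a7f1e28", | |
| "metadata": {}, | |
| "source": [ | |
| "`jax.lax.ragged_dot` is the grouped matmul the expert computation relies on: the rows of `lhs` are split into consecutive groups of sizes `group_sizes`, and group `i` is multiplied by `rhs[i]`. A small sketch (arbitrary shapes; CPU is fine) checks it against a dense per-row reference:\n", | |
| "\n", | |
| "```python\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "\n", | |
| "lhs = jax.random.normal(jax.random.key(0), (6, 3))  # 6 tokens, hidden=3\n", | |
| "rhs = jax.random.normal(jax.random.key(1), (3, 3, 5))  # 3 experts, each [hidden, out]\n", | |
| "group_sizes = jnp.array([2, 1, 3], dtype=jnp.int32)  # tokens per expert, in row order\n", | |
| "\n", | |
| "out = jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)  # [6, 5]\n", | |
| "\n", | |
| "# Dense reference: look up each row's group id, then multiply by that group's matrix\n", | |
| "row_group = jnp.repeat(jnp.arange(3), group_sizes, total_repeat_length=6)  # [0, 0, 1, 2, 2, 2]\n", | |
| "reference = jnp.einsum(\"mk,mkn->mn\", lhs, rhs[row_group])\n", | |
| "```\n", | |
| "\n", | |
| "In `_sharded_moe_impl`, `group_sizes` is `local_expert_token_count`, which is why the received tokens are first sorted by local expert id." | |
| ] | |
| }, | |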
| { | |
| "cell_type": "code", | |
| "execution_count": 53, | |
| "id": "fb8d3b5c-3155-48ce-878c-6e3093aa53cd", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "\"\"\"Mixture of Experts (MoE) Layer with token dropping\n", | |
| "\n", | |
| "Using Ragged All-to-All Communication and Ragged Dot in JAX.\n", | |
| "\"\"\"\n", | |
| "\n", | |
| "__author__ = \"Liutong Zhou\"\n", | |
| "\n", | |
| "from __future__ import annotations\n", | |
| "\n", | |
| "from dataclasses import dataclass, field\n", | |
| "from typing import Self\n", | |
| "\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "from jax.sharding import Mesh, PartitionSpec as P\n", | |
| "from jaxtyping import Array, Bool, Float, Int\n", | |
| "\n", | |
| "\n", | |
| "def calculate_offsets(sizes: Int[Array, \"n\"]) -> Int[Array, \"n\"]:\n", | |
| " \"\"\"Calculate chunk start offsets given chunk sizes\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " sizes : Int[Array, \"n\"]\n", | |
| " An array of chunk sizes\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " offsets : Int[Array, \"n\"]\n", | |
| " An array of offsets starting at 0 such that offsets[i] is the starting index\n", | |
| " of chunk i in a flattened array.\n", | |
| "\n", | |
| " Examples\n", | |
| " --------\n", | |
| " >>> sizes = jnp.array([2, 3, 1])\n", | |
| " >>> calculate_offsets(sizes)\n", | |
| " Array([0, 2, 5], dtype=int32)\n", | |
| " \"\"\"\n", | |
| " sizes_i32 = jnp.asarray(sizes, dtype=jnp.int32)\n", | |
| " # Pad with 0 at the start, remove the last element, then cumsum\n", | |
| " return jnp.cumsum(jnp.pad(sizes_i32[:-1], (1, 0)))\n", | |
| "\n", | |
| "\n", | |
| "def calculate_drop_token_mask(\n", | |
| " token_expert_ids: Int[Array, \"n_tokens\"],\n", | |
| " token_expert_probs: Float[Array, \"n_tokens\"],\n", | |
| " num_experts: int | Int[Array, \"\"],\n", | |
| " expert_capacity: int | Int[Array, \"\"],\n", | |
| ") -> Bool[Array, \"n_tokens\"]:\n", | |
| " \"\"\"Calculate a boolean mask to indicate which tokens should be dropped based on expert capacity.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " token_expert_ids : Int[Array, \"n_tokens\"]\n", | |
| " 1-D array of expert IDs assigned to each token.\n", | |
| " token_expert_probs : Float[Array, \"n_tokens\"]\n", | |
| " 1-D array of probabilities assigned to each token for its expert.\n", | |
| " num_experts : int\n", | |
| " Total number of experts.\n", | |
| " expert_capacity : int\n", | |
| " Maximum number of tokens each expert can handle.\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " drop_mask : Bool[Array, \"n_tokens\"]\n", | |
| " A boolean mask indicating which tokens should be dropped (True) or kept (False).\n", | |
| " \"\"\"\n", | |
| " # Calculate each token's within-expert rank\n", | |
| " # Sort token experts by expert ID (ascending) and then by probability (descending)\n", | |
| " # so that after sorting, tokens for expert 0 are at the top, followed by tokens for expert 1, etc.\n", | |
| " index_to_sorted = jnp.lexsort((-token_expert_probs, token_expert_ids))\n", | |
| " token_expert_ids_sorted_by_rank = token_expert_ids[index_to_sorted]\n", | |
| "\n", | |
| " # Vectorized within-expert rank calculation\n", | |
| " # Find where each expert segment starts in token_expert_ids_sorted_by_rank\n", | |
| " expert_load_size = jnp.bincount(token_expert_ids, length=num_experts)\n", | |
| " expert_start_offsets = calculate_offsets(expert_load_size)\n", | |
| " # token absolute index - its expert start offset gives within-expert rank\n", | |
| " token_within_expert_rank = (\n", | |
| " jnp.arange(token_expert_ids_sorted_by_rank.shape[0])\n", | |
| " - expert_start_offsets[token_expert_ids_sorted_by_rank]\n", | |
| " )\n", | |
| "\n", | |
| " # Token dropping: cap the rank at expert_capacity\n", | |
| " is_dropped_sorted = token_within_expert_rank >= expert_capacity\n", | |
| "\n", | |
| " # revert the sorting to get the original token order\n", | |
| " index_to_inverse_sorted = jnp.argsort(index_to_sorted)\n", | |
| " drop_mask = is_dropped_sorted[index_to_inverse_sorted]\n", | |
| " return drop_mask\n", | |
| "\n", | |
| "\n", | |
| "@jax.tree_util.register_dataclass\n", | |
| "@dataclass(frozen=True, slots=True)\n", | |
| "class MOERaggedDispatcher:\n", | |
| " \"\"\"Orchestrates MoE token dispatch to, and return from, expert devices using ragged all-to-all communication\n", | |
| "\n", | |
| " When built via `from_target_device_ids`, this class precomputes the array layout required to move variable-sized\n", | |
| " slices of data between devices. It computes where data should be read from\n", | |
| " (input offsets) and where it should be written to (output offsets) for both\n", | |
| " the forward dispatch and the backward return trip.\n", | |
| "\n", | |
| " Attributes\n", | |
| " ----------\n", | |
| " send_sizes : Int[Array, \"devices\"]\n", | |
| " Number of items this device sends to each peer device.\n", | |
| " input_offsets : Int[Array, \"devices\"]\n", | |
| " Local array offsets to read data from during send.\n", | |
| " recv_sizes : Int[Array, \"devices\"]\n", | |
| " Number of items this device receives from each peer device.\n", | |
| " recv_offsets : Int[Array, \"devices\"]\n", | |
| " Local array offsets where received data will be stored.\n", | |
| " fwd_remote_output_offsets : Int[Array, \"devices\"]\n", | |
| " The offsets on the *receiver* devices where our sent data should be written.\n", | |
| " bwd_remote_output_offsets : Int[Array, \"devices\"]\n", | |
| " The offsets on the *original sender* devices where the returned results\n", | |
| " should be written.\n", | |
| " \"\"\"\n", | |
| "\n", | |
| " # Forward Pass Layout\n", | |
| " send_sizes: Int[Array, \"devices\"] # how much to read locally\n", | |
| " input_offsets: Int[Array, \"devices\"] # where to read from locally\n", | |
| " recv_sizes: Int[Array, \"devices\"] # how much to write locally\n", | |
| " recv_offsets: Int[Array, \"devices\"] # where to write locally\n", | |
| "\n", | |
| " # Remote Writing Instructions\n", | |
| " fwd_remote_output_offsets: Int[Array, \"devices\"] # where to write remotely\n", | |
| " # where to write remotely on return trip\n", | |
| " bwd_remote_output_offsets: Int[Array, \"devices\"]\n", | |
| "\n", | |
| " axis_name: str | tuple[str, ...] = field(\n", | |
| " default=\"ep\", metadata={\"static\": True}\n", | |
| " ) # mark as static for JAX jit\n", | |
| "\n", | |
| " @property\n", | |
| " def num_devices(self) -> int:\n", | |
| " \"\"\"Number of devices involved in all-to-all communication.\"\"\"\n", | |
| " return jax.lax.axis_size(self.axis_name)\n", | |
| "\n", | |
| " @classmethod\n", | |
| " def from_target_device_ids(\n", | |
| " cls,\n", | |
| " target_device_ids: Int[Array, \"n\"],\n", | |
| " *,\n", | |
| " axis_name: str | tuple[str, ...],\n", | |
| " mask: Bool[Array, \"n\"] | None = None,\n", | |
| " ) -> Self:\n", | |
| " \"\"\"Create a globally consistent communication plan.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " target_device_ids : Int[Array, \"n\"]\n", | |
| " 1-D array indicating the target device ID for each item to be sent.\n", | |
| " axis_name : str | tuple[str, ...]\n", | |
| " Device mesh axis name(s) along which to perform all-to-all communication.\n", | |
| " mask : Bool[Array, \"n\"], optional\n", | |
| " Optional boolean mask indicating which items to send (True) or drop (False).\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " MOERaggedDispatcher\n", | |
| " An instance of MOERaggedDispatcher with precomputed communication layout.\n", | |
| " \"\"\"\n", | |
| " num_devices = jax.lax.axis_size(axis_name)\n", | |
| " device_send_load = jnp.bincount(target_device_ids, length=num_devices)\n", | |
| " input_offsets = calculate_offsets(device_send_load)\n", | |
| " # Only send this much, drop the masked-out items\n", | |
| " device_send_load_actual = jnp.bincount(\n", | |
| " target_device_ids,\n", | |
| " weights=mask.astype(target_device_ids.dtype) if mask is not None else None,\n", | |
| " length=num_devices,\n", | |
| " )\n", | |
| "\n", | |
| " # 1. Exchange sizes: Tell peer devices how much this device is sending;\n", | |
| " # receive how much they are sending to this device\n", | |
| " recv_sizes = jax.lax.all_to_all(\n", | |
| " device_send_load_actual,\n", | |
| " axis_name=axis_name,\n", | |
| " split_axis=0,\n", | |
| " concat_axis=0,\n", | |
| " tiled=True,\n", | |
| " )\n", | |
| "\n", | |
| " recv_offsets = calculate_offsets(recv_sizes)\n", | |
| "\n", | |
| " # 2. Exchange Output Offsets:\n", | |
| " # a) Forward: Tell original senders where to write in receivers' buffer.\n", | |
| " fwd_remote_output_offsets = jax.lax.all_to_all(\n", | |
| " recv_offsets,\n", | |
| " axis_name=axis_name,\n", | |
| " split_axis=0,\n", | |
| " concat_axis=0,\n", | |
| " tiled=True,\n", | |
| " )\n", | |
| "\n", | |
| " # b) Backward: Tell senders (expert devices) where to write back in receivers' (original senders') buffer.\n", | |
| " # When the expert forward is done, they need to return data to the original device. Tell\n", | |
| " # expert devices to write it exactly where tokens were originally read from.\n", | |
| " bwd_remote_output_offsets = jax.lax.all_to_all(\n", | |
| " input_offsets,\n", | |
| " axis_name=axis_name,\n", | |
| " split_axis=0,\n", | |
| " concat_axis=0,\n", | |
| " tiled=True,\n", | |
| " )\n", | |
| "\n", | |
| " return cls(\n", | |
| " send_sizes=device_send_load_actual,\n", | |
| " input_offsets=input_offsets,\n", | |
| " recv_sizes=recv_sizes,\n", | |
| " recv_offsets=recv_offsets,\n", | |
| " fwd_remote_output_offsets=fwd_remote_output_offsets,\n", | |
| " bwd_remote_output_offsets=bwd_remote_output_offsets,\n", | |
| " axis_name=axis_name,\n", | |
| " )\n", | |
| "\n", | |
| " def dispatch_forward[T: Array](\n", | |
| " self, data: T, capacity: int | Int[Array, \"\"]\n", | |
| " ) -> tuple[T, Bool[Array, \"n\"]]:\n", | |
| " \"\"\"Send data to expert devices.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " data : Array\n", | |
| " The input data (e.g. tokens) sorted by target device. If a mask was used, data must be\n", | |
| " sorted so that dropped items come last within each device group.\n", | |
| " capacity : int\n", | |
| " The total size of the receiver's buffer (must be sufficient for worst case).\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " received : Array\n", | |
| " The data received from peer devices, packed contiguously in output_buffer.\n", | |
| " Shape (capacity, ...).\n", | |
| " is_valid : Bool[Array, \"n\"]\n", | |
| " 1-D boolean mask indicating which received items are valid in the output.\n", | |
| " received data beyond the actual received size are invalid and should be ignored.\n", | |
| " \"\"\"\n", | |
| " # pre-allocate an output buffer of static size for receiving data\n", | |
| " output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)\n", | |
| " # each device sends data and returns the received data\n", | |
| " received = jax.lax.ragged_all_to_all(\n", | |
| " data,\n", | |
| " output_buffer, # for storing what this device will receive after forward dispatch\n", | |
| " self.input_offsets, # Read from here (local)\n", | |
| " self.send_sizes, # Read this amount (local)\n", | |
| " self.fwd_remote_output_offsets, # Write to here (remote)\n", | |
| " self.recv_sizes, # Write this amount (remote)\n", | |
| " axis_name=self.axis_name,\n", | |
| " )\n", | |
| " is_valid = jnp.arange(capacity) < jnp.sum(self.recv_sizes)\n", | |
| " return received, is_valid\n", | |
| "\n", | |
| " def dispatch_backward[T: Array](\n", | |
| " self, data: T, capacity: int | Int[Array, \"\"]\n", | |
| " ) -> tuple[T, Bool[Array, \"n\"]]:\n", | |
| " \"\"\"Return processed data to the original sender devices.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " data : Array\n", | |
| " The processed data (must be in the same order as received).\n", | |
| " capacity : int\n", | |
| " The size of the buffer on the original sender (to restore shape).\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " returned_to_original_sender : Array\n", | |
| " The results, placed back into their original slots on the sender.\n", | |
| " is_valid : Bool[Array, \"n\"]\n", | |
| " 1-D boolean mask indicating which returned items are valid in the output.\n", | |
| " returned data beyond the actual returned size are invalid and should be ignored.\n", | |
| " \"\"\"\n", | |
| " output_buffer = jnp.zeros((capacity,) + data.shape[1:], dtype=data.dtype)\n", | |
| "\n", | |
| " # Note: Roles of input/output offsets are effectively swapped for the return trip\n", | |
| " returned_to_original_sender = jax.lax.ragged_all_to_all(\n", | |
| " data,\n", | |
| " output_buffer,\n", | |
| " self.recv_offsets, # Read from here (local processed data)\n", | |
| " self.recv_sizes, # Read this amount\n", | |
| " self.bwd_remote_output_offsets, # Write to here (remote original sender)\n", | |
| " self.send_sizes, # Write this amount\n", | |
| " axis_name=self.axis_name,\n", | |
| " )\n", | |
| " is_valid = jnp.arange(capacity) < jnp.sum(self.send_sizes)\n", | |
| " return returned_to_original_sender, is_valid" | |
| ] | |
| }, | |
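| { | |
| "cell_type": "markdown", | |
| "id": "e5c83a9d-1f24-4b70-86de-3c9b0a2f5d46", | |
| "metadata": {}, | |
| "source": [ | |
| "To make the layout concrete, the sketch below replays `from_target_device_ids`, `dispatch_forward` and `dispatch_backward` in NumPy for a hypothetical 3-device `ep` group with no dropped items. It relies on one observation: a tiled `all_to_all` that exchanges one length-`n` vector per device is a transpose of the stacked `[n, n]` matrix (row = device). Each `ragged_all_to_all` then becomes a loop of slice copies driven by the same offsets:\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "\n", | |
| "def exclusive_cumsum(sizes):\n", | |
| "    # Same result as calculate_offsets: offsets[i] = sum(sizes[:i])\n", | |
| "    return np.concatenate([[0], np.cumsum(sizes)[:-1]]).astype(int)\n", | |
| "\n", | |
| "\n", | |
| "# Hypothetical toy input: target device of each item, already sorted per device\n", | |
| "targets = [np.array([0, 1, 1, 2]), np.array([0, 0, 2]), np.array([1, 2, 2, 2])]\n", | |
| "n = len(targets)\n", | |
| "payloads = [10 * i + np.arange(len(t)) for i, t in enumerate(targets)]  # value = 10 * src + slot\n", | |
| "\n", | |
| "# Row i of each matrix is what device i holds locally\n", | |
| "send_sizes = np.stack([np.bincount(t, minlength=n) for t in targets])  # [i, j]: i sends to j\n", | |
| "input_offsets = np.stack([exclusive_cumsum(row) for row in send_sizes])\n", | |
| "recv_sizes = send_sizes.T  # all_to_all of per-device vectors == transpose\n", | |
| "recv_offsets = np.stack([exclusive_cumsum(row) for row in recv_sizes])\n", | |
| "fwd_remote_output_offsets = recv_offsets.T\n", | |
| "bwd_remote_output_offsets = input_offsets.T\n", | |
| "\n", | |
| "# dispatch_forward: device i copies its chunk for j into j's buffer\n", | |
| "capacity = int(recv_sizes.sum(axis=1).max())\n", | |
| "received = [np.full(capacity, -1) for _ in range(n)]\n", | |
| "for i in range(n):\n", | |
| "    for j in range(n):\n", | |
| "        size, src, dst = send_sizes[i, j], input_offsets[i, j], fwd_remote_output_offsets[i, j]\n", | |
| "        received[j][dst : dst + size] = payloads[i][src : src + size]\n", | |
| "\n", | |
| "# dispatch_backward: roles swap and results land back in their original slots\n", | |
| "returned = [np.full(len(t), -1) for t in targets]\n", | |
| "for j in range(n):\n", | |
| "    for i in range(n):\n", | |
| "        size, src, dst = recv_sizes[j, i], recv_offsets[j, i], bwd_remote_output_offsets[j, i]\n", | |
| "        returned[i][dst : dst + size] = received[j][src : src + size]\n", | |
| "```\n", | |
| "\n", | |
| "Device 0, for example, receives `[0, 10, 11]`: its own item 0, then items 0 and 1 from device 1, followed by unused buffer slots (`-1`), which is exactly the region `is_valid` masks out." | |
| ] | |
| }, | |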
| { | |
| "cell_type": "code", | |
| "execution_count": 54, | |
| "id": "b2b86bd6-a7b6-4cc3-a196-444b88dcc5b1", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def moe_layer(\n", | |
| " sequences: Float[Array, \"batch seq hidden\"],\n", | |
| " router_weights: Float[Array, \"hidden experts_total\"],\n", | |
| " expert_weights: Float[Array, \"experts_total hidden hidden\"],\n", | |
| " *,\n", | |
| " mesh: Mesh,\n", | |
| " top_k: int,\n", | |
| " expert_axis_name: str | tuple[str, ...] = \"ep\",\n", | |
| " capacity_factor: float = 1.2,\n", | |
| ") -> Float[Array, \"batch seq hidden\"]:\n", | |
| " \"\"\"Execute a Shard-Mapped Mixture of Experts layer with Ragged All-to-All Communication.\n", | |
| "\n", | |
| " Parameters\n", | |
| " ----------\n", | |
| " sequences : Float[Array, \"batch seq hidden\"]\n", | |
| " Input token sequences.\n", | |
| " router_weights : Float[Array, \"hidden experts_total\"]\n", | |
| " Weights for the gating network.\n", | |
| " expert_weights : Float[Array, \"experts_total hidden hidden\"]\n", | |
| " Weights for the experts (sharded).\n", | |
| " mesh : Mesh\n", | |
| " The JAX device mesh.\n", | |
| " top_k : int\n", | |
| " Number of experts to select per token.\n", | |
| " expert_axis_name : str | tuple[str, ...], optional\n", | |
| " Name of the mesh axis for experts, by default \"ep\".\n", | |
| " capacity_factor : float, optional\n", | |
| " Multiplier for expert capacity, by default 1.2.\n", | |
| " Expert Capacity = (total_tokens * top_k / total_experts) * capacity_factor.\n", | |
| "\n", | |
| " Returns\n", | |
| " -------\n", | |
| " Float[Array, \"batch seq hidden\"]\n", | |
| " The processed sequences.\n", | |
| " \"\"\"\n", | |
| "\n", | |
| " @jax.jit\n", | |
| " @jax.shard_map(\n", | |
| " mesh=mesh,\n", | |
| " in_specs=(\n", | |
| " P(\"dp\", expert_axis_name, None), # sequences[batch@dp, seq@ep, hidden]\n", | |
| " P(None, None), # router_weights[hidden, experts_total]\n", | |
| " # expert_weights[experts_total@ep, hidden, hidden]\n", | |
| " P(expert_axis_name, None, None),\n", | |
| " ),\n", | |
| " out_specs=P(\"dp\", expert_axis_name, None),\n", | |
| " )\n", | |
| " def _sharded_moe_impl(\n", | |
| " x_shard: Float[Array, \"local_batch local_seq hidden\"],\n", | |
| " router_weights: Float[Array, \"hidden experts_total\"],\n", | |
| " local_expert_weights: Float[Array, \"local_experts hidden hidden\"],\n", | |
| " ) -> Float[Array, \"local_batch local_seq hidden\"]:\n", | |
| " # x_shard[local_batch{dp}, local_seq{ep}, hidden]\n", | |
| "\n", | |
| " # Setup and preparations\n", | |
| " num_devices_ep = jax.lax.axis_size(expert_axis_name)\n", | |
| " num_experts_per_device = local_expert_weights.shape[0]\n", | |
| " num_total_experts = router_weights.shape[-1]\n", | |
| " assert (\n", | |
| " num_total_experts == num_devices_ep * num_experts_per_device\n", | |
| " ), f\"{num_total_experts=} must equal {num_devices_ep=} * {num_experts_per_device=}\"\n", | |
| "\n", | |
| " # A. Calculating Token Routing\n", | |
| " # 1. Flatten Local Batch: (b, s, hidden) -> (b*s, hidden)\n", | |
| " tokens = x_shard.reshape(-1, x_shard.shape[-1])\n", | |
| "\n", | |
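| " # Average tokens per expert = Total Tokens to route / Total Experts, scaled by capacity_factor\n", | |
| " # Keep capacities as static Python ints: they size buffers (jnp.zeros / jnp.arange) under jit,\n", | |
| " # where jnp.ceil would return a traced float. -(-x // 1) computes ceil(x) in plain Python.\n", | |
| " capacity_exact = (\n", | |
| " tokens.shape[0] * top_k * num_devices_ep / num_total_experts * capacity_factor\n", | |
| " )\n", | |
| " expert_capacity = int(-(-capacity_exact // 1))\n", | |
| " device_capacity = expert_capacity * num_experts_per_device\n", | |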
| "\n", | |
| " # 2. Routing (Top-K): (b*s, hidden) @ (hidden, experts_total) -> (b*s, experts_total) -> topk -> (b*s, k)\n", | |
| " top_k_logits, top_k_expert_ids = jax.lax.top_k(tokens @ router_weights, k=top_k)\n", | |
| " top_k_expert_probs = jax.nn.softmax(top_k_logits, axis=-1)\n", | |
| "\n", | |
| " # 3. Expand: (b*s, hidden) -> (b*s*k, hidden)\n", | |
| " expert_ids_flat = top_k_expert_ids.ravel() # (b*s, k) -> (b*s*k,)\n", | |
| " expert_probs_flat = top_k_expert_probs.ravel() # (b*s, k) -> (b*s*k,)\n", | |
| " # Repeat interleave each token k times to create dispatch tokens\n", | |
| " tokens_flat = jnp.repeat(tokens, repeats=top_k, axis=0) # (b*s*k, hidden)\n", | |
| " # Track original token segment IDs to sum k results back later: (b*s*k,)\n", | |
| " token_segment_ids = jnp.repeat(\n", | |
| " jnp.arange(tokens.shape[0]), repeats=top_k, axis=0\n", | |
| " )\n", | |
| "\n", | |
| " # B. Token Dispatching with Token Dropping on senders' side\n", | |
| " is_token_dropped = calculate_drop_token_mask(\n", | |
| " token_expert_ids=expert_ids_flat,\n", | |
| " token_expert_probs=expert_probs_flat,\n", | |
| " num_experts=num_total_experts,\n", | |
| " expert_capacity=expert_capacity,\n", | |
| " )\n", | |
| " # Sort the data to dispatch (e.g. tokens_flat) so that the items to send are grouped by\n", | |
| " # 1. Target Device (ascending)\n", | |
| " # 2. Is Dropped? Dropped tokens go to the end within each device group (so that they are not sent)\n", | |
| " # 3. Local Expert ID within Device\n", | |
| " target_device_ids, target_local_expert_ids = jnp.divmod(\n", | |
| " expert_ids_flat, num_experts_per_device\n", | |
| " )\n", | |
| " index_to_sorted_for_dispatch = jnp.lexsort(\n", | |
| " (target_local_expert_ids, is_token_dropped, target_device_ids)\n", | |
| " )\n", | |
| "\n", | |
| " tokens_sorted_for_dispatch = tokens_flat[index_to_sorted_for_dispatch]\n", | |
| " expert_probs_sorted_for_dispatch = expert_probs_flat[\n", | |
| " index_to_sorted_for_dispatch\n", | |
| " ]\n", | |
| " target_local_expert_ids_sorted_for_dispatch = target_local_expert_ids[\n", | |
| " index_to_sorted_for_dispatch\n", | |
| " ]\n", | |
| " token_segment_ids_sorted_for_dispatch = token_segment_ids[\n", | |
| " index_to_sorted_for_dispatch\n", | |
| " ]\n", | |
| " # Initialize Device Dispatcher\n", | |
| " dispatcher = MOERaggedDispatcher.from_target_device_ids(\n", | |
| " target_device_ids, axis_name=expert_axis_name, mask=~is_token_dropped\n", | |
| " )\n", | |
| " # Dispatch Forward: Send tokens to expert devices, capping expert capacity on senders' side\n", | |
| " # Worst case, a receiver gets a full device_capacity from every sender, so provision its buffer for that.\n", | |
| " # note: receiver_capacity = tokens across the ep group * top_k * capacity_factor, so for a fixed global batch it does not grow with the number of devices.\n", | |
| " receiver_capacity = device_capacity * num_devices_ep\n", | |
| " recv_tokens, _ = dispatcher.dispatch_forward(\n", | |
| " tokens_sorted_for_dispatch, receiver_capacity\n", | |
| " )\n", | |
| " recv_probs, _ = dispatcher.dispatch_forward(\n", | |
| " expert_probs_sorted_for_dispatch, receiver_capacity\n", | |
| " )\n", | |
| " recv_local_expert_ids, is_valid = dispatcher.dispatch_forward(\n", | |
| " target_local_expert_ids_sorted_for_dispatch, receiver_capacity\n", | |
| " )\n", | |
| "\n", | |
| " # C. Expert Computation\n", | |
| " # Group tokens by local expert ids (0 to experts_per_device-1) for ragged dot\n", | |
| " # Note: we move invalid tokens (0s from output buffer) to the end so they are ignored in computation\n", | |
| " safe_local_expert_ids = jnp.where(\n", | |
| " is_valid, recv_local_expert_ids, num_experts_per_device\n", | |
| " )\n", | |
| " indices_to_sort_by_local_expert = jnp.argsort(safe_local_expert_ids)\n", | |
| " recv_tokens_by_local_expert = recv_tokens[indices_to_sort_by_local_expert]\n", | |
| " recv_probs_by_local_expert = recv_probs[indices_to_sort_by_local_expert]\n", | |
| " # Calculate batch size per local expert for ragged_dot\n", | |
| " local_expert_token_count = jnp.bincount(\n", | |
| " safe_local_expert_ids,\n", | |
| " # Effectively mask out invalid tokens\n", | |
| " weights=is_valid.astype(safe_local_expert_ids.dtype),\n", | |
| " length=num_experts_per_device,\n", | |
| " )\n", | |
| " # Perform MOE computation: (Batch, Hidden) @ (Hidden, Hidden) for each expert group\n", | |
| " # Note: expert_outputs.shape[0] = recv_tokens_by_local_expert.shape[0] with\n", | |
| " # expert_outputs[sum(local_expert_token_count):] = 0\n", | |
| " expert_outputs = jax.lax.ragged_dot(\n", | |
| " recv_tokens_by_local_expert,\n", | |
| " local_expert_weights,\n", | |
| " group_sizes=local_expert_token_count,\n", | |
| " )\n", | |
| " # Apply gating weights\n", | |
| " expert_outputs = expert_outputs * recv_probs_by_local_expert[:, None]\n", | |
| "\n", | |
| " # D. Dispatch Back (Return Trip)\n", | |
| " # Restore Order (Undo local expert sort)\n", | |
| " indices_invert_sort_by_local_expert = jnp.argsort(\n", | |
| " indices_to_sort_by_local_expert\n", | |
| " )\n", | |
| " packed_outputs = expert_outputs[indices_invert_sort_by_local_expert]\n", | |
| " tokens_moe_back, _ = dispatcher.dispatch_backward(\n", | |
| " packed_outputs, capacity=tokens_sorted_for_dispatch.shape[0]\n", | |
| " )\n", | |
| "\n", | |
| " # Finally, segment-sum the expert outputs\n", | |
| " # Sum the K contributions back into the original token slots using the segment IDs.\n", | |
| " # (b*s*k, hidden) -> (b*s, hidden) corresponding to tokens due to usage of segment_ids\n", | |
| " combined_output = jax.ops.segment_sum(\n", | |
| " tokens_moe_back,\n", | |
| " segment_ids=token_segment_ids_sorted_for_dispatch,\n", | |
| " num_segments=tokens.shape[0],\n", | |
| " )\n", | |
| "\n", | |
| " return combined_output.reshape(x_shard.shape)\n", | |
| "\n", | |
| " return _sharded_moe_impl(sequences, router_weights, expert_weights)" | |
| ] | |
| }, | |
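The local expert computation above follows a sort, count, grouped-matmul, unsort pattern. The following is a minimal single-device sketch of just that pattern. It uses toy sizes, and the names `capacity`, `num_local_experts` and `recv_local_expert_ids` are hypothetical stand-ins for the notebook's buffers. It checks that `jax.lax.ragged_dot` on the expert-sorted rows matches a per-token matmul with each token's own expert weights:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Toy sizes, chosen only for illustration
num_local_experts = 3  # experts hosted on this device
hidden = 4
capacity = 8  # receive-buffer slots, including zero padding

key_x, key_w = jax.random.split(jax.random.key(0))
recv_tokens = jax.random.normal(key_x, (capacity, hidden))
local_expert_weights = jax.random.normal(key_w, (num_local_experts, hidden, hidden))
# Hypothetical local expert id per received slot; the last two slots are padding
recv_local_expert_ids = jnp.array([2, 0, 1, 0, 2, 0, 0, 0])
is_valid = jnp.arange(capacity) < 6

# 1. Padding gets the sentinel id num_local_experts, so argsort puts it after every real group
safe_ids = jnp.where(is_valid, recv_local_expert_ids, num_local_experts)
order = jnp.argsort(safe_ids)
tokens_sorted = recv_tokens[order]

# 2. Number of valid tokens per local expert (padding contributes weight 0)
group_sizes = jnp.bincount(
    safe_ids, weights=is_valid.astype(jnp.int32), length=num_local_experts
).astype(jnp.int32)

# 3. One grouped matmul: the g-th run of rows is multiplied by local_expert_weights[g]
out_sorted = jax.lax.ragged_dot(
    tokens_sorted, local_expert_weights, group_sizes=group_sizes
)

# 4. argsort of a permutation is its inverse: restore the original slot order
out = out_sorted[jnp.argsort(order)]

# Reference: per-token matmul with that token's expert weights
reference = jnp.einsum(
    "th,thd->td", recv_tokens, local_expert_weights[recv_local_expert_ids]
)
valid = np.asarray(is_valid)
print(group_sizes.tolist())  # [3, 1, 2]
print(np.allclose(np.asarray(out)[valid], np.asarray(reference)[valid], atol=1e-4))
```

`jnp.argsort` is stable by default, so tokens routed to the same expert also keep their relative order inside each group.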
| { | |
| "cell_type": "markdown", | |
| "id": "d9acfd28-6e97-4391-b052-c946d7afb86e", | |
| "metadata": {}, | |
| "source": [ | |
| "### Knowledge Distillation\n", | |
| "We split the available devices into two disjoint meshes:\n", | |
| "* `mesh_inference`: runs the teacher model's forward pass only\n", | |
| "* `mesh_train`: runs student model training (forward + backward + update)\n", | |
| " \n", | |
| "Because JAX uses **asynchronous dispatch**, each JAX call in the Python `for` loop below returns immediately, before its device computation finishes. \n", | |
| "JAX can therefore enqueue `teacher_forward` for the next iteration on `mesh_inference` while `mesh_train` is still computing gradients in `train_student_step` for the current one, **automatically creating a pipeline in which both meshes work simultaneously**. Anything that reads a result on the host inside the loop (e.g. printing `loss`) blocks until that step finishes and serializes the two meshes." | |
| ] | |
| }, | |
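The asynchronous-dispatch pipeline can be tried without GPUs. The sketch below is an assumption-laden toy, not the notebook's code. It emulates 8 CPU devices through `XLA_FLAGS=--xla_force_host_platform_device_count=8`, which is set before `jax` is imported. It uses hypothetical linear teacher and student models, and for brevity it relies on `jax.jit` sharding propagation instead of `shard_map`. The loop never reads a result on the host, so every call only enqueues work. The single `jax.block_until_ready` after the loop is the only synchronization point:

```python
import os

# Assumption: emulate 8 devices on the host CPU; must be set before importing jax
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
half = len(devices) // 2
mesh_teacher = Mesh(np.array(devices[:half]), ("tp",))
mesh_student = Mesh(np.array(devices[half:]), ("dp",))

key_t, key_x = jax.random.split(jax.random.key(0))
# Hypothetical toy models: a fixed linear "teacher" and a trainable linear "student"
w_teacher = jax.device_put(
    jax.random.normal(key_t, (32, 8)), NamedSharding(mesh_teacher, P(None, "tp"))
)
w_student = jax.device_put(jnp.zeros((32, 8)), NamedSharding(mesh_student, P()))
x = jax.random.normal(key_x, (16, 32))
x_teacher = jax.device_put(x, NamedSharding(mesh_teacher, P()))
x_student = jax.device_put(x, NamedSharding(mesh_student, P("dp", None)))


@jax.jit
def teacher_forward(w, x):
    # Runs on the devices holding w and x (mesh_teacher)
    return x @ w


@jax.jit
def student_step(w, x, y):
    def loss_fn(w):
        return jnp.mean(jnp.sum((x @ w - y) ** 2, axis=-1))

    loss, grad = jax.value_and_grad(loss_fn)(w)
    # Runs on the devices holding w, x, y (mesh_student)
    return w - 0.01 * grad, loss


losses = []
for step in range(5):
    # Each call only enqueues work and returns immediately (asynchronous dispatch)
    y_teacher = teacher_forward(w_teacher, x_teacher)
    # Cross-mesh transfer of the teacher logits to the student devices
    y_teacher = jax.device_put(y_teacher, NamedSharding(mesh_student, P("dp", None)))
    w_student, loss = student_step(w_student, x_student, y_teacher)
    losses.append(loss)  # no host sync inside the loop

# Block once, after all steps have been enqueued
losses = [float(v) for v in jax.block_until_ready(losses)]
print(losses)
```

Converting `loss` to a Python float (or formatting it with `print`) inside the loop would make the host wait for each student step before it dispatches the next teacher step.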
| { | |
| "cell_type": "code", | |
| "execution_count": 55, | |
| "id": "d9abd854-27da-42f9-841a-89a1d16389d7", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "devices = jax.devices()\n", | |
| "# Split devices: first half for the teacher, second half for the student (e.g. 4 + 4 on 8 devices; adjust to your hardware and each model's memory footprint)\n", | |
| "split_idx = len(devices) // 2\n", | |
| "\n", | |
| "# Tensor Parallel for big model inference\n", | |
| "mesh_inference = Mesh(devices[:split_idx], axis_names=(\"tp\",))\n", | |
| "\n", | |
| "# Data Parallel for small model training\n", | |
| "mesh_train = Mesh(devices[split_idx:], axis_names=(\"dp\",))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 56, | |
| "id": "94acbd52-9613-4d2b-8a71-039a01f8d2e3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Initialize models on separate device meshes\n", | |
| "with jax.set_mesh(mesh_inference):\n", | |
| " key_model_teacher, key_batch = jax.random.split(jax.random.key(0), num=2)\n", | |
| " model_teacher = Model(layer_sizes=(32, 128, 256, 128, 4), key=key_model_teacher)\n", | |
| " # Initialize a random (fake) batch\n", | |
| " batch = model_teacher.init_random_batch(key=key_batch, batch_size=16)\n", | |
| "with jax.set_mesh(mesh_train):\n", | |
| " model_student = Model(layer_sizes=(32, 4), key=jax.random.key(1))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 57, | |
| "id": "1d02b631-7b80-4773-8d60-3c659203dcc6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Tensor-parallel (TP) inference: shard weights along their input (row) dim and activations along their feature dim\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(\"tp\", None), b=P(\"tp\"))\n", | |
| "model_teacher_pspec = Model(\n", | |
| " layer_sizes=model_teacher.layer_sizes,\n", | |
| " layers=[linear_pspec] * model_teacher.num_layers,\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "@jax.jit\n", | |
| "@jax.shard_map(\n", | |
| " mesh=mesh_inference,\n", | |
| " in_specs=(model_teacher_pspec, P(None, \"tp\")),\n", | |
| " out_specs=P(None, \"tp\"), # Output is sharded on 'tp' axis\n", | |
| ")\n", | |
| "def teacher_forward(\n", | |
| " model_sharded: Model, x: Float[Array, \"batch h\"]\n", | |
| ") -> Float[Array, \"batch C\"]:\n", | |
| " \"\"\"Mock Forward pass of teacher. Returns logits\"\"\"\n", | |
| "\n", | |
| " def linear_tp(x: jax.Array, layer: LinearParams) -> jax.Array:\n", | |
| " \"\"\"Row-parallel linear layer: local partial matmul, then reduce-scatter over 'tp'\"\"\"\n", | |
| " matmul_local = x @ layer.w\n", | |
| " matmul_reduce_scattered = jax.lax.psum_scatter(\n", | |
| " matmul_local, \"tp\", scatter_dimension=1, tiled=True\n", | |
| " )\n", | |
| " return matmul_reduce_scattered + layer.b\n", | |
| "\n", | |
| " return model_sharded.predict(x, linear_fwd_fn=linear_tp)" | |
| ] | |
| }, | |
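You can check why `teacher_forward` reduce-scatters after the local matmul using plain NumPy. This sketch simulates the `tp` axis by splitting arrays into `n_tp` hypothetical shards, with no devices involved. It mimics `psum_scatter(..., scatter_dimension=1, tiled=True)` by summing the per-device partial products and keeping one column block per device:

```python
import numpy as np

n_tp = 4  # hypothetical size of the "tp" mesh axis
batch, d_in, d_out = 2, 8, 12
rng = np.random.default_rng(0)
x = rng.standard_normal((batch, d_in))
w = rng.standard_normal((d_in, d_out))
b = rng.standard_normal(d_out)

# Per-device shards, mirroring the PartitionSpecs used by teacher_forward
x_shards = np.split(x, n_tp, axis=1)  # x: P(None, "tp")
w_shards = np.split(w, n_tp, axis=0)  # w: P("tp", None)
b_shards = np.split(b, n_tp)  # b: P("tp")

# Local matmul: every device holds a full-width (batch, d_out) partial sum
partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# Reduce-scatter: sum the partials over devices, then device i keeps column block i
reduced_blocks = np.split(sum(partials), n_tp, axis=1)

# Bias is added after the reduction, so each column block receives it exactly once
y_shards = [blk + bs for blk, bs in zip(reduced_blocks, b_shards)]

# The result is sharded P(None, "tp"), the same layout the next layer expects for x
y = np.concatenate(y_shards, axis=1)
print(np.allclose(y, x @ w + b))  # True
```

Adding `b` before the reduction would count the bias `n_tp` times. That is why `linear_tp` adds it to the already-reduced block.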
| { | |
| "cell_type": "code", | |
| "execution_count": 58, | |
| "id": "b1bb9342-8a10-4209-83af-5c278a4c69eb", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Data parallel (DP) training of the small student model: replicate the weights, shard only the data\n", | |
| "linear_pspec: LinearPSpec = Linear(w=P(), b=P())\n", | |
| "model_student_pspec = Model(\n", | |
| " layer_sizes=model_student.layer_sizes,\n", | |
| " layers=[linear_pspec] * model_student.num_layers,\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "@jax.value_and_grad\n", | |
| "def kd_loss_dp(\n", | |
| " model_student: Model,\n", | |
| " x: Float[Array, \"batch_local hidden\"],\n", | |
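| " y_pred_teacher: Float[Array, \"batch_local C\"],\n", | |
| ") -> Float[Array, \"\"]:\n", | |
| " \"\"\"Knowledge-distillation loss: squared error between student and teacher logits, averaged over the global batch via pmean over 'dp'\"\"\"\n", | |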
| " y_pred_student = model_student.predict(x, linear_fwd_fn=linear)\n", | |
| " loss_dp_local = jnp.sum((y_pred_student - y_pred_teacher) ** 2, axis=-1).mean()\n", | |
| " return jax.lax.pmean(loss_dp_local, axis_name=\"dp\")\n", | |
| "\n", | |
| "\n", | |
| "@partial(jax.jit, donate_argnames=[\"model_student\"]) # donate the old student params so XLA can reuse their buffers\n", | |
| "@jax.shard_map(\n", | |
| " mesh=mesh_train,\n", | |
| " # Only shard data (X,y) for DP\n", | |
| " in_specs=(model_student_pspec, P(\"dp\", None), P(\"dp\", None)),\n", | |
| " out_specs=(model_student_pspec, P()),\n", | |
| ")\n", | |
| "def train_student_step(\n", | |
| " model_student: Model,\n", | |
| " x: Float[Array, \"batch_local hidden\"],\n", | |
| " target_logits: Float[Array, \"batch_local C\"],\n", | |
| ") -> tuple[Model, Float[Array, \"\"]]:\n", | |
| " loss, grads = kd_loss_dp(model_student, x, target_logits)\n", | |
| " # Update weights\n", | |
| " model_student_updated = jax.tree.map(\n", | |
| " lambda w, w_grad: w - 0.01 * w_grad, model_student, grads\n", | |
| " )\n", | |
| " return model_student_updated, loss" | |
| ] | |
| }, | |
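`kd_loss_dp` averages each device's local loss with `pmean` over `dp`. With equal-sized shards, the mean of the per-shard mean losses equals the full-batch mean. By linearity, the same holds for the gradients. The minimal sketch below simulates the `dp` shards with `jax.vmap` on a single device, as an illustrative stand-in for `shard_map`. The names `local_loss` and `n_dp` are hypothetical. The sketch checks the averaged values against a single-device reference:

```python
import jax
import jax.numpy as jnp


def local_loss(w, x, y):
    # Same form as the local loss in kd_loss_dp: squared error summed over classes, mean over the batch
    return jnp.mean(jnp.sum((x @ w - y) ** 2, axis=-1))


n_dp = 4  # hypothetical size of the "dp" mesh axis
key_w, key_x, key_y = jax.random.split(jax.random.key(0), 3)
w = jax.random.normal(key_w, (32, 4))
x = jax.random.normal(key_x, (16, 32))
y = jax.random.normal(key_y, (16, 4))

# Single-device reference on the full batch
full_loss, full_grad = jax.value_and_grad(local_loss)(w, x, y)

# Simulate P("dp", None): split the batch into n_dp contiguous, equal shards; keep w replicated
x_shards = x.reshape(n_dp, -1, x.shape[-1])
y_shards = y.reshape(n_dp, -1, y.shape[-1])
shard_losses, shard_grads = jax.vmap(
    jax.value_and_grad(local_loss), in_axes=(None, 0, 0)
)(w, x_shards, y_shards)

# What a mean over the "dp" axis computes: the average across shards
dp_loss = shard_losses.mean()
dp_grad = shard_grads.mean(axis=0)
print(dp_loss, full_loss)
```

With unequal shard sizes, the plain average would over-weight the smaller shards, so a weighted mean would be needed instead.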
| { | |
| "cell_type": "code", | |
| "execution_count": 59, | |
| "id": "f42998d0-b788-47c3-b085-296bfb175b03", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Loss: 1.09\n", | |
| "Loss: 1.01\n", | |
| "Loss: 0.94\n", | |
| "Loss: 0.88\n", | |
| "Loss: 0.83\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Train loop\n", | |
| "losses = []\n", | |
| "for i in range(5):\n", | |
| "    # Teacher inference on mesh_inference\n", | |
| "    with jax.set_mesh(mesh_inference):\n", | |
| "        y_pred_teacher = teacher_forward(model_teacher, batch.inputs)\n", | |
| "\n", | |
| "    # Student training on mesh_train\n", | |
| "    with jax.set_mesh(mesh_train):\n", | |
| "        # Transfer the teacher logits and the inputs to the training devices\n", | |
| "        y_pred_teacher = jax.device_put(\n", | |
| "            y_pred_teacher, NamedSharding(mesh_train, P(\"dp\", None))\n", | |
| "        )\n", | |
| "        batch_inputs = jax.device_put(\n", | |
| "            batch.inputs, NamedSharding(mesh_train, P(\"dp\", None))\n", | |
| "        )\n", | |
| "        model_student, loss = train_student_step(\n", | |
| "            model_student, batch_inputs, y_pred_teacher\n", | |
| "        )\n", | |
| "    # Keep the loss on device; formatting it here would block the host every step\n", | |
| "    losses.append(loss)\n", | |
| "\n", | |
| "for loss in losses:\n", | |
| "    print(f\"Loss: {loss:.2f}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "77fc6ef2-a584-4661-8c20-37a417bac9c0", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.13.11" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |