Created
March 13, 2026 14:00
-
-
Save adrn/691718dcd9aa2af064d9283b8825ae32 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "3e04b65c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "from array_api_compat import array_namespace" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "b40449d2", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Here are some pretend \"library\" functions, i.e. functions in Astropy written with the\n", | |
| "# array API. They don't have to know anything about JIT or JAX.\n", | |
| "def normalize(x):\n", | |
| " xp = array_namespace(x)\n", | |
| " lo = xp.min(x)\n", | |
| " hi = xp.max(x)\n", | |
| " return (x - lo) / (hi - lo)\n", | |
| "\n", | |
| "def normalize_sum(x):\n", | |
| " xp = array_namespace(x)\n", | |
| " return xp.sum(normalize(x))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "4ad63284", | |
| "metadata": {}, | |
| "source": [ | |
| "This works with NumPy:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "3294ba6c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0. , 0.5, 1. ])" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x = np.array([1, 2, 3])\n", | |
| "normalize(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ebd07185", | |
| "metadata": {}, | |
| "source": [ | |
| "It also works with JAX:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "7dba48c1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Array([0. , 0.5, 1. ], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x = jnp.array([1., 2, 3])\n", | |
| "normalize(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "af19969d", | |
| "metadata": {}, | |
| "source": [ | |
| "And we can JIT-compile it with `jax.jit`:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "678556f0", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Array([0. , 0.5, 1. ], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "jax.jit(normalize)(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "a6facee4", | |
| "metadata": {}, | |
| "source": [ | |
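| "Or differentiate it with `jax.grad` (applied to `normalize_sum`, because `jax.grad` requires a scalar-valued function):" | |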
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "2db0a4cc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Array([-0.25, 0.5 , -0.25], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "jax.grad(normalize_sum)(x)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "demo-notebooks (3.12.10)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment