Last active
March 29, 2022 05:15
-
-
Save dfm/2f38b8349221ccfd9df38f93a99c1de9 to your computer and use it in GitHub Desktop.
Live-coded notebooks for the GPRV workshop. Check out https://github.com/dfm/gprv for more self-contained versions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "b102d5cc-0756-45d8-ae50-abcf03538936", | |
| "metadata": {}, | |
| "source": [ | |
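| "# Introduction to `jax`\n", | |
| "\n", | |
| "A short tour of `jax.numpy` as a drop-in replacement for NumPy, followed by `jax.jit` (compilation), `jax.grad` (automatic differentiation), and `jax.vmap` (vectorization), including how `grad` and `vmap` operate on a dict of parameters." | |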
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "e627de02-c33d-4dbc-a287-539c95ecb0b7", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/opt/homebrew/Caskroom/miniforge/base/envs/tinygp/lib/python3.9/site-packages/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.\n", | |
| " warnings.warn(\"JAX on Mac ARM machines is experimental and minimally tested. \"\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import jax\n", | |
| "\n", | |
| "jax.config.update(\"jax_enable_x64\", True)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "2e832a0c-cb6b-47b3-9d5b-f83837bdaafe", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import jax.numpy as jnp" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "8ff69de5-8f8a-44cc-8c6f-e8e66807108f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([0. , 0.5, 1. , 1.5, 2. ], dtype=float64)" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x = jnp.linspace(0, 2, 5)\n", | |
| "x" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "37ea9bd7-dfa5-45ba-a8f6-27018461f090", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([0. , 0.47942554, 0.84147098, 0.99749499, 0.90929743], dtype=float64)" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "jnp.sin(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "c578f014-e232-4834-a207-6ef518e97a50", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "\n", | |
| "y = np.linspace(0, 2, 5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "d04e872f-d9c1-4914-80ac-7c912cec721f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([0. , 0.47942554, 0.84147098, 0.99749499, 0.90929743], dtype=float64)" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "jnp.sin(y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "048206c1-71b0-4bd7-af7e-6b663f309045", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0. , 0.47942554, 0.84147098, 0.99749499, 0.90929743])" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "np.sin(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 48, | |
| "id": "1505361c-677a-4f66-b437-cd29cd67bbb9", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from functools import partial\n", | |
| "\n", | |
| "@partial(jax.jit, backend=\"cpu\")\n", | |
| "def func(x):\n", | |
| " arg = jnp.sin(x)\n", | |
| " return 1.5 + jnp.exp(arg)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "id": "7082e98a-22d8-4c85-834c-18af44fe7ecd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([2.5 , 3.1151463 , 3.81977682, 4.21148102, 3.98257773], dtype=float64)" | |
| ] | |
| }, | |
| "execution_count": 49, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "func(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "id": "d257b2ad-e12f-4ac0-9361-aa6608ab369b", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "grad_func = jax.grad(func)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "id": "68574e9d-69ab-4743-884a-843003f7b23e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray(1.41742422, dtype=float64, weak_type=True)" | |
| ] | |
| }, | |
| "execution_count": 43, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "grad_func(0.5)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "id": "ffa28e94-1ae0-4617-bc73-404f344c3fe9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([ 1. , 1.41742422, 1.25338077, 0.19180258,\n", | |
| " -1.03311687], dtype=float64)" | |
| ] | |
| }, | |
| "execution_count": 45, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "jax.vmap(grad_func)(x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "id": "0e6c7d7b-8e68-4022-8f69-565c15eee9fe", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def func(params):\n", | |
| " arg = jnp.sin(params[\"a\"])\n", | |
| " return params[\"b\"] + jnp.exp(arg)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 57, | |
| "id": "e1104a4b-eddc-4e24-8a20-3e6f2730ac22", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([1. , 2.1151463 , 3.31977682, 4.21148102, 4.48257773], dtype=float64)" | |
| ] | |
| }, | |
| "execution_count": 57, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "params = {\n", | |
| " \"a\": np.linspace(0, 2, 5),\n", | |
| " \"b\": jnp.linspace(0, 2, 5),\n", | |
| "}\n", | |
| "func(params)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 58, | |
| "id": "6f528892-f419-42c3-bfe4-34733660aa43", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([1. , 2.1151463 , 3.31977682, 4.21148102, 4.48257773], dtype=float64)" | |
| ] | |
| }, | |
| "execution_count": 58, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "func(params)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 59, | |
| "id": "31f97436-719a-4036-914c-4e469299a9df", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "{'a': DeviceArray([ 1. , 1.41742422, 1.25338077, 0.19180258,\n", | |
| " -1.03311687], dtype=float64),\n", | |
| " 'b': DeviceArray([1., 1., 1., 1., 1.], dtype=float64)}" | |
| ] | |
| }, | |
| "execution_count": 59, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "jax.vmap(jax.grad(func))(params)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "e4071c23-3e78-4608-8e53-0b760d5ccb0c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.9" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment