Created
November 10, 2023 05:27
-
-
Save zhangqiaorjc/73cc154f22cbae474f6959d9a9fc8589 to your computer and use it in GitHub Desktop.
sincos remat example.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/zhangqiaorjc/73cc154f22cbae474f6959d9a9fc8589/sincos-remat-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "wgXLGxV2-nfI" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import jax\n", | |
| "from jax import core\n", | |
| "from jax.ad_checkpoint import checkpoint_name\n", | |
| "import jax.numpy as jnp" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# New primitive named 'sincos'; its rules are registered in the cells below.\n", | |
| "sincos_p = core.Primitive('sincos')\n", | |
| "# Multi-output primitive: the impl and abstract-eval rules each return two values (sin, cos).\n", | |
| "sincos_p.multiple_results = True" | |
| ], | |
| "metadata": { | |
| "id": "OXui8YA2-uF4" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "@sincos_p.def_impl\n", | |
| "def sincos_impl(x):\n", | |
| " return jnp.sin(x), jnp.cos(x)" | |
| ], | |
| "metadata": { | |
| "id": "eeeD56dS-8Lc" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "@sincos_p.def_abstract_eval\n", | |
| "def sincos_abstract_eval(x):\n", | |
| " return x, x" | |
| ], | |
| "metadata": { | |
| "id": "m0L5ow22_AmN" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def sincos(x):\n", | |
| " return sincos_p.bind(x)" | |
| ], | |
| "metadata": { | |
| "id": "0xyg3Jg__Q8C" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Trace sincos on a scalar to inspect the jaxpr emitted for the new primitive.\n", | |
| "jax.make_jaxpr(sincos)(5)" | |
| ], | |
| "metadata": { | |
| "id": "m1cFiNCb_LbK", | |
| "outputId": "5f464fa0-b521-4f16-98d3-4bea39ec3850" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m b\u001b[35m:i32[]\u001b[39m c\u001b[35m:i32[]\u001b[39m = sincos a \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(b, c) }" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 8 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "@jax.custom_vjp\n", | |
| "def sin(x):\n", | |
| " return jnp.sin(x)\n", | |
| "\n", | |
| "def sin_fwd(x):\n", | |
| " return sincos(x)\n", | |
| "\n", | |
| "def sin_bwd(res, g):\n", | |
| " return (res * g,)\n", | |
| "\n", | |
| "sin.defvjp(sin_fwd, sin_bwd)" | |
| ], | |
| "metadata": { | |
| "id": "w6pG4Q0__PdI" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# The gradient jaxpr calls sincos once and multiplies its second (cos) output by the cotangent.\n", | |
| "jax.make_jaxpr(jax.grad(sin))(5.0)" | |
| ], | |
| "metadata": { | |
| "id": "MUZNly7x_vWW", | |
| "outputId": "34e730b9-6b0f-4794-c2ba-a384f2fc9813" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22m_\u001b[35m:f32[]\u001b[39m b\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
| " c\u001b[35m:f32[]\u001b[39m = mul b 1.0\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 14 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def loss(x):\n", | |
| " return jnp.exp(sin(x))" | |
| ], | |
| "metadata": { | |
| "id": "5tgmUiTH_2aH" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Without checkpointing, sincos appears once and both of its outputs feed the gradient.\n", | |
| "jax.make_jaxpr(jax.grad(loss))(5.0)" | |
| ], | |
| "metadata": { | |
| "id": "NW80ucjxAK1z", | |
| "outputId": "21839d74-de71-4e48-d91e-222b199b4705" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m c\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
| " d\u001b[35m:f32[]\u001b[39m = exp b\n", | |
| " e\u001b[35m:f32[]\u001b[39m = mul 1.0 d\n", | |
| " f\u001b[35m:f32[]\u001b[39m = mul c e\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(f,) }" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 17 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "@jax.checkpoint\n", | |
| "def loss(x):\n", | |
| " return jnp.exp(sin(x))\n", | |
| "\n", | |
| "jax.make_jaxpr(jax.grad(loss))(5.0)" | |
| ], | |
| "metadata": { | |
| "id": "C1swFN6MAMRa", | |
| "outputId": "5a81c4bd-d466-422d-b52a-762441006c91" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m _\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
| " _\u001b[35m:f32[]\u001b[39m = exp b\n", | |
| " c\u001b[35m:f32[]\u001b[39m = remat2[\n", | |
| " differentiated=True\n", | |
| " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; d\u001b[35m:f32[]\u001b[39m e\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22mf\u001b[35m:f32[]\u001b[39m g\u001b[35m:f32[]\u001b[39m = sincos d\n", | |
| " h\u001b[35m:f32[]\u001b[39m = exp f\n", | |
| " i\u001b[35m:f32[]\u001b[39m = mul e h\n", | |
| " j\u001b[35m:f32[]\u001b[39m = mul g i\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(j,) }\n", | |
| " policy=None\n", | |
| " prevent_cse=True\n", | |
| " ] a 1.0\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 20 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def loss(x):\n", | |
| " x = jax._src.ad_checkpoint.checkpoint_name(sin(x), 'sin(x)')\n", | |
| " return jnp.exp(x)\n", | |
| "\n", | |
| "loss = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('sin(x)'))\n", | |
| "\n", | |
| "jax.make_jaxpr(jax.grad(loss))(5.0)" | |
| ], | |
| "metadata": { | |
| "id": "tZP1R5t7AWU8", | |
| "outputId": "48d1919c-7d02-4c98-c169-d454239b8ed4" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m _\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
| " c\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] b\n", | |
| " _\u001b[35m:f32[]\u001b[39m = exp c\n", | |
| " d\u001b[35m:f32[]\u001b[39m = remat2[\n", | |
| " differentiated=True\n", | |
| " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; e\u001b[35m:f32[]\u001b[39m f\u001b[35m:f32[]\u001b[39m g\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22m_\u001b[35m:f32[]\u001b[39m h\u001b[35m:f32[]\u001b[39m = sincos f\n", | |
| " i\u001b[35m:f32[]\u001b[39m = exp e\n", | |
| " j\u001b[35m:f32[]\u001b[39m = mul g i\n", | |
| " k\u001b[35m:f32[]\u001b[39m = mul h j\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(k,) }\n", | |
| " policy=<function save_only_these_names.<locals>.policy at 0x7f531ed7b130>\n", | |
| " prevent_cse=True\n", | |
| " ] c a 1.0\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(d,) }" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 23 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "@jax.custom_vjp\n", | |
| "def sin(x):\n", | |
| " return jnp.sin(x)\n", | |
| "\n", | |
| "def sin_fwd(x):\n", | |
| " sinx, cosx = sincos(x)\n", | |
| " sinx = jax._src.ad_checkpoint.checkpoint_name(sinx, 'sin(x)')\n", | |
| " cosx = jax._src.ad_checkpoint.checkpoint_name(cosx, 'sin(x)')\n", | |
| " return sinx, cosx\n", | |
| "\n", | |
| "def sin_bwd(res, g):\n", | |
| " return (res * g,)\n", | |
| "\n", | |
| "sin.defvjp(sin_fwd, sin_bwd)\n", | |
| "\n", | |
| "def loss(x):\n", | |
| " return jnp.exp(sin(x))\n", | |
| "\n", | |
| "loss = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('sin(x)'))\n", | |
| "\n", | |
| "jax.make_jaxpr(jax.grad(loss))(5.0)" | |
| ], | |
| "metadata": { | |
| "id": "QbiqHCecAr_l", | |
| "outputId": "31d4a105-df41-4e7b-f17c-814a19ad6a2c" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m c\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
| " d\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] b\n", | |
| " e\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] c\n", | |
| " _\u001b[35m:f32[]\u001b[39m = exp d\n", | |
| " f\u001b[35m:f32[]\u001b[39m = remat2[\n", | |
| " differentiated=True\n", | |
| " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; g\u001b[35m:f32[]\u001b[39m h\u001b[35m:f32[]\u001b[39m i\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
| " \u001b[39m\u001b[22m\u001b[22mj\u001b[35m:f32[]\u001b[39m = exp h\n", | |
| " k\u001b[35m:f32[]\u001b[39m = mul i j\n", | |
| " l\u001b[35m:f32[]\u001b[39m = mul g k\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(l,) }\n", | |
| " policy=<function save_only_these_names.<locals>.policy at 0x7f531ed7ba30>\n", | |
| " prevent_cse=True\n", | |
| " ] e d 1.0\n", | |
| " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(f,) }" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 25 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [], | |
| "metadata": { | |
| "id": "t0IEhLtICZY4" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment