Compare solve gradients
```python
from scipy import linalg
from numdifftools import Gradient, Jacobian, Derivative
import numpy as np

rng = np.random.default_rng()
```

```python
import pytensor
import pytensor.tensor as pt
```

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
```
```python
def numeric_gradient(a, b, assume_a='gen', a_is_chol=False):
    a_shape = a.shape

    def f(a, b):
        # When assume_a == 'pos', the numerical jitter applied here will screw up
        # everything, so we work with the Cholesky factor instead
        a2 = a.reshape(a_shape)

        if a_is_chol:
            a2 = a2 @ a2.T

        return linalg.solve(a2, b, assume_a=assume_a).sum()

    d_f = Gradient(f)
    return d_f(a.ravel(), b).reshape(a_shape)


def pytensor_gradient(a, b, assume_a='gen', a_is_chol=False):
    A = pt.tensor('A', shape=(5, 5))
    B = pt.tensor('B', shape=(5,))

    if a_is_chol:
        # See above -- A is a Cholesky factor
        A2 = A @ A.T
    else:
        A2 = A

    X = pt.linalg.solve(A2, B, assume_a=assume_a)
    dX = pt.grad(X.sum(), A)
    f_dX = pytensor.function([A, B], dX)

    return f_dX(a, b)


def jax_gradient(a, b, assume_a='gen', a_is_chol=False):
    @jax.jit
    def f_jax(a, b):
        if a_is_chol:
            # See above -- a is a Cholesky factor
            a = a @ a.T
        return jsp.linalg.solve(a, b, assume_a=assume_a).sum()

    return jax.grad(f_jax, 0)(a, b)
```
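Before running the comparisons, it may help to note the closed form these functions are all estimating: for `L(A) = sum(solve(A, b))`, the gradient with respect to `A` is `-outer(A^{-T} @ ones, A^{-1} @ b)`. The sketch below is an addition to the gist, not part of the original notebook; `a_test` and `b_test` are throwaway names.

```python
# Added sketch (not in the original gist): closed-form gradient of
# L(A) = sum(solve(A, b)) with respect to A, i.e. -outer(A^{-T} @ 1, A^{-1} @ b).
# The test matrices are fresh draws, purely for illustration.
a_test = rng.normal(size=(5, 5))
b_test = rng.normal(size=(5,))

x = np.linalg.solve(a_test, b_test)         # x = A^{-1} b
w = np.linalg.solve(a_test.T, np.ones(5))   # w = A^{-T} 1
closed_form = -np.outer(w, x)

# Should be close to zero for all three implementations; shown here against the numeric one.
print(np.abs(closed_form - numeric_gradient(a_test, b_test, 'gen')).max())
```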
```python
def compare_gradients(a, b, assume_a='gen', exclude=None, a_is_chol=False):
    grads = []
    names = []
    if exclude is None:
        exclude = []
    if isinstance(exclude, str):
        exclude = [exclude]

    if 'numeric' not in exclude:
        numeric_grad = numeric_gradient(a, b, assume_a, a_is_chol)
        grads.append(numeric_grad)
        names.append('Numeric')

    if 'pytensor' not in exclude:
        pt_grad = pytensor_gradient(a, b, assume_a, a_is_chol)
        grads.append(pt_grad)
        names.append('Pytensor')

    if 'jax' not in exclude:
        jax_grad = jax_gradient(a, b, assume_a, a_is_chol)
        grads.append(jax_grad)
        names.append('JAX')

    with np.printoptions(linewidth=1000, precision=3, suppress=True):
        for grad, name in zip(grads, names):
            print(f'{name} Gradient:')
            print('-' * 30)
            print(grad)
            print('\n')
```
# Gen

```python
a = rng.normal(size=(5, 5))
b = rng.normal(size=(5,))
compare_gradients(a, b, 'gen')
```

```
Numeric Gradient:
------------------------------
[[ 1.323 -1.178 -1.647 0.459 -0.206]
 [ 1.269 -1.13 -1.58 0.441 -0.198]
 [ 2.398 -2.135 -2.986 0.833 -0.373]
 [-3.607 3.211 4.492 -1.253 0.561]
 [ 0.686 -0.611 -0.855 0.238 -0.107]]

Pytensor Gradient:
------------------------------
[[ 1.323 -1.178 -1.647 0.459 -0.206]
 [ 1.269 -1.13 -1.58 0.441 -0.198]
 [ 2.398 -2.135 -2.986 0.833 -0.373]
 [-3.607 3.211 4.492 -1.253 0.561]
 [ 0.686 -0.611 -0.855 0.238 -0.107]]

JAX Gradient:
------------------------------
[[ 1.323 -1.178 -1.647 0.459 -0.206]
 [ 1.269 -1.13 -1.58 0.441 -0.198]
 [ 2.398 -2.135 -2.986 0.833 -0.373]
 [-3.607 3.211 4.492 -1.253 0.561]
 [ 0.686 -0.611 -0.855 0.238 -0.107]]
```
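A small added aside, not in the original gist: since the closed form above is an outer product, the `'gen'` gradient should have rank one. This reuses the `a` and `b` defined in the cell above.

```python
# Added check: the gradient of sum(solve(a, b)) w.r.t. a is -outer(w, x), so its rank is 1.
G_gen = pytensor_gradient(a, b, 'gen')
print(np.linalg.matrix_rank(G_gen))  # expect 1
```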
# Sym

```python
a = rng.normal(size=(5, 5))
b = rng.normal(size=(5,))

a = (a + a.T) / 2

compare_gradients(a, b, 'sym')
```

```
Numeric Gradient:
------------------------------
[[ 0.615 16.829 -1.455 -3.879 -5.17 ]
 [ 0. 27.426 -11.735 -9.001 -10.102]
 [ 0. 0. 0.67 2.536 3.291]
 [ 0. 0. 0. 0.58 1.069]
 [ 0. 0. 0. 0. 0.385]]

Pytensor Gradient:
------------------------------
[[ 0.615 1.071 -0.386 -0.094 -0.046]
 [15.758 27.426 -9.875 -2.413 -1.184]
 [-1.069 -1.86 0.67 0.164 0.08 ]
 [-3.785 -6.588 2.372 0.58 0.284]
 [-5.124 -8.918 3.211 0.785 0.385]]

JAX Gradient:
------------------------------
[[ 0.615 1.071 -0.386 -0.094 -0.046]
 [15.758 27.426 -9.875 -2.413 -1.184]
 [-1.069 -1.86 0.67 0.164 0.08 ]
 [-3.785 -6.588 2.372 0.58 0.284]
 [-5.124 -8.918 3.211 0.785 0.385]]
```
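The disagreement here looks like a bookkeeping difference rather than a bug. With `assume_a='sym'` and the default `lower=False`, `scipy.linalg.solve` reads only the upper triangle of `a`, so the finite-difference gradient attributes each symmetric pair's full contribution to the upper entry and reports zeros below the diagonal, while PyTensor and JAX return a gradient with respect to all 25 entries. If that reading is right, folding the dense gradient onto the upper triangle should recover the numeric result; the check below is an addition, not part of the original gist.

```python
# Added sketch: fold the dense PyTensor gradient onto the upper triangle
# (off-diagonal pairs summed, diagonal kept) and compare with the
# finite-difference gradient, which only "sees" the upper triangle that scipy reads.
num = numeric_gradient(a, b, 'sym')
dense = pytensor_gradient(a, b, 'sym')
folded = np.triu(dense + dense.T, k=1) + np.diag(np.diag(dense))
print(np.abs(num - folded).max())
```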
# Pos

```python
a = rng.normal(size=(5, 5))
b = rng.normal(size=(5,))

a = np.linalg.cholesky(a @ a.T)

compare_gradients(a, b, 'pos', a_is_chol=True)
```

```
Numeric Gradient:
------------------------------
[[ 0.035 -12.41 -0.76 -16.333 -122.356]
 [ -0.467 -61.222 -5.828 -81.918 -613.042]
 [ -0.052 -29.476 -2.247 -39.08 -292.625]
 [ -0.576 -72.837 -7.001 -97.504 -729.656]
 [ -0.235 -28.908 -2.8 -38.711 -289.68 ]]

Pytensor Gradient:
------------------------------
[[ 0.035 -12.41 -0.76 -16.333 -122.356]
 [ -0.467 -61.222 -5.828 -81.918 -613.042]
 [ -0.052 -29.476 -2.247 -39.08 -292.625]
 [ -0.576 -72.837 -7.001 -97.504 -729.656]
 [ -0.235 -28.908 -2.8 -38.711 -289.68 ]]

JAX Gradient:
------------------------------
[[ 0.035 -12.409 -0.76 -16.333 -122.354]
 [ -0.467 -61.221 -5.827 -81.918 -613.035]
 [ -0.052 -29.476 -2.247 -39.079 -292.622]
 [ -0.576 -72.837 -7. -97.503 -729.647]
 [ -0.235 -28.907 -2.799 -38.71 -289.676]]
```
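Here `a` is a Cholesky factor and all three gradients are taken with respect to that factor (`a_is_chol=True`), since entry-wise jitter would break the symmetry that `assume_a='pos'` relies on (see the comment in `numeric_gradient`). For reference, if `G` is the gradient with respect to `A = C @ C.T`, the chain rule gives the gradient with respect to `C` as `(G + G.T) @ C`. The sketch below, an addition to the gist, checks that identity through the unambiguous `'gen'` path so that only the reparameterization itself is being tested; `c_test` and `b_test` are throwaway names.

```python
# Added sketch of the reparameterization used above: with A = C @ C.T and
# G = dL/dA, the chain rule gives dL/dC = (G + G.T) @ C.
m = rng.normal(size=(5, 5))
c_test = np.linalg.cholesky(m @ m.T)
b_test = rng.normal(size=(5,))

G_full = pytensor_gradient(c_test @ c_test.T, b_test, 'gen')
grad_wrt_chol = (G_full + G_full.T) @ c_test

print(np.abs(grad_wrt_chol - numeric_gradient(c_test, b_test, 'gen', a_is_chol=True)).max())
```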
Double-check that going through the Cholesky factor isn't doing something weird:

```python
a = rng.normal(size=(5, 5))
b = rng.normal(size=(5,))

a = a @ a.T

compare_gradients(a, b, 'pos', exclude='numeric', a_is_chol=False)
```

```
Pytensor Gradient:
------------------------------
[[-458.861 138.34 533.123 -182.192 -384.686]
 [ 135.137 -40.742 -157.008 53.657 113.292]
 [ 516.842 -155.821 -600.488 205.214 433.295]
 [-185.079 55.799 215.032 -73.486 -155.161]
 [-374.896 113.026 435.569 -148.853 -314.294]]

JAX Gradient:
------------------------------
[[-458.865 138.342 533.128 -182.194 -384.689]
 [ 135.138 -40.742 -157.009 53.657 113.293]
 [ 516.847 -155.822 -600.493 205.216 433.298]
 [-185.081 55.799 215.034 -73.487 -155.162]
 [-374.899 113.027 435.572 -148.855 -314.296]]
```
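The numeric gradient is excluded above for the reason noted in `numeric_gradient`: entry-wise jitter is not a safe way to differentiate through `assume_a='pos'`. A perturbation that respects symmetry sidesteps the problem: for a step `eps * (E_ij + E_ji)`, the directional derivative predicted by a gradient `G` is `G[i, j] + G[j, i]`. The sketch below is an added, rough check along those lines; the step size and the helper name are illustrative choices, not part of the gist.

```python
# Added sketch: central finite differences with symmetric perturbations, which
# keep the perturbed matrix symmetric (and, for a small step, positive definite).
def symmetric_fd_check(a, b, G, eps=1e-6):
    n = a.shape[0]
    max_err = 0.0
    for i in range(n):
        for j in range(i, n):
            da = np.zeros_like(a)
            da[i, j] += eps
            da[j, i] += eps  # doubles the step on the diagonal when i == j
            fd = (linalg.solve(a + da, b, assume_a='pos').sum()
                  - linalg.solve(a - da, b, assume_a='pos').sum()) / (2 * eps)
            max_err = max(max_err, abs(fd - (G[i, j] + G[j, i])))
    return max_err

# Reuses the positive-definite a and the b from the cell above.
print(symmetric_fd_check(a, b, pytensor_gradient(a, b, 'pos')))
```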