Skip to content

Instantly share code, notes, and snippets.

@jessegrabowski
Created February 11, 2025 06:33
Show Gist options
  • Select an option

  • Save jessegrabowski/4f30ebbde57758466f5b1579063ca6d0 to your computer and use it in GitHub Desktop.

Select an option

Save jessegrabowski/4f30ebbde57758466f5b1579063ca6d0 to your computer and use it in GitHub Desktop.
Compare solve gradients
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "7e1a5154",
"metadata": {},
"outputs": [],
"source": [
"from scipy import linalg\n",
"from numdifftools import Gradient, Jacobian, Derivative\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2d252962",
"metadata": {},
"outputs": [],
"source": [
"import pytensor\n",
"import pytensor.tensor as pt"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b53559de",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "37d94941",
"metadata": {},
"outputs": [],
"source": [
"def numeric_gradient(a, b, assume_a='gen', a_is_chol=False):\n",
"    \"\"\"Gradient of solve(a, b).sum() w.r.t. a, by numerical differentiation.\n",
"\n",
"    If a_is_chol, ``a`` is a Cholesky factor and the solve is done on a @ a.T,\n",
"    so that perturbing ``a`` cannot break positive-definiteness.\n",
"    \"\"\"\n",
"    a_shape = a.shape\n",
"\n",
"    def f(a, b):\n",
"        # When assume_a == 'pos', the numerical jitter applied here will screw up everything\n",
"        # so we work with the cholesky factor instead\n",
"        a2 = a.reshape(a_shape)\n",
"\n",
"        if a_is_chol:\n",
"            a2 = a2 @ a2.T\n",
"\n",
"        return linalg.solve(a2, b, assume_a=assume_a).sum()\n",
"\n",
"    # Gradient differentiates w.r.t. a flat vector; reshape back to matrix form.\n",
"    d_f = Gradient(f)\n",
"    return d_f(a.ravel(), b).reshape(a_shape)\n",
"\n",
"def pytensor_gradient(a, b, assume_a='gen', a_is_chol=False):\n",
"    \"\"\"Gradient of solve(A, B).sum() w.r.t. A, via pytensor symbolic autodiff.\"\"\"\n",
"    # Take the static shapes from the inputs instead of hard-coding (5, 5) and\n",
"    # (5,), so this works for any system size, not just 5x5.\n",
"    A = pt.tensor('A', shape=a.shape)\n",
"    B = pt.tensor('B', shape=b.shape)\n",
"\n",
"    if a_is_chol:\n",
"        # See above -- A is a cholesky factor\n",
"        A2 = A @ A.T\n",
"    else:\n",
"        A2 = A\n",
"\n",
"    X = pt.linalg.solve(A2, B, assume_a=assume_a)\n",
"    dX = pt.grad(X.sum(), A)\n",
"    f_dX = pytensor.function([A, B], dX)\n",
"\n",
"    return f_dX(a, b)\n",
"\n",
"def jax_gradient(a, b, assume_a='gen', a_is_chol=False):\n",
"    \"\"\"Gradient of solve(a, b).sum() w.r.t. a, via JAX autodiff.\"\"\"\n",
"    @jax.jit\n",
"    def f_jax(a, b):\n",
"        if a_is_chol:\n",
"            # See above -- a is a cholesky factor\n",
"            a = a @ a.T\n",
"        return jsp.linalg.solve(a, b, assume_a=assume_a).sum()\n",
"\n",
"    # Differentiate w.r.t. the first argument (a) only.\n",
"    return jax.grad(f_jax, 0)(a, b)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f7994795",
"metadata": {},
"outputs": [],
"source": [
"def compare_gradients(a, b, assume_a='gen', exclude=None, a_is_chol=False):\n",
"    \"\"\"Compute and pretty-print the gradient from each backend side by side.\n",
"\n",
"    ``exclude`` may be a backend key or list of keys among 'numeric',\n",
"    'pytensor', and 'jax'; those backends are skipped.\n",
"    \"\"\"\n",
"    if exclude is None:\n",
"        exclude = []\n",
"    if isinstance(exclude, str):\n",
"        exclude = [exclude]\n",
"\n",
"    # (exclusion key, display label, gradient function) for each backend,\n",
"    # in the order the results are printed.\n",
"    backends = [\n",
"        ('numeric', 'Numeric', numeric_gradient),\n",
"        ('pytensor', 'Pytensor', pytensor_gradient),\n",
"        ('jax', 'JAX', jax_gradient),\n",
"    ]\n",
"\n",
"    grads = []\n",
"    names = []\n",
"    for key, label, grad_fn in backends:\n",
"        if key in exclude:\n",
"            continue\n",
"        grads.append(grad_fn(a, b, assume_a, a_is_chol))\n",
"        names.append(label)\n",
"\n",
"    with np.printoptions(linewidth=1000, precision=3, suppress=True):\n",
"        for grad, name in zip(grads, names):\n",
"            print(f'{name} Gradient:')\n",
"            print('-' * 30)\n",
"            print(grad)\n",
"            print('\\n')"
]
},
{
"cell_type": "markdown",
"id": "5f6280b6",
"metadata": {},
"source": [
"# Gen"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "37ceafa5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Numeric Gradient:\n",
"------------------------------\n",
"[[ 1.323 -1.178 -1.647 0.459 -0.206]\n",
" [ 1.269 -1.13 -1.58 0.441 -0.198]\n",
" [ 2.398 -2.135 -2.986 0.833 -0.373]\n",
" [-3.607 3.211 4.492 -1.253 0.561]\n",
" [ 0.686 -0.611 -0.855 0.238 -0.107]]\n",
"\n",
"\n",
"Pytensor Gradient:\n",
"------------------------------\n",
"[[ 1.323 -1.178 -1.647 0.459 -0.206]\n",
" [ 1.269 -1.13 -1.58 0.441 -0.198]\n",
" [ 2.398 -2.135 -2.986 0.833 -0.373]\n",
" [-3.607 3.211 4.492 -1.253 0.561]\n",
" [ 0.686 -0.611 -0.855 0.238 -0.107]]\n",
"\n",
"\n",
"JAX Gradient:\n",
"------------------------------\n",
"[[ 1.323 -1.178 -1.647 0.459 -0.206]\n",
" [ 1.269 -1.13 -1.58 0.441 -0.198]\n",
" [ 2.398 -2.135 -2.986 0.833 -0.373]\n",
" [-3.607 3.211 4.492 -1.253 0.561]\n",
" [ 0.686 -0.611 -0.855 0.238 -0.107]]\n",
"\n",
"\n"
]
}
],
"source": [
"# Generic unstructured matrix -- all three backends should agree.\n",
"a = rng.normal(size=(5, 5))\n",
"b = rng.normal(size=(5,))\n",
"compare_gradients(a, b, 'gen')"
]
},
{
"cell_type": "markdown",
"id": "38316b80",
"metadata": {},
"source": [
"# Sym"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7223c550",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Numeric Gradient:\n",
"------------------------------\n",
"[[ 0.615 16.829 -1.455 -3.879 -5.17 ]\n",
" [ 0. 27.426 -11.735 -9.001 -10.102]\n",
" [ 0. 0. 0.67 2.536 3.291]\n",
" [ 0. 0. 0. 0.58 1.069]\n",
" [ 0. 0. 0. 0. 0.385]]\n",
"\n",
"\n",
"Pytensor Gradient:\n",
"------------------------------\n",
"[[ 0.615 1.071 -0.386 -0.094 -0.046]\n",
" [15.758 27.426 -9.875 -2.413 -1.184]\n",
" [-1.069 -1.86 0.67 0.164 0.08 ]\n",
" [-3.785 -6.588 2.372 0.58 0.284]\n",
" [-5.124 -8.918 3.211 0.785 0.385]]\n",
"\n",
"\n",
"JAX Gradient:\n",
"------------------------------\n",
"[[ 0.615 1.071 -0.386 -0.094 -0.046]\n",
" [15.758 27.426 -9.875 -2.413 -1.184]\n",
" [-1.069 -1.86 0.67 0.164 0.08 ]\n",
" [-3.785 -6.588 2.372 0.58 0.284]\n",
" [-5.124 -8.918 3.211 0.785 0.385]]\n",
"\n",
"\n"
]
}
],
"source": [
"# Symmetric case: symmetrize a before solving with assume_a='sym'.\n",
"# NOTE(review): the numeric gradient perturbs each entry of a independently,\n",
"# and its result here disagrees with the autodiff backends (see the recorded\n",
"# output) -- presumably because solve(assume_a='sym') reads only one triangle\n",
"# of a; confirm.\n",
"a = rng.normal(size=(5, 5))\n",
"b = rng.normal(size=(5,))\n",
"\n",
"a = (a + a.T) / 2\n",
"\n",
"compare_gradients(a, b, 'sym')"
]
},
{
"cell_type": "markdown",
"id": "fc046be0",
"metadata": {},
"source": [
"# Pos"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cfe61c3a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Numeric Gradient:\n",
"------------------------------\n",
"[[ 0.035 -12.41 -0.76 -16.333 -122.356]\n",
" [ -0.467 -61.222 -5.828 -81.918 -613.042]\n",
" [ -0.052 -29.476 -2.247 -39.08 -292.625]\n",
" [ -0.576 -72.837 -7.001 -97.504 -729.656]\n",
" [ -0.235 -28.908 -2.8 -38.711 -289.68 ]]\n",
"\n",
"\n",
"Pytensor Gradient:\n",
"------------------------------\n",
"[[ 0.035 -12.41 -0.76 -16.333 -122.356]\n",
" [ -0.467 -61.222 -5.828 -81.918 -613.042]\n",
" [ -0.052 -29.476 -2.247 -39.08 -292.625]\n",
" [ -0.576 -72.837 -7.001 -97.504 -729.656]\n",
" [ -0.235 -28.908 -2.8 -38.711 -289.68 ]]\n",
"\n",
"\n",
"JAX Gradient:\n",
"------------------------------\n",
"[[ 0.035 -12.409 -0.76 -16.333 -122.354]\n",
" [ -0.467 -61.221 -5.827 -81.918 -613.035]\n",
" [ -0.052 -29.476 -2.247 -39.079 -292.622]\n",
" [ -0.576 -72.837 -7. -97.503 -729.647]\n",
" [ -0.235 -28.907 -2.799 -38.71 -289.676]]\n",
"\n",
"\n"
]
}
],
"source": [
"# Positive-definite case: pass the Cholesky factor (a_is_chol=True) so the\n",
"# numerical perturbation acts on the factor and cannot push a @ a.T out of\n",
"# the positive-definite cone.\n",
"a = rng.normal(size=(5, 5))\n",
"b = rng.normal(size=(5,))\n",
"\n",
"a = np.linalg.cholesky(a @ a.T)\n",
"\n",
"compare_gradients(a, b, 'pos', a_is_chol=True)"
]
},
{
"cell_type": "markdown",
"id": "230d1762",
"metadata": {},
"source": [
"Double-check that parameterizing through the Cholesky factor isn't itself changing the gradient comparison: differentiate directly with respect to the positive-definite matrix and compare the autodiff backends"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "55689b6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pytensor Gradient:\n",
"------------------------------\n",
"[[-458.861 138.34 533.123 -182.192 -384.686]\n",
" [ 135.137 -40.742 -157.008 53.657 113.292]\n",
" [ 516.842 -155.821 -600.488 205.214 433.295]\n",
" [-185.079 55.799 215.032 -73.486 -155.161]\n",
" [-374.896 113.026 435.569 -148.853 -314.294]]\n",
"\n",
"\n",
"JAX Gradient:\n",
"------------------------------\n",
"[[-458.865 138.342 533.128 -182.194 -384.689]\n",
" [ 135.138 -40.742 -157.009 53.657 113.293]\n",
" [ 516.847 -155.822 -600.493 205.216 433.298]\n",
" [-185.081 55.799 215.034 -73.487 -155.162]\n",
" [-374.899 113.027 435.572 -148.855 -314.296]]\n",
"\n",
"\n"
]
}
],
"source": [
"# Sanity check: differentiate directly w.r.t. the positive-definite matrix\n",
"# itself (no Cholesky reparameterization). The numeric backend is excluded\n",
"# because its jitter would break positive-definiteness (see cell 4's comment).\n",
"a = rng.normal(size=(5, 5))\n",
"b = rng.normal(size=(5,))\n",
"\n",
"a = a @ a.T\n",
"\n",
"compare_gradients(a, b, 'pos', exclude='numeric', a_is_chol=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fda85a5a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment