Skip to content

Instantly share code, notes, and snippets.

@jessegrabowski
Last active June 26, 2025 10:53
Show Gist options
  • Select an option

  • Save jessegrabowski/6b3dc83898ed355b850a7f698476ef71 to your computer and use it in GitHub Desktop.

Select an option

Save jessegrabowski/6b3dc83898ed355b850a7f698476ef71 to your computer and use it in GitHub Desktop.
Sparse Jacobians in PyTensor
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "54289e6e",
"metadata": {},
"outputs": [],
"source": [
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from pytensor import sparse as pts\n",
"from pytensor.graph.basic import explicit_graph_inputs\n",
"\n",
"import numpy as np\n",
"from scipy import sparse\n",
"import networkx as nx"
]
},
{
"cell_type": "markdown",
"id": "02bed835",
"metadata": {},
"source": [
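"Test system: a map $f: \\mathbb{R}^5 \\to \\mathbb{R}^5$ whose Jacobian is sparse (9 of its 25 entries are structurally nonzero). The code below is 0-indexed, so $x_i$ is the variable `x{i-1}` and $f_i$ is `equations[i-1]`.\n",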
"\n",
"$$\n",
"\\begin{aligned}\n",
"f_1(x) &= x_1^2 + \\sin x_2, \\\\\n",
"f_2(x) &= x_2^3, \\\\\n",
"f_3(x) &= e^{x_3} + x_1, \\\\\n",
"f_4(x) &= e^{x_4} + 3x_1, \\\\\n",
"f_5(x) &= x_5^3 + x_2 x_5,\n",
"\\end{aligned}\n",
"\\qquad\n",
"x = \\begin{bmatrix} x_1 \\\\ x_2 \\\\ x_3 \\\\ x_4 \\\\ x_5 \\end{bmatrix}.\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "39f02063",
"metadata": {},
"outputs": [],
"source": [
"def get_jacobian_connectivity(variables, equations):\n",
"    \"\"\"Return the structural sparsity pattern of the Jacobian d(equations)/d(variables).\n",
"\n",
"    Entry (i, j) is marked as structurally nonzero when ``variables[j]`` is one of the\n",
"    explicit graph inputs of ``equations[i]``. This is a structural test, so it can\n",
"    flag entries whose derivative happens to evaluate to zero.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    variables : sequence of TensorVariable\n",
"        Input variables; their order defines the Jacobian columns.\n",
"    equations : sequence of TensorVariable\n",
"        Output expressions; their order defines the Jacobian rows.\n",
"\n",
"    Returns\n",
"    -------\n",
"    list of two tuples\n",
"        ``[rows, cols]`` in row-major order, usable directly as COO coordinates.\n",
"        Both tuples are empty when no equation depends on any of the variables.\n",
"    \"\"\"\n",
"    rows, cols = [], []\n",
"    for i, eq in enumerate(equations):\n",
"        # Walk each graph once per equation rather than once per (equation, variable) pair.\n",
"        eq_inputs = set(explicit_graph_inputs(eq))\n",
"        for j, var in enumerate(variables):\n",
"            if var in eq_inputs:\n",
"                rows.append(i)\n",
"                cols.append(j)\n",
"\n",
"    # Always return two sequences so callers can use coords[0] / coords[1] even when the\n",
"    # pattern is empty (list(zip(*[])) would collapse to a bare empty list).\n",
"    return [tuple(rows), tuple(cols)]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "42dbdf04",
"metadata": {},
"outputs": [],
"source": [
"variables = pt.dscalars([f'x{i}' for i in range(5)])\n",
"x0, x1, x2, x3, x4 = variables\n",
"\n",
"# Row i of the Jacobian corresponds to equations[i]; column j to variables[j].\n",
"equations = [\n",
"    x0**2 + pt.sin(x1),\n",
"    x1**3,\n",
"    pt.exp(x2) + x0,\n",
"    pt.exp(x3) + 3 * x0,\n",
"    x4**3 + x1 * x4,\n",
"]\n",
"\n",
"n_eqs, n_vars = len(equations), len(variables)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "362aa6f2",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[(0, 0, 1, 2, 2, 3, 3, 4, 4), (0, 1, 1, 0, 2, 0, 3, 1, 4)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (row, column) coordinates of the structurally nonzero Jacobian entries\n",
"get_jacobian_connectivity(variables, equations)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6613e45f",
"metadata": {},
"outputs": [],
"source": [
"# TODO: Make an Op?\n",
"def _greedy_color(\n",
"    connectivity: sparse.sparray | sparse.spmatrix,\n",
"    strategy: str = 'largest_first',\n",
") -> tuple[nx.Graph, np.ndarray, int]:\n",
"    \"\"\"Greedily color the graph whose adjacency matrix is ``connectivity``.\n",
"\n",
"    With ``connectivity = S @ S.T`` (``S`` the Jacobian sparsity pattern), two rows are\n",
"    adjacent when they share a nonzero column, so rows that receive the same color are\n",
"    structurally orthogonal and can be summed without their nonzeros overlapping.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    connectivity : square scipy sparse array or matrix\n",
"        Conflict (adjacency) matrix; each stored entry becomes an edge (diagonal\n",
"        entries become self-loops).\n",
"    strategy : str\n",
"        Node-ordering strategy forwarded to ``networkx.greedy_color``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    G : networkx.Graph\n",
"        Graph built from ``connectivity``, with nodes ``0 .. n - 1``.\n",
"    coloring : np.ndarray of int\n",
"        ``coloring[i]`` is the color assigned to node ``i``.\n",
"    n_colors : int\n",
"        Number of distinct colors used.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If ``connectivity`` is not a square 2D matrix.\n",
"    \"\"\"\n",
"    # Explicit check instead of assert, which is stripped under `python -O`.\n",
"    if connectivity.ndim != 2 or connectivity.shape[0] != connectivity.shape[1]:\n",
"        raise ValueError(f'connectivity must be a square 2D matrix, got shape {connectivity.shape}')\n",
"\n",
"    G = nx.convert_matrix.from_scipy_sparse_array(connectivity)\n",
"    coloring_dict = nx.algorithms.coloring.greedy_color(G, strategy=strategy)\n",
"\n",
"    # from_scipy_sparse_array creates every node 0..n-1 (isolated ones included), so the\n",
"    # colors can be read back in node order; this also handles an empty 0 x 0 matrix.\n",
"    n_nodes = connectivity.shape[0]\n",
"    coloring = np.array([coloring_dict[node] for node in range(n_nodes)], dtype=int)\n",
"\n",
"    n_colors = np.unique(coloring).size\n",
"\n",
"    return G, coloring, n_colors"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "913fd5a3",
"metadata": {},
"outputs": [],
"source": [
"coords = get_jacobian_connectivity(variables, equations)\n",
"pattern_values = np.ones(len(coords[0]), dtype='bool')\n",
"\n",
"# Boolean sparsity pattern S of the (n_eqs x n_vars) Jacobian, sized from the system itself.\n",
"sparsity = sparse.coo_array((pattern_values, coords), shape=(n_eqs, n_vars))\n",
"# Rows i and k conflict iff they share a nonzero column: (S @ S.T)[i, k] != 0.\n",
"# Coloring this graph groups structurally orthogonal rows (row compression).\n",
"output_connectivity = sparsity @ sparsity.T\n",
"# Columns j and l conflict iff they share a nonzero row (column compression).\n",
"input_connectivity = sparsity.T @ sparsity"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0d7c816c",
"metadata": {},
"outputs": [],
"source": [
"# Color the row-conflict graph: equations sharing a color never share a nonzero column.\n",
"G, output_coloring, n_colors = _greedy_color(output_connectivity, strategy='largest_first')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f3225e84",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 1, 1, 2, 2])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_coloring"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5f90d65c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Row-conflict graph with nodes colored by their assigned color\n",
"# (self-loops come from the diagonal of S @ S.T).\n",
"nx.draw_networkx(G, node_color=output_coloring, cmap='tab10');"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "668b188d",
"metadata": {},
"outputs": [],
"source": [
"assert sparsity.shape[0] == output_coloring.size\n",
"\n",
"# projection_matrix[c, i] is 1.0 exactly when equation i has color c, else 0.0.\n",
"projection_matrix = np.equal.outer(np.arange(n_colors), output_coloring).astype(float)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4d2355a9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0., 0., 0.],\n",
" [0., 1., 1., 0., 0.],\n",
" [0., 0., 0., 1., 1.]])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"projection_matrix"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "94ccad32",
"metadata": {},
"outputs": [],
"source": [
"projected_eqs = projection_matrix @ pt.stack(equations)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fb144e67",
"metadata": {},
"outputs": [],
"source": [
"def coo_to_csc(rows, cols, data, shape):\n",
"    \"\"\"Convert COO triplets into the (data, indices, indptr, shape) arguments of a CSC matrix.\n",
"\n",
"    ``rows``, ``cols`` and ``data`` may be NumPy arrays or PyTensor variables; all three are\n",
"    converted with ``pt.as_tensor_variable`` so the result can be fed to ``pytensor.sparse.CSC``.\n",
"\n",
"    Entries are ordered by column and, within each column, by row, so the CSC indices are\n",
"    sorted. Duplicate (row, col) pairs are kept as separate entries (not summed).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    rows, cols : 1D integer array-like\n",
"        Row and column coordinate of each stored entry.\n",
"    data : 1D array-like\n",
"        Value of each stored entry, aligned with ``rows`` / ``cols``.\n",
"    shape : tuple of int\n",
"        ``(n_rows, n_cols)`` of the matrix.\n",
"\n",
"    Returns\n",
"    -------\n",
"    tuple\n",
"        ``(csc_data, csc_indices, csc_indptr, shape)``.\n",
"    \"\"\"\n",
"    rows = pt.as_tensor_variable(rows)\n",
"    cols = pt.as_tensor_variable(cols)\n",
"    data = pt.as_tensor_variable(data)\n",
"    n_rows, n_cols = shape\n",
"\n",
"    # Sort by the composite key (col, row). For distinct (row, col) pairs the key is unique, so\n",
"    # the order no longer depends on argsort's non-stable default algorithm. Cast to int64 first\n",
"    # so the key cannot overflow int32 coordinates on large matrices.\n",
"    sort_key = cols.astype('int64') * n_rows + rows.astype('int64')\n",
"    order = pt.argsort(sort_key)\n",
"    csc_data = data[order]\n",
"    csc_indices = rows[order]\n",
"    sorted_cols = cols[order]\n",
"\n",
"    # indptr[j]:indptr[j + 1] is the slice of entries that belong to column j.\n",
"    counts = pt.bincount(sorted_cols, minlength=n_cols)\n",
"    csc_indptr = pt.concatenate(([0], pt.cumsum(counts).astype(int)))\n",
"\n",
"    return csc_data, csc_indices, csc_indptr, shape"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3a4dc7a3",
"metadata": {},
"outputs": [],
"source": [
"# Row compression: row c of `compressed_jac` is the sum of the Jacobian rows with color c.\n",
"# Same-colored rows never share a nonzero column, so every structural nonzero J[i, j] can be\n",
"# read back as compressed_jac[output_coloring[i], j]. This differentiates n_colors outputs\n",
"# (vector-Jacobian products, i.e. reverse mode / L_op) instead of one per equation.\n",
"# TODO: Is this really the fastest way to do this?\n",
"# TODO: Try the input (column) compression; that needs forward-mode Jacobian-vector\n",
"#       products (R_op) with one seed vector per input color.\n",
"compressed_jac = pt.stack(pt.jacobian(projected_eqs, variables), axis=-1)\n",
"\n",
"# TODO: Make COO constructor in pytensor directly?\n",
"rows, cols = sparsity.coords\n",
"compressed_index = (output_coloring[rows], cols)\n",
"jac_nonzeros = compressed_jac[compressed_index]\n",
"\n",
"# Alternative scheme with a column coloring `input_coloring` of `input_connectivity`:\n",
"#   compressed_index = (rows, input_coloring[cols])\n",
"#   jac_nonzeros = compressed_jac[compressed_index]\n",
"sparse_jac = pts.CSC(*coo_to_csc(rows, cols, jac_nonzeros, (n_eqs, n_vars)))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c7404821",
"metadata": {},
"outputs": [],
"source": [
"f_sparse_jac = pytensor.function(variables, sparse_jac)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "5f827e79",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"matrix([[2. , 0.54030231, 0. , 0. , 0. ],\n",
" [0. , 3. , 0. , 0. , 0. ],\n",
" [1. , 0. , 2.71828183, 0. , 0. ],\n",
" [3. , 0. , 0. , 2.71828183, 0. ],\n",
" [0. , 1. , 0. , 0. , 4. ]])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Evaluate the sparse Jacobian at x = (1, ..., 1) and densify it for display\n",
"f_sparse_jac(*np.ones(n_vars)).todense()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "6b3ab6e6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"40.9 μs ± 211 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Build the inputs outside the timed statement so both benchmarks time only the function call.\n",
"timing_point = np.ones(n_vars)\n",
"%timeit f_sparse_jac(*timing_point)"
]
},
{
"cell_type": "markdown",
"id": "fda1891a",
"metadata": {},
"source": [
"## Sanity check: compare against the dense Jacobian"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "238e8bc4",
"metadata": {},
"outputs": [],
"source": [
"dense_jac = pt.stack(pt.jacobian(pt.stack(equations), variables), axis=-1)\n",
"dense_fn = pytensor.function(variables, dense_jac)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "60556abe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[2. , 0.54030231, 0. , 0. , 0. ],\n",
" [0. , 3. , 0. , 0. , 0. ],\n",
" [1. , 0. , 2.71828183, 0. , 0. ],\n",
" [3. , 0. , 0. , 2.71828183, 0. ],\n",
" [0. , 1. , 0. , 0. , 4. ]])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_point = np.ones(n_vars)\n",
"dense_val = dense_fn(*test_point)\n",
"\n",
"# The colored sparse Jacobian must agree with the dense reference at every entry.\n",
"np.testing.assert_allclose(f_sparse_jac(*test_point).toarray(), dense_val)\n",
"dense_val"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f78bdac9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"22.1 μs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"# Same input construction as the sparse benchmark, so the two timings are comparable.\n",
"timing_point = np.ones(n_vars)\n",
"%timeit dense_fn(*timing_point)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment