Last active
June 26, 2025 10:53
-
-
Save jessegrabowski/6b3dc83898ed355b850a7f698476ef71 to your computer and use it in GitHub Desktop.
Sparse Jacobians in Pytensor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "54289e6e", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import pytensor.tensor as pt\n", | |
| "import pytensor\n", | |
| "from pytensor.graph.basic import explicit_graph_inputs\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "from scipy import sparse\n", | |
| "import networkx as nx" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "02bed835", | |
| "metadata": {}, | |
| "source": [ | |
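| "Test system (written with 1-based indices; the code below uses zero-based names, so $x_1, \\dots, x_5$ are `x0`, ..., `x4` and $f_1, \\dots, f_5$ are `equations[0]`, ..., `equations[4]`):\n", | |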
| "\n", | |
| "$$\n", | |
| "\\begin{aligned}\n", | |
| "f_1(x) &= x_1^2 + \\sin x_2, \\\\\n", | |
| "f_2(x) &= x_2^3, \\\\\n", | |
| "f_3(x) &= e^{x_3} + x_1, \\\\\n", | |
| "f_4(x) &= e^{x_4} + 3x_1, \\\\\n", | |
| "f_5(x) &= x_5^3 + x_2 x_5,\n", | |
| "\\end{aligned}\n", | |
| "\\qquad\n", | |
| "x = \\begin{bmatrix} x_1 \\\\ x_2 \\\\ x_3 \\\\ x_4 \\\\ x_5 \\end{bmatrix}.\n", | |
| "$$" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "39f02063", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def get_jacobian_connectivity(variables, equations):\n", | |
| " connectivity_pattern = []\n", | |
| "\n", | |
| " n_eqs = len(equations)\n", | |
| " n_vars = len(variables)\n", | |
| " \n", | |
| " for i, eq in enumerate(equations):\n", | |
| " for j, var in enumerate(variables):\n", | |
| " if var in explicit_graph_inputs(eq):\n", | |
| " connectivity_pattern.append((i, j))\n", | |
| " \n", | |
| " \n", | |
| " return list(zip(*connectivity_pattern))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "42dbdf04", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "x0, x1, x2, x3, x4 = variables = pt.dscalars('x0 x1 x2 x3 x4'.split())\n", | |
| "\n", | |
| "equations = [\n", | |
| " x0 ** 2 + pt.sin(x1),\n", | |
| " x1 ** 3,\n", | |
| " pt.exp(x2) + x0,\n", | |
| " pt.exp(x3) + 3 * x0,\n", | |
| " x4 ** 3 + x1 * x4\n", | |
| "]\n", | |
| "\n", | |
| "n_vars = len(variables)\n", | |
| "n_eqs = len(equations)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "362aa6f2", | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[(0, 0, 1, 2, 2, 3, 3, 4, 4), (0, 1, 1, 0, 2, 0, 3, 1, 4)]" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "test_values = dict.fromkeys(variables, 1.0)\n", | |
| "get_jacobian_connectivity(variables, equations)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "6613e45f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# TODO: Make an Op?\n", | |
| "def _greedy_color(\n", | |
| " connectivity: sparse.spmatrix,\n", | |
| " strategy: str = 'largest_first',\n", | |
| ") -> tuple[nx.Graph, np.ndarray, int]:\n", | |
| " assert connectivity.ndim == 2\n", | |
| " assert connectivity.shape[0] == connectivity.shape[1]\n", | |
| " G = nx.convert_matrix.from_scipy_sparse_array(connectivity)\n", | |
| " coloring_dict = nx.algorithms.coloring.greedy_color(G, strategy)\n", | |
| " \n", | |
| " indices, colors = list(zip(*coloring_dict.items()))\n", | |
| " coloring = np.asarray(colors)[np.argsort(indices)]\n", | |
| " \n", | |
| " n_colors = np.unique(coloring).size\n", | |
| " \n", | |
| " return G, coloring, n_colors" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "913fd5a3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "coords = get_jacobian_connectivity(variables, equations)\n", | |
| "data = np.ones(len(coords[0]), dtype='bool')\n", | |
| "\n", | |
| "sparsity = sparse.coo_array((data, coords), (5, 5))\n", | |
| "output_connectivity = sparsity @ sparsity.T\n", | |
| "input_connectivity = sparsity.T @ sparsity" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "0d7c816c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "G, output_coloring, n_colors = _greedy_color(output_connectivity)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "f3225e84", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([0, 1, 1, 2, 2])" | |
| ] | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "output_coloring" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "5f90d65c", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "", | |
| "text/plain": [ | |
| "<Figure size 640x480 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "nx.draw_networkx(G)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "668b188d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "assert output_coloring.size == sparsity.shape[0]\n", | |
| "\n", | |
| "projection_matrix = (\n", | |
| " np.arange(n_colors)[:, None] == output_coloring[None, :]\n", | |
| ").astype(float)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "4d2355a9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[1., 0., 0., 0., 0.],\n", | |
| " [0., 1., 1., 0., 0.],\n", | |
| " [0., 0., 0., 1., 1.]])" | |
| ] | |
| }, | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "projection_matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "94ccad32", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "projected_eqs = projection_matrix @ pt.stack(equations)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "fb144e67", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def coo_to_csc(rows, cols, data, shape):\n", | |
| " rows = pt.as_tensor_variable(rows)\n", | |
| " cols = pt.as_tensor_variable(cols)\n", | |
| " n_rows, n_cols = shape\n", | |
| " \n", | |
| " order = pt.argsort(cols)\n", | |
| " csc_data = data[order]\n", | |
| " csc_indices = rows[order]\n", | |
| " sorted_cols = cols[order]\n", | |
| " \n", | |
| " counts = pt.bincount(sorted_cols, minlength=n_cols)\n", | |
| " csc_indptr = pt.concatenate(([0], pt.cumsum(counts).astype(int)))\n", | |
| " \n", | |
| " return csc_data, csc_indices, csc_indptr, shape " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "3a4dc7a3", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# TODO: Is this really the fastest way to do this?\n", | |
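| "# TODO: Try coloring the inputs (columns) instead. That compresses columns, which presumably needs forward-mode products (R_op) rather than the reverse-mode (L_op) path used by pt.jacobian -- confirm.\n", | |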
| "compressed_jac = pt.stack(pt.jacobian(projected_eqs, variables), axis=-1)\n", | |
| "\n", | |
| "\n", | |
| "# TODO: Make COO constructor in pytensor directly?\n", | |
| "rows, cols = sparsity.coords\n", | |
| "compressed_index = (output_coloring[rows], cols)\n", | |
| "data = compressed_jac[compressed_index]\n", | |
| "\n", | |
| "\n", | |
| "# Alternative scheme if we had colored the inputs instead, i.e. input_coloring = _greedy_color(input_connectivity)[1]\n", | |
| "# rows, cols = sparsity.coords\n", | |
| "# compressed_index = (rows, input_coloring[cols])\n", | |
| "# data = compressed_jac[compressed_index]\n", | |
| "from pytensor import sparse as pts\n", | |
| "sparse_jac = pts.CSC(*coo_to_csc(rows, cols, data, (n_eqs, n_vars)))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "c7404821", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "f_sparse_jac = pytensor.function(variables, sparse_jac)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "5f827e79", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "matrix([[2. , 0.54030231, 0. , 0. , 0. ],\n", | |
| " [0. , 3. , 0. , 0. , 0. ],\n", | |
| " [1. , 0. , 2.71828183, 0. , 0. ],\n", | |
| " [3. , 0. , 0. , 2.71828183, 0. ],\n", | |
| " [0. , 1. , 0. , 0. , 4. ]])" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "f_sparse_jac(*np.ones(5,)).todense()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "6b3ab6e6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "40.9 μs ± 211 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit f_sparse_jac(*np.ones(5,))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "fda1891a", | |
| "metadata": {}, | |
| "source": [ | |
| "## Sanity Check" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "238e8bc4", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "dense_jac = pt.stack(pt.jacobian(pt.stack(equations), variables), axis=-1)\n", | |
| "dense_fn = pytensor.function(variables, dense_jac)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "60556abe", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[2. , 0.54030231, 0. , 0. , 0. ],\n", | |
| " [0. , 3. , 0. , 0. , 0. ],\n", | |
| " [1. , 0. , 2.71828183, 0. , 0. ],\n", | |
| " [3. , 0. , 0. , 2.71828183, 0. ],\n", | |
| " [0. , 1. , 0. , 0. , 4. ]])" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "dense_fn(*np.ones(5,))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "f78bdac9", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "22.1 μs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%timeit dense_fn(1., 1., 1., 1., 1.)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.9" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment