Last active
May 23, 2025 07:00
-
-
Save jessegrabowski/f194bf5cdf92937b42075027fa46f3d5 to your computer and use it in GitHub Desktop.
Specialized Dot Benchmarks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "36896a04", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np\n", | |
| "from scipy import linalg\n", | |
| "\n", | |
| "\n", | |
| "# Derive a fixed integer seed from a phrase so the random draws are reproducible.\n", | |
| "SEED = sum(map(ord, 'matrix madness!'))\n", | |
| "rng = np.random.default_rng(SEED)\n", | |
| "\n", | |
| "def make_Ab(size=(10, 10), b_ndim=1, assume_a='gen'):\n", | |
| "    \"\"\"Draw a random matrix ``A`` with the requested structure and a random ``b``.\n", | |
| "\n", | |
| "    Parameters\n", | |
| "    ----------\n", | |
| "    size : tuple of int\n", | |
| "        Shape of ``A``.\n", | |
| "    b_ndim : {1, 2}\n", | |
| "        1 gives a vector ``b`` of length ``size[-1]``; 2 gives a\n", | |
| "        ``(size[-1], size[-1])`` matrix.\n", | |
| "    assume_a : str\n", | |
| "        'gen' (dense), 'sym' (symmetrized), 'pos' (``A @ A.T``), 'tri' (lower\n", | |
| "        triangular), or 'tridiag' / 'tridiag-direct' (tridiagonal).\n", | |
| "\n", | |
| "    Returns\n", | |
| "    -------\n", | |
| "    A, b : ndarray\n", | |
| "        Both converted to Fortran (column-major) order.\n", | |
| "    \"\"\"\n", | |
| "    assert b_ndim in [1, 2]\n", | |
| "    assert assume_a in ['gen', 'sym', 'pos', 'tri', 'tridiag', 'tridiag-direct']\n", | |
| "\n", | |
| "    A = rng.normal(size=size)\n", | |
| "    b = rng.normal(size=(size[-1],) * b_ndim)\n", | |
| "\n", | |
| "    if assume_a == 'sym':\n", | |
| "        A = (A + A.T) / 2\n", | |
| "    elif assume_a == 'pos':\n", | |
| "        A = A @ A.T\n", | |
| "    elif assume_a == 'tri':\n", | |
| "        A = np.tril(A)\n", | |
| "    elif assume_a in ('tridiag', 'tridiag-direct'):\n", | |
| "        # 'tridiag-direct' names a different product routine, not a different\n", | |
| "        # structure, so it gets the same tridiagonal A instead of a dense one.\n", | |
| "        # Keep only the sub-, main and super-diagonal; zero everything else.\n", | |
| "        A = np.triu(np.tril(A, k=1), k=-1)\n", | |
| "\n", | |
| "    return np.asfortranarray(A), np.asfortranarray(b)\n", | |
| "\n", | |
| "def direct_tridiag(dl, d, du, b):\n", | |
| "    \"\"\"Multiply a tridiagonal matrix, given by its three diagonals, with ``b``.\n", | |
| "\n", | |
| "    Parameters\n", | |
| "    ----------\n", | |
| "    dl, d, du : ndarray\n", | |
| "        Sub-, main and super-diagonal of the matrix.\n", | |
| "    b : array_like\n", | |
| "        Vector (1-D) or matrix (2-D) to multiply.\n", | |
| "\n", | |
| "    Returns\n", | |
| "    -------\n", | |
| "    ndarray\n", | |
| "        The product, computed without forming the dense matrix.\n", | |
| "    \"\"\"\n", | |
| "    rhs = np.asarray(b)\n", | |
| "\n", | |
| "    if rhs.ndim == 1:\n", | |
| "        out = d * rhs\n", | |
| "        # Row i also picks up dl[i - 1] * b[i - 1] and du[i] * b[i + 1].\n", | |
| "        out[1:] += dl * rhs[:-1]\n", | |
| "        out[:-1] += du * rhs[1:]\n", | |
| "        return out\n", | |
| "\n", | |
| "    # Matrix right-hand side: each diagonal entry scales a whole row of b.\n", | |
| "    main = d[..., :, None] if rhs.ndim == 2 else d\n", | |
| "    out = main * rhs\n", | |
| "    out[1:, :] += dl[:, None] * rhs[:-1, :]\n", | |
| "    out[:-1, :] += du[:, None] * rhs[1:, :]\n", | |
| "    return out\n", | |
| " \n", | |
| " \n", | |
| "# Silence spurious floating-point warnings on M4 Macs, see https://github.com/numpy/numpy/issues/28687\n", | |
| "@np.errstate(all=\"ignore\")\n", | |
| "def test_dots(A, b, assume_a='gen'):\n", | |
| "    \"\"\"Time ``A @ b`` against a structure-aware product and check that they agree.\n", | |
| "\n", | |
| "    Prints a 'Default' and a 'Specialized' ``%timeit`` report, so this must run\n", | |
| "    under IPython/Jupyter.\n", | |
| "\n", | |
| "    Parameters\n", | |
| "    ----------\n", | |
| "    A : ndarray\n", | |
| "        2-D left operand whose structure is described by ``assume_a``.\n", | |
| "    b : ndarray\n", | |
| "        1-D or 2-D right operand.\n", | |
| "    assume_a : str\n", | |
| "        'sym' (BLAS symv/symm), 'tri' (lower triangular, BLAS trmv/trmm),\n", | |
| "        'tridiag' (BLAS gbmv, 1-D ``b`` only) or 'tridiag-direct'\n", | |
| "        (``direct_tridiag``).\n", | |
| "\n", | |
| "    Raises\n", | |
| "    ------\n", | |
| "    NotImplementedError\n", | |
| "        If there is no specialized product for ``assume_a`` and ``b.ndim``.\n", | |
| "    AssertionError\n", | |
| "        If the specialized result differs from ``A @ b``.\n", | |
| "    \"\"\"\n", | |
| "    print('Default: ', end='\\t')\n", | |
| "    x1 = A @ b\n", | |
| "    %timeit A @ b\n", | |
| "\n", | |
| "    print('Specialized: ', end='\\t')\n", | |
| "    if assume_a == 'sym':\n", | |
| "        # symv takes a vector right-hand side, symm a matrix one. Passing the\n", | |
| "        # operands lets SciPy pick the routine matching their dtype.\n", | |
| "        f = linalg.get_blas_funcs('symv' if b.ndim == 1 else 'symm', (A, b))\n", | |
| "        alpha = 1\n", | |
| "        x2 = f(alpha, A, b)\n", | |
| "        %timeit f(alpha, A, b)\n", | |
| "    elif assume_a == 'tri':\n", | |
| "        # A is taken to be lower triangular, hence lower=1.\n", | |
| "        if b.ndim == 1:\n", | |
| "            f = linalg.get_blas_funcs('trmv', (A, b))\n", | |
| "            x2 = f(A, b, lower=1)\n", | |
| "            %timeit f(A, b, lower=1)\n", | |
| "        else:\n", | |
| "            f = linalg.get_blas_funcs('trmm', (A, b))\n", | |
| "            x2 = f(1, A, b, lower=1)\n", | |
| "            %timeit f(1, A, b, lower=1)\n", | |
| "    elif assume_a == 'tridiag':\n", | |
| "        if b.ndim != 1:\n", | |
| "            raise NotImplementedError('the gbmv benchmark only supports a 1-D b')\n", | |
| "        m, n = A.shape\n", | |
| "        kl, ku = 1, 1\n", | |
| "        alpha = 1\n", | |
| "        dl, d, du = [np.diag(A, k=k) for k in [-1, 0, 1]]\n", | |
| "        # Band storage with one row per diagonal: superdiagonal padded with a\n", | |
| "        # leading 0, main diagonal, subdiagonal padded with a trailing 0.\n", | |
| "        # Transposing once here keeps the view creation out of the timed call.\n", | |
| "        ab = np.c_[np.r_[0, du], d, np.r_[dl, 0]].T\n", | |
| "\n", | |
| "        f = linalg.get_blas_funcs('gbmv', (ab, b))\n", | |
| "        x2 = f(m, n, kl, ku, alpha, ab, b)\n", | |
| "        %timeit f(m, n, kl, ku, alpha, ab, b)\n", | |
| "    elif assume_a == 'tridiag-direct':\n", | |
| "        dl, d, du = [np.diag(A, k=k) for k in [-1, 0, 1]]\n", | |
| "        x2 = direct_tridiag(dl, d, du, b)\n", | |
| "        %timeit direct_tridiag(dl, d, du, b)\n", | |
| "    else:\n", | |
| "        # Covers 'gen', 'pos' and any unknown value: without this branch x2\n", | |
| "        # would be unbound when the final comparison runs.\n", | |
| "        raise NotImplementedError(f'no specialized product for assume_a={assume_a!r}')\n", | |
| "\n", | |
| "    np.testing.assert_allclose(x1, x2, atol=1e-6, rtol=1e-6)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9fcac217", | |
| "metadata": {}, | |
| "source": [ | |
| "# Matrix-Vector" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "0509f5d4", | |
| "metadata": {}, | |
| "source": [ | |
| "## Triangular" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "72801bc6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "---------------------------------------Testing Shape (10, 10)---------------------------------------\n", | |
| "Default: \t343 ns ± 5 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t275 ns ± 1.21 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "--------------------------------------Testing Shape (100, 100)--------------------------------------\n", | |
| "Default: \t919 ns ± 8.73 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t696 ns ± 3.46 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n", | |
| "Default: \t13.9 μs ± 227 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "Specialized: \t8.66 μs ± 144 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "------------------------------------Testing Shape (10000, 10000)------------------------------------\n", | |
| "Default: \t3.27 ms ± 13.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "Specialized: \t3.13 ms ± 5.28 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Lower-triangular A times a vector b (BLAS trmv) over increasing sizes.\n", | |
| "for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n", | |
| " print(f'Testing Shape {size}'.center(100, '-'))\n", | |
| " A, b = make_Ab(size=size, assume_a='tri', b_ndim=1)\n", | |
| " test_dots(A, b, assume_a='tri')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "3443c3b6", | |
| "metadata": {}, | |
| "source": [ | |
| "## Symmetric" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "a23609cb", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "---------------------------------------Testing Shape (10, 10)---------------------------------------\n", | |
| "Default: \t362 ns ± 11.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t162 ns ± 1.06 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n", | |
| "--------------------------------------Testing Shape (100, 100)--------------------------------------\n", | |
| "Default: \t922 ns ± 14.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t1.92 μs ± 50.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n", | |
| "Default: \t13.6 μs ± 63.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "Specialized: \t43 μs ± 104 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n", | |
| "------------------------------------Testing Shape (10000, 10000)------------------------------------\n", | |
| "Default: \t3.26 ms ± 14.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "Specialized: \t12.8 ms ± 45.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Symmetric A times a vector b (BLAS symv) over increasing sizes.\n", | |
| "for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n", | |
| " print(f'Testing Shape {size}'.center(100, '-'))\n", | |
| " A, b = make_Ab(size=size, assume_a='sym', b_ndim=1)\n", | |
| " test_dots(A, b, assume_a='sym')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "a37b32fa", | |
| "metadata": {}, | |
| "source": [ | |
| "## Tridiagonal" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "17bdb121", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "---------------------------------------Testing Shape (10, 10)---------------------------------------\n", | |
| "Default: \t366 ns ± 4.91 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t242 ns ± 2.64 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "--------------------------------------Testing Shape (100, 100)--------------------------------------\n", | |
| "Default: \t960 ns ± 28 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t328 ns ± 6.16 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n", | |
| "Default: \t13.7 μs ± 70.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "Specialized: \t1.58 μs ± 6.94 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "------------------------------------Testing Shape (10000, 10000)------------------------------------\n", | |
| "Default: \t3.24 ms ± 15.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "Specialized: \t10.8 μs ± 43 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Tridiagonal A times a vector b using the banded BLAS routine gbmv.\n", | |
| "for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n", | |
| " print(f'Testing Shape {size}'.center(100, '-'))\n", | |
| " A, b = make_Ab(size=size, assume_a='tridiag', b_ndim=1)\n", | |
| " test_dots(A, b, assume_a='tridiag')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "085dfd7d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "---------------------------------------Testing Shape (10, 10)---------------------------------------\n", | |
| "Default: \t359 ns ± 8.39 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t1.26 μs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "--------------------------------------Testing Shape (100, 100)--------------------------------------\n", | |
| "Default: \t1.04 μs ± 35.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t1.38 μs ± 20.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n", | |
| "Default: \t13.6 μs ± 61.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "Specialized: \t2.68 μs ± 18.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "------------------------------------Testing Shape (10000, 10000)------------------------------------\n", | |
| "Default: \t3.23 ms ± 5.55 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "Specialized: \t65.8 μs ± 2.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Same kind of tridiagonal input as the gbmv benchmark; only the product routine (direct_tridiag) differs.\n", | |
| "for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n", | |
| " print(f'Testing Shape {size}'.center(100, '-'))\n", | |
| " A, b = make_Ab(size=size, assume_a='tridiag', b_ndim=1)\n", | |
| " test_dots(A, b, assume_a='tridiag-direct')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d302e064", | |
| "metadata": {}, | |
| "source": [ | |
| "# Matrix-Matrix" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "4f5d6643", | |
| "metadata": {}, | |
| "source": [ | |
| "## Triangular" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "6450ce84", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "---------------------------------------Testing Shape (10, 10)---------------------------------------\n", | |
| "Default: \t542 ns ± 2.82 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t369 ns ± 3.11 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "--------------------------------------Testing Shape (100, 100)--------------------------------------\n", | |
| "Default: \t7.81 μs ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "Specialized: \t8.13 μs ± 306 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n", | |
| "Default: \t2.51 ms ± 31.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "Specialized: \t2 ms ± 29.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "------------------------------------Testing Shape (10000, 10000)------------------------------------\n", | |
| "Default: \t2.58 s ± 7.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
| "Specialized: \t1.28 s ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Lower-triangular A times a square matrix b (BLAS trmm) over increasing sizes.\n", | |
| "for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n", | |
| " print(f'Testing Shape {size}'.center(100, '-'))\n", | |
| " A, b = make_Ab(size=size, assume_a='tri', b_ndim=2)\n", | |
| " test_dots(A, b, assume_a='tri')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "93fb86e8", | |
| "metadata": {}, | |
| "source": [ | |
| "## Symmetric" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "fd9e3133", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "---------------------------------------Testing Shape (10, 10)---------------------------------------\n", | |
| "Default: \t542 ns ± 3.71 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t274 ns ± 4.02 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "--------------------------------------Testing Shape (100, 100)--------------------------------------\n", | |
| "Default: \t7.85 μs ± 36.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "Specialized: \t15 μs ± 574 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n", | |
| "Default: \t2.69 ms ± 31.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "Specialized: \t3.02 ms ± 27.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "------------------------------------Testing Shape (10000, 10000)------------------------------------\n", | |
| "Default: \t2.58 s ± 4.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
| "Specialized: \t2.6 s ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Symmetric A times a square matrix b (BLAS symm) over increasing sizes.\n", | |
| "for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n", | |
| " print(f'Testing Shape {size}'.center(100, '-'))\n", | |
| " A, b = make_Ab(size=size, assume_a='sym', b_ndim=2)\n", | |
| " test_dots(A, b, assume_a='sym')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "4d8a7b14", | |
| "metadata": {}, | |
| "source": [ | |
| "## Tridiagonal" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "ae133aeb", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "---------------------------------------Testing Shape (10, 10)---------------------------------------\n", | |
| "Default: \t543 ns ± 3.92 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n", | |
| "Specialized: \t3.99 μs ± 24.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "--------------------------------------Testing Shape (100, 100)--------------------------------------\n", | |
| "Default: \t7.83 μs ± 25.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n", | |
| "Specialized: \t60 μs ± 13.1 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n", | |
| "-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n", | |
| "Default: \t2.49 ms ± 32.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "Specialized: \t4.03 ms ± 85.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n", | |
| "------------------------------------Testing Shape (10000, 10000)------------------------------------\n", | |
| "Default: \t2.59 s ± 5.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n", | |
| "Specialized: \t758 ms ± 3.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Only direct_tridiag is benchmarked here: test_dots raises NotImplementedError for 'tridiag' (gbmv) with a 2-D b.\n", | |
| "for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n", | |
| " print(f'Testing Shape {size}'.center(100, '-'))\n", | |
| " A, b = make_Ab(size=size, assume_a='tridiag', b_ndim=2)\n", | |
| " test_dots(A, b, assume_a='tridiag-direct')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "47dfd569", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.9" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment