Skip to content

Instantly share code, notes, and snippets.

@jessegrabowski
Last active May 23, 2025 07:00
Show Gist options
  • Select an option

  • Save jessegrabowski/f194bf5cdf92937b42075027fa46f3d5 to your computer and use it in GitHub Desktop.

Select an option

Save jessegrabowski/f194bf5cdf92937b42075027fa46f3d5 to your computer and use it in GitHub Desktop.
Specialized Dot Benchmarks
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "36896a04",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy import linalg\n",
"\n",
"\n",
"SEED = sum(map(ord, 'matrix madness!'))\n",
"rng = np.random.default_rng(SEED)\n",
"\n",
"def make_Ab(size=(10, 10), b_ndim=1, assume_a='gen'):\n",
" assert b_ndim in [1, 2]\n",
" assert assume_a in ['gen', 'sym', 'pos', 'tri', 'tridiag', 'tridiag-direct']\n",
" \n",
" A = rng.normal(size=size)\n",
" b = rng.normal(size=(size[-1],) * b_ndim)\n",
" \n",
" if assume_a == 'sym':\n",
" A = (A + A.T) / 2\n",
" elif assume_a == 'pos':\n",
" A = A @ A.T\n",
" elif assume_a == 'tri':\n",
" A = np.tril(A)\n",
" elif assume_a == 'tridiag':\n",
" dl, d, du = (np.diag(A, k=k) for k in [-1, 0, 1])\n",
" A = np.diag(dl, k=-1) + np.diag(d) + np.diag(du, k=1)\n",
" \n",
" return np.asfortranarray(A), np.asfortranarray(b)\n",
"\n",
"def direct_tridiag(dl, d, du, b):\n",
" b = np.asarray(b)\n",
" x = d[..., :, None] * b if b.ndim == 2 else d * b\n",
"\n",
" if b.ndim == 1:\n",
" x[1:] += dl * b[:-1]\n",
" x[:-1] += du * b[1:]\n",
" else:\n",
" x[1:, :] += dl[:, None] * b[:-1, :]\n",
" x[:-1, :] += du[:, None] * b[1:, :]\n",
" \n",
" return x\n",
" \n",
" \n",
"# Mask inaccurate errors on M4 mac, see https://github.com/numpy/numpy/issues/28687 \n",
"@np.errstate(all=\"ignore\")\n",
"def test_dots(A, b, assume_a='gen'):\n",
" print('Default: ', end='\\t')\n",
" x1 = A @ b\n",
" %timeit A @ b\n",
" \n",
" print('Specialized: ', end='\\t')\n",
" if assume_a == 'sym':\n",
" if b.ndim == 1:\n",
" f = linalg.get_blas_funcs('symv')\n",
" \n",
" alpha = 1\n",
" x2 = f(alpha, A, b)\n",
" %timeit f(alpha, A, b)\n",
" else:\n",
" f = linalg.get_blas_funcs('symm')\n",
" alpha = 1\n",
"\n",
" x2 = f(alpha, A, b)\n",
" %timeit f(alpha, A, b)\n",
" \n",
" elif assume_a == 'pos':\n",
" raise NotImplementedError\n",
" elif assume_a == 'tri':\n",
" if b.ndim == 1:\n",
" f = linalg.get_blas_funcs('trmv')\n",
" x2 = f(A, b, lower=1)\n",
" %timeit f(A, b, lower=1)\n",
" else:\n",
" f = linalg.get_blas_funcs('trmm')\n",
" x2 = f(1, A, b, lower=1)\n",
" %timeit f(1, A, b, lower=1)\n",
" elif assume_a == 'tridiag':\n",
" if b.ndim == 1:\n",
" m, n = A.shape\n",
" kl, ku = 1, 1\n",
" alpha = 1\n",
" dl, d, du = [np.diag(A, k=k) for k in [-1, 0, 1]]\n",
" a = np.c_[np.r_[0, du], d, np.r_[dl, 0]]\n",
"\n",
" f = linalg.get_blas_funcs('gbmv')\n",
" x2 = f(m, n, kl, ku, alpha, a.T, b)\n",
" %timeit f(m, n, kl, ku, alpha, a.T, b)\n",
" else:\n",
" raise NotImplementedError\n",
" elif assume_a == 'tridiag-direct':\n",
" m, n = A.shape\n",
" dl, d, du = [np.diag(A, k=k) for k in [-1, 0, 1]]\n",
" x2 = direct_tridiag(dl, d, du, b)\n",
" %timeit direct_tridiag(dl, d, du, b)\n",
" \n",
" np.testing.assert_allclose(x1, x2, atol=1e-6, rtol=1e-6)"
]
},
{
"cell_type": "markdown",
"id": "9fcac217",
"metadata": {},
"source": [
"# Matrix-Vector"
]
},
{
"cell_type": "markdown",
"id": "0509f5d4",
"metadata": {},
"source": [
"## Triangular"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "72801bc6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------Testing Shape (10, 10)---------------------------------------\n",
"Default: \t343 ns ± 5 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t275 ns ± 1.21 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"--------------------------------------Testing Shape (100, 100)--------------------------------------\n",
"Default: \t919 ns ± 8.73 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t696 ns ± 3.46 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n",
"Default: \t13.9 μs ± 227 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"Specialized: \t8.66 μs ± 144 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"------------------------------------Testing Shape (10000, 10000)------------------------------------\n",
"Default: \t3.27 ms ± 13.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"Specialized: \t3.13 ms ± 5.28 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n",
" print(f'Testing Shape {size}'.center(100, '-'))\n",
" A, b = make_Ab(size=size, assume_a='tri', b_ndim=1)\n",
" test_dots(A, b, assume_a='tri')"
]
},
{
"cell_type": "markdown",
"id": "3443c3b6",
"metadata": {},
"source": [
"## Symmetrical"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a23609cb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------Testing Shape (10, 10)---------------------------------------\n",
"Default: \t362 ns ± 11.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t162 ns ± 1.06 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)\n",
"--------------------------------------Testing Shape (100, 100)--------------------------------------\n",
"Default: \t922 ns ± 14.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t1.92 μs ± 50.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n",
"Default: \t13.6 μs ± 63.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"Specialized: \t43 μs ± 104 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"------------------------------------Testing Shape (10000, 10000)------------------------------------\n",
"Default: \t3.26 ms ± 14.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"Specialized: \t12.8 ms ± 45.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n",
" print(f'Testing Shape {size}'.center(100, '-'))\n",
" A, b = make_Ab(size=size, assume_a='sym', b_ndim=1)\n",
" test_dots(A, b, assume_a='sym')"
]
},
{
"cell_type": "markdown",
"id": "a37b32fa",
"metadata": {},
"source": [
"## Tridiagonal"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "17bdb121",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------Testing Shape (10, 10)---------------------------------------\n",
"Default: \t366 ns ± 4.91 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t242 ns ± 2.64 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"--------------------------------------Testing Shape (100, 100)--------------------------------------\n",
"Default: \t960 ns ± 28 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t328 ns ± 6.16 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n",
"Default: \t13.7 μs ± 70.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"Specialized: \t1.58 μs ± 6.94 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"------------------------------------Testing Shape (10000, 10000)------------------------------------\n",
"Default: \t3.24 ms ± 15.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"Specialized: \t10.8 μs ± 43 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
]
}
],
"source": [
"for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n",
" print(f'Testing Shape {size}'.center(100, '-'))\n",
" A, b = make_Ab(size=size, assume_a='tridiag', b_ndim=1)\n",
" test_dots(A, b, assume_a='tridiag')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "085dfd7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------Testing Shape (10, 10)---------------------------------------\n",
"Default: \t359 ns ± 8.39 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t1.26 μs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"--------------------------------------Testing Shape (100, 100)--------------------------------------\n",
"Default: \t1.04 μs ± 35.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t1.38 μs ± 20.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n",
"Default: \t13.6 μs ± 61.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"Specialized: \t2.68 μs ± 18.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"------------------------------------Testing Shape (10000, 10000)------------------------------------\n",
"Default: \t3.23 ms ± 5.55 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"Specialized: \t65.8 μs ± 2.36 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
"source": [
"for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n",
" print(f'Testing Shape {size}'.center(100, '-'))\n",
" A, b = make_Ab(size=size, assume_a='tridiag', b_ndim=1)\n",
" test_dots(A, b, assume_a='tridiag-direct')"
]
},
{
"cell_type": "markdown",
"id": "d302e064",
"metadata": {},
"source": [
"# Matrix-Matrix"
]
},
{
"cell_type": "markdown",
"id": "4f5d6643",
"metadata": {},
"source": [
"## Triangular"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6450ce84",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------Testing Shape (10, 10)---------------------------------------\n",
"Default: \t542 ns ± 2.82 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t369 ns ± 3.11 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"--------------------------------------Testing Shape (100, 100)--------------------------------------\n",
"Default: \t7.81 μs ± 19.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"Specialized: \t8.13 μs ± 306 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n",
"Default: \t2.51 ms ± 31.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"Specialized: \t2 ms ± 29.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"------------------------------------Testing Shape (10000, 10000)------------------------------------\n",
"Default: \t2.58 s ± 7.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"Specialized: \t1.28 s ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n",
" print(f'Testing Shape {size}'.center(100, '-'))\n",
" A, b = make_Ab(size=size, assume_a='tri', b_ndim=2)\n",
" test_dots(A, b, assume_a='tri')"
]
},
{
"cell_type": "markdown",
"id": "93fb86e8",
"metadata": {},
"source": [
"## Symmetric"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fd9e3133",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------Testing Shape (10, 10)---------------------------------------\n",
"Default: \t542 ns ± 3.71 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t274 ns ± 4.02 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"--------------------------------------Testing Shape (100, 100)--------------------------------------\n",
"Default: \t7.85 μs ± 36.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"Specialized: \t15 μs ± 574 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n",
"Default: \t2.69 ms ± 31.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"Specialized: \t3.02 ms ± 27.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"------------------------------------Testing Shape (10000, 10000)------------------------------------\n",
"Default: \t2.58 s ± 4.77 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"Specialized: \t2.6 s ± 4.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n",
" print(f'Testing Shape {size}'.center(100, '-'))\n",
" A, b = make_Ab(size=size, assume_a='sym', b_ndim=2)\n",
" test_dots(A, b, assume_a='sym')"
]
},
{
"cell_type": "markdown",
"id": "4d8a7b14",
"metadata": {},
"source": [
"## Tridiagonal"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ae133aeb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------------------------------------Testing Shape (10, 10)---------------------------------------\n",
"Default: \t543 ns ± 3.92 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"Specialized: \t3.99 μs ± 24.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"--------------------------------------Testing Shape (100, 100)--------------------------------------\n",
"Default: \t7.83 μs ± 25.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n",
"Specialized: \t60 μs ± 13.1 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n",
"-------------------------------------Testing Shape (1000, 1000)-------------------------------------\n",
"Default: \t2.49 ms ± 32.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"Specialized: \t4.03 ms ± 85.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"------------------------------------Testing Shape (10000, 10000)------------------------------------\n",
"Default: \t2.59 s ± 5.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"Specialized: \t758 ms ± 3.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"for size in [(10, 10), (100, 100), (1000, 1000), (10_000, 10_000)]:\n",
" print(f'Testing Shape {size}'.center(100, '-'))\n",
" A, b = make_Ab(size=size, assume_a='tridiag', b_ndim=2)\n",
" test_dots(A, b, assume_a='tridiag-direct')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47dfd569",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment