Passing a jitted function to another jitted function as an argument
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple example\n",
"\n",
"Original version by Pablo Winant:\n",
"http://nbviewer.jupyter.org/github/albop/numba_experiments/blob/master/Using%20a%20class%20to%20define%20functions%20with%20numba.ipynb\n",
"\n",
"Given a function $f$, a strictly positive integer $N$ and a real number $x$, we want to define the function $I(f)$ such that: $I(f)(x)=\\frac{1}{N} \\sum_{n=0}^{N-1} f(\\frac{n}{N-1} x)$"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.38.0'"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numba\n",
"numba.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## code"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def fjit(fun):\n",
"    'just-in-time compile a function by wrapping it in a singleton class'\n",
"\n",
"    # with Numba 0.49 or later, import jitclass from numba.experimental instead\n",
"    from numba import jitclass, njit\n",
"    import time\n",
"\n",
"    # the function is jitted first\n",
"    jitted_fun = njit(fun)\n",
"\n",
"    # Generate a timestamp-based class name like 'Singleton_Sat_Jan__2_18_08_32_2016'\n",
"    classname = 'Singleton_' + time.asctime().replace(' ','_').replace(':','_')\n",
"\n",
"    # programmatically create a class equivalent to :\n",
"    # class Singleton_Sat_Jan__2_18_08_32_2016:\n",
"    #     def __init__(self): pass\n",
"    #     def value(self, x): return jitted_fun(x)\n",
"\n",
"    def __init__(self): pass\n",
"    def value(self, x): return jitted_fun(x)\n",
"    SingletonClass = type(classname, (object,), {'__init__': __init__, 'value': value})\n",
"\n",
"    # jit compile the class\n",
"    # spec is [] since we don't store attributes\n",
"    spec = []\n",
"    sc = jitclass(spec)(SingletonClass)\n",
"\n",
"    # return a unique instance of the class\n",
"    return sc()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## usage"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from numba import njit"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"@fjit\n",
"def f(x):\n",
"    return x**2"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9.0"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# f is now a \"jitted\" function (wrapped in a class)\n",
"f.value(3.)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from numpy import linspace\n",
"@njit\n",
"def function_of_a_function(f,x,N):\n",
"    xvec = linspace(0,x,N)\n",
"    t = 0.0\n",
"    for i in range(N):\n",
"        t += f.value(xvec[i])\n",
"    return t/N"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 158 ms, sys: 3.64 ms, total: 162 ms\n",
"Wall time: 160 ms\n",
"CPU times: user 169 µs, sys: 1 µs, total: 170 µs\n",
"Wall time: 173 µs\n"
]
},
{
"data": {
"text/plain": [
"0.3333500016668329"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time function_of_a_function(f, 1, 10000)\n",
"%time function_of_a_function(f, 1, 10000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performance cost\n",
"\n",
"Here is a comparison of several ways to compute $\\frac{1}{N}\\sum_{n=1}^N f(x)$ (the example is deliberately very simple so that automatic inlining can work).\n",
"\n",
"The performance ranking, with timings relative to the fastest variant, is as follows (updated):\n",
"\n",
"- manual inlining (copy and paste) + jit (x1)\n",
"- njitted function called from a jitted function (x1)\n",
"- jitted function passed as an argument (x6)\n",
"- jit-class method (x1600)\n",
"- pure python (x8000)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# functions\n",
"\n",
"# pure python\n",
"def fun(x):\n",
"    return x**2\n",
"\n",
"# jitclass\n",
"from numba import jitclass\n",
"spec = []\n",
"@jitclass(spec)\n",
"class SingletonClass:\n",
"    def __init__(self):\n",
"        pass\n",
"    def method(self, x):\n",
"        return x**2\n",
"sc = SingletonClass()\n",
"\n",
"# jitted\n",
"@njit\n",
"def jitted_fun(x):\n",
"    return x**2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# functions of functions\n",
"\n",
"def funfun_direct(x,N):\n",
"    t = 0.0\n",
"    for i in range(N):\n",
"        t = (x)**2\n",
"    return t\n",
"\n",
"funfun_direct_jitted = njit(funfun_direct)\n",
"\n",
"@njit\n",
"def funfun_jitted(x,N):\n",
"    t = 0.0\n",
"    for i in range(N):\n",
"        t = jitted_fun(x)\n",
"    return t\n",
"\n",
"@njit\n",
"def funfun_class(f,x,N):\n",
"    t = 0.0\n",
"    for i in range(N):\n",
"        t = f.method(x)\n",
"    return t"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# jitted function passed as an argument\n",
"@njit\n",
"def funfun_jitted_arg(f, x, N):\n",
"    t = 0.0\n",
"    for i in range(N):\n",
"        t = f(x)\n",
"    return t"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"N = 1000000\n",
"x = 1.0"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 75.7 ms, sys: 2.17 ms, total: 77.9 ms\n",
"Wall time: 76.2 ms\n"
]
},
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# pure python code\n",
"%time funfun_direct(x, N)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 121 ms, sys: 3.18 ms, total: 124 ms\n",
"Wall time: 122 ms\n",
"CPU times: user 14.7 ms, sys: 129 µs, total: 14.9 ms\n",
"Wall time: 14.8 ms\n"
]
},
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# jit class\n",
"%time funfun_class(sc,x,N)\n",
"%time funfun_class(sc,x,N)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 52.8 ms, sys: 3.16 ms, total: 55.9 ms\n",
"Wall time: 53.4 ms\n",
"CPU times: user 5 µs, sys: 0 ns, total: 5 µs\n",
"Wall time: 9.06 µs\n"
]
},
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# jitted, manual inlining\n",
"%time funfun_direct_jitted(x, N)\n",
"%time funfun_direct_jitted(x, N)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 69.8 ms, sys: 2.51 ms, total: 72.3 ms\n",
"Wall time: 70.1 ms\n",
"CPU times: user 5 µs, sys: 1 µs, total: 6 µs\n",
"Wall time: 9.78 µs\n"
]
},
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# jitted automatic inlining\n",
"%time funfun_jitted(x,N)\n",
"%time funfun_jitted(x,N)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 52.4 ms, sys: 3.28 ms, total: 55.7 ms\n",
"Wall time: 53.1 ms\n",
"CPU times: user 48 µs, sys: 1 µs, total: 49 µs\n",
"Wall time: 52.9 µs\n"
]
},
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# jitted function passed as an argument\n",
"%time funfun_jitted_arg(jitted_fun, x, N)\n",
"%time funfun_jitted_arg(jitted_fun, x, N)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## strange inlining problem:\n",
"\n",
"If x is not passed as an argument but defined in the function instead, inlining doesn't seem to happen."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"@njit\n",
"def funfun_manual_inline(N):\n",
"    t = 0.0\n",
"    x = 2.0\n",
"    for i in range(N):\n",
"        t = (x)**2\n",
"    return t\n",
"\n",
"@njit\n",
"def funfun_auto_inline(N):\n",
"    t = 0.0\n",
"    x = 2.0\n",
"    for i in range(N):\n",
"        t = jitted_fun(x)\n",
"    return t"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 47.6 ms, sys: 2.15 ms, total: 49.8 ms\n",
"Wall time: 48.2 ms\n",
"CPU times: user 5 µs, sys: 1 µs, total: 6 µs\n",
"Wall time: 9.3 µs\n"
]
},
{
"data": {
"text/plain": [
"4.0"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time funfun_manual_inline(N)\n",
"%time funfun_manual_inline(N)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 52.9 ms, sys: 3.02 ms, total: 55.9 ms\n",
"Wall time: 54 ms\n",
"CPU times: user 6 µs, sys: 0 ns, total: 6 µs\n",
"Wall time: 9.06 µs\n"
]
},
{
"data": {
"text/plain": [
"4.0"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time funfun_auto_inline(N)\n",
"%time funfun_auto_inline(N)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
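
The value `0.3333500016668329` returned by `function_of_a_function(f, 1, 10000)` in the simple example can be cross-checked without Numba. The sketch below is not part of the notebook: `mean_on_grid` and `closed_form_square` are hypothetical helper names, and the `N >= 2` guard is an added assumption (the notebook's `linspace` version also accepts `N = 1`). It evaluates $I(f)(x)$ directly in plain Python and compares it with the closed form for $f(t) = t^2$, which follows from $\sum_{n=0}^{N-1} n^2 = (N-1)N(2N-1)/6$.

```python
def mean_on_grid(f, x, N):
    """Plain-Python reference for I(f)(x) = (1/N) * sum_{n=0}^{N-1} f(n/(N-1) * x)."""
    if N < 2:
        # Assumption of this sketch: require N >= 2 so that n/(N-1) is defined.
        raise ValueError("N must be at least 2")
    total = 0.0
    for n in range(N):
        total += f(n / (N - 1) * x)
    return total / N


def closed_form_square(x, N):
    """Closed form of I(f)(x) for f(t) = t**2: x**2 * (2N - 1) / (6 * (N - 1))."""
    return x**2 * (2 * N - 1) / (6 * (N - 1))


numeric = mean_on_grid(lambda t: t**2, 1.0, 10000)
exact = closed_form_square(1.0, 10000)
print(numeric, exact)
```

Both numbers should agree with the notebook's output up to floating-point rounding, which suggests the jitted versions compute the intended quantity.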