Created
April 28, 2018 05:01
-
-
Save oyamad/19e1b44df2db1373a2fc39b3967775e3 to your computer and use it in GitHub Desktop.
Passing a jitted function to another jitted function as an argument
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Simple example\n", | |
| "\n", | |
| "Original version by Pablo Winant:\n", | |
| "http://nbviewer.jupyter.org/github/albop/numba_experiments/blob/master/Using%20a%20class%20to%20define%20functions%20with%20numba.ipynb\n", | |
| "\n", | |
| "Given a function $f$, an integer $N \\geq 2$ and a real number $x$, we want to define the function $I(f)$ such that: $I(f)(x)=\\frac{1}{N} \\sum_{n=0}^{N-1} f\\left(\\frac{n}{N-1} x\\right)$, i.e. the mean of $f$ over $N$ evenly spaced points of $[0, x]$." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "'0.38.0'" | |
| ] | |
| }, | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import numba\n", | |
| "numba.__version__" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "source": [ | |
| "## code" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def fjit(fun):\n", | |
| " 'just-in-time compile a function by wrapping it in a singleton class'\n", | |
| " \n", | |
| " from numba import jitclass\n", | |
| " import time\n", | |
| " \n", | |
| " # the function is jitted first\n", | |
| " jitted_fun = njit(fun)\n", | |
| "\n", | |
| " # Generate a random class name like 'Singleton_Sat_Jan__2_18_08_32_2016'\n", | |
| " classname = 'Singleton_' + time.asctime().replace(' ','_').replace(':','_')\n", | |
| " \n", | |
| " # programmatically create a class equivalent to :\n", | |
| " # class Singleton_Sat_Jan__2_18_08_32_2016:\n", | |
| " # def __init__(self): pass\n", | |
| " # def value(self, x): return fj(x)\n", | |
| " \n", | |
| " def __init__(self): pass\n", | |
| " def value(self, x): return jitted_fun(x)\n", | |
| " SingletonClass = type(classname, (object,), {'__init__': __init__, 'value': value})\n", | |
| " \n", | |
| " # jit compile the class\n", | |
| " # spec is [] since we don't store attributes\n", | |
| " spec = []\n", | |
| " sc = jitclass(spec)(SingletonClass)\n", | |
| " \n", | |
| " # return a unique instance of the class\n", | |
| " return sc()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## usage" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from numba import njit" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@fjit\n", | |
| "def f(x):\n", | |
| " return x**2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "9.0" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# f is now a \"jitted\" function (wrapped in a class)\n", | |
| "f.value(3.)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from numpy import linspace\n", | |
| "@njit\n", | |
| "def function_of_a_function(f,x,N):\n", | |
| " xvec = linspace(0,x,N)\n", | |
| " t = 0.0\n", | |
| " for i in range(N):\n", | |
| " t += f.value(xvec[i])\n", | |
| " return t/N" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 158 ms, sys: 3.64 ms, total: 162 ms\n", | |
| "Wall time: 160 ms\n", | |
| "CPU times: user 169 µs, sys: 1 µs, total: 170 µs\n", | |
| "Wall time: 173 µs\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.3333500016668329" | |
| ] | |
| }, | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# first call includes JIT compilation; second call times the compiled code\n", | |
| "%time function_of_a_function(f, 1, 10000)\n", | |
| "%time function_of_a_function(f, 1, 10000)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Performance cost\n", | |
| "\n", | |
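| "Here is a comparison of several ways to evaluate $f(x) = x^2$ inside a loop of $N$ iterations. Each iteration overwrites the previous result, so every variant returns $f(x)$, which equals $\\frac{1}{N}\\sum_{n=1}^N f(x)$ (the example is deliberately very simple so that automatic inlining can work).\n", | |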
| "\n", | |
| "The performance ranking is as follows (wall time of the second, post-compilation call, relative to the fastest variant; updated):\n", | |
| "\n", | |
| "- copy and paste + jit (x1)\n", | |
| "- call to a separate njitted function (x1)\n", | |
| "- jitted function argument (x6) \n", | |
| "- jit-class method (x1600)\n", | |
| "- pure python (x8000)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# functions\n", | |
| "\n", | |
| "# pure python\n", | |
| "def fun(x):\n", | |
| " return x**2\n", | |
| "\n", | |
| "# jitclass\n", | |
| "from numba import jitclass\n", | |
| "spec = []\n", | |
| "@jitclass(spec)\n", | |
| "class SingletonClass:\n", | |
| " def __init__(self):\n", | |
| " pass\n", | |
| " def method(self, x):\n", | |
| " return x**2\n", | |
| "sc = SingletonClass()\n", | |
| "\n", | |
| "# jitted\n", | |
| "@njit\n", | |
| "def jitted_fun(x):\n", | |
| " return x**2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# functions of functions\n", | |
| "\n", | |
| "def funfun_direct(x,N):\n", | |
| " t = 0.0\n", | |
| " for i in range(N):\n", | |
| " t = (x)**2\n", | |
| " return t\n", | |
| "\n", | |
| "funfun_direct_jitted = njit(funfun_direct)\n", | |
| "\n", | |
| "@njit\n", | |
| "def funfun_jitted(x,N):\n", | |
| " t = 0.0\n", | |
| " for i in range(N):\n", | |
| " t = jitted_fun(x)\n", | |
| " return t\n", | |
| "\n", | |
| "@njit\n", | |
| "def funfun_class(f,x,N):\n", | |
| " t = 0.0\n", | |
| " for i in range(N):\n", | |
| " t = f.method(x)\n", | |
| " return t" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# jitted function passed as an argument\n", | |
| "@njit\n", | |
| "def funfun_jitted_arg(f, x, N):\n", | |
| " t = 0.0\n", | |
| " for i in range(N):\n", | |
| " t = f(x)\n", | |
| " return t" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# benchmark parameters used by the cells below: loop length and evaluation point\n", | |
| "N = 1000000\n", | |
| "x = 1.0" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 75.7 ms, sys: 2.17 ms, total: 77.9 ms\n", | |
| "Wall time: 76.2 ms\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1.0" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# pure python code (not compiled, so a single call is timed)\n", | |
| "%time funfun_direct(x, N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "scrolled": true | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 121 ms, sys: 3.18 ms, total: 124 ms\n", | |
| "Wall time: 122 ms\n", | |
| "CPU times: user 14.7 ms, sys: 129 µs, total: 14.9 ms\n", | |
| "Wall time: 14.8 ms\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1.0" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# jit class (the first call includes compilation)\n", | |
| "%time funfun_class(sc,x,N)\n", | |
| "%time funfun_class(sc,x,N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 52.8 ms, sys: 3.16 ms, total: 55.9 ms\n", | |
| "Wall time: 53.4 ms\n", | |
| "CPU times: user 5 µs, sys: 0 ns, total: 5 µs\n", | |
| "Wall time: 9.06 µs\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1.0" | |
| ] | |
| }, | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# jitted, manual inlining (the first call includes compilation)\n", | |
| "%time funfun_direct_jitted(x, N)\n", | |
| "%time funfun_direct_jitted(x, N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 69.8 ms, sys: 2.51 ms, total: 72.3 ms\n", | |
| "Wall time: 70.1 ms\n", | |
| "CPU times: user 5 µs, sys: 1 µs, total: 6 µs\n", | |
| "Wall time: 9.78 µs\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1.0" | |
| ] | |
| }, | |
| "execution_count": 15, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# jitted automatic inlining (the first call includes compilation)\n", | |
| "%time funfun_jitted(x,N)\n", | |
| "%time funfun_jitted(x,N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "metadata": { | |
| "scrolled": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 52.4 ms, sys: 3.28 ms, total: 55.7 ms\n", | |
| "Wall time: 53.1 ms\n", | |
| "CPU times: user 48 µs, sys: 1 µs, total: 49 µs\n", | |
| "Wall time: 52.9 µs\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "1.0" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# jitted function passed as an argument (the first call includes compilation)\n", | |
| "%time funfun_jitted_arg(jitted_fun, x, N)\n", | |
| "%time funfun_jitted_arg(jitted_fun, x, N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "## strange inlining problem:\n", | |
| "\n", | |
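| "It was originally observed that if x is not passed as an argument but defined in the function instead, inlining doesn't seem to happen. In the timings below (numba 0.38.0), however, both versions take a comparable time after compilation." | |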
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@njit\n", | |
| "def funfun_manual_inline(N):\n", | |
| " t = 0.0\n", | |
| " x = 2.0\n", | |
| " for i in range(N):\n", | |
| " t = (x)**2\n", | |
| " return t\n", | |
| "\n", | |
| "@njit\n", | |
| "def funfun_auto_inline(N):\n", | |
| " t = 0.0\n", | |
| " x = 2.0\n", | |
| " for i in range(N):\n", | |
| " t = jitted_fun(x)\n", | |
| " return t" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 47.6 ms, sys: 2.15 ms, total: 49.8 ms\n", | |
| "Wall time: 48.2 ms\n", | |
| "CPU times: user 5 µs, sys: 1 µs, total: 6 µs\n", | |
| "Wall time: 9.3 µs\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "4.0" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# first call includes JIT compilation; second call times the compiled code\n", | |
| "%time funfun_manual_inline(N)\n", | |
| "%time funfun_manual_inline(N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 52.9 ms, sys: 3.02 ms, total: 55.9 ms\n", | |
| "Wall time: 54 ms\n", | |
| "CPU times: user 6 µs, sys: 0 ns, total: 6 µs\n", | |
| "Wall time: 9.06 µs\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "4.0" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# first call includes JIT compilation; second call times the compiled code\n", | |
| "%time funfun_auto_inline(N)\n", | |
| "%time funfun_auto_inline(N)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.4" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 1 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment