Numba - an overview

Numba is a just-in-time (JIT) compiler for Python that works best on code that uses NumPy arrays, NumPy functions, and loops. ArviZ includes Numba as an optional dependency, and several helper functions in utils.py take advantage of Numba when it is installed. On systems where Numba is installed, ArviZ also lets you disable and re-enable Numba at runtime.
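The idea of an optional JIT can be sketched as a decorator that compiles a function with `numba.njit` when Numba imports successfully and otherwise returns the function unchanged. This is a minimal sketch under that assumption; `maybe_jit`, `HAS_NUMBA` and `sum_of_squares` are hypothetical names for illustration, not ArviZ's `conditional_jit`.

```python
try:
    import numba  # optional dependency
    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False


def maybe_jit(func=None, **jit_kwargs):
    """Hypothetical helper: compile with numba.njit if Numba is importable,
    otherwise return the original Python function unchanged."""
    def wrapper(f):
        if HAS_NUMBA:
            return numba.njit(**jit_kwargs)(f)
        return f

    # Support both bare @maybe_jit and @maybe_jit(fastmath=True) usage
    if func is None:
        return wrapper
    return wrapper(func)


@maybe_jit
def sum_of_squares(values):
    # Plain loop: slow in pure Python, fast once compiled by Numba
    total = 0.0
    for v in values:
        total += v * v
    return total
```

Because the fallback is the untouched Python function, code that uses the decorator behaves the same with or without Numba; only the speed differs.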

A simple example demonstrating the effectiveness of Numba

[1]:
import arviz as az
from arviz.utils import conditional_jit, Numba
from arviz.stats import geweke
from arviz.stats.diagnostics import ks_summary
import numpy as np
import timeit
[2]:
data = np.random.randn(1000000)
[3]:
def variance(data, ddof=0): # Method to calculate variance without using numba
    a_a, b_b = 0, 0
    for i in data:
        a_a = a_a + i
        b_b = b_b + i * i
    var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)
    var = var * (len(data) / (len(data) - ddof))
    return var
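The loop in `variance` accumulates $\sum x_i$ and $\sum x_i^2$ in a single pass, then applies the identity below, with $n$ = `len(data)` and ddof the delta degrees of freedom. Expanding the square shows that it agrees with the usual definition of the sample variance:

```latex
\hat\sigma^2
  = \frac{n}{n-\mathrm{ddof}}\left[\frac{1}{n}\sum_{i=1}^{n} x_i^2
    - \left(\frac{1}{n}\sum_{i=1}^{n} x_i\right)^2\right]
  = \frac{1}{n-\mathrm{ddof}}\sum_{i=1}^{n}\left(x_i - \bar{x}\right)^2,
\qquad \bar{x} = \frac{1}{n}\sum_{i=1}^{n} x_i
```

This one-pass form suits a single loop, but it can lose precision through cancellation when the mean is large relative to the spread; for standard-normal data like `data` it is accurate.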
[4]:
%timeit variance(data, ddof=1)
382 ms ± 23.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[5]:
@conditional_jit
def variance_jit(data, ddof=0): # Calculating variance with numba
    a_a, b_b = 0, 0
    for i in data:
        a_a = a_a + i
        b_b = b_b + i * i
    var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)
    var = var * (len(data) / (len(data) - ddof))
    return var
[6]:
%timeit variance_jit(data, ddof=1)
1.88 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

That is roughly 200 times faster (382 ms vs. 1.88 ms)! Let’s compare this to NumPy:

[7]:
%timeit np.var(data, ddof=1)
7.52 ms ± 866 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Here the jitted loop is about four times faster than np.var (1.88 ms vs. 7.52 ms), so in certain scenarios Numba can outperform NumPy. Let’s see Numba’s effect on a few ArviZ functions.
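The first call to a Numba-compiled function includes compilation time, so a fair comparison calls it once before timing. The notebook imports `timeit` without using it; the sketch below uses `timeit.repeat` to time a warmed-up loop against `np.var` outside IPython. `variance_loop` is a hypothetical re-implementation of `variance_jit` using `numba.njit`, and the array size is an arbitrary choice; absolute timings will vary by machine.

```python
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python if Numba is missing
    def njit(func):
        return func


@njit
def variance_loop(data, ddof=0):
    # One-pass variance, same formula as variance_jit above
    a_a, b_b = 0.0, 0.0
    for i in data:
        a_a += i
        b_b += i * i
    n = len(data)
    var = b_b / n - (a_a / n) ** 2
    return var * n / (n - ddof)


data = np.random.randn(100_000)
variance_loop(data, 1)  # warm-up call: triggers JIT compilation outside the timing

number = 10
loop_times = timeit.repeat(lambda: variance_loop(data, 1), number=number, repeat=5)
numpy_times = timeit.repeat(lambda: np.var(data, ddof=1), number=number, repeat=5)

# Best of the repeats, converted to seconds per single call
loop_best = min(loop_times) / number
numpy_best = min(numpy_times) / number
print(f"loop: {loop_best * 1e3:.3f} ms per call, np.var: {numpy_best * 1e3:.3f} ms per call")
```

Taking the minimum over repeats reduces the influence of background noise, which is the same convention `timeit` recommends.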

[8]:
Numba.disable_numba() # This disables numba
Numba.numba_flag
[8]:
False
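A runtime switch like `Numba.disable_numba()` / `Numba.enable_numba()` can be pictured as a class-level flag that a wrapper checks on every call, dispatching to either the compiled or the original function. The sketch below is a hypothetical illustration of that pattern; `NumbaSwitch`, `switchable_jit` and `mean_abs` are invented names, not ArviZ's implementation.

```python
import functools

try:
    from numba import njit
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False


class NumbaSwitch:
    """Hypothetical global on/off switch for JIT-compiled code paths."""
    enabled = NUMBA_AVAILABLE

    @classmethod
    def disable(cls):
        cls.enabled = False

    @classmethod
    def enable(cls):
        # Only switch on if Numba can actually be imported
        cls.enabled = NUMBA_AVAILABLE


def switchable_jit(func):
    """Build a compiled version once; pick compiled or pure Python at call time."""
    compiled = njit(func) if NUMBA_AVAILABLE else func

    @functools.wraps(func)
    def dispatcher(*args, **kwargs):
        impl = compiled if NumbaSwitch.enabled else func
        return impl(*args, **kwargs)

    return dispatcher


@switchable_jit
def mean_abs(values):
    total = 0.0
    for v in values:
        total += abs(v)
    return total / len(values)
```

Checking the flag inside the wrapper, rather than at decoration time, is what lets an already-decorated function change behaviour after `disable()` or `enable()` is called.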
[9]:
data = np.random.randn(1000000)
smaller_data = np.random.randn(1000)
[10]:
%timeit geweke(data)
96.1 ms ± 8.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[11]:
%timeit geweke(smaller_data)
3.31 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[12]:
Numba.enable_numba() # This re-enables numba
Numba.numba_flag # This indicates the status of Numba
[12]:
True
[13]:
%timeit geweke(data)
29.1 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
[14]:
%timeit geweke(smaller_data)
1.5 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
[15]:
Numba.enable_numba()
Numba.numba_flag
[15]:
True

With Numba enabled, geweke runs roughly two to three times faster (about 3.3x on the large array and 2.2x on the small one). Let’s check another function.

[16]:
summary_data = np.random.randn(1000,100,10)
school = az.load_arviz_data("centered_eight").posterior["mu"].values
[17]:
Numba.disable_numba()
Numba.numba_flag
[17]:
False
[18]:
%timeit ks_summary(summary_data)
124 ms ± 5.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
[19]:
%timeit ks_summary(school)
2.29 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
[20]:
Numba.enable_numba()
Numba.numba_flag
[20]:
True
[21]:
%timeit ks_summary(summary_data)
15.9 ms ± 977 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
[22]:
%timeit ks_summary(school)
1.3 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Numba provides a substantial speedup once again: about 7.8x for ks_summary on the large array and about 1.8x on the smaller centered_eight data.
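The speedup factors quoted in this notebook follow directly from the mean `%timeit` values reported above. This short sketch just recomputes those ratios from the printed timings (in milliseconds); it does not rerun any benchmark, and the labels are descriptive only.

```python
# Mean timings copied from the %timeit outputs above: (without Numba, with Numba), in ms
timings_ms = {
    "variance (pure Python vs. jitted)": (382.0, 1.88),
    "geweke, 1,000,000 draws": (96.1, 29.1),
    "geweke, 1,000 draws": (3.31, 1.5),
    "ks_summary, (1000, 100, 10) array": (124.0, 15.9),
    "ks_summary, centered_eight mu": (2.29, 1.3),
}

# Speedup = time without Numba / time with Numba
speedups = {name: slow / fast for name, (slow, fast) in timings_ms.items()}

for name, ratio in speedups.items():
    print(f"{name}: {ratio:.1f}x faster with Numba")
```

The spread of these ratios illustrates that the benefit depends on the function and on the input size.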