Pythran Case: Resampling
While hanging around on Stack Overflow (everybody does this, no?), I found this Numpy code snippet:
import numpy as np

def resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = np.cumsum(qs)
    for j, key in enumerate(rands):
        i = np.argmax(lookup > key)
        results[j] = xs[i]
    return results
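For intuition: each entry of rands selects an element of xs with probability given by the matching weight in qs (inverse transform sampling against the cumulative weights). A tiny run with made-up values, using the resample function above:

import numpy as np

xs = np.array([10.0, 20.0, 30.0])       # values to draw from
qs = np.array([0.1, 0.2, 0.7])          # weights, summing to 1
rands = np.array([0.05, 0.5, 0.95])     # uniform draws

# cumsum(qs) == [0.1, 0.3, 1.0]; each key picks the first bucket
# whose cumulative weight exceeds it.
print(resample(qs, xs, rands))          # -> [10. 30. 30.]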
When running it through timeit, we get:
% python -m timeit -s 'import numpy as np; np.random.seed(0) ; n = 1000; xs = np.arange(n, dtype=np.float64); qs = np.array([1.0/n,]*n); rands = np.random.rand(n); from resample import resample' 'resample(qs, xs, rands)'
100 loops, best of 3: 3.02 msec per loop
The initialization code, after the -s switch, is run only once; it includes a call to np.random.seed so that the later comparisons are consistent.
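Equivalently, the same measurement can be scripted through the timeit module; a minimal sketch:

import timeit

setup = ("import numpy as np; np.random.seed(0); n = 1000; "
         "xs = np.arange(n, dtype=np.float64); qs = np.array([1.0/n]*n); "
         "rands = np.random.rand(n); from resample import resample")
# Best of 3 repetitions of 100 calls each, like the CLI run above.
best = min(timeit.repeat('resample(qs, xs, rands)', setup=setup,
                         repeat=3, number=100))
print('%g msec per loop' % (best / 100 * 1e3))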
First step: Pythran
What kind of optimisations could improve this code? np.cumsum, np.argmax and the lookup > key comparison all run as native Numpy code, so there should not be much to gain there.
But if we look carefully, lookup > key builds an intermediate array, which is then passed as an argument to np.argmax. This temporary array is not needed, as np.argmax could work on a stream. That's a typical shortcoming of Numpy's eager evaluation, a pedantic way of saying that expressions are evaluated where they appear, not when their result is needed (which would be lazy evaluation).
Pythran automatically detects when an expression can be lazily evaluated (even when it is bound to a variable, which is not the case here). So maybe we can get some speedup?
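To make the difference concrete, here is a pure-Python sketch of both evaluation strategies (illustrative only: this is not how Pythran implements it, and the generator version is of course slow in plain Python):

import numpy as np

lookup = np.cumsum(np.full(1000, 1.0 / 1000))
key = 0.5

# Eager: `lookup > key` allocates a full boolean temporary first.
i_eager = np.argmax(lookup > key)

# Lazy sketch: stream the comparison and stop at the first hit,
# never materializing the boolean array.
i_lazy = next(j for j, v in enumerate(lookup) if v > key)

assert i_eager == i_lazy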
To use Pythran, we just add a comment line that states the expected types of the top-level function:
#pythran export resample(float[], float[], float[])
import numpy as np

def resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = np.cumsum(qs)
    for j, key in enumerate(rands):
        i = np.argmax(lookup > key)
        results[j] = xs[i]
    return results
And then call the pythran compiler:
% pythran resample.py
This turns the Python file into a native extension module, namely resample.so on Linux.
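Since the native module takes precedence over the original .py file on the import path, the import machinery picks it up transparently; one can check by inspecting the module's __file__ attribute, which should now point at the shared object:

% python -c 'import resample; print(resample.__file__)'

Running the benchmark again yields a nice speedup: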
% python -m timeit -s 'import numpy as np; np.random.seed(0) ; n = 1000; xs = np.arange(n, dtype=np.float64); qs = np.array([1.0/n,]*n); rands = np.random.rand(n); from resample import resample' 'resample(qs, xs, rands)'
1000 loops, best of 3: 1.23 msec per loop
Second step: Pythran + OpenMP
But could we do better? An astute reader will note that the for loop can be run in parallel (iterations are independent: each one writes to a distinct results[j]). There's a famous standard for C, C++ and Fortran, called OpenMP, for parallelizing this kind of trivial loop (and for doing many non-trivial things too, but that's not the point here). It turns out Pythran supports OpenMP :-). We just add an extra comment (which should look pretty familiar to anyone accustomed to OpenMP) on the parallel loop:
#pythran export resample(float[], float[], float[])
import numpy as np

def resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = np.cumsum(qs)
    #omp parallel for
    for j, key in enumerate(rands):
        i = np.argmax(lookup > key)
        results[j] = xs[i]
    return results
And adding the -fopenmp flag to the pythran call:
% pythran resample.py -fopenmp
We get an extra speedup (only two cores there, sorry about this :-/):
% python -m timeit -s 'import numpy as np; np.random.seed(0) ; n = 1000; xs = np.arange(n, dtype=np.float64); qs = np.array([1.0/n,]*n); rands = np.random.rand(n); from resample import resample' 'resample(qs, xs, rands)'
1000 loops, best of 3: 693 usec per loop
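As a side note, the number of OpenMP worker threads can be pinned through the standard OMP_NUM_THREADS environment variable, which helps keep timings reproducible across runs:

% OMP_NUM_THREADS=2 python -m timeit -s 'import numpy as np; np.random.seed(0) ; n = 1000; xs = np.arange(n, dtype=np.float64); qs = np.array([1.0/n,]*n); rands = np.random.rand(n); from resample import resample' 'resample(qs, xs, rands)'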
Third step: Pythran + Brain
Now wait… calling np.argmax on an array of bools is indeed a nice trick to get the index of the first value where lookup > key, but it evaluates the whole array. There is no early exit, while there could be one (the array only holds 0s and 1s, after all). As pointed out in the SO thread, one could write an np_index(haystack, needle) function that behaves like list.index:
#pythran export resample(float[], float[], float[])
import numpy as np

def np_index(haystack, needle):
    for i, v in enumerate(haystack):
        if v == needle:
            return i
    raise ValueError("Value not found")

def resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = np.cumsum(qs)
    #omp parallel for
    for j, key in enumerate(rands):
        i = np_index(lookup > key, True)
        results[j] = xs[i]
    return results
There are a few things to note in this implementation:
- there's no pythran export for np_index, as it's not meant to be used outside the module;
- np_index plays well with lazy evaluation: the tail of the lookup > key expression is not evaluated if a matching value is found before the end;
- Pythran supports built-in exceptions ;-)
And a last benchmark, without OpenMP:
% pythran resample.py
% python -m timeit -s 'import numpy as np; np.random.seed(0) ; n = 1000; xs = np.arange(n, dtype=np.float64); qs = np.array([1.0/n,]*n); rands = np.random.rand(n); from resample import resample' 'resample(qs, xs, rands)'
1000 loops, best of 3: 491 usec per loop
And with OpenMP:
% pythran resample.py -fopenmp
% python -m timeit -s 'import numpy as np; np.random.seed(0) ; n = 1000; xs = np.arange(n, dtype=np.float64); qs = np.array([1.0/n,]*n); rands = np.random.rand(n); from resample import resample' 'resample(qs, xs, rands)'
1000 loops, best of 3: 326 usec per loop
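As an aside, this particular pattern also has a loop-free pure-Numpy formulation: lookup is a cumulative sum of non-negative weights, hence sorted, so the whole search can be done with a vectorized binary search. A minimal sketch using np.searchsorted (the clip is a guard against a key falling at or above lookup[-1] due to round-off):

import numpy as np

def resample_searchsorted(qs, xs, rands):
    lookup = np.cumsum(qs)
    # For each key, find the first index where lookup > key,
    # in O(log n) per key instead of O(n).
    idx = np.searchsorted(lookup, rands, side='right')
    np.clip(idx, 0, len(xs) - 1, out=idx)
    return xs[idx]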
The Stack Overflow Solution
For reference, the Numba solution proposed as the answer to the Stack Overflow thread is:
import numba as nb
import numpy as np

@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    lookup = numba_cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
On my laptop, it runs in:
10 loops, best of 3: 419 usec per loop
The equivalent implementation in Pythran does not need the type-annotated wrapper around np.cumsum, as that function is already supported natively:
#pythran export resample(float[], float[], float[])
import numpy as np

def resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    #omp parallel for
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
And once compiled with Pythran it runs (no OpenMP) in:
1000 loops, best of 3: 350 usec per loop
Pythran and Numba timings are within the same range. Numba is still easier to integrate (Just-In-Time compilation is really nice!), but it calls for a lower-level implementation style. Pythran handles that style efficiently too, but it's not my preferred way of programming in Python ;-).
Final Thoughts
This is just a retelling of the original Stack Overflow post, reinterpreted with Pythran in mind. What do we learn? Numpy provides a lot of nice facilities, but one still needs to understand some of its internals to get the best out of it. And with Pythran you can do so while keeping a relatively high level of abstraction!