Cython 实现内循环#
那些希望从底层操作中获得极高性能的人应该认真考虑直接使用 C 中提供的迭代 API,但对于不熟悉 C 或 C++ 的人来说,Cython 是一个不错的折中方案,可以提供合理的性能权衡。对于 nditer 对象,这意味着让迭代器处理广播、dtype 转换和缓冲,同时将内循环交给 Cython。
在我们的例子中,我们将创建一个平方和函数。首先,让我们用直接的 Python 实现这个函数。我们希望支持一个类似于 NumPy sum 函数的 ‘axis’ 参数,所以我们需要为 op_axes 参数构造一个列表。下面是它的样子。
示例
>>> def axis_to_axeslist(axis, ndim):
... if axis is None:
... return [-1] * ndim
... else:
... if type(axis) is not tuple:
... axis = (axis,)
... axeslist = [1] * ndim
... for i in axis:
... axeslist[i] = -1
... ax = 0
... for i in range(ndim):
... if axeslist[i] != -1:
... axeslist[i] = ax
... ax += 1
... return axeslist
...
>>> def sum_squares_py(arr, axis=None, out=None):
... axeslist = axis_to_axeslist(axis, arr.ndim)
... it = np.nditer([arr, out], flags=['reduce_ok',
... 'buffered', 'delay_bufalloc'],
... op_flags=[['readonly'], ['readwrite', 'allocate']],
... op_axes=[None, axeslist],
... op_dtypes=['float64', 'float64'])
... with it:
... it.operands[1][...] = 0
... it.reset()
... for x, y in it:
... y[...] += x*x
... return it.operands[1]
...
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_py(a)
array(55.)
>>> sum_squares_py(a, axis=-1)
array([ 5., 50.])
为了 Cython 化这个函数,我们用专门针对 float64 dtype 的 Cython 代码替换内循环 (y[…] += x*x)。当 ‘external_loop’ 标志启用时,提供给内循环的数组总是_一维_的,因此只需要很少的检查。
下面是 sum_squares.pyx 的列表
import numpy as np
cimport numpy as np
cimport cython
def axis_to_axeslist(axis, ndim):
if axis is None:
return [-1] * ndim
else:
if type(axis) is not tuple:
axis = (axis,)
axeslist = [1] * ndim
for i in axis:
axeslist[i] = -1
ax = 0
for i in range(ndim):
if axeslist[i] != -1:
axeslist[i] = ax
ax += 1
return axeslist
@cython.boundscheck(False)
def sum_squares_cy(arr, axis=None, out=None):
cdef np.ndarray[double] x
cdef np.ndarray[double] y
cdef int size
cdef double value
axeslist = axis_to_axeslist(axis, arr.ndim)
it = np.nditer([arr, out], flags=['reduce_ok', 'external_loop',
'buffered', 'delay_bufalloc'],
op_flags=[['readonly'], ['readwrite', 'allocate']],
op_axes=[None, axeslist],
op_dtypes=['float64', 'float64'])
with it:
it.operands[1][...] = 0
it.reset()
for xarr, yarr in it:
x = xarr
y = yarr
size = x.shape[0]
for i in range(size):
value = x[i]
y[i] = y[i] + value * value
return it.operands[1]
在这台机器上,将 .pyx 文件构建成模块的过程如下,但您可能需要查找一些 Cython 教程来了解您系统配置的具体信息。
$ cython sum_squares.pyx
$ gcc -shared -pthread -fPIC -fwrapv -O2 -Wall -I/usr/include/python2.7 -fno-strict-aliasing -o sum_squares.so sum_squares.c
从 Python 解释器运行此代码会产生与我们原生的 Python/NumPy 代码相同的答案。
示例
>>> from sum_squares import sum_squares_cy
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_cy(a)
array(55.0)
>>> sum_squares_cy(a, axis=-1)
array([ 5., 50.])
在 IPython 中进行一些计时测试表明,Cython 内循环减少的开销和内存分配,相比于直接的 Python 代码和使用 NumPy 内置 sum 函数的表达式,提供了非常不错的速度提升。
>>> a = np.random.rand(1000,1000)
>>> timeit sum_squares_py(a, axis=-1)
10 loops, best of 3: 37.1 ms per loop
>>> timeit np.sum(a*a, axis=-1)
10 loops, best of 3: 20.9 ms per loop
>>> timeit sum_squares_cy(a, axis=-1)
100 loops, best of 3: 11.8 ms per loop
>>> np.all(sum_squares_cy(a, axis=-1) == np.sum(a*a, axis=-1))
True
>>> np.all(sum_squares_py(a, axis=-1) == np.sum(a*a, axis=-1))
True