Cython 实现内循环#

那些希望从底层操作中获得极高性能的人应该认真考虑直接使用 C 中提供的迭代 API,但对于不熟悉 C 或 C++ 的人来说,Cython 是一个不错的折中方案,可以提供合理的性能权衡。对于 nditer 对象,这意味着让迭代器处理广播、dtype 转换和缓冲,同时将内循环交给 Cython。

在我们的例子中,我们将创建一个平方和函数。首先,让我们用直接的 Python 实现这个函数。我们希望支持一个类似于 NumPy sum 函数的 ‘axis’ 参数,所以我们需要为 op_axes 参数构造一个列表。下面是它的样子。

示例

>>> def axis_to_axeslist(axis, ndim):
...     if axis is None:
...         return [-1] * ndim
...     else:
...         if type(axis) is not tuple:
...             axis = (axis,)
...         axeslist = [1] * ndim
...         for i in axis:
...             axeslist[i] = -1
...         ax = 0
...         for i in range(ndim):
...             if axeslist[i] != -1:
...                 axeslist[i] = ax
...                 ax += 1
...         return axeslist
...
>>> def sum_squares_py(arr, axis=None, out=None):
...     axeslist = axis_to_axeslist(axis, arr.ndim)
...     it = np.nditer([arr, out], flags=['reduce_ok',
...                                       'buffered', 'delay_bufalloc'],
...                 op_flags=[['readonly'], ['readwrite', 'allocate']],
...                 op_axes=[None, axeslist],
...                 op_dtypes=['float64', 'float64'])
...     with it:
...         it.operands[1][...] = 0
...         it.reset()
...         for x, y in it:
...             y[...] += x*x
...         return it.operands[1]
...
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_py(a)
array(55.)
>>> sum_squares_py(a, axis=-1)
array([  5.,  50.])

为了 Cython 化这个函数,我们用专门针对 float64 dtype 的 Cython 代码替换内循环 (y[…] += x*x)。当 ‘external_loop’ 标志启用时,提供给内循环的数组总是_一维_的,因此只需要很少的检查。

下面是 sum_squares.pyx 的列表

import numpy as np
cimport numpy as np
cimport cython

def axis_to_axeslist(axis, ndim):
    if axis is None:
        return [-1] * ndim
    else:
        if type(axis) is not tuple:
            axis = (axis,)
        axeslist = [1] * ndim
        for i in axis:
            axeslist[i] = -1
        ax = 0
        for i in range(ndim):
            if axeslist[i] != -1:
                axeslist[i] = ax
                ax += 1
        return axeslist

@cython.boundscheck(False)
def sum_squares_cy(arr, axis=None, out=None):
    cdef np.ndarray[double] x
    cdef np.ndarray[double] y
    cdef int size
    cdef double value

    axeslist = axis_to_axeslist(axis, arr.ndim)
    it = np.nditer([arr, out], flags=['reduce_ok', 'external_loop',
                                      'buffered', 'delay_bufalloc'],
                op_flags=[['readonly'], ['readwrite', 'allocate']],
                op_axes=[None, axeslist],
                op_dtypes=['float64', 'float64'])
    with it:
        it.operands[1][...] = 0
        it.reset()
        for xarr, yarr in it:
            x = xarr
            y = yarr
            size = x.shape[0]
            for i in range(size):
               value = x[i]
               y[i] = y[i] + value * value
        return it.operands[1]

在这台机器上,将 .pyx 文件构建成模块的过程如下,但您可能需要查找一些 Cython 教程来了解您系统配置的具体信息。

$ cython sum_squares.pyx
$ gcc -shared -pthread -fPIC -fwrapv -O2 -Wall -I/usr/include/python2.7 -fno-strict-aliasing -o sum_squares.so sum_squares.c

从 Python 解释器运行此代码会产生与我们原生的 Python/NumPy 代码相同的答案。

示例

>>> from sum_squares import sum_squares_cy 
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_cy(a) 
array(55.0)
>>> sum_squares_cy(a, axis=-1) 
array([  5.,  50.])

在 IPython 中进行一些计时测试表明,Cython 内循环减少的开销和内存分配,相比于直接的 Python 代码和使用 NumPy 内置 sum 函数的表达式,提供了非常不错的速度提升。

>>> a = np.random.rand(1000,1000)

>>> timeit sum_squares_py(a, axis=-1)
10 loops, best of 3: 37.1 ms per loop

>>> timeit np.sum(a*a, axis=-1)
10 loops, best of 3: 20.9 ms per loop

>>> timeit sum_squares_cy(a, axis=-1)
100 loops, best of 3: 11.8 ms per loop

>>> np.all(sum_squares_cy(a, axis=-1) == np.sum(a*a, axis=-1))
True

>>> np.all(sum_squares_py(a, axis=-1) == np.sum(a*a, axis=-1))
True