将内循环置于 Cython 中#

那些希望从其底层操作中获得出色性能的人应强烈考虑直接使用 C 提供的迭代 API,但对于不熟悉 C 或 C++ 的人来说,Cython 是一个很好的折衷方案,具有合理的性能权衡。对于 nditer 对象,这意味着让迭代器处理广播、dtype 转换和缓冲,同时将内循环交给 Cython。

为了我们的示例,我们将创建一个平方和函数。首先,让我们在纯 Python 中实现此函数。我们希望支持一个类似于 numpy sum 函数的 ‘axis’ 参数,因此我们需要为 op_axes 参数构建一个列表。示例如下。

示例

>>> def axis_to_axeslist(axis, ndim):
...     if axis is None:
...         return [-1] * ndim
...     else:
...         if type(axis) is not tuple:
...             axis = (axis,)
...         axeslist = [1] * ndim
...         for i in axis:
...             axeslist[i] = -1
...         ax = 0
...         for i in range(ndim):
...             if axeslist[i] != -1:
...                 axeslist[i] = ax
...                 ax += 1
...         return axeslist
...
>>> def sum_squares_py(arr, axis=None, out=None):
...     axeslist = axis_to_axeslist(axis, arr.ndim)
...     it = np.nditer([arr, out], flags=['reduce_ok',
...                                       'buffered', 'delay_bufalloc'],
...                 op_flags=[['readonly'], ['readwrite', 'allocate']],
...                 op_axes=[None, axeslist],
...                 op_dtypes=['float64', 'float64'])
...     with it:
...         it.operands[1][...] = 0
...         it.reset()
...         for x, y in it:
...             y[...] += x*x
...         return it.operands[1]
...
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_py(a)
array(55.)
>>> sum_squares_py(a, axis=-1)
array([  5.,  50.])

要将此函数 Cython 化,我们将内循环 (y[…] += x*x) 替换为专门用于 float64 dtype 的 Cython 代码。启用 ‘external_loop’ 标志后,提供给内循环的数组将始终是一维的,因此几乎不需要进行检查。

以下是 sum_squares.pyx 的代码列表

import numpy as np
cimport numpy as np
cimport cython

def axis_to_axeslist(axis, ndim):
    if axis is None:
        return [-1] * ndim
    else:
        if type(axis) is not tuple:
            axis = (axis,)
        axeslist = [1] * ndim
        for i in axis:
            axeslist[i] = -1
        ax = 0
        for i in range(ndim):
            if axeslist[i] != -1:
                axeslist[i] = ax
                ax += 1
        return axeslist

@cython.boundscheck(False)
def sum_squares_cy(arr, axis=None, out=None):
    cdef np.ndarray[double] x
    cdef np.ndarray[double] y
    cdef int size
    cdef double value

    axeslist = axis_to_axeslist(axis, arr.ndim)
    it = np.nditer([arr, out], flags=['reduce_ok', 'external_loop',
                                      'buffered', 'delay_bufalloc'],
                op_flags=[['readonly'], ['readwrite', 'allocate']],
                op_axes=[None, axeslist],
                op_dtypes=['float64', 'float64'])
    with it:
        it.operands[1][...] = 0
        it.reset()
        for xarr, yarr in it:
            x = xarr
            y = yarr
            size = x.shape[0]
            for i in range(size):
               value = x[i]
               y[i] = y[i] + value * value
        return it.operands[1]

在此机器上,将 .pyx 文件构建为模块的方式如下所示,但您可能需要查阅一些 Cython 教程以了解您的系统配置的具体信息。

$ cython sum_squares.pyx
$ gcc -shared -pthread -fPIC -fwrapv -O2 -Wall -I/usr/include/python2.7 -fno-strict-aliasing -o sum_squares.so sum_squares.c

在 Python 解释器中运行此代码会产生与我们原生 Python/NumPy 代码相同的结果。

示例

>>> from sum_squares import sum_squares_cy 
>>> a = np.arange(6).reshape(2,3)
>>> sum_squares_cy(a) 
array(55.0)
>>> sum_squares_cy(a, axis=-1) 
array([  5.,  50.])

在 IPython 中进行一些计时显示,Cython 内循环减少的开销和内存分配,相比于纯 Python 代码和使用 NumPy 内置 sum 函数的表达式,提供了非常显著的加速。

>>> a = np.random.rand(1000,1000)

>>> timeit sum_squares_py(a, axis=-1)
10 loops, best of 3: 37.1 ms per loop

>>> timeit np.sum(a*a, axis=-1)
10 loops, best of 3: 20.9 ms per loop

>>> timeit sum_squares_cy(a, axis=-1)
100 loops, best of 3: 11.8 ms per loop

>>> np.all(sum_squares_cy(a, axis=-1) == np.sum(a*a, axis=-1))
True

>>> np.all(sum_squares_py(a, axis=-1) == np.sum(a*a, axis=-1))
True