如何扩展 NumPy#

静态和重复的东西很无聊。动态的东西
和随机的东西令人困惑。艺术则介于两者之间。
John A. Locke
科学是微分方程。宗教是边界条件。
Alan Turing

编写扩展模块#

虽然 ndarray 对象旨在允许在 Python 中进行快速计算,但它也旨在通用且满足各种计算需求。因此,如果绝对速度至关重要,那么没有什么可以替代针对您的应用程序和硬件精心编写的编译循环。这就是 numpy 包含 f2py 的原因之一,这样就可以轻松地将 (简单) C/C++ 和 (任意) Fortran 代码直接链接到 Python 中。我们鼓励您使用和改进此机制。本节的目的不是记录此工具,而是记录此工具所依赖的编写扩展模块的更基本步骤。

当扩展模块被编写、编译并安装到 Python 路径 (sys.path) 的某个位置时,该代码就可以像标准 Python 文件一样导入到 Python 中。它将包含在 C 代码中定义和编译的对象和方法。在 Python 中执行此操作的基本步骤已记录在案,您可以在网上找到更多信息,网址为 www.python.org

除了 Python C-API 之外,还有一个完整而丰富的 NumPy C-API,允许在 C 层面上进行复杂的处理。但是,对于大多数应用程序,通常只需要使用几个 API 调用。例如,如果您只需要提取指向内存的指针以及一些形状信息以传递给另一个计算例程,那么您将使用与尝试创建新的类似数组类型或为 ndarray 添加新数据类型时不同的调用。本章记录了最常用的 API 调用和宏。

必需子程序#

为了让 Python 将其用作扩展模块,您的 C 代码中必须定义一个函数。该函数必须命名为 init{name},其中 {name} 是 Python 中模块的名称。此函数必须声明为对例程外部的代码可见。除了添加所需的方法和常量之外,此子程序还必须包含诸如 import_array() 和/或 import_ufunc() 的调用,具体取决于需要哪个 C-API。忘记放置这些命令会在实际调用任何 C-API 子程序时立即显示为丑陋的分段错误(崩溃)。实际上,可以在单个文件中拥有多个 init{name} 函数,在这种情况下,该文件将定义多个模块。但是,有一些技巧可以使其正常工作,此处未作介绍。

一个最小的 init{name} 方法如下所示

PyMODINIT_FUNC
init{name}(void)
{
   (void)Py_InitModule({name}, mymethods);
   import_array();
}

mymethods 必须是 PyMethodDef 结构的数组(通常是静态声明的),其中包含方法名称、实际 C 函数、指示方法是否使用关键字参数的变量以及文档字符串。这些将在下一节中解释。如果要向模块添加常量,则存储 Py_InitModule 的返回值,它是一个模块对象。向模块添加项目的常用方法是使用 PyModule_GetDict(module) 获取模块字典。使用模块字典,您可以手动向模块添加任何您想要的内容。向模块添加对象的更简单方法是使用三个附加 Python C-API 调用中的一个,这些调用不需要单独提取模块字典。这些在 Python 文档中有所记载,但为了方便起见,这里重复列出。

int PyModule_AddObject(PyObject *module, char *name, PyObject *value)#
int PyModule_AddIntConstant(PyObject *module, char *name, long value)#
int PyModule_AddStringConstant(PyObject *module, char *name, char *value)#

所有这三个函数都需要module 对象(Py_InitModule 的返回值)。name 是一个字符串,用于标记模块中的值。根据调用的函数的不同,value 参数是通用对象(PyModule_AddObject 会窃取对其的引用)、整数常量或字符串常量。

定义函数#

传递给 Py_InitModule 函数的第二个参数是一个结构,它使在模块中定义函数变得容易。在上例中,mymethods 结构将在文件的前面(通常在 init{name} 子程序之前)定义为

static PyMethodDef mymethods[] = {
    { nokeywordfunc,nokeyword_cfunc,
      METH_VARARGS,
      Doc string},
    { keywordfunc, keyword_cfunc,
      METH_VARARGS|METH_KEYWORDS,
      Doc string},
    {NULL, NULL, 0, NULL} /* Sentinel */
}

mymethods 数组中的每个条目都是一个 PyMethodDef 结构,包含 1) Python 名称,2) 实现该函数的 C 函数,3) 指示该函数是否接受关键字参数的标志,以及 4) 该函数的文档字符串。可以通过向该表中添加更多条目来为单个模块定义任意数量的函数。最后一个条目必须全部为 NULL,如所示,作为哨兵。Python 会查找此条目以了解模块的所有函数都已定义。

完成扩展模块的最后一步是实际编写执行所需函数的代码。有两种函数:不接受关键字参数的函数和接受关键字参数的函数。

不带关键字参数的函数#

不接受关键字参数的函数应编写为

static PyObject*
nokeyword_cfunc (PyObject *dummy, PyObject *args)
{
    /* convert Python arguments */
    /* do function */
    /* return something */
}

在此上下文中不使用虚拟参数,可以安全地忽略它。args 参数包含作为元组传递给函数的所有参数。此时您可以做任何您想做的事情,但通常管理输入参数的最简单方法是调用 PyArg_ParseTuple (args, format_string, addresses_to_C_variables…) 或 PyArg_UnpackTuple (tuple, “name”, min, max, …)。关于如何使用第一个函数的良好描述包含在 Python C-API 参考手册的第 5.5 节(解析参数和构建值)中。您应该特别注意“O&”格式,它使用转换器函数在 Python 对象和 C 对象之间进行转换。所有其他格式函数都可以(大多)被认为是此一般规则的特例。NumPy C-API 中定义了几个可能会有用的转换器函数。特别是,PyArray_DescrConverter 函数对于支持任意数据类型规范非常有用。此函数将任何有效的 data-type Python 对象转换为 PyArray_Descr* 对象。记住传入应该填充的 C 变量的地址。

在 NumPy 源代码中有很多关于如何使用 PyArg_ParseTuple 的示例。标准用法如下所示

PyObject *input;
PyArray_Descr *dtype;
if (!PyArg_ParseTuple(args, "OO&", &input,
                      PyArray_DescrConverter,
                      &dtype)) return NULL;

使用“O”格式字符串时,务必记住您获得的是对象的借用引用。但是,转换函数通常需要某种形式的内存处理。在此示例中,如果转换成功,dtype将持有指向PyArray_Descr*对象的新的引用,而input将持有借用引用。因此,如果此转换与另一个转换(例如转换为整数)混合,并且数据类型转换成功但整数转换失败,则需要在返回之前释放对数据类型对象的引用计数。一种典型的方法是在调用PyArg_ParseTuple之前将dtype设置为NULL,然后在返回之前对dtype使用Py_XDECREF

处理输入参数后,将编写实际执行工作的代码(可能需要调用其他函数)。C 函数的最后一步是返回某些内容。如果遇到错误,则应返回NULL(确保已实际设置错误)。如果不需要返回任何内容,则递增Py_None 并返回它。如果应返回单个对象,则返回它(确保您首先拥有对它的引用)。如果应返回多个对象,则需要返回一个元组。Py_BuildValue (format_string, c_variables…) 函数可以轻松地根据 C 变量构建 Python 对象的元组。特别注意格式字符串中“N”和“O”之间的区别,否则很容易造成内存泄漏。“O”格式字符串会递增其对应的PyObject* C 变量的引用计数,而“N”格式字符串会窃取对相应PyObject* C 变量的引用。如果您已经为对象创建了引用,并且只想将该引用提供给元组,则应使用“N”。如果您只拥有对象的借用引用,并且需要创建一个引用来提供给元组,则应使用“O”。

带关键字参数的函数#

这些函数与不带关键字参数的函数非常相似。唯一的区别在于函数签名是:

static PyObject*
keyword_cfunc (PyObject *dummy, PyObject *args, PyObject *kwds)
{
...
}

kwds 参数持有 Python 字典,其键是关键字参数的名称,其值是对应的关键字参数值。可以根据需要处理此字典。但是,最简单的处理方法是用对 PyArg_ParseTuple (args, format_string, addresses…) 函数的调用替换为对 PyArg_ParseTupleAndKeywords (args, kwds, format_string, char *kwlist[], addresses…) 的调用。此函数的 kwlist 参数是一个以NULL结尾的字符串数组,提供预期的关键字参数。对于格式字符串中的每个条目,应该有一个字符串。使用此函数,如果传递了无效的关键字参数,则会引发 TypeError。

有关此函数的更多帮助,请参阅 Python 文档中扩展和嵌入教程的第 1.8 节(扩展函数的关键字参数)。

引用计数#

编写扩展模块时最大的困难是引用计数。这是 f2py、weave、Cython、ctypes 等流行的重要原因……如果错误地处理引用计数,则可能会从内存泄漏到分段错误等问题。我所知道的正确处理引用计数的唯一策略是付出心血和汗水。首先,您必须牢记每个 Python 变量都有一个引用计数。然后,您准确地了解每个函数对对象引用计数的作用,以便在需要时正确使用 DECREF 和 INCREF。引用计数确实可以测试您对编程技巧的耐心和尽责程度。尽管描述严峻,但大多数引用计数情况都非常简单明了,最常见的困难是在由于某些错误而过早退出例程之前没有对对象使用 DECREF。其次,常见的错误是没有拥有传递给将窃取引用的函数或宏的对象的引用(例如PyTuple_SET_ITEM,以及大多数接受PyArray_Descr 对象的函数)。

创建变量或它是某些函数的返回值时,通常会获得对变量的新引用(但是,有一些突出的例外情况——例如从元组或字典中获取项目)。当您拥有引用时,您有责任确保在不再需要变量时调用Py_DECREF (var)(并且没有其他函数“窃取”了它的引用)。此外,如果您将 Python 对象传递给将“窃取”引用的函数,则需要确保您拥有它(或使用Py_INCREF 获取您自己的引用)。您还将遇到借用引用的概念。借用引用的函数不会更改对象的引用计数,也不期望“保留”引用。它只是临时使用该对象。当您使用PyArg_ParseTuplePyArg_UnpackTuple 时,您会收到对元组中对象的借用引用,并且不应在函数内部更改其引用计数。通过练习,您可以学习正确地进行引用计数,但一开始可能会令人沮丧。

引用计数错误的一个常见来源是Py_BuildValue 函数。请仔细注意“N”格式字符和“O”格式字符之间的区别。如果您在子例程中创建了一个新对象(例如输出数组),并且您将其作为返回值元组的一部分返回,则很可能应该在Py_BuildValue 中使用“N”格式字符。“O”字符会将引用计数增加一。这将使调用者拥有一个全新数组的两个引用计数。当变量被删除并且引用计数减少一时,仍然会有额外的引用计数,并且该数组永远不会被释放。您将遇到由引用计数引起的内存泄漏。使用“N”字符可以避免这种情况,因为它将向调用者返回一个(元组内的)只有一个引用计数的对象。

处理数组对象#

大多数 NumPy 扩展模块都需要访问 ndarray 对象(或其子类之一)的内存。最简单的方法不需要您了解 NumPy 的内部结构。该方法是:

  1. 确保您正在处理一个行为良好的数组(已对齐,采用机器字节序且为单段), 类型和维数正确。

    1. 通过使用PyArray_FromAny 或基于它的宏从某个 Python 对象转换它。

    2. 通过使用PyArray_NewFromDescr 或基于它的更简单的宏或函数构造所需形状和类型的新的 ndarray。

  2. 获取数组的形状及其实际数据的指针。

  3. 将数据和形状信息传递给实际执行计算的子例程或代码的其他部分。

  4. 如果您正在编写算法,那么我建议您使用数组中包含的步幅信息来访问数组的元素(PyArray_GetPtr 宏使这变得轻松)。然后,您可以放宽要求,以免强制使用单段数组以及可能导致的数据复制。

以下小节将介绍每个子主题。

转换任意序列对象#

从任何可以转换为数组的 Python 对象获取数组的主要例程是 PyArray_FromAny。此函数非常灵活,具有许多输入参数。几个宏使其更易于使用基本函数。PyArray_FROM_OTF可以说是这些宏中最有用的一个,用于最常见的用途。它允许您将任意 Python 对象转换为特定内置数据类型(例如 float)的数组,同时指定一组特定的要求(例如 连续、对齐和可写)。语法如下:

PyArray_FROM_OTF

从任何可以转换为数组的 Python 对象 *obj* 返回一个 ndarray。返回数组的维数由对象决定。返回数组的所需数据类型在 *typenum* 中提供,它应该是枚举类型之一。返回数组的 *requirements* 可以是标准数组标志的任何组合。下面将详细解释每个参数。成功时,您将收到对数组的新引用。失败时,返回 NULL 并设置异常。

obj

该对象可以是任何可转换为 ndarray 的 Python 对象。如果该对象已经是满足要求的 ndarray(或其子类),则返回一个新引用。否则,将构造一个新数组。除非使用数组接口,否则 *obj* 的内容将被复制到新数组中,这样就不需要复制数据。可以转换为数组的对象包括:1) 任何嵌套序列对象,2) 任何公开数组接口的对象,3) 任何具有 __array__ 方法的对象(该方法应返回一个 ndarray),以及 4) 任何标量对象(成为一个零维数组)。满足要求的 ndarray 的子类将被传递。如果您想确保一个基类 ndarray,则在 requirements 标志中使用 NPY_ARRAY_ENSUREARRAY。只有在必要时才会进行复制。如果您想保证复制,则将 NPY_ARRAY_ENSURECOPY 传递给 requirements 标志。

typenum

枚举类型之一,或者如果数据类型应该从对象本身确定,则为 NPY_NOTYPE。可以使用基于 C 的名称

或者,可以使用平台上支持的位宽名称。例如

只有在不丢失精度的情况下才能将对象转换为所需类型。否则,将返回 NULL 并引发错误。在 requirements 标志中使用 NPY_ARRAY_FORCECAST 来覆盖此行为。

requirements

ndarray 的内存模型允许在每个维度中使用任意步幅来前进到数组的下一个元素。但是,通常需要与期望 C 连续或 Fortran 连续内存布局的代码接口。此外,ndarray 可能未对齐(元素的地址不是元素大小的整数倍数),如果您尝试取消引用指向数组数据的指针,这可能会导致程序崩溃(或至少工作速度变慢)。这两个问题都可以通过将 Python 对象转换为更适合您特定用途的数组来解决。

requirements 标志允许指定可接受的数组类型。如果传入的对象不满足此要求,则会进行复制,以便返回的对象将满足这些要求。这些 ndarray 可以使用非常通用的内存指针。此标志允许指定返回的数组对象的所需属性。所有标志都在详细的 API 章节中解释。最常用的标志是 NPY_ARRAY_IN_ARRAYNPY_ARRAY_OUT_ARRAYNPY_ARRAY_INOUT_ARRAY

NPY_ARRAY_IN_ARRAY

此标志对于必须按 C 连续顺序排列且已对齐的数组很有用。此类数组通常是某些算法的输入数组。

NPY_ARRAY_OUT_ARRAY

此标志用于指定一个按 C 连续顺序排列、已对齐且也可以写入的数组。这样的数组通常作为输出返回(尽管通常这样的输出数组是从头开始创建的)。

NPY_ARRAY_INOUT_ARRAY

此标志用于指定将同时用作输入和输出的数组。PyArray_ResolveWritebackIfCopy 必须在接口例程结束时在 Py_DECREF 之前调用,以便将临时数据写回传入的原始数组。使用 NPY_ARRAY_WRITEBACKIFCOPY 标志要求输入对象已经是数组(因为其他对象无法以这种方式自动更新)。如果发生错误,请对设置了这些标志的数组使用 PyArray_DiscardWritebackIfCopy (obj)。这将设置底层基础数组的可写性,而不会导致内容复制回原始数组。

其他可以作为附加要求进行 OR 运算的有用标志是

NPY_ARRAY_FORCECAST

强制转换为所需类型,即使这样做可能会丢失信息。

NPY_ARRAY_ENSURECOPY

确保生成的数组是原始数组的副本。

NPY_ARRAY_ENSUREARRAY

确保生成的對象是實際的 ndarray,而不是子類。

注意

数组是否进行字节交换由数组的数据类型决定。PyArray_FROM_OTF 始终请求本机字节顺序数组,因此在 requirements 参数中不需要 NPY_ARRAY_NOTSWAPPED 标志。也无法从此例程获取字节交换数组。

创建全新的 ndarray#

扩展模块代码中经常需要创建新的数组。也许需要一个输出数组,而你不想让调用者提供它;也许只需要一个临时数组来保存中间计算结果。无论出于何种需要,都有简单的方法来获取所需任何数据类型的ndarray对象。执行此操作最通用的函数是PyArray_NewFromDescr。所有数组创建函数都通过这段被大量复用的代码。由于其灵活性,使用起来可能有些混乱。因此,存在更易于使用的简化形式。这些形式是PyArray_SimpleNew函数族的一部分,这些函数通过为常见用例提供默认值来简化接口。

访问ndarray内存和访问ndarray的元素#

如果obj是一个ndarray (PyArrayObject*),则ndarray的数据区域由void*指针PyArray_DATA (obj)或char*指针PyArray_BYTES (obj)指向。请记住,(通常)此数据区域可能未根据数据类型对齐,它可能表示字节交换的数据,并且/或者它可能不可写。如果数据区域已对齐并采用本机字节序,则访问数组特定元素的方法仅由npy_intp变量数组PyArray_STRIDES (obj)确定。特别是,这个整数c数组显示了必须添加到当前元素指针的**字节数**,以获取每个维度中的下一个元素。对于小于4维的数组,存在PyArray_GETPTR{k} (obj, …)宏,其中{k}是整数1、2、3或4,这使得使用数组步幅更容易。参数……表示数组中的{k}个非负整数索引。例如,假设E是一个3维ndarray。指向元素E[i,j,k]的(void*)指针可通过PyArray_GETPTR3 (E, i, j, k)获得。

如前所述,C风格连续数组和Fortran风格连续数组具有特定的步幅模式。两个数组标志(NPY_ARRAY_C_CONTIGUOUSNPY_ARRAY_F_CONTIGUOUS)指示特定数组的步幅模式是否与C风格连续或Fortran风格连续匹配,或者两者都不匹配。步幅模式是否与标准C或Fortran模式匹配可以使用PyArray_IS_C_CONTIGUOUS (obj)和PyArray_ISFORTRAN (obj)分别进行测试。大多数第三方库期望连续数组。但是,支持通用步幅通常并不困难。我鼓励你尽可能在自己的代码中使用步幅信息,并将单段要求保留给包装第三方代码。使用ndarray提供的步幅信息,而不是要求连续步幅,可以减少否则必须进行的复制。

示例#

以下示例显示了如何编写一个接受两个输入参数(将被转换为数组)和一个输出参数(必须是数组)的包装器。该函数返回None并更新输出数组。请注意,对于NumPy v1.14及更高版本,WRITEBACKIFCOPY语义已更新。

static PyObject *
example_wrapper(PyObject *dummy, PyObject *args)
{
    PyObject *arg1=NULL, *arg2=NULL, *out=NULL;
    PyObject *arr1=NULL, *arr2=NULL, *oarr=NULL;

    if (!PyArg_ParseTuple(args, "OOO!", &arg1, &arg2,
        &PyArray_Type, &out)) return NULL;

    arr1 = PyArray_FROM_OTF(arg1, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    if (arr1 == NULL) return NULL;
    arr2 = PyArray_FROM_OTF(arg2, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    if (arr2 == NULL) goto fail;
#if NPY_API_VERSION >= 0x0000000c
    oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY2);
#else
    oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
#endif
    if (oarr == NULL) goto fail;

    /* code that makes use of arguments */
    /* You will probably need at least
       nd = PyArray_NDIM(<..>)    -- number of dimensions
       dims = PyArray_DIMS(<..>)  -- npy_intp array of length nd
                                     showing length in each dim.
       dptr = (double *)PyArray_DATA(<..>) -- pointer to data.

       If an error occurs goto fail.
     */

    Py_DECREF(arr1);
    Py_DECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
    PyArray_ResolveWritebackIfCopy(oarr);
#endif
    Py_DECREF(oarr);
    Py_INCREF(Py_None);
    return Py_None;

 fail:
    Py_XDECREF(arr1);
    Py_XDECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
    PyArray_DiscardWritebackIfCopy(oarr);
#endif
    Py_XDECREF(oarr);
    return NULL;
}