如何扩展 NumPy#

静态和重复的东西很无聊。动态
和随机的东西令人困惑。艺术则介于两者之间。
John A. Locke
科学是一个微分方程。宗教是一个边界条件。
Alan Turing

编写扩展模块#

虽然 ndarray 对象旨在允许在 Python 中进行快速计算,但它也旨在通用且满足各种计算需求。因此,如果绝对速度至关重要,则没有比针对您的应用程序和硬件精心设计的编译循环更好的替代方案了。这就是 numpy 包含 f2py 的原因之一,以便提供一个易于使用的机制,用于将 (简单) C/C++ 和 (任意) Fortran 代码直接链接到 Python 中。鼓励您使用和改进此机制。本节的目的是不记录此工具,而是记录此工具所依赖的编写扩展模块的更基本步骤。

当扩展模块被编写、编译并安装到 Python 路径 (sys.path) 的某个位置时,代码就可以像标准 Python 文件一样导入到 Python 中。它将包含在 C 代码中定义和编译的对象和方法。在 Python 中执行此操作的基本步骤已记录在案,您可以在网上找到更多信息,网址为 www.python.org

除了 Python C-API 之外,NumPy 还提供了一个完整且丰富的 C-API,允许在 C 级别进行复杂的处理。但是,对于大多数应用程序,通常只需要使用几个 API 调用。例如,如果您只需要提取内存指针以及一些形状信息以传递给另一个计算例程,那么您将使用与尝试创建新的类数组类型或为 ndarray 添加新的数据类型时不同的调用。本章记录了最常用的 API 调用和宏。

必需的子例程#

为了让 Python 将您的 C 代码用作扩展模块,必须在您的 C 代码中定义一个函数。该函数必须命名为 init{name},其中 {name} 是 Python 中模块的名称。此函数必须声明为对例程外部的代码可见。除了添加您需要的 method 和常量之外,此子例程还必须包含诸如 import_array() 和/或 import_ufunc() 的调用,具体取决于需要哪个 C-API。忘记放置这些命令会导致在实际调用任何 C-API 子例程时出现难看的段错误 (崩溃)。实际上,可以在单个文件中包含多个 init{name} 函数,在这种情况下,该文件将定义多个模块。但是,有一些技巧可以使其正常工作,这里不作介绍。

一个最小的 init{name} 方法如下所示

PyMODINIT_FUNC
init{name}(void)
{
   (void)Py_InitModule({name}, mymethods);
   import_array();
}

mymethods 必须是 PyMethodDef 结构的数组(通常是静态声明的),其中包含方法名称、实际的 C 函数、一个指示方法是否使用关键字参数的变量以及文档字符串。这些将在下一节中解释。如果您想向模块添加常量,则存储 Py_InitModule 返回的值,该值是一个模块对象。向模块添加项目的通用方法是使用 PyModule_GetDict(module) 获取模块字典。使用模块字典,您可以手动向模块添加任何您想要的内容。向模块添加对象的更简单方法是使用三个额外的 Python C-API 调用,这些调用不需要单独提取模块字典。这些在 Python 文档中有所记录,但为了方便起见,这里重复一遍

int PyModule_AddObject(PyObject *module, char *name, PyObject *value)#
int PyModule_AddIntConstant(PyObject *module, char *name, long value)#
int PyModule_AddStringConstant(PyObject *module, char *name, char *value)#

所有这三个函数都需要 module 对象(Py_InitModule 的返回值)。name 是一个字符串,用于标记模块中的值。根据调用的函数,value 参数要么是一个通用对象(PyModule_AddObject 会窃取对它的引用)、一个整数常量或一个字符串常量。

定义函数#

传递给 Py_InitModule 函数的第二个参数是一个结构,它可以轻松地在模块中定义函数。在上面给出的示例中,mymethods 结构将在文件中的较早位置定义(通常在 init{name} 子例程之前)为

static PyMethodDef mymethods[] = {
    { nokeywordfunc,nokeyword_cfunc,
      METH_VARARGS,
      Doc string},
    { keywordfunc, keyword_cfunc,
      METH_VARARGS|METH_KEYWORDS,
      Doc string},
    {NULL, NULL, 0, NULL} /* Sentinel */
}

mymethods 数组中的每个条目都是一个 PyMethodDef 结构,包含 1) Python 名称,2) 实现函数的 C 函数,3) 指示此函数是否接受关键字的标志,以及 4) 函数的文档字符串。可以通过向此表添加更多条目来为单个模块定义任意数量的函数。最后一个条目必须全部为 NULL,如所示,用作哨兵。Python 查找此条目以知道模块的所有函数都已定义。

完成扩展模块的最后一步是实际编写执行所需函数的代码。函数有两种:不接受关键字参数的函数和接受关键字参数的函数。

不带关键字参数的函数#

不接受关键字参数的函数应编写为

static PyObject*
nokeyword_cfunc (PyObject *dummy, PyObject *args)
{
    /* convert Python arguments */
    /* do function */
    /* return something */
}

在此上下文中不使用 dummy 参数,可以安全地忽略它。args 参数包含作为元组传递给函数的所有参数。此时您可以执行任何操作,但通常管理输入参数的最简单方法是调用 PyArg_ParseTuple (args, format_string, addresses_to_C_variables…) 或 PyArg_UnpackTuple (tuple, “name”, min, max, …)。有关如何使用第一个函数的良好描述包含在 Python C-API 参考手册的第 5.5 节(解析参数和构建值)中。您应该特别注意“O&”格式,该格式使用转换器函数在 Python 对象和 C 对象之间进行转换。所有其他格式函数都可以(大部分)被认为是此一般规则的特例。NumPy C-API 中定义了几个可能会有用的转换器函数。特别是,PyArray_DescrConverter 函数对于支持任意数据类型规范非常有用。此函数将任何有效的数据类型 Python 对象转换为 PyArray_Descr* 对象。请记住传入应填充的 C 变量的地址。

在 NumPy 源代码中有很多关于如何使用 PyArg_ParseTuple 的示例。标准用法如下

PyObject *input;
PyArray_Descr *dtype;
if (!PyArg_ParseTuple(args, "OO&", &input,
                      PyArray_DescrConverter,
                      &dtype)) return NULL;

务必记住,使用“O”格式字符串时,您将获得对对象的借用引用。但是,转换器函数通常需要某种形式的内存处理。在此示例中,如果转换成功,则 dtype 将保存对 PyArray_Descr* 对象的新引用,而 input 将保存借用引用。因此,如果此转换与另一个转换(例如转换为整数)混合,并且数据类型转换成功但整数转换失败,则在返回之前,您需要释放对数据类型对象的引用计数。一种典型的方法是在调用 PyArg_ParseTuple 之前将 dtype 设置为 NULL,然后在返回之前对 dtype 使用 Py_XDECREF

处理完输入参数后,将编写实际执行工作的代码(可能需要调用其他函数)。C 函数的最后一步是返回某些内容。如果遇到错误,则应返回 NULL(确保已实际设置错误)。如果不需要返回任何内容,则递增 Py_None 并返回它。如果应返回单个对象,则返回它(首先确保您拥有对它的引用)。如果应返回多个对象,则需要返回一个元组。 Py_BuildValue (format_string, c_variables…) 函数可以轻松地从 C 变量构建 Python 对象的元组。请特别注意格式字符串中“N”和“O”之间的区别,否则很容易造成内存泄漏。“O”格式字符串会递增与其对应的 PyObject* C 变量的引用计数,而“N”格式字符串会窃取与其对应的 PyObject* C 变量的引用。如果您已为对象创建了一个引用并且只想将该引用提供给元组,则应使用“N”。如果您仅拥有对对象的借用引用并且需要创建一个引用来提供给元组,则应使用“O”。

带关键字参数的函数#

这些函数与不带关键字参数的函数非常相似。唯一的区别是函数签名为

static PyObject*
keyword_cfunc (PyObject *dummy, PyObject *args, PyObject *kwds)
{
...
}

kwds 参数保存一个 Python 字典,其键是关键字参数的名称,其值是相应的关键字参数值。可以根据需要处理此字典。但是,处理它的最简单方法是用对 PyArg_ParseTuple (args, format_string, addresses…) 的调用替换为对 PyArg_ParseTupleAndKeywords (args, kwds, format_string, char *kwlist[], addresses…) 的调用。此函数的 kwlist 参数是一个以 NULL 结尾的字符串数组,提供预期的关键字参数。格式字符串中的每个条目都应该有一个字符串。使用此函数将在传递无效关键字参数时引发 TypeError。

有关此函数的更多帮助,请参阅 Python 文档中“扩展和嵌入”教程的第 1.8 节(扩展函数的关键字参数)。

引用计数#

编写扩展模块时最大的困难在于引用计数。这是 f2py、weave、Cython、ctypes 等流行的重要原因……如果引用计数处理不当,可能会导致从内存泄漏到段错误等各种问题。我所知的唯一正确处理引用计数的策略是:血汗泪。首先,你要强迫自己记住每个 Python 变量都有一个引用计数。然后,你必须准确理解每个函数对你的对象引用计数做了什么,以便在需要时正确使用 DECREF 和 INCREF。引用计数可以真正考验你对编程技巧的耐心和勤奋程度。尽管描述得比较严峻,但大多数引用计数的情况都非常简单明了,最常见的困难是在由于某些错误过早退出例程之前,没有对对象使用 DECREF。其次,常见的错误是在将对象传递给将窃取其引用的函数或宏(例如 PyTuple_SET_ITEM,以及大多数使用 PyArray_Descr 对象的函数)时,没有拥有该对象的引用。

通常,当变量被创建或作为某些函数的返回值时,你会获得一个新的变量引用(但是,也有一些明显的例外——例如从元组或字典中获取项目)。当你拥有该引用时,你负责确保在不再需要该变量(并且没有其他函数“窃取”了它的引用)时调用 Py_DECREF (var)。此外,如果你将 Python 对象传递给一个将“窃取”其引用的函数,那么你需要确保你拥有该引用(或使用 Py_INCREF 获取你自己的引用)。你还会遇到借用引用的概念。借用引用的函数不会更改对象的引用计数,也不会期望“持有”该引用。它只是暂时使用该对象。当你使用 PyArg_ParseTuplePyArg_UnpackTuple 时,你会收到对元组中对象的借用引用,并且不应该在你的函数内部更改它们的引用计数。通过练习,你可以学会正确处理引用计数,但一开始可能会让人感到沮丧。

引用计数错误的一个常见来源是 Py_BuildValue 函数。请仔细注意“N”格式字符和“O”格式字符之间的区别。如果你在子例程中创建了一个新对象(例如输出数组),并且你将其通过返回值元组传递回去,那么你很可能应该在 Py_BuildValue 中使用“N”格式字符。“O”字符会将引用计数增加 1。这将使调用方对一个全新的数组拥有两个引用计数。当变量被删除并且引用计数减少 1 时,仍然存在额外的引用计数,并且该数组将永远不会被释放。你将遇到由引用计数引起的内存泄漏。使用“N”字符可以避免这种情况,因为它将返回一个(在元组内)只有一个引用计数的对象给调用方。

处理数组对象#

大多数 NumPy 的扩展模块都需要访问 ndarray 对象(或其子类之一)的内存。最简单的方法不需要你了解 NumPy 的内部机制。该方法是

  1. 确保你正在处理一个行为良好的数组(对齐的、机器字节序的和单段的),并且具有正确的类型和维度数量。

    1. 通过使用 PyArray_FromAny 或基于它的宏,将其从某个 Python 对象转换而来。

    2. 通过使用 PyArray_NewFromDescr 或基于它的更简单的宏或函数,构造一个具有你所需形状和类型的新的 ndarray。

  2. 获取数组的形状和指向其实际数据的指针。

  3. 将数据和形状信息传递给实际执行计算的子例程或代码的其他部分。

  4. 如果你正在编写算法,那么我建议你使用数组中包含的步长信息来访问数组的元素(PyArray_GetPtr 宏使这变得轻而易举)。然后,你可以放宽你的要求,以避免强制使用单段数组以及可能导致的数据复制。

以下各小节将介绍这些子主题。

转换任意序列对象#

从任何可以转换为数组的 Python 对象获取数组的主要例程是 PyArray_FromAny。此函数非常灵活,具有许多输入参数。几个宏使使用基本函数变得更容易。PyArray_FROM_OTF 可以说是这些宏中最有用的一个,适用于最常见的使用场景。它允许你将任意 Python 对象转换为特定内置数据类型(例如 float)的数组,同时指定一组特定的要求(例如 连续、对齐和可写)。语法如下:

PyArray_FROM_OTF

从任何可以转换为数组的 Python 对象 obj 返回一个 ndarray。返回数组中的维度数量由对象决定。返回数组的所需数据类型在 typenum 中提供,它应该是枚举类型之一。返回数组的 requirements 可以是标准数组标志的任何组合。下面将详细解释每个参数。成功时,你会收到对该数组的新引用。失败时,将返回 NULL 并设置异常。

obj

该对象可以是任何可转换为 ndarray 的 Python 对象。如果该对象已经是(ndarray 的子类)并且满足要求,则返回一个新的引用。否则,将构造一个新的数组。除非使用数组接口,否则 obj 的内容将被复制到新数组中,这样数据就不需要复制。可以转换为数组的对象包括:1) 任何嵌套序列对象,2) 任何公开数组接口的对象,3) 任何具有 __array__ 方法的对象(该方法应返回一个 ndarray),以及 4) 任何标量对象(变为零维数组)。满足要求的 ndarray 的子类将被传递。如果你想确保基础 ndarray,则在要求标志中使用 NPY_ARRAY_ENSUREARRAY。仅在必要时进行复制。如果你想保证复制,则将 NPY_ARRAY_ENSURECOPY 传递给要求标志。

typenum

枚举类型之一或 NPY_NOTYPE(如果数据类型应由对象本身确定)。可以使用基于 C 的名称

或者,可以使用平台上支持的位宽名称。例如

仅当可以在不丢失精度的情况下将对象转换为所需类型时,才会进行转换。否则,将返回 NULL 并引发错误。在要求标志中使用 NPY_ARRAY_FORCECAST 可覆盖此行为。

requirements

ndarray 的内存模型允许在每个维度中使用任意步长来前进到数组的下一个元素。但是,通常情况下,你需要与期望 C 连续或 Fortran 连续内存布局的代码进行交互。此外,ndarray 可能会未对齐(元素的地址不是元素大小的整数倍数),如果尝试取消对数组数据中指针的引用,这会导致程序崩溃(或至少工作速度变慢)。这两个问题都可以通过将 Python 对象转换为更适合你特定用途的数组来解决。

requirements 标志允许指定哪种数组是可以接受的。如果传入的对象不满足此要求,则会进行复制,以便返回的对象将满足要求。这些 ndarray 可以使用非常通用的内存指针。此标志允许指定返回的数组对象的所需属性。所有标志都在详细的 API 章节中进行了说明。最常需要的标志是 NPY_ARRAY_IN_ARRAYNPY_ARRAY_OUT_ARRAYNPY_ARRAY_INOUT_ARRAY

NPY_ARRAY_IN_ARRAY

此标志对于必须以 C 连续顺序且对齐的数组很有用。这些类型的数组通常是某些算法的输入数组。

NPY_ARRAY_OUT_ARRAY

此标志用于指定一个数组,该数组按 C 连续顺序排列,已对齐,并且也可以写入。这样的数组通常作为输出返回(尽管通常此类输出数组是从头开始创建的)。

NPY_ARRAY_INOUT_ARRAY

此标志用于指定一个将同时用于输入和输出的数组。在接口例程结束时,在调用 Py_DECREF 之前,必须调用 PyArray_ResolveWritebackIfCopy 以将临时数据写回传入的原始数组。使用 NPY_ARRAY_WRITEBACKIFCOPY 标志要求输入对象已经是数组(因为其他对象无法以这种方式自动更新)。如果发生错误,请对设置了这些标志的数组使用 PyArray_DiscardWritebackIfCopy (obj)。这将设置底层基础数组的可写性,而不会导致内容复制回原始数组。

其他可作为附加要求进行 OR 操作的有用标志是

NPY_ARRAY_FORCECAST

强制转换为所需类型,即使这样做会导致信息丢失。

NPY_ARRAY_ENSURECOPY

确保结果数组是原始数组的副本。

NPY_ARRAY_ENSUREARRAY

确保结果对象是实际的 ndarray 而不是子类。

注意

数组是否字节交换由数组的数据类型决定。 PyArray_FROM_OTF 始终请求本机字节序数组,因此在 requirements 参数中不需要 NPY_ARRAY_NOTSWAPPED 标志。也无法从此例程获取字节交换数组。

创建一个全新的 ndarray#

通常,必须从扩展模块代码中创建新的数组。也许需要一个输出数组,并且您不希望调用者提供它。也许只需要一个临时数组来保存中间计算。无论需要什么,都有简单的方法来获取所需任何数据类型的 ndarray 对象。执行此操作的最通用函数是 PyArray_NewFromDescr。所有数组创建函数都通过这段大量重复使用的代码。由于其灵活性,使用起来可能有点令人困惑。因此,存在一些更易于使用的更简单的形式。这些形式是 PyArray_SimpleNew 函数系列的一部分,这些函数通过为常见用例提供默认值来简化接口。

获取 ndarray 内存并访问 ndarray 的元素#

如果 obj 是一个 ndarray (PyArrayObject*),则 ndarray 的数据区域由 void* 指针 PyArray_DATA (obj) 或 char* 指针 PyArray_BYTES (obj) 指向。请记住(一般而言),此数据区域可能未根据数据类型对齐,它可能表示字节交换数据,并且/或者它可能不可写。如果数据区域已对齐且为本机字节序,则获取数组特定元素的方法仅由 npy_intp 变量数组 PyArray_STRIDES (obj) 决定。特别是,此整数 c 数组显示必须向当前元素指针添加多少**字节**才能到达每个维度中的下一个元素。对于小于 4 维的数组,存在 PyArray_GETPTR{k} (obj, …) 宏,其中 {k} 是整数 1、2、3 或 4,这使得使用数组步幅更容易。参数 …. 表示数组中 {k} 个非负整数索引。例如,假设 E 是一个 3 维 ndarray。指向元素 E[i,j,k] 的 (void*) 指针可通过 PyArray_GETPTR3 (E, i, j, k) 获取。

如前所述,C 样式连续数组和 Fortran 样式连续数组具有特定的步幅模式。两个数组标志 (NPY_ARRAY_C_CONTIGUOUSNPY_ARRAY_F_CONTIGUOUS) 指示特定数组的步幅模式是否与 C 样式连续或 Fortran 样式连续匹配,或者两者都不匹配。可以使用 PyArray_IS_C_CONTIGUOUS (obj) 和 PyArray_ISFORTRAN (obj) 分别测试步幅模式是否与标准 C 或 Fortran 模式匹配。大多数第三方库都期望连续数组。但是,通常支持通用步幅并不困难。我鼓励您在自己的代码中尽可能地使用步幅信息,并将单段要求保留用于包装第三方代码。使用 ndarray 提供的步幅信息而不是要求连续步幅可以减少必须进行的复制操作。

示例#

以下示例显示了如何编写一个包装器,该包装器接受两个输入参数(将转换为数组)和一个输出参数(必须是数组)。该函数返回 None 并更新输出数组。请注意,NumPy v1.14 及更高版本更新了 WRITEBACKIFCOPY 语义的使用

static PyObject *
example_wrapper(PyObject *dummy, PyObject *args)
{
    PyObject *arg1=NULL, *arg2=NULL, *out=NULL;
    PyObject *arr1=NULL, *arr2=NULL, *oarr=NULL;

    if (!PyArg_ParseTuple(args, "OOO!", &arg1, &arg2,
        &PyArray_Type, &out)) return NULL;

    arr1 = PyArray_FROM_OTF(arg1, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    if (arr1 == NULL) return NULL;
    arr2 = PyArray_FROM_OTF(arg2, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    if (arr2 == NULL) goto fail;
#if NPY_API_VERSION >= 0x0000000c
    oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY2);
#else
    oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
#endif
    if (oarr == NULL) goto fail;

    /* code that makes use of arguments */
    /* You will probably need at least
       nd = PyArray_NDIM(<..>)    -- number of dimensions
       dims = PyArray_DIMS(<..>)  -- npy_intp array of length nd
                                     showing length in each dim.
       dptr = (double *)PyArray_DATA(<..>) -- pointer to data.

       If an error occurs goto fail.
     */

    Py_DECREF(arr1);
    Py_DECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
    PyArray_ResolveWritebackIfCopy(oarr);
#endif
    Py_DECREF(oarr);
    Py_INCREF(Py_None);
    return Py_None;

 fail:
    Py_XDECREF(arr1);
    Py_XDECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
    PyArray_DiscardWritebackIfCopy(oarr);
#endif
    Py_XDECREF(oarr);
    return NULL;
}