如何扩展NumPy#

静止和重复是枯燥的。动态的
和随机是令人困惑的。中间是艺术。
John A. Locke
科学是一阶微分方程。宗教是一个边界条件。
Alan Turing

编写扩展模块#

虽然 ndarray 对象旨在允许 Python 中的快速计算,但它也旨在通用并满足各种计算需求。因此,如果绝对速度至关重要,那么没有比针对您的应用程序和硬件精心制作的编译循环更好的了。这也是 numpy 包含 f2py 的原因之一,它提供了一种易于使用的机制,可以将(简单的)C/C++ 和(任意的)Fortran 代码直接链接到 Python 中。我们鼓励您使用和改进此机制。本节的目的不是记录此工具,而是记录此工具所依赖的编写扩展模块的更基本步骤。

当扩展模块被编写、编译并安装到 Python 路径 (sys.path) 中的某个位置时,就可以像导入标准的 python 文件一样在 Python 中导入该代码。它将包含用 C 代码定义和编译的对象和方法。执行此操作的基本步骤在 Python 中已有很好的记录,您可以在 Python 本身的文档中找到更多信息,在线地址为 www.python.org

除了 Python C-API 之外,NumPy 还有一个完整而丰富的 C-API,可以进行复杂的 C 级操作。但是,对于大多数应用程序,通常只会使用几个 API 调用。例如,如果您只需要提取内存指针以及一些形状信息以传递给另一个计算例程,那么您将使用与尝试创建新的类似数组的类型或为 ndarray 添加新的数据类型非常不同的调用。本章记录了最常用的 API 调用和宏。

必需的子程序#

为了让 Python 将您的 C 代码用作扩展模块,您的 C 代码中必须恰好定义一个函数。该函数必须命名为 init{name},其中 {name} 是 Python 中模块的名称。必须声明此函数,使其对该例程外部的代码可见。除了添加您想要的函数和常量之外,此子程序还必须包含 import_array() 和/或 import_ufunc() 等调用,具体取决于所需的 C-API。忘记放置这些命令将在任何 C-API 子程序实际被调用时导致丑陋的段错误(崩溃)。实际上可以在单个文件中拥有多个 init{name} 函数,在这种情况下,该文件将定义多个模块。但是,有一些技巧可以使此工作正常进行,这里不涵盖。

一个最小的 init{name} 方法看起来像

PyMODINIT_FUNC
init{name}(void)
{
   (void)Py_InitModule({name}, mymethods);
   import_array();
}

mymethods 必须是一个数组(通常是静态声明的)PyMethodDef 结构,其中包含方法名称、实际 C 函数、指示方法是否使用关键字参数的变量以及文档字符串。这些将在下一节中进行解释。如果您想向模块添加常量,则存储 Py_InitModule 的返回值,它是一个模块对象。向模块添加项目的最通用方法是使用 PyModule_GetDict(module) 获取模块字典。有了模块字典,您可以手动向模块添加任何您喜欢的内容。向模块添加对象的更简单方法是使用三个额外的 Python C-API 调用之一,它们不需要单独提取模块字典。这些在 Python 文档中有记录,但为方便起见在此重复:

int PyModule_AddObject(PyObject *module, char *name, PyObject *value)#
int PyModule_AddIntConstant(PyObject *module, char *name, long value)#
int PyModule_AddStringConstant(PyObject *module, char *name, char *value)#

所有这三个函数都需要 module 对象(Py_InitModule 的返回值)。name 是一个字符串,用于标记模块中的值。根据调用哪个函数,value 参数是要么是通用对象(PyModule_AddObject 会获取其引用),要么是整数常量,要么是字符串常量。

定义函数#

传递给 Py_InitModule 函数的第二个参数是一个结构,可以方便地在模块中定义函数。在上面的示例中,mymethods 结构应该在文件前面(通常在 init{name} 子程序之前)定义为

static PyMethodDef mymethods[] = {
    { nokeywordfunc,nokeyword_cfunc,
      METH_VARARGS,
      Doc string},
    { keywordfunc, keyword_cfunc,
      METH_VARARGS|METH_KEYWORDS,
      Doc string},
    {NULL, NULL, 0, NULL} /* Sentinel */
}

mymethods 数组中的每个条目都是一个 PyMethodDef 结构,包含 1) Python 名称,2) 实现函数的 C 函数,3) 指示函数是否接受关键字的标志,以及 4) 函数的文档字符串。可以通过向此表添加更多条目来为单个模块定义任意数量的函数。最后一个条目必须全部为 NULL,如所示,以作为哨兵。Python 会查找此条目以知道模块的所有函数都已定义。

完成扩展模块必须做的最后一件事是实际编写执行所需功能的代码。有两种函数:不接受关键字参数的函数和接受关键字参数的函数。

不带关键字参数的函数#

不接受关键字参数的函数应编写为

static PyObject*
nokeyword_cfunc (PyObject *dummy, PyObject *args)
{
    /* convert Python arguments */
    /* do function */
    /* return something */
}

dummy 参数在此上下文中未使用,可以安全地忽略。args 参数将传递给函数的所有参数作为元组包含。此时您可以做任何您想做的事情,但通常管理输入参数最简单的方法是调用 PyArg_ParseTuple (args, format_string, addresses_to_C_variables…) 或 PyArg_UnpackTuple (tuple, “name”, min, max, …)。Python C-API 参考手册第 5.5 节(解析参数和构建值)中对此第一个函数的用法进行了很好的描述。您应特别注意“O&”格式,它使用转换器函数在 Python 对象和 C 对象之间进行转换。所有其他格式函数(大部分)都可以视为此通用规则的特例。NumPy C-API 中定义了几个可能使用的转换器函数。特别是,PyArray_DescrConverter 函数对于支持任意数据类型规范非常有用。此函数将任何有效的数据类型 Python 对象转换为 PyArray_Descr* 对象。请记住传入应该被填充的 C 变量的地址。

NumPy 源代码中有许多关于如何使用 PyArg_ParseTuple 的示例。标准的用法如下

PyObject *input;
PyArray_Descr *dtype;
if (!PyArg_ParseTuple(args, "OO&", &input,
                      PyArray_DescrConverter,
                      &dtype)) return NULL;

在使用“O”格式字符串时,重要的是要记住您获得的是对象的借用引用。但是,转换器函数通常需要某种形式的内存处理。在此示例中,如果转换成功,dtype 将持有对 PyArray_Descr* 对象的*新*引用,而 input 将持有借用引用。因此,如果此转换与其他转换(例如转换为整数)混合,并且数据类型转换成功但整数转换失败,那么您需要在返回之前释放数据类型对象的引用计数。一种典型的方法是在调用 PyArg_ParseTuple 之前将 dtype 设置为 NULL,然后在使用 Py_XDECREF 处理 dtype 后再返回。 典型的做法是在调用 PyArg_ParseTuple 之前将 dtype 设置为 NULL,然后在返回之前使用 Py_XDECREF 处理 dtype

在处理完输入参数后,就编写实际执行工作的代码(可能需要调用其他函数)。C 函数的最后一步是返回某些内容。如果遇到错误,则应返回 NULL(确保已设置错误)。如果什么都不应该返回,则增加 Py_None 并返回它。如果要返回单个对象,则返回它(确保您首先拥有其引用)。如果要返回多个对象,则需要返回一个元组。Py_BuildValue (format_string, c_variables…) 函数可以方便地从 C 变量构建 Python 对象元组。请特别注意格式字符串中“N”和“O”之间的区别,否则很容易造成内存泄漏。“O”格式字符串会增加对应 C 变量的 PyObject* 的引用计数,而“N”格式字符串会“窃取”对应 PyObject* C 变量的引用。如果您已经为该对象创建了引用并只想将该引用交给元组,则应使用“N”。如果您只有一个对象的借用引用,则需要创建引用以提供给元组,这时应使用“O”。

带关键字参数的函数#

这些函数与不带关键字参数的函数非常相似。唯一的区别在于函数签名是

static PyObject*
keyword_cfunc (PyObject *dummy, PyObject *args, PyObject *kwds)
{
...
}

kwds 参数包含一个 Python 字典,其键是关键字参数的名称,其值是相应的关键字参数值。此字典可以按您认为合适的方式进行处理。然而,最简单的处理方法是将 PyArg_ParseTuple (args, format_string, addresses…) 函数替换为调用 PyArg_ParseTupleAndKeywords (args, kwds, format_string, char *kwlist[], addresses…)。此函数的 kwlist 参数是一个以 NULL 结尾的字符串数组,提供了预期的关键字参数。对于 format_string 中的每个条目都应该有一个字符串。使用此函数会在传递无效关键字参数时引发 TypeError。

有关此函数的更多帮助,请参阅 Python 文档中“扩展和嵌入教程”的第 1.8 节(扩展函数的关键字参数)。

引用计数#

编写扩展模块时最大的困难是引用计数。这是 f2py、weave、Cython、ctypes 等流行的原因。如果您错误地处理引用计数,可能会导致从内存泄漏到段错误的问题。我知道的唯一正确处理引用计数的策略是付出的努力。首先,您要强迫自己记住每个 Python 变量都有一个引用计数。然后,您要精确理解每个函数如何影响对象的引用计数,以便在需要时正确使用 DECREF 和 INCREF。引用计数确实可以考验您对编程技艺的耐心和勤奋程度。尽管描述严峻,但大多数引用计数的情况都相当直接,最常见的困难是由于某种错误而过早退出例程时没有使用 DECREF。其次,常见错误是未拥有传递给将“窃取”引用的函数或宏的对象(*例如*,PyTuple_SET_ITEM,以及大多数接受 PyArray_Descr 对象的函数)。

通常,当创建变量或函数返回值时(*尽管有一些显著的例外——例如从元组或字典中获取项*),您会获得变量的新引用。当您拥有引用时,您有责任确保在不再需要变量时调用 Py_DECREF (var)(并且没有其他函数“窃取”其引用)。另外,如果您将 Python 对象传递给将“窃取”引用的函数,那么您需要确保您拥有它(或使用 Py_INCREF 来获取自己的引用)。您还会遇到“借用引用”的概念。借用引用的函数不会改变对象的引用计数,并且不期望“持有”该引用。它只是临时使用该对象。当您使用 PyArg_ParseTuplePyArg_UnpackTuple 时,您会收到元组中对象的借用引用,并且不应在函数内部更改它们的引用计数。通过实践,您可以学会正确处理引用计数,但一开始可能会令人沮丧。

引用计数错误的一个常见来源是 Py_BuildValue 函数。请仔细注意“N”格式字符和“O”格式字符之间的区别。如果您在子程序中创建了一个新对象(例如输出数组),并且将其作为返回值的元组传回,那么您很可能应该在 Py_BuildValue 中使用“N”格式字符。“O”字符会将引用计数加一。这会让调用者对一个全新的数组有两个引用计数。当变量被删除并将引用计数减一后,仍将存在额外的引用计数,并且数组将永远不会被释放。您将有一个由引用计数引起的内存泄漏。使用“N”字符可以避免这种情况,因为它将返回给调用者一个具有单个引用计数的对象(在元组内)。

处理数组对象#

NumPy 的大多数扩展模块都需要访问 ndarray 对象(或其子类)的内存。最简单的方法是,您不需要了解 NumPy 的内部知识。方法是

  1. 确保您正在处理一个行为良好(对齐、符合机器字节序且为单段)的、类型和维度正确的数组。

    1. 通过 PyArray_FromAny 或在其上构建的宏将其从任何 Python 对象转换为数组。

    2. 通过使用 PyArray_NewFromDescr 或基于它的更简单的宏或函数来构造具有所需形状和类型的新的 ndarray。

  2. 获取数组的形状以及指向其实际数据的指针。

  3. 将数据和形状信息传递给实际执行计算的子程序或其他代码段。

  4. 如果您正在编写算法,我建议您使用数组中包含的步幅信息来访问数组的元素(PyArray_GetPtr 宏使此过程变得轻松)。然后,您可以放宽要求,以避免强制使用单段数组和可能产生的数据复制。

这些子主题中的每一个都在以下子节中介绍。

转换任意序列对象#

从任何可以转换为数组的 Python 对象获取数组的主要例程是 PyArray_FromAny。此函数非常灵活,有许多输入参数。几个宏使得使用基本函数更加容易。PyArray_FROM_OTF 可能是这些宏中最常用的宏。它允许您将任意 Python 对象转换为特定内置数据类型(*例如*,float)的数组,同时指定一组特定的要求(*例如*,连续、对齐和可写)。语法是

PyArray_FROM_OTF

从任何可以转换为数组的 Python 对象 obj 返回一个 ndarray。返回数组的维度数由对象确定。返回数组的所需数据类型在 typenum 中提供,它应该是枚举类型之一。返回数组的 requirements 可以是标准数组标志的任意组合。每个参数将在下面更详细地解释。成功时,您将获得数组的新引用。失败时,将返回 NULL 并设置异常。

obj

该对象可以是任何可转换为 ndarray 的 Python 对象。如果对象已经是满足要求的 ndarray(或其子类),则返回新引用。否则,将构造一个新数组。除非使用数组接口,否则 obj 的内容将被复制到新数组中,从而无需复制数据。可以转换为数组的对象包括:1) 任何嵌套序列对象,2) 任何公开数组接口的对象,3) 任何具有 __array__ 方法的对象(应返回 ndarray),以及 4) 任何标量对象(变成零维数组)。如果满足要求的 ndarray 的子类将被传递。如果您想确保是基类 ndarray,请在要求标志中使用 NPY_ARRAY_ENSUREARRAY。仅在必要时才复制。如果您想保证复制,请将 NPY_ARRAY_ENSURECOPY 传递给要求标志。

typenum

枚举类型之一,或者如果数据类型应从对象本身确定,则为 NPY_NOTYPE。可以使用 C 风格的名称

或者,可以使用平台上支持的位宽名称。例如

只有在不丢失精度的情况下才能进行转换,对象才会被转换为所需类型。否则将返回 NULL 并引发错误。使用要求标志中的 NPY_ARRAY_FORCECAST 来覆盖此行为。

requirements

ndarray 的内存模型允许任意的步幅(stride)在每个维度上前进到数组的下一个元素。然而,通常您需要与期望 C 连续或 Fortran 连续内存布局的代码进行交互。此外,ndarray 可能未对齐(元素的地址不是元素大小的整数倍),这可能导致您的程序崩溃(或至少运行得更慢),如果您尝试解引用数组数据中的指针。所有这些问题都可以通过将 Python 对象转换为更适合您特定用例的“行为良好”的数组来解决。

requirements 标志允许指定接受哪种类型的数组。如果传入的对象不满足此要求,则会进行复制,以便返回的对象将满足要求。这些 ndarray 可以使用非常通用的内存指针。此标志允许指定返回的数组对象的所需属性。所有标志都在详细的 API 章节中进行了说明。最常需要的标志是 NPY_ARRAY_IN_ARRAYNPY_ARRAY_OUT_ARRAYNPY_ARRAY_INOUT_ARRAY

NPY_ARRAY_IN_ARRAY

此标志对于必须是 C 连续顺序且对齐的数组很有用。这类数组通常是某些算法的输入数组。

NPY_ARRAY_OUT_ARRAY

此标志对于指定一个 C 连续顺序、对齐且可写入的数组很有用。此类数组通常作为输出返回(尽管通常此类输出数组是从头开始创建的)。

NPY_ARRAY_INOUT_ARRAY

此标志对于指定一个将用于输入和输出的数组很有用。必须在 Py_DECREF(在接口例程结束时)之前调用 PyArray_ResolveWritebackIfCopy,以将临时数据写回到传入的原始数组中。使用 NPY_ARRAY_WRITEBACKIFCOPY 标志要求输入对象已经是数组(因为其他对象不能自动更新)。如果发生错误,请在具有这些标志设置的数组上调用 PyArray_DiscardWritebackIfCopy (obj)。这将把底层基础数组设置为可写,而不会将内容复制回原始数组。

可以 OR 运算作为附加要求的其他有用的标志是

NPY_ARRAY_FORCECAST

即使在可能丢失信息的情况下也要强制转换为所需类型。

NPY_ARRAY_ENSURECOPY

确保结果数组是原始数组的副本。

NPY_ARRAY_ENSUREARRAY

确保结果对象是实际的 ndarray 而不是子类。

注意

数组是否字节序颠倒由数组的数据类型决定。PyArray_FROM_OTF 始终请求本机字节序数组,因此在 requirements 参数中不需要 NPY_ARRAY_NOTSWAPPED 标志。此外,此例程无法获得字节序颠倒的数组。

创建全新的 ndarray#

在扩展模块代码中,经常需要从头开始创建新的数组。也许需要一个输出数组,而您不希望调用者提供它。也许只需要一个临时数组来保存中间计算。无论需要什么,都有简单的方法可以获得所需数据类型的 ndarray 对象。最通用的函数是 PyArray_NewFromDescr。所有数组创建函数都通过这段代码实现。由于其灵活性,使用起来可能有些令人困惑。因此,存在更易于使用的简化形式。这些形式是 PyArray_SimpleNew 系列函数的一部分,该系列函数通过为常见用例提供默认值来简化接口。

访问 ndarray 内存并访问 ndarray 的元素#

如果 obj 是一个 ndarray (PyArrayObject*),那么 ndarray 的数据区域由 void* 指针 PyArray_DATA (obj) 或 char* 指针 PyArray_BYTES (obj) 指向。请记住,(通常) 这个数据区域可能未按数据类型对齐,它可能表示字节顺序已交换的数据,以及/或者它可能不是可写的。如果数据区域已对齐且是本地字节序,那么如何访问数组的特定元素仅由 npy_intp 变量数组 PyArray_STRIDES (obj) 决定。特别是,这个 c 数组的整数显示了必须添加到当前元素指针的 **字节** 数,才能获得每个维度中的下一个元素。对于维度小于 4 的数组,有 PyArray_GETPTR{k} (obj, …) 宏,其中 {k} 是整数 1、2、3 或 4,可以更轻松地使用数组步幅。参数 …. 代表数组的 {k} 个非负整数索引。例如,假设 E 是一个 3 维 ndarray。指向元素 E[i,j,k] 的 (void*) 指针通过 PyArray_GETPTR3 (E, i, j, k) 获得。

如前所述,C 风格连续数组和 Fortran 风格连续数组具有特定的步幅模式。两个数组标志 (NPY_ARRAY_C_CONTIGUOUSNPY_ARRAY_F_CONTIGUOUS) 表明特定数组的步幅模式是否匹配 C 风格连续或 Fortran 风格连续,或两者都不是。是否匹配标准 C 或 Fortran 步幅模式可以使用 PyArray_IS_C_CONTIGUOUS (obj) 和 PyArray_ISFORTRAN (obj) 进行测试。大多数第三方库都期望连续数组。但是,支持通用步幅通常并不困难。我鼓励您在自己的代码中尽可能多地使用步幅信息,并将单段要求保留用于包装第三方代码。使用 ndarray 提供的步幅信息而不是要求连续步幅可以减少必须进行的复制。

示例#

以下示例显示了如何编写一个接受两个输入参数(将转换为数组)和一个输出参数(必须是数组)的包装器。该函数返回 None 并更新输出数组。请注意,对于 NumPy v1.14 及更高版本, WRITEBACKIFCOPY 的语义已更新。

static PyObject *
example_wrapper(PyObject *dummy, PyObject *args)
{
    PyObject *arg1=NULL, *arg2=NULL, *out=NULL;
    PyObject *arr1=NULL, *arr2=NULL, *oarr=NULL;

    if (!PyArg_ParseTuple(args, "OOO!", &arg1, &arg2,
        &PyArray_Type, &out)) return NULL;

    arr1 = PyArray_FROM_OTF(arg1, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    if (arr1 == NULL) return NULL;
    arr2 = PyArray_FROM_OTF(arg2, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    if (arr2 == NULL) goto fail;
#if NPY_API_VERSION >= 0x0000000c
    oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY2);
#else
    oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
#endif
    if (oarr == NULL) goto fail;

    /* code that makes use of arguments */
    /* You will probably need at least
       nd = PyArray_NDIM(<..>)    -- number of dimensions
       dims = PyArray_DIMS(<..>)  -- npy_intp array of length nd
                                     showing length in each dim.
       dptr = (double *)PyArray_DATA(<..>) -- pointer to data.

       If an error occurs goto fail.
     */

    Py_DECREF(arr1);
    Py_DECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
    PyArray_ResolveWritebackIfCopy(oarr);
#endif
    Py_DECREF(oarr);
    Py_INCREF(Py_None);
    return Py_None;

 fail:
    Py_XDECREF(arr1);
    Py_XDECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
    PyArray_DiscardWritebackIfCopy(oarr);
#endif
    Py_XDECREF(oarr);
    return NULL;
}