如何扩展 NumPy#
编写扩展模块#
尽管 ndarray 对象旨在允许在 Python 中进行快速计算,但它也被设计为通用型并满足各种计算需求。因此,如果绝对速度至关重要,那么针对您的应用程序和硬件精心编写的编译循环是无可替代的。这是 NumPy 包含 f2py 的原因之一,以便提供易于使用的机制,将(简单)C/C++ 和(任意)Fortran 代码直接链接到 Python。我们鼓励您使用和改进这种机制。本节的目的不是记录此工具,而是记录此工具所依赖的编写扩展模块的更基本步骤。
当一个扩展模块被编写、编译并安装到 Python 路径 (sys.path) 中的某个位置时,代码就可以像标准 Python 文件一样被导入到 Python 中。它将包含在 C 代码中定义和编译的对象和方法。在 Python 中执行此操作的基本步骤都有详细记录,您可以在 www.python.org 在线获取的 Python 自身文档中找到更多信息。
除了 Python C-API 之外,NumPy 还有一个完整且丰富的 C-API,允许在 C 级别进行复杂的操控。然而,对于大多数应用程序,通常只会使用少数几个 API 调用。例如,如果您只需要提取内存指针和一些形状信息以传递给另一个计算例程,那么您将使用与尝试创建新的类数组类型或为 ndarray 添加新的数据类型时非常不同的调用。本章记录了最常用的 API 调用和宏。
必需的子例程#
为了让 Python 将其用作扩展模块,您的 C 代码中必须定义一个函数。该函数必须命名为 init{name},其中 {name} 是 Python 模块的名称。该函数必须声明为在例程外部可见。除了添加您想要的方法和常量外,此子例程还必须包含诸如 import_array()
和/或 import_ufunc()
等调用,具体取决于所需的 C-API。一旦实际调用任何 C-API 子例程,忘记放置这些命令将表现为丑陋的段错误(崩溃)。实际上,可以在单个文件中拥有多个 init{name} 函数,在这种情况下,该文件将定义多个模块。但是,有一些技巧可以使其正常工作,这里不作介绍。
一个最小的 init{name}
方法看起来像
PyMODINIT_FUNC
init{name}(void)
{
(void)Py_InitModule({name}, mymethods);
import_array();
}
mymethods 必须是 PyMethodDef 结构体的一个数组(通常是静态声明),其中包含方法名称、实际的 C 函数、一个指示方法是否使用关键字参数的变量以及文档字符串。这些将在下一节中解释。如果您想向模块添加常量,则存储 Py_InitModule 返回的值,它是一个模块对象。向模块添加项目的最通用方法是使用 PyModule_GetDict(module) 获取模块字典。通过模块字典,您可以手动将任何内容添加到模块中。一种更简单的向模块添加对象的方法是使用三个额外的 Python C-API 调用之一,它们不需要单独提取模块字典。这些在 Python 文档中有记载,但为了方便起见在此处重复列出
-
int PyModule_AddStringConstant(PyObject *module, char *name, char *value)#
这三个函数都要求提供 module 对象(Py_InitModule 的返回值)。 name 是模块中值的标签字符串。根据调用的函数, value 参数可以是一个通用对象(
PyModule_AddObject
会窃取其引用),一个整数常量,或一个字符串常量。
定义函数#
传递给 Py_InitModule 函数的第二个参数是一个结构体,它使得在模块中定义函数变得容易。在上面的示例中,mymethods 结构体将在文件前面(通常在 init{name} 子例程之前)定义为
static PyMethodDef mymethods[] = {
{ nokeywordfunc,nokeyword_cfunc,
METH_VARARGS,
Doc string},
{ keywordfunc, keyword_cfunc,
METH_VARARGS|METH_KEYWORDS,
Doc string},
{NULL, NULL, 0, NULL} /* Sentinel */
}
mymethods 数组中的每个条目都是一个 PyMethodDef
结构体,包含 1) Python 名称,2) 实现该功能的 C 函数,3) 指示该函数是否接受关键字参数的标志,以及 4) 该函数的文档字符串。通过向此表中添加更多条目,可以为单个模块定义任意数量的函数。最后一个条目必须全部为 NULL,如所示,作为哨兵。Python 会查找此条目以知道模块的所有函数都已定义。
完成扩展模块的最后一件事是实际编写执行所需功能的代码。有两种类型的函数:不接受关键字参数的函数,和接受关键字参数的函数。
无关键字参数的函数#
不接受关键字参数的函数应该这样编写
static PyObject*
nokeyword_cfunc (PyObject *dummy, PyObject *args)
{
/* convert Python arguments */
/* do function */
/* return something */
}
在此上下文中,dummy 参数未使用,可以安全忽略。 args 参数包含作为元组传递给函数的所有参数。此时您可以做任何您想做的事情,但通常管理输入参数最简单的方法是调用 PyArg_ParseTuple
(args, format_string, addresses_to_C_variables…) 或 PyArg_UnpackTuple
(tuple, “name”, min, max, …)。关于如何使用第一个函数的详细描述包含在 Python C-API 参考手册的 5.5 节(解析参数和构建值)中。您应该特别注意 “O&” 格式,它使用转换函数在 Python 对象和 C 对象之间进行转换。所有其他格式函数都可以(大部分)被认为是此通用规则的特例。NumPy C-API 中定义了几个可能有用转换函数。特别是,PyArray_DescrConverter
函数对于支持任意数据类型规范非常有用。此函数将任何有效的数据类型 Python 对象转换为 PyArray_Descr* 对象。请记住传入应该填充的 C 变量的地址。
在 NumPy 源代码中有很多关于如何使用 PyArg_ParseTuple
的示例。标准用法如下
PyObject *input;
PyArray_Descr *dtype;
if (!PyArg_ParseTuple(args, "OO&", &input,
PyArray_DescrConverter,
&dtype)) return NULL;
重要的是要记住,当使用“O”格式字符串时,您会获得对对象的借用引用。但是,转换函数通常需要某种形式的内存处理。在此示例中,如果转换成功,dtype 将持有对 PyArray_Descr* 对象的新引用,而 input 将持有借用引用。因此,如果此转换与另一个转换(例如转换为整数)混合,并且数据类型转换成功但整数转换失败,则在返回之前您需要释放数据类型对象的引用计数。一个典型的方法是在调用 PyArg_ParseTuple
之前将 dtype 设置为 NULL
,然后在返回之前对 dtype 使用 Py_XDECREF
。
输入参数处理完毕后,编写实际执行工作的代码(可能根据需要调用其他函数)。C 函数的最后一步是返回一些东西。如果遇到错误,则应返回 NULL
(确保已实际设置错误)。如果不需要返回任何东西,则递增 Py_None
并返回它。如果应返回单个对象,则返回该对象(首先确保您拥有其引用)。如果应返回多个对象,则需要返回一个元组。Py_BuildValue
(format_string, c_variables…) 函数可以轻松地从 C 变量构建 Python 对象的元组。请特别注意格式字符串中“N”和“O”之间的区别,否则很容易造成内存泄漏。“O”格式字符串将相应 PyObject* C 变量的引用计数增加一,而“N”格式字符串则窃取相应 PyObject* C 变量的引用。如果您已经为对象创建了引用并只想将该引用提供给元组,则应使用“N”。如果您只拥有对对象的借用引用并且需要创建一个以提供给元组,则应使用“O”。
带关键字参数的函数#
这些函数与无关键字参数的函数非常相似。唯一的区别是函数签名是
static PyObject*
keyword_cfunc (PyObject *dummy, PyObject *args, PyObject *kwds)
{
...
}
kwds 参数持有一个 Python 字典,其键是关键字参数的名称,其值是相应的关键字参数值。您可以根据需要处理此字典。然而,最简单的处理方式是将 PyArg_ParseTuple
(args, format_string, addresses…) 函数替换为调用 PyArg_ParseTupleAndKeywords
(args, kwds, format_string, char *kwlist[], addresses…)。此函数的 kwlist 参数是一个以 NULL
结尾的字符串数组,提供预期的关键字参数。format_string 中的每个条目都应该有一个字符串。如果传入无效的关键字参数,此函数将引发 TypeError。
有关此函数的更多帮助,请参阅 Python 文档中“扩展和嵌入教程”的第 1.8 节(扩展函数的关键字参数)。
引用计数#
编写扩展模块时最大的困难是引用计数。这是 f2py、weave、Cython、ctypes 等工具流行的重要原因。如果您错误处理引用计数,可能会导致从内存泄漏到段错误的问题。我所知道的唯一正确处理引用计数的策略是血、汗和泪。首先,您必须牢记每个 Python 变量都有一个引用计数。然后,您需要准确理解每个函数对您对象的引用计数做了什么,以便在需要时正确使用 DECREF 和 INCREF。引用计数确实可以考验您对编程技艺的耐心和勤奋程度。尽管描绘得很严峻,但大多数引用计数的情况都相当简单,最常见的困难是没有在由于某种错误而提前退出例程之前对对象使用 DECREF。其次,是常见的错误,即没有拥有传递给将窃取引用的函数或宏(例如 PyTuple_SET_ITEM
以及大多数接受 PyArray_Descr
对象的函数)的对象的引用。
通常,当一个变量被创建或作为某个函数的返回值时(然而有一些显著的例外——例如从元组或字典中获取项),您会获得该变量的新引用。当您拥有该引用时,您有责任确保在不再需要该变量时(且没有其他函数“窃取”其引用)调用 Py_DECREF
(var)。此外,如果您将一个 Python 对象传递给一个将“窃取”引用的函数,那么您需要确保您拥有它(或使用 Py_INCREF
来获取您自己的引用)。您还会遇到借用引用的概念。借用引用的函数不会改变对象的引用计数,也不期望“持有”该引用。它只是临时使用该对象。当您使用 PyArg_ParseTuple
或 PyArg_UnpackTuple
时,您会收到对元组中对象的借用引用,并且不应在函数内部改变它们的引用计数。通过实践,您可以学会正确处理引用计数,但这在开始时可能会令人沮丧。
引用计数错误的一个常见来源是 Py_BuildValue
函数。请仔细注意“N”格式字符和“O”格式字符之间的区别。如果您在子例程中创建了一个新对象(例如一个输出数组),并且您将其作为返回值的元组传回,那么您很可能应该在 Py_BuildValue
中使用“N”格式字符。“O”字符会将引用计数增加一。这将导致调用者对一个全新数组拥有两个引用计数。当变量被删除且引用计数减一后,仍然会多出一个引用计数,并且数组将永远不会被释放。您将遇到由引用计数导致的内存泄漏。使用“N”字符将避免这种情况,因为它将返回给调用者一个(元组内部)具有单个引用计数的对象。
处理数组对象#
大多数 NumPy 扩展模块都需要访问 ndarray 对象(或其子类)的内存。最简单的方法不需要您了解 NumPy 的内部细节。方法是
确保您正在处理正确类型和维度的、行为良好的数组(对齐、机器字节序且单段)。
通过使用
PyArray_FromAny
或基于它构建的宏,将其从某个 Python 对象转换而来。通过使用
PyArray_NewFromDescr
或基于它构建的更简单的宏或函数,构造一个所需形状和类型的新 ndarray。
获取数组的形状和指向其实际数据的指针。
将数据和形状信息传递给实际执行计算的子例程或其他代码段。
如果您正在编写算法,我建议您使用数组中包含的步长信息来访问数组元素(
PyArray_GetPtr
宏使这变得简单)。然后,您可以放宽要求,以便不强制使用单段数组并避免可能导致的数据复制。
以下小节将涵盖这些子主题。
转换任意序列对象#
从任何可以转换为数组的 Python 对象获取数组的主要例程是 PyArray_FromAny
。此函数非常灵活,有许多输入参数。一些宏使其更易于使用基本函数。PyArray_FROM_OTF
可谓是这些宏中最有用的一种,适用于最常见的用途。它允许您将任意 Python 对象转换为特定内置数据类型(例如浮点数)的数组,同时指定一组特定的要求(例如连续、对齐和可写)。语法是
PyArray_FROM_OTF
从任何可转换为数组的 Python 对象 obj 返回一个 ndarray。返回数组的维度数由对象确定。返回数组所需的精确数据类型在 typenum 中提供,它应该是枚举类型之一。返回数组的 requirements 可以是标准数组标志的任意组合。这些参数的每个方面将在下面更详细地解释。成功时,您将获得数组的一个新引用。失败时,返回
NULL
并设置异常。- obj
该对象可以是任何可转换为 ndarray 的 Python 对象。如果该对象已经是(或其子类)满足要求的 ndarray,则返回一个新的引用。否则,将构造一个新的数组。除非使用数组接口,否则 obj 的内容将复制到新数组中,因此数据无需复制。可以转换为数组的对象包括:1) 任何嵌套序列对象,2) 任何公开数组接口的对象,3) 任何具有
__array__
方法(应返回 ndarray)的对象,以及 4) 任何标量对象(变为零维数组)。满足要求的 ndarray 的子类将直接通过。如果您想确保是基类 ndarray,则在要求标志中使用NPY_ARRAY_ENSUREARRAY
。仅在必要时才进行复制。如果您想保证复制,则将NPY_ARRAY_ENSURECOPY
传递给要求标志。- typenum
如果数据类型应由对象本身确定,则为枚举类型之一或
NPY_NOTYPE
。可以使用基于 C 的名称或者,可以使用平台上支持的位宽名称。例如
只有在不丢失精度的情况下才能将对象转换为所需类型。否则将返回
NULL
并引发错误。在要求标志中使用NPY_ARRAY_FORCECAST
来覆盖此行为。- 要求
ndarray 的内存模型允许在每个维度中任意步长,以前进到数组的下一个元素。然而,通常您需要与期望 C 连续或 Fortran 连续内存布局的代码进行接口。此外,ndarray 可能未对齐(元素的地址不是元素大小的整数倍),如果您尝试解引用数组数据中的指针,这可能会导致您的程序崩溃(或至少运行得更慢)。通过将 Python 对象转换为对您的特定用法更“行为良好”的数组,可以解决这两个问题。
要求标志允许指定可接受的数组类型。如果传入的对象不满足这些要求,则会创建一个副本,以便返回的对象将满足这些要求。这些 ndarray 可以使用一个非常通用的内存指针。此标志允许指定返回数组对象的所需属性。所有标志都在详细的 API 章中解释。最常用的标志是
NPY_ARRAY_IN_ARRAY
、NPY_ARRAY_OUT_ARRAY
和NPY_ARRAY_INOUT_ARRAY
NPY_ARRAY_IN_ARRAY
此标志对于必须以 C 连续顺序和对齐的数组很有用。这些类型的数组通常是某些算法的输入数组。
NPY_ARRAY_OUT_ARRAY
此标志用于指定一个以 C 连续顺序、对齐且可写入的数组。这样的数组通常作为输出返回(尽管通常这样的输出数组是从零开始创建的)。
NPY_ARRAY_INOUT_ARRAY
此标志用于指定一个将同时用于输入和输出的数组。在接口例程结束时,必须在
Py_DECREF
之前调用PyArray_ResolveWritebackIfCopy
,以将临时数据写回传入的原始数组中。使用NPY_ARRAY_WRITEBACKIFCOPY
标志要求输入对象已经是数组(因为其他对象不能以这种方式自动更新)。如果发生错误,对设置了这些标志的数组使用PyArray_DiscardWritebackIfCopy
(obj)。这将使底层基数组可写,而不会导致内容被复制回原始数组。
可以作为额外要求进行或运算的其他有用标志包括
NPY_ARRAY_FORCECAST
强制转换为所需类型,即使在不丢失信息的情况下无法完成。
NPY_ARRAY_ENSURECOPY
确保生成的数组是原始数组的副本。
NPY_ARRAY_ENSUREARRAY
确保结果对象是实际的 ndarray 而不是子类。
注意
数组是否字节交换由数组的数据类型决定。本地字节序数组总是通过 PyArray_FROM_OTF
请求的,因此在 requirements 参数中不需要 NPY_ARRAY_NOTSWAPPED
标志。也无法从该例程获取字节交换数组。
创建全新的 ndarray#
通常,必须在扩展模块代码中创建新数组。也许需要一个输出数组,而您不想让调用者提供它。也许只需要一个临时数组来保存中间计算。无论需求是什么,都有简单的方法可以获取所需数据类型的 ndarray 对象。最通用的函数是 PyArray_NewFromDescr
。所有数组创建函数都通过此高度重用的代码。由于其灵活性,使用起来可能有些混乱。因此,存在更简单的形式,更易于使用。这些形式是 PyArray_SimpleNew
函数家族的一部分,它们通过为常见用例提供默认值来简化接口。
获取 ndarray 内存并访问 ndarray 元素#
如果 obj 是一个 ndarray (PyArrayObject*),则 ndarray 的数据区域由 void* 指针 PyArray_DATA
(obj) 或 char* 指针 PyArray_BYTES
(obj) 指向。请记住,(通常)此数据区域可能未根据数据类型对齐,它可能表示字节交换数据,和/或它可能不可写入。如果数据区域已对齐且为本机字节序,则如何访问数组的特定元素仅由 npy_intp 变量数组 PyArray_STRIDES
(obj) 确定。特别是,此整数的 c 数组显示了必须添加到当前元素指针的多少字节才能到达每个维度中的下一个元素。对于小于 4 维的数组,有 PyArray_GETPTR{k}
(obj, …) 宏,其中 {k} 是整数 1、2、3 或 4,它们使使用数组步长更容易。参数 … 表示数组中的 {k} 个非负整数索引。例如,假设 E
是一个 3 维 ndarray。指向元素 E[i,j,k]
的 (void*) 指针通过 PyArray_GETPTR3
(E, i, j, k) 获得。
如前所述,C 风格连续数组和 Fortran 风格连续数组具有特定的步长模式。两个数组标志(NPY_ARRAY_C_CONTIGUOUS
和 NPY_ARRAY_F_CONTIGUOUS
)指示特定数组的步长模式是否与 C 风格连续或 Fortran 风格连续匹配,或者两者都不匹配。是否与标准 C 或 Fortran 步长模式匹配可以使用 PyArray_IS_C_CONTIGUOUS
(obj) 和 PyArray_ISFORTRAN
(obj) 分别进行测试。大多数第三方库都期望连续数组。但是,通常支持通用步长并不困难。我鼓励您尽可能在自己的代码中使用步长信息,并将单段要求保留用于封装第三方代码。使用 ndarray 提供的步长信息而不是要求连续步长可以减少原本必须进行的复制。
示例#
以下示例展示了您如何编写一个包装器,它接受两个输入参数(将转换为数组)和一个输出参数(必须是数组)。该函数返回 None 并更新输出数组。请注意 NumPy v1.14 及更高版本对 WRITEBACKIFCOPY 语义的更新用法
static PyObject *
example_wrapper(PyObject *dummy, PyObject *args)
{
PyObject *arg1=NULL, *arg2=NULL, *out=NULL;
PyObject *arr1=NULL, *arr2=NULL, *oarr=NULL;
if (!PyArg_ParseTuple(args, "OOO!", &arg1, &arg2,
&PyArray_Type, &out)) return NULL;
arr1 = PyArray_FROM_OTF(arg1, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
if (arr1 == NULL) return NULL;
arr2 = PyArray_FROM_OTF(arg2, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
if (arr2 == NULL) goto fail;
#if NPY_API_VERSION >= 0x0000000c
oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY2);
#else
oarr = PyArray_FROM_OTF(out, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
#endif
if (oarr == NULL) goto fail;
/* code that makes use of arguments */
/* You will probably need at least
nd = PyArray_NDIM(<..>) -- number of dimensions
dims = PyArray_DIMS(<..>) -- npy_intp array of length nd
showing length in each dim.
dptr = (double *)PyArray_DATA(<..>) -- pointer to data.
If an error occurs goto fail.
*/
Py_DECREF(arr1);
Py_DECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
PyArray_ResolveWritebackIfCopy(oarr);
#endif
Py_DECREF(oarr);
Py_INCREF(Py_None);
return Py_None;
fail:
Py_XDECREF(arr1);
Py_XDECREF(arr2);
#if NPY_API_VERSION >= 0x0000000c
PyArray_DiscardWritebackIfCopy(oarr);
#endif
Py_XDECREF(oarr);
return NULL;
}