NEP 47 — 采用数组 API 标准#
- 作者:
Ralf Gommers <ralf.gommers@gmail.com>
- 作者:
Stephan Hoyer <shoyer@gmail.com>
- 作者:
Aaron Meurer <asmeurer@gmail.com>
- 状态:
已取代
- 被替换为:
- 类型:
标准轨迹
- 创建:
2021-01-21
- 决议:
https://mail.python.org/archives/list/[email protected]/message/Z6AA5CL47NHBNEPTFWYOTSUVSRDGHYPN/
注意
此 NEP 在 NumPy 1.22.0-1.26.x 中以实验标签发布(导入时会发出警告)。在 NumPy 2.0.0 之前它已被移除,那时 NumPy 在其主命名空间中获得了对数组 API 标准的支持(参见 NEP 56 — NumPy 主命名空间中的数组 API 标准支持)。numpy.array_api
的代码已移至独立包:array-api-strict。有关 numpy.array_api
模块的最后一个版本与 numpy
之间差异的详细概述,请参见 1.26.x 文档中的此表。
摘要#
我们建议采用由Python 数据 API 标准联盟开发的Python 数组 API 标准。在 NumPy 中将其作为单独的新命名空间实现,将允许依赖 NumPy 的库的作者以及最终用户编写可在 NumPy 和采用此标准的所有其他数组/张量库之间移植的代码。
注意
我们预计此 NEP 将在相当长的一段时间内保持草稿状态。鉴于范围很大,我们预计不会很快提出接受;相反,我们希望征求对高级设计和实现的反馈,并了解需要在本 NEP 中更好地描述的内容或在实现或数组 API 标准本身中需要更改的内容。
动机和范围#
Python 用户拥有丰富的数值计算、数据科学、机器学习和深度学习库和框架选择。每年都会出现推动这些领域技术发展的新框架。所有这些活动和创造性带来一个意外的结果是多维数组(又名张量)库的碎片化——这些是这些领域的根本数据结构。选择包括 NumPy、Tensorflow、PyTorch、Dask、JAX、CuPy、MXNet 等。
这些库的 API 在很大程度上是相似的,但差异足以使编写可与多个(或所有)这些库一起工作的代码非常困难。数组 API 标准旨在解决这个问题,方法是指定构造和使用数组的最常见方式的 API。拟议的 API 与 NumPy 的 API 非常相似,主要是在以下方面有所偏离:(a)NumPy 做出的设计选择本质上不能移植到其他实现中,以及 (b) 其他库故意偏离 NumPy,因为 NumPy 的设计被证明存在问题或不必要的复杂性。
有关数组 API 标准目的的更长讨论,请参阅数组 API 标准的“目的和范围”部分以及宣布联盟成立[1]和发布第一个社区审查标准草案版本的两个博客文章[2]。
此 NEP 的范围包括
采用 2021 年版本的数组 API 标准
添加一个单独的命名空间,暂定名为
numpy.array_api
需要/希望在新的命名空间之外进行的更改,例如
ndarray
对象上的新 dunder 方法实现选择以及新命名空间中的函数与主
numpy
命名空间中的函数之间的差异符合数组 API 标准的新数组对象
维护工作和测试策略
对 NumPy 的总公开 API 表面以及其他未来和正在讨论的设计选择的影响
与现有的和提议的 NumPy 数组协议的关系(
__array_ufunc__
、__array_function__
、__array_module__
)。对现有 NumPy 功能的改进要求
此 NEP 的范围之外
数组 API 标准本身的更改。这些可能会在审查此 NEP 的过程中出现,但应根据需要进行上游处理,并且此 NEP 随后会更新。
用法和影响#
本节稍后将详细说明,目前我们参考 数组 API 标准用例部分中给出的用例
除了这些用例之外,新命名空间还包含许多数组库广泛使用和支持的功能。因此,它是向 NumPy 新手传授并推荐为“最佳实践”的一组很好的函数。这与 NumPy 的主命名空间形成对比,后者包含许多已被取代或我们认为是错误的函数和对象——但由于向后兼容性原因,我们无法删除这些函数和对象。
下游库使用 numpy.array_api
命名空间旨在使它们能够使用多种类型的数组,而无需对所有这些数组库都具有硬依赖关系
下游库中的采用#
我们将array_api
命名空间的原型实现与SciPy、scikit-learn以及其他依赖NumPy的感兴趣的库一起使用,以获得更多关于设计的经验,并找出是否缺少任何重要部分。
支持多个数组库的模式旨在类似于
def somefunc(x, y):
# Retrieves standard namespace. Raises if x and y have different
# namespaces. See Appendix for possible get_namespace implementation
xp = get_namespace(x, y)
out = xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
return out
get_namespace
调用实际上是库作者选择使用标准API命名空间,从而明确支持所有符合标准的数组库。
asarray
/ asanyarray
模式#
许多现有库使用与NumPy本身相同的asarray
(或asanyarray
)模式;接受任何可以强制转换为np.ndarray
的对象。我们认为这种设计模式存在问题——记住Python的禅宗,“显式优于隐式”,以及这种模式在SciPy生态系统中对于ndarray
子类和过度热切的对象创建历来存在问题。所有其他数组/张量库都更严格,这在实践中效果很好。我们建议新库的作者避免使用asarray
模式。相反,他们应该只接受NumPy数组,或者如果他们想支持多种数组,则通过检查__array_namespace__
来检查传入的数组对象是否支持数组API标准,如上例所示。
现有库也可以进行这样的检查,并且只有在检查失败时才调用asarray
。这与NEP 30 — NumPy数组的鸭子类型 - 实现中的__duckarray__
理念非常相似。
应用代码中的采用#
最终用户可以将新的命名空间视为NumPy主命名空间的清理和精简版本。鼓励最终用户像这样使用此命名空间
import numpy.array_api as xp
x = xp.linspace(0, 2*xp.pi, num=100)
y = xp.cos(x)
似乎是完全合理的,并且可能是有益的——用户为每个目的只提供一个函数(我们认为是最佳实践的函数),然后他们编写的代码更容易移植到其他库。
向后兼容性#
不建议弃用或删除现有的NumPy API或其他向后不兼容的更改。
高级设计#
数组API标准包含大约120个对象,所有这些对象都具有直接的NumPy等效项。此图显示了高级别包含的内容
与NumPy目前提供的功能相比,最重要的更改是
一个新的数组对象,
numpy.array_api.Array
,它是围绕
np.ndarray
的精简纯Python(非子类)包装器,符合标准指定的强制转换规则和索引行为,
除了dunder方法之外没有其他方法,
不支持NumPy索引行为的全部范围(参见下面的索引),
没有不同的标量对象,只有0维数组,
不能直接构造。应该使用诸如
asarray()
之类的数组构造函数。
array_api
命名空间中的函数不接受
array_like
输入,只接受numpy.array_api
数组对象,只有Python标量在数组对象的dunder运算符中受支持,不支持
__array_ufunc__
和__array_function__
,在其签名中使用仅位置和仅关键字参数,
具有内联类型注释,
与已存在于NumPy中的等效项相比,其函数的签名和语义可能会有细微的更改,
只支持dtype字面量,不支持格式字符串或其他指定dtype的方法,
通常可能只支持与其NumPy对应项相比受限的dtype集。
将向NumPy添加对DLPack的支持,
将通过新数组对象上的
.device
属性以及array_api
命名空间中数组创建函数中的device=
关键字添加新的“设备支持”语法,强制转换规则将与NumPy当前具有的规则不同。输出dtype可以从输入dtype推导出来(即没有基于值的强制转换),并且0维数组被视为>=1维数组。不允许跨类型强制转换(例如,int到float)。
并非NumPy拥有的所有dtype都是标准的一部分。只支持布尔值、有符号和无符号整数以及高达
float64
的浮点dtype。复杂dtype预计将在标准的下一个版本中添加。扩展精度、字符串、void、对象和日期时间dtype以及结构化dtype不包含在内。
所需的对现有NumPy功能的改进包括
向
numpy.linalg
中当前缺少此类支持的一些函数添加对矩阵堆栈的支持。向
np.argmin
和np.argmax
添加keepdims
关键字。向
np.asarray
添加“永不复制”模式。向
np.finfo()
添加smallest_normal。DLPack支持。
此外,numpy.array_api
实现被选择为数组API标准的*最小*实现。这意味着它不仅符合数组API的所有要求,而且明确不包含API或行为没有明确要求它。标准本身不要求实现如此严格,但是通过NumPy数组API实现这样做将使其成为数组API标准的规范实现。任何想要使用数组API标准的人都可以使用NumPy实现,并确保他们的代码没有使用其他符合标准的实现中不存在的行为。
特别是,这意味着
numpy.array_api
将只包含标准中列出的那些函数。这也适用于Array
对象上的方法,函数只接受标准要求的输入dtype(例如,像
cos
这样的超越函数不接受整数dtype,因为标准只要求它们接受浮点dtype),类型提升仅针对标准要求的dtype组合发生(参见下面的Dtype和强制转换规则部分),
索引限于可能的索引类型的子集(参见下面的索引)。
array_api
命名空间中的函数#
让我们从一个函数实现示例开始,该示例显示了与主命名空间中等效函数的最重要区别
def matmul(x1: Array, x2: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
See its docstring for more information.
"""
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in matmul")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
return Array._new(np.matmul(x1._array, x2._array))
此函数不接受array_like
输入,只接受numpy.array_api.Array
。这有多种原因。其他数组库都这样工作。要求用户显式地进行Python标量、列表、生成器或其他外部对象的强制转换会导致设计更简洁,行为更少意外。它具有更高的性能——减少了asarray
调用的开销。静态类型更容易。子类将按预期工作。而且由于用户必须在极少数情况下显式地强制转换为ndarray
,因此冗长性的轻微增加似乎是一个可以接受的小代价。
此函数不支持__array_ufunc__
也不支持__array_function__
。这些协议与数组API标准模块本身具有相似的目的,但通过不同的机制。因为只接受Array
实例,所以通过这些协议之一进行分派不再有用。
此函数在其签名中使用仅位置参数。这使代码更易于移植——例如,编写max(a=a, ...)
不再有效,因此如果其他库将第一个参数命名为input
而不是a
,那也没关系。请注意,NumPy已经对作为ufunc的函数使用仅位置参数。仅关键字参数(在上面的示例中未显示)的基本原理有两个:最终用户代码的清晰度,以及将来更容易扩展签名而无需担心关键字的顺序。
此函数具有内联类型注释。内联注释比单独的存根文件更容易维护。并且由于类型很简单,因此这不会像NumPy当前拥有的存根文件中那样导致大量使用类型别名或联合的混乱。
此函数仅接受数值型 dtype(即,不包括 bool
)。它也不允许输入 dtype 为不同类型(内部 _result_type()
函数会在不同类型组合(例如 _result_type(int32, float64)
)上引发 TypeError
)。这样可以使实现最小化。避免在 NumPy 中有效但在数组 API 规范中不需要的组合,可以让子模块的用户知道他们没有依赖 NumPy 的特定行为,而这种行为在其他库的符合数组 API 的实现中可能不存在。
支持零拷贝数据交换的 DLPack#
将一种数组转换为另一种数组的能力非常有价值,并且在后续库想要支持多种数组类型时是必要的。这需要一个明确指定的数据交换协议。NumPy 已经支持其中的两种,即缓冲区协议(即 PEP 3118)和 __array_interface__
(Python 端)/ __array_struct__
(C 端)协议。两者工作方式相似,允许“生产者”描述数据在内存中的布局方式,以便“消费者”可以使用该数据构建自己的数组类型。
DLPack 的工作方式非常相似。与 NumPy 中已有的选项相比,更倾向于使用 DLPack 的主要原因是:
DLPack 是唯一支持设备的协议(例如,使用 CUDA 或 ROCm 驱动程序的 GPU 或 OpenCL 设备)。NumPy 仅限 CPU,但其他数组库并非如此。每个设备使用一个协议是不可行的,因此必须支持设备。
广泛的支持。DLPack 拥有所有协议中最广泛的采用率。只有 NumPy 不支持它,而其他库的经验都是积极的。这与 NumPy 支持的协议形成对比,这些协议的使用非常少——当其他库想要与 NumPy 交互时,它们通常使用(更有限且特定于 NumPy 的)
__array__
协议。
向 NumPy 添加对 DLPack 的支持需要:
添加一个
ndarray.__dlpack__()
方法,该方法返回一个用PyCapsule
包装的dlpack
C 结构。添加一个
np.from_dlpack(obj)
函数,其中obj
支持__dlpack__()
,并返回一个ndarray
。
DLPack 目前是一个大约 200 行代码的头文件,旨在直接包含,因此不需要外部依赖项。实现应该很简单。
设备支持的语法#
NumPy 本身仅限 CPU,因此它显然不需要设备支持。但是,其他库(例如 TensorFlow、PyTorch、JAX、MXNet)支持多种类型的设备:CPU、GPU、TPU 和更多更特殊的硬件。为了在具有多个设备的系统上编写可移植代码,通常需要在与其他数组相同的设备上创建新数组,或者检查两个数组是否位于同一设备上。因此,需要相应的语法。
数组对象将具有一个 .device
属性,该属性可以比较不同数组的设备(只有当两个数组来自同一个库并且是相同的硬件设备时,它们才应该比较相等)。此外,还需要数组创建函数中的 device=
关键字。例如:
def empty(shape: Union[int, Tuple[int, ...]], /, *,
dtype: Optional[dtype] = None,
device: Optional[device] = None) -> Array:
"""
Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
"""
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
return Array._new(np.empty(shape, dtype=dtype))
NumPy 的实现很简单,只需将 device 属性设置为字符串 "cpu"
,如果数组创建函数遇到任何其他值,则引发异常。
数据类型和转换规则#
此命名空间中支持的数据类型包括布尔型、8/16/32/64 位有符号和无符号整数以及 32/64 位浮点型数据类型。这些将作为数据类型字面量添加到命名空间中,并使用预期的名称(例如,bool
、uint16
、float64
)。
最明显的遗漏是复数数据类型。数组 API 标准第一个版本中缺乏复数支持的原因是,一些库(PyTorch、MXNet)仍在添加对复数数据类型的支持。标准的下一个版本预计将包含 complex64
和 complex128
(有关详细信息,请参阅 此问题)。
指定函数的数据类型(例如,通过 dtype=
关键字)预期仅使用数据类型字面量。格式字符串、Python 内置数据类型或数据类型字面量的字符串表示形式不被接受。这将以很小的代价提高代码的可读性和可移植性。此外,除了基本的相等比较之外,这些数据类型字面量本身不应具有任何行为。特别是,由于数组 API 没有标量对象,因此不允许使用 float32(0.0)
之类的语法(可以使用 asarray(0.0, dtype=float32)
创建一个 0 维数组)。
转换规则仅在相同类型的不同数据类型之间定义(即,布尔型到布尔型、整数到整数或浮点型到浮点型)。这也意味着省略了在 NumPy 中会向上转换为 float64 的整数-uint64 组合。这样做的原因是混合类型(例如,整数到浮点型)的转换行为在不同的库之间有所不同。
类型提升图。任何两种类型之间的提升由它们在此格上的连接给出。只有参与数组的类型很重要,而其值无关紧要。虚线表示 Python 标量在溢出时的行为未定义。Python 标量本身仅允许在数组对象的运算符中使用,而不是在函数内部使用。布尔型、整数型和浮点型数据类型没有连接,表示混合类型的提升未定义(对于 NumPy 实现,这些会引发异常)。
NumPy 和数组 API 标准中的转换规则之间最重要的区别在于如何处理标量和 0 维数组。在标准中,数组标量不存在,0 维数组遵循与更高维数组相同的转换规则。此外,标准中没有基于值的转换。运算的结果类型可以完全根据其输入数组的 dtype 来预测,而不管其形状或值如何。Python 标量仅允许在双下划线运算符(如 __add__
)中使用,并且仅当它们与数组 dtype 类型相同。它们始终转换为数组的 dtype,而不管值如何。溢出行为未定义。
有关更多详细信息,请参阅 数组 API 标准的类型提升规则部分。
在实现中,这意味着:
确保在
Array
构造函数中将 NumPy 中会产生标量对象的任何运算转换为 0 维数组。检查会应用基于值的转换的组合,并确保它们提升到正确的类型。这可以通过例如手动广播 0 维输入(阻止它们参与基于值的转换)或通过显式传递
signature
参数到底层的 ufunc 来实现。在双下划线运算符方法中,如果 Python 标量输入与数组类型相同,则手动将它们转换为匹配 dtype 的 0 维数组,否则引发异常。对于超出给定 dtype 范围的标量(其行为由规范未定义),使用
np.array(scalar, dtype=dtype)
的行为(强制转换或引发 OverflowError)。
索引#
使用 ndarray
将返回标量的索引表达式,例如 arr_2d[0, 0]
,将使用新的 Array
对象返回一个 0 维数组。这样做的原因有几个:数组标量在很大程度上被认为是一个设计错误,其他任何数组库都没有复制;对于非 CPU 库来说,它效果更好(通常数组可以驻留在设备上,标量驻留在主机上);它只是一个更一致的设计。要从 0 维数组中获取 Python 标量,可以使用类型的内置函数,例如 float(arr_0d)
。
标准中的其他 索引模式 的工作方式与 numpy.ndarray
几乎相同。一个值得注意的区别是,切片索引中的裁剪(例如,a[:n]
,其中 n
大于第一个轴的大小)是未指定的行为,因为这种类型的检查在加速器上可能代价很高。
标准省略了高级索引(通过整数数组进行索引),布尔索引仅限于单个 n 维布尔数组。这是因为这些索引模式不适合所有类型的数组或 JIT 编译。此外,一些高级 NumPy 索引语义,例如在单个索引中混合高级和非高级索引的语义,被认为是 NumPy 中的设计错误。这些更高级索引类型的缺失似乎没有问题;如果用户或库作者想要使用它们,他们可以通过零拷贝转换为 numpy.ndarray
来实现。这将正确地向任何阅读代码的人发出信号,表明它随后是特定于 NumPy 的,而不是可移植到所有符合条件的数组类型。
作为一个最小实现,numpy.array_api
将明确不允许具有裁剪边界的切片、高级索引以及与其他索引混合的布尔索引。
数组对象#
标准中的数组对象除了双下划线方法之外没有其他方法。它也不允许直接构造,而是更倾向于使用 asarray
之类的数组构造方法。这样做的原因是并非所有数组库的数组对象上都有方法(例如,TensorFlow 没有)。它还只提供一种执行操作的方式,而不是具有实际上重复的函数和方法。
将可能产生视图的操作(例如,索引、nonzero
)与变异操作(例如,项或切片赋值)混合使用,在标准中明确说明不支持。这不能轻易地在数组对象本身中禁止;相反,这将通过文档向用户提供指导。
标准规范并未规定数组对象本身的名称。我们建议将其命名为 Array
。这符合 PEP 8 中关于类名大小写的规范,并且不会与任何现有的 NumPy 类名冲突。[3] 请注意,数组类的实际名称并不重要,因为它本身不包含在顶级命名空间中,无法直接构造。
实现#
array_api
命名空间的原型可以在 numpy/numpy#18585 中找到。其 __init__.py
文件中的文档字符串包含一些关于实现细节的重要说明。包装函数的代码在与 NumPy API 存在差异的地方也包含了 # Note:
注释。实现完全使用纯 Python 完成,主要由包装类/函数组成,这些类/函数在应用输入验证和任何更改后的行为后,会传递到相应的 NumPy 函数。尚未实现的一个重要部分是 DLPack 支持,因为它在 np.ndarray
中的实现仍在进行中 (numpy/numpy#19083)。
numpy.array_api
模块被认为是实验性的。这意味着导入它将发出 UserWarning
警告。另一种方法是将模块命名为 numpy._array_api
,但选择了发出警告,这样将来就不需要重命名模块,从而避免破坏用户代码。由于大量使用了仅限位置参数语法,该模块还要求 Python 3.8 或更高版本。
该模块的实验性质也意味着它尚未在 NumPy 文档中的任何地方提及,模块文档字符串和本 NEP 除外。实现的文档本身就是一个具有挑战性的问题。目前,实现中的每个文档字符串都简单地引用了它实现的底层 NumPy 函数。但是,这并不理想,因为底层 NumPy 函数的行为可能与数组 API 中的相应函数不同,例如,数组 API 中不存在的额外关键字参数。有人建议可以直接从规范本身提取文档,但这需要对规范的编写方式进行一些技术更改,因此当前的实现尚未尝试这样做。
数组 API 规范附带了一个正在进行中的 官方测试套件,该套件旨在测试任何库对数组 API 规范的符合性。因此,实现中包含的测试将是最小的,因为大部分行为将由这个测试套件验证。NumPy 本身针对 array_api
子模块的测试将仅包括测试数组 API 测试套件未涵盖的行为,例如,测试实现是最小的,并正确拒绝诸如不允许的类型组合之类的东西。一个 CI 作业将添加到数组 API 测试套件存储库中,以定期针对 NumPy 实现对其进行测试。如果库希望这样做,数组 API 测试套件设计为可以进行供应商管理,但是这个想法被 NumPy 拒绝了,因为与现有的 NumPy 测试套件相比,它花费的时间非常多,而且测试套件本身仍在开发中。
dtype 对象#
我们必须能够比较 dtype 的相等性,并且必须能够使用以下表达式
np.array_api.some_func(..., dtype=x.dtype)
上述内容意味着最好有 np.array_api.float32 == np.array_api.ndarray(...).dtype
。
用户不应假设 dtype 具有类层次结构,但是如果方便的话,我们可以使用类层次结构来实现它。我们考虑了以下实现 dtype 对象的选项
将 dtype 作为主命名空间中的别名,例如
np.array_api.float32 = np.float32
。使 dtype 成为
np.dtype
的实例,例如np.array_api.float32 = np.dtype(np.float32)
。创建仅具有所需方法/属性(目前只有
__eq__
)的新单例类。
从与主命名空间外部函数交互的角度来看,(2) 似乎是最简单的,而 (3) 最符合标准。(2) 不会阻止用户访问 dtype 对象的 NumPy 特定属性,例如 (3) 会阻止那样做,尽管与 (1) 不同的是,它不允许创建像 float32(0.0)
这样的标量对象。(2) 还为每个 dtype 保留一个对象——使用 (1),arr.dtype
仍然是 dtype 实例。当前实现使用 (2)。
待定:标准目前还没有很好的方法来检查 dtype 的属性,例如询问“这是整数 dtype 吗?”。也许这对用户来说很容易做到,例如
def _get_dtype(dt_or_arr):
return dt_or_arr.dtype if hasattr(dt_or_arr, 'dtype') else dt_or_arr
def is_floating(dtype_or_array):
dtype = _get_dtype(dtype_or_array)
return dtype in (float32, float64)
def is_integer(dtype_or_array):
dtype = _get_dtype(dtype_or_array)
return dtype in (uint8, uint16, uint32, uint64, int8, int16, int32, int64)
但是,将其添加到标准中可能是有意义的。请注意,NumPy 本身目前没有很好的方法来提出此类问题,请参阅 gh-17325。
替代方案#
有人提议将 NumPy 数组 API 实现作为 NumPy 的一个单独库。这被拒绝了,因为将其分开会降低人们审查它的可能性,而将其作为实验性子模块包含在 NumPy 本身中,将使已经依赖 NumPy 的最终用户和库作者更容易访问该实现。
附录 - 可能的 get_namespace
实现#
在 应用程序代码中的采用 部分中提到的 get_namespace
函数可以像这样实现
def get_namespace(*xs):
# `xs` contains one or more arrays, or possibly Python scalars (accepting
# those is a matter of taste, but doesn't seem unreasonable).
namespaces = {
x.__array_namespace__() if hasattr(x, '__array_namespace__') else None for x in xs if not isinstance(x, (bool, int, float, complex))
}
if not namespaces:
# one could special-case np.ndarray above or use np.asarray here if
# older numpy versions need to be supported.
raise ValueError("Unrecognized array input")
if len(namespaces) != 1:
raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")
xp, = namespaces
if xp is None:
raise ValueError("The input is not a supported array type")
return xp
讨论#
参考文献和脚注#
版权#
本文档已进入公共领域。[1]