NEP 47 — 采用数组 API 标准#

作者:

Ralf Gommers <ralf.gommers@gmail.com>

作者:

Stephan Hoyer <shoyer@gmail.com>

作者:

Aaron Meurer <asmeurer@gmail.com>

状态:

已取代

替换为:

NEP 56 — NumPy 主命名空间中的数组 API 标准支持

类型:

标准跟踪

创建:

2021-01-21

决议:

https://mail.python.org/archives/list/[email protected]/message/Z6AA5CL47NHBNEPTFWYOTSUVSRDGHYPN/

注意

此 NEP 已在 NumPy 1.22.0-1.26.x 中以实验标签的形式实现并发布(它在导入时发出警告)。它在 NumPy 2.0.0 之前被移除,NumPy 在 2.0.0 版本中获得了在其主命名空间中支持数组 API 标准的功能(请参阅 NEP 56 — NumPy 主命名空间中的数组 API 标准支持)。numpy.array_api 的代码已移动到一个独立的包中:array-api-strict。有关最新版本的 numpy.array_api 模块与 numpy 之间差异的详细概述,请参阅 1.26.x 文档中的此表

摘要#

我们建议采用由 Python 数据 API 标准联盟 开发的 Python 数组 API 标准。在 NumPy 中将其作为单独的新命名空间实现,将允许依赖 NumPy 的库的作者以及最终用户编写可在 NumPy 和采用此标准的所有其他数组/张量库之间移植的代码。

注意

我们预计此 NEP 将在相当长一段时间内保持草稿状态。鉴于范围很大,我们预计不会很快提出接受它的建议;相反,我们希望征求关于高级设计和实现的反馈,并了解此 NEP 中需要更好地描述的内容或实现或数组 API 标准本身中需要更改的内容。

动机和范围#

Python 用户可以从众多用于数值计算、数据科学、机器学习和深度学习的库和框架中进行选择。每年都会出现推动这些领域最新技术发展的新框架。所有这些活动和创造力的一个意外后果是多维数组(也称为张量)库的碎片化——这些是这些领域的基本数据结构。选择包括 NumPy、Tensorflow、PyTorch、Dask、JAX、CuPy、MXNet 等。

每个库的 API 在很大程度上是相似的,但差异足够大,以至于很难编写适用于多个(或所有)库的代码。数组 API 标准旨在解决此问题,方法是为构建和使用数组的最常见方法指定 API。提议的 API 与 NumPy 的 API 非常相似,主要是在以下方面存在偏差:(a)NumPy 做出的设计选择本质上无法移植到其他实现,以及(b)其他库有意偏离 NumPy,因为 NumPy 的设计结果证明存在问题或不必要的复杂性。

有关数组 API 标准目的的更长讨论,请参阅 数组 API 标准的“目的和范围”部分 以及宣布联盟成立的两篇博文 [1] 和发布标准第一个草稿版本以供社区审查 [2]

此 NEP 的范围包括

  • 采用 2021 年版本的数组 API 标准

  • 添加一个单独的命名空间,暂定名为 numpy.array_api

  • 在新命名空间之外需要/希望进行的更改,例如 ndarray 对象上的新 dunder 方法

  • 实现选择以及新命名空间中的函数与主 numpy 命名空间中的函数之间的差异

  • 符合数组 API 标准的新数组对象

  • 维护工作和测试策略

  • 对 NumPy 的总公开 API 表面以及其他未来和正在讨论的设计选择的影响

  • 与现有的和提议的 NumPy 数组协议的关系(__array_ufunc____array_function____array_module__)。

  • 现有 NumPy 功能所需的改进

此 NEP 范围之外的是

  • 数组 API 标准本身的更改。这些可能会在审查此 NEP 时出现,但应根据需要上游,然后更新此 NEP。

使用和影响#

本节将在以后详细阐述,目前我们参考了 数组 API 标准用例部分 中给出的用例

除了这些用例之外,新命名空间还包含许多数组库广泛使用和支持的功能。因此,它是向 NumPy 新手传授和推荐为“最佳实践”的一组很好的函数。这与 NumPy 的主命名空间形成对比,后者包含许多已被取代或我们认为是错误的函数和对象——但由于向后兼容性的原因,我们无法将其移除。

下游库使用 numpy.array_api 命名空间的目的是使它们能够使用多种类型的数组,而无需硬性依赖所有这些数组库

_images/nep-0047-library-dependencies.png

在下游库中采用#

array_api 命名空间的原型实现将与 SciPy、scikit-learn 和其他依赖 NumPy 的感兴趣库一起使用,以便获得更多关于设计方面的经验,并找出是否缺少任何重要部分。

支持多个数组库的模式旨在类似于

def somefunc(x, y):
    # Retrieves standard namespace. Raises if x and y have different
    # namespaces.  See Appendix for possible get_namespace implementation
    xp = get_namespace(x, y)
    out = xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
    return out

get_namespace 调用实际上是库作者选择使用标准 API 命名空间,从而明确地支持所有符合标准的数组库。

asarray / asanyarray 模式#

许多现有库使用与 NumPy 本身相同的 asarray(或 asanyarray)模式;接受任何可以强制转换为 np.ndarray 的对象。我们认为这种设计模式存在问题——牢记 Python 之禅,“显式优于隐式”,以及这种模式在 SciPy 生态系统中对于 ndarray 子类以及过度热心的对象创建方面一直存在的问题。所有其他数组/张量库都更加严格,这在实践中运行良好。我们建议新库的作者避免使用 asarray 模式。相反,他们应该只接受 NumPy 数组,或者如果他们希望支持多种类型的数组,则通过检查 __array_namespace__ 来检查传入的数组对象是否支持数组 API 标准,如上例所示。

现有库也可以进行这样的检查,并且仅在检查失败时才调用 asarray。这与 NEP 30 — NumPy 数组的鸭子类型 - 实现 中的 __duckarray__ 想法非常相似。

在应用程序代码中采用#

最终用户可以将新命名空间视为 NumPy 主命名空间的清理和精简版本。鼓励最终用户像这样使用此命名空间

import numpy.array_api as xp

x = xp.linspace(0, 2*xp.pi, num=100)
y = xp.cos(x)

似乎是完全合理的,并且可能是有益的——用户为每个目的只提供一个函数(我们认为是最佳实践),然后他们编写更容易移植到其他库的代码。

向后兼容性#

不建议弃用或删除现有的 NumPy API 或其他向后不兼容的更改。

高级设计#

数组 API 标准由大约 120 个对象组成,所有这些对象都具有直接的 NumPy 等价物。此图显示了高级别包含的内容

_images/nep-0047-scope-of-array-API.png

与 NumPy 目前提供的功能相比,最重要的更改是

  • 一个新的数组对象,numpy.array_api.Array,它

    • 是围绕 np.ndarray 的纯 Python(非子类)薄包装器,

    • 符合标准指定的转换规则和索引行为,

    • 除了 dunder 方法之外没有其他方法,

    • 不支持 NumPy 索引行为的全部范围(请参阅下面的 索引),

    • 没有不同的标量对象,只有 0-D 数组,

    • 不能直接构造。相反,应使用诸如 asarray() 之类的数组构造函数。

  • array_api 命名空间中的函数

    • 不接受 array_like 输入,仅接受 numpy.array_api 数组对象,并且仅在数组对象上的 dunder 运算符中支持 Python 标量,

    • 不支持 __array_ufunc____array_function__

    • 在其签名中使用仅位置和仅关键字参数,

    • 具有内联类型注释,

    • 与已存在于 NumPy 中的等效函数相比,其签名和语义可能存在细微差异,

    • 仅支持 dtype 文字,不支持格式字符串或其他指定 dtype 的方法,

    • 通常可能仅支持与其 NumPy 对应项相比的一组受限的 dtype。

  • DLPack 支持将添加到 NumPy 中,

  • 将添加用于“设备支持”的新语法,通过新数组对象上的 .device 属性以及 array_api 命名空间中数组创建函数中的 device= 关键字。

  • 转换规则将不同于 NumPy 当前使用的规则。输出数据类型可以从输入数据类型派生(即没有基于值的转换),并且 0 维数组被视为 >=1 维数组。不同种类之间的转换(例如,整数到浮点数)是不允许的。

  • 并非 NumPy 中的所有数据类型都是标准的一部分。仅支持布尔型、有符号和无符号整数以及高达 float64 的浮点型数据类型。复杂数据类型预计将在标准的下一个版本中添加。扩展精度、字符串、空值、对象和日期时间数据类型,以及结构化数据类型,都不包括在内。

需要改进现有 NumPy 功能,包括

  • numpy.linalg 中的一些当前缺少此类支持的函数中,添加对矩阵堆栈的支持。

  • keepdims 关键字添加到 np.argminnp.argmax 中。

  • np.asarray 添加“永不复制”模式。

  • np.finfo() 添加 smallest_normal。

  • DLPack 支持。

此外,numpy.array_api 的实现被选择为数组 API 标准的最小实现。这意味着它不仅符合数组 API 的所有要求,而且明确地不包含任何 API 或行为,除非标准明确要求。标准本身不要求实现如此严格,但通过 NumPy 数组 API 实现这样做将使其成为数组 API 标准的规范实现。任何想要使用数组 API 标准的人都可以使用 NumPy 实现,并确保他们的代码没有使用其他符合标准的实现中不存在的行为。

特别是,这意味着

  • numpy.array_api 将仅包含标准中列出的那些函数。这也适用于 Array 对象上的方法。

  • 函数将仅接受标准要求的输入数据类型(例如,像 cos 这样的超越函数将不接受整数数据类型,因为标准仅要求它们接受浮点数据类型)。

  • 类型提升仅针对标准要求的数据类型组合发生(参见下面的 数据类型和转换规则 部分)。

  • 索引仅限于可能的索引类型的一个子集(参见下面的 索引)。

array_api 命名空间中的函数#

让我们从一个函数实现示例开始,该示例展示了与主命名空间中等效函数最重要的区别

def matmul(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
    See its docstring for more information.
    """
    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
        raise TypeError("Only numeric dtypes are allowed in matmul")

    # Call result type here just to raise on disallowed type combinations
    _result_type(x1.dtype, x2.dtype)

    return Array._new(np.matmul(x1._array, x2._array))

此函数不接受 array_like 输入,仅接受 numpy.array_api.Array。这有几个原因。其他数组库都以这种方式工作。要求用户显式地执行 Python 标量、列表、生成器或其他外部对象的强制转换会导致更简洁的设计,并且减少意外行为。它具有更高的性能——减少了 asarray 调用的开销。静态类型更容易。子类将按预期工作。而且,由于用户必须在很少的情况下显式地强制转换为 ndarray 而导致的轻微冗长似乎是一个可以接受的小代价。

此函数不支持 __array_ufunc____array_function__。这些协议与数组 API 标准模块本身具有相似的目的,但通过不同的机制实现。因为只接受 Array 实例,所以通过其中一个协议进行分派不再有用。

此函数在其签名中使用仅位置参数。这使得代码更易于移植——例如,编写 max(a=a, ...) 将不再有效,因此如果其他库将第一个参数称为 input 而不是 a,那也没问题。请注意,NumPy 已经对用作通用函数的函数使用了仅位置参数。关键字仅参数(在以上示例中未显示)的基本原理有两个:最终用户代码的清晰度,以及将来更容易扩展签名而不必担心关键字的顺序。

此函数具有内联类型注释。内联注释比单独的存根文件更容易维护。并且由于类型很简单,因此与 NumPy 当前拥有的存根文件中的类型别名或联合相比,这不会导致大量混乱。

此函数仅接受数值数据类型(即,不包括 bool)。它也不允许输入数据类型属于不同的种类(内部 _result_type() 函数将在不同种类类型组合(如 _result_type(int32, float64))上引发 TypeError)。这使得实现能够保持最小化。防止在 NumPy 中有效但在数组 API 规范中未要求的组合,使用户能够知道他们没有依赖 NumPy 特定的行为,而这些行为可能不存在于其他库中的符合数组 API 的实现中。

用于零拷贝数据交换的 DLPack 支持#

将一种数组转换为另一种数组的能力非常宝贵,并且在后续库想要支持多种数组时确实必要。这需要一个明确指定的数据交换协议。NumPy 已经支持了其中的两个,即缓冲区协议(即 PEP 3118)和 __array_interface__(Python 端)/ __array_struct__(C 端)协议。两者工作方式类似,让“生产者”描述数据在内存中的布局方式,以便“消费者”可以使用该数据构建自己的数组类型。

DLPack 的工作方式非常相似。与 NumPy 中已有的选项相比,首选 DLPack 的主要原因是

  1. DLPack 是唯一具有设备支持的协议(例如,使用 CUDA 或 ROCm 驱动程序的 GPU,或 OpenCL 设备)。NumPy 仅限于 CPU,但其他数组库并非如此。每个设备都有一个协议是不可行的,因此设备支持是必须的。

  2. 广泛的支持。DLPack 拥有所有协议中最广泛的采用。只有 NumPy 缺少支持,其他库的经验也很好。这与 NumPy 支持的协议形成对比,这些协议的使用非常少——当其他库想要与 NumPy 交互时,它们通常使用(更有限且 NumPy 特定的)__array__ 协议。

向 NumPy 添加对 DLPack 的支持需要

  • 添加一个 ndarray.__dlpack__() 方法,该方法返回一个包装在 PyCapsule 中的 dlpack C 结构。

  • 添加一个 np.from_dlpack(obj) 函数,其中 obj 支持 __dlpack__(),并返回一个 ndarray

DLPack 目前是一个大约 200 行代码的头文件,并且旨在直接包含,因此不需要外部依赖项。实现应该很简单。

设备支持语法#

NumPy 本身是仅限 CPU 的,因此它显然不需要设备支持。但是,其他库(例如 TensorFlow、PyTorch、JAX、MXNet)支持多种类型的设备:CPU、GPU、TPU 和更多异构硬件。为了在具有多个设备的系统上编写可移植代码,通常需要在与其他一些数组相同的设备上创建新数组,或者检查两个数组是否位于同一设备上。因此需要为此添加语法。

数组对象将具有一个 .device 属性,该属性可以比较不同数组的设备(只有当两个数组都来自同一个库并且是同一个硬件设备时,它们才应该比较相等)。此外,需要在数组创建函数中使用 device= 关键字。例如

def empty(shape: Union[int, Tuple[int, ...]], /, *,
          dtype: Optional[dtype] = None,
          device: Optional[device] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
    """
    if device not in ["cpu", None]:
        raise ValueError(f"Unsupported device {device!r}")
    return Array._new(np.empty(shape, dtype=dtype))

NumPy 的实现与将 device 属性设置为字符串 "cpu" 并 在数组创建函数遇到任何其他值时引发异常一样简单。

数据类型和转换规则#

此命名空间中支持的数据类型为布尔型、8/16/32/64 位有符号和无符号整数以及 32/64 位浮点型数据类型。这些将作为数据类型字面量添加到命名空间中,并使用预期的名称(例如,booluint16float64)。

最明显的遗漏是复数数据类型。数组 API 标准第一个版本中缺少复数支持的原因是,几个库(PyTorch、MXNet)仍在添加对复数数据类型的支持。标准的下一个版本预计将包含 complex64complex128(有关更多详细信息,请参阅 此问题)。

指定函数的数据类型,例如通过 dtype= 关键字,预计仅使用数据类型字面量。格式字符串、Python 内置数据类型或数据类型字面量的字符串表示形式不被接受。这将以很少的代价提高代码的可读性和可移植性。此外,除了基本的相等性比较之外,不期望这些数据类型字面量本身具有任何行为。特别是,由于数组 API 没有标量对象,因此不允许像 float32(0.0) 这样的语法(可以使用 asarray(0.0, dtype=float32) 创建一个 0 维数组)。

转换规则仅在相同种类的不同数据类型之间定义(即,布尔型到布尔型、整数到整数或浮点数到浮点数)。这也意味着省略了在 NumPy 中将向上转换为 float64 的整数-uint64 组合。这样做的原因是库之间混合种类(例如,整数到浮点数)转换行为存在差异。

_images/nep-0047-casting-rules-lattice.png

类型提升图。任何两种类型之间的提升由它们在此格上的连接给出。只有参与数组的类型才重要,它们的数值不重要。虚线表示 Python 标量在溢出时行为未定义。Python 标量本身仅允许在数组对象上的运算符中使用,而不是在函数内部使用。布尔型、整数和浮点型数据类型没有连接,表示混合种类提升未定义(对于 NumPy 实现,这些会引发异常)。

NumPy 中的转换规则与数组 API 标准中的转换规则之间最显着的区别在于如何处理标量和 0 维数组。在标准中,数组标量不存在,并且 0 维数组遵循与更高维数组相同的转换规则。此外,标准中没有基于值的转换。运算的结果类型可以完全根据其输入数组的数据类型来预测,而不管其形状或值如何。Python 标量仅允许在 dunder 运算符(如 __add__)中使用,并且仅当它们与数组数据类型属于同一种类时。它们始终转换为数组的数据类型,而不管其值如何。溢出行为未定义。

有关更多详细信息,请参阅 数组 API 标准的类型提升规则部分

在实现中,这意味着

  • 确保 NumPy 中将产生标量对象的任何运算都转换为 Array 构造函数中的 0 维数组。

  • 检查将应用基于值的转换的组合,并确保它们提升到正确的类型。这可以通过例如手动广播 0 维输入(防止它们参与基于值的转换)或通过显式地将 signature 参数传递给底层通用函数来实现。

  • 在 dunder 运算符方法中,如果 Python 标量输入与数组属于同一种类,则将其手动转换为匹配数据类型的 0 维数组,否则引发异常。对于超出给定数据类型范围的标量(规范未定义其行为),使用 np.array(scalar, dtype=dtype) 的行为(强制转换或引发 OverflowError)。

索引#

使用 ndarray 将返回标量的索引表达式,例如 arr_2d[0, 0],将使用新的 Array 对象返回一个 0 维数组。这有几个原因:数组标量在很大程度上被认为是一个设计错误,其他数组库没有复制;对于非 CPU 库来说,它效果更好(通常数组可以驻留在设备上,标量驻留在主机上);它只是一个更一致的设计。要从 0 维数组中获取 Python 标量,可以使用该类型的内置函数,例如 float(arr_0d)

标准中的其他 索引模式 在很大程度上与 numpy.ndarray 中的工作方式相同。一个值得注意的区别是,切片索引中的裁剪(例如,a[:n],其中 n 大于第一个轴的大小)是未指定的行为,因为这种类型的检查在加速器上可能很昂贵。

标准省略了高级索引(通过整数数组索引),并且布尔索引仅限于单个 n 维布尔数组。这是因为这些索引模式不适合所有类型的数组或 JIT 编译。此外,一些高级 NumPy 索引语义,例如在单个索引中混合高级和非高级索引的语义,被认为是 NumPy 中的设计错误。这些更高级索引类型的缺失似乎没有问题;如果用户或库作者想要使用它们,可以通过零拷贝转换为 numpy.ndarray 来实现。这将正确地向阅读代码的人发出信号,表明它然后是 NumPy 特定的,而不是所有符合标准的数组类型都可移植的。

作为最小实现,numpy.array_api 将明确地不允许具有裁剪边界的切片、高级索引以及与其他索引混合的布尔索引。

数组对象#

标准中的数组对象除了 dunder 方法之外没有其他方法。它也不允许直接构造,而是更倾向于使用数组构造方法,例如 asarray。这样做的理由是,并非所有数组库在其数组对象上都有方法(例如,TensorFlow 没有)。它还只提供了一种执行某项操作的方式,而不是提供实际上是重复的函数和方法。

将可能产生视图的操作(例如,索引,nonzero)与变异操作(例如,项目或切片赋值)混合使用,在标准中明确记录为不支持。这在数组对象本身中不容易禁止;相反,这将是通过文档向用户提供的指导。

标准目前没有规定数组对象本身的名称。我们建议将其命名为 Array。这使用了 PEP 8 中类的大写规范,并且不会与任何现有的 NumPy 类名称冲突。[3] 请注意,数组类的实际名称并不那么重要,因为它本身不包含在顶级命名空间中,也不能直接构造。

实现#

array_api 命名空间的原型可以在 numpy/numpy#18585 中找到。其 __init__.py 中的文档字符串包含关于实现细节的一些重要说明。包装器函数的代码还在 NumPy API 与之存在差异的地方包含了 # Note: 注释。该实现完全使用纯 Python,主要由包装器类/函数组成,这些类/函数在应用输入验证和任何更改后的行为后传递给相应的 NumPy 函数。一个尚未实现的重要部分是 DLPack 支持,因为它在 np.ndarray 中的实现仍在进行中 (numpy/numpy#19083)。

numpy.array_api 模块被认为是实验性的。这意味着导入它将发出 UserWarning。另一种方法是将模块命名为 numpy._array_api,但选择了警告,以便将来无需重命名模块,从而可能破坏用户代码。由于广泛使用了仅限位置的参数语法,因此该模块还需要 Python 3.8 或更高版本。

该模块的实验性质还意味着它尚未在 NumPy 文档中的任何地方提及,除了其模块文档字符串和此 NEP 之外。实现的文档本身就是一个具有挑战性的问题。目前,实现中的每个文档字符串都简单地引用了它实现的基础 NumPy 函数。但是,这并不理想,因为基础 NumPy 函数的行为可能与数组 API 中的相应函数不同,例如,数组 API 中不存在的其他关键字参数。有人建议可以从规范本身直接提取文档,但这需要对规范的编写方式进行一些技术更改,因此当前的实现尚未尝试这样做。

数组 API 规范附带了一个正在进行的 官方测试套件,旨在测试任何库对数组 API 规范的符合性。因此,实现中包含的测试将是最小的,因为大多数行为将由此测试套件验证。NumPy 本身针对 array_api 子模块的测试将仅包括测试数组 API 测试套件未涵盖的行为,例如,测试实现是最小的,并正确拒绝诸如不允许的类型组合之类的东西。将向数组 API 测试套件存储库添加一个 CI 作业,以定期针对 NumPy 实现对其进行测试。如果库希望这样做,数组 API 测试套件旨在被销售,但此想法被 NumPy 拒绝,因为与现有的 NumPy 测试套件相比,它花费的时间非常长,并且因为测试套件本身仍在开发中。

dtype 对象#

我们必须能够比较 dtype 以确定是否相等,并且必须可以使用以下表达式

np.array_api.some_func(..., dtype=x.dtype)

上面暗示最好有 np.array_api.float32 == np.array_api.ndarray(...).dtype

用户不应假设 dtype 具有类层次结构,但是如果方便,我们可以使用类层次结构来实现它。我们考虑了以下选项来实现 dtype 对象

  1. 将 dtype 作为主命名空间中的别名,例如,np.array_api.float32 = np.float32

  2. 使 dtype 成为 np.dtype 的实例,例如,np.array_api.float32 = np.dtype(np.float32)

  3. 创建仅具有所需方法/属性(目前仅为 __eq__)的新单例类。

从与主命名空间外部的函数交互的角度来看,(2) 似乎是最简单的,而 (3) 最符合标准。(2) 不会阻止用户访问 dtype 对象的 NumPy 特定属性,例如 (3) 会做的那样,尽管与 (1) 不同,它确实不允许创建标量对象,例如 float32(0.0)。(2) 还为每个 dtype 保留了一个对象——使用 (1),arr.dtype 仍然将是 dtype 实例。当前的实现使用 (2)。

待定:标准还没有一个好的方法来检查 dtype 的属性,例如询问“这是否是一个整数 dtype?”。也许这对用户来说很容易,例如

def _get_dtype(dt_or_arr):
    return dt_or_arr.dtype if hasattr(dt_or_arr, 'dtype') else dt_or_arr

def is_floating(dtype_or_array):
    dtype = _get_dtype(dtype_or_array)
    return dtype in (float32, float64)

def is_integer(dtype_or_array):
    dtype = _get_dtype(dtype_or_array)
    return dtype in (uint8, uint16, uint32, uint64, int8, int16, int32, int64)

但是添加标准可能是有意义的。请注意,NumPy 本身目前没有一个很好的方法来询问此类问题,请参阅 gh-17325

来自下游库作者的反馈#

待办事项 - 这只能在尝试了一些用例后才能完成

Leo Fang (CuPy):“我的印象是,对于 CuPy,我们可以简单地获取这个新的数组对象并 s/numpy/cupy”

备选方案#

有人提议将 NumPy 数组 API 实现作为与 NumPy 分开的库。这被拒绝了,因为将其分开会降低人们审查它的可能性,而将其作为实验性子模块包含在 NumPy 本身中将使已经依赖 NumPy 的最终用户和库作者更容易访问该实现。

附录 - 可能的 get_namespace 实现#

应用程序代码中的采用 部分中提到的 get_namespace 函数可以像这样实现

def get_namespace(*xs):
    # `xs` contains one or more arrays, or possibly Python scalars (accepting
    # those is a matter of taste, but doesn't seem unreasonable).
    namespaces = {
        x.__array_namespace__() if hasattr(x, '__array_namespace__') else None for x in xs if not isinstance(x, (bool, int, float, complex))
    }

    if not namespaces:
        # one could special-case np.ndarray above or use np.asarray here if
        # older numpy versions need to be supported.
        raise ValueError("Unrecognized array input")

    if len(namespaces) != 1:
        raise ValueError(f"Multiple namespaces for array inputs: {namespaces}")

    xp, = namespaces
    if xp is None:
        raise ValueError("The input is not a supported array type")

    return xp

讨论#

参考文献和脚注#