NEP 37 — NumPy 类模块的调度协议#

作者:

Stephan Hoyer <shoyer@google.com>

作者:

Hameer Abbasi

作者:

Sebastian Berg

状态:

已取代

被替换为:

NEP 56 — NumPy 主命名空间中的数组 API 标准支持

类型:

标准跟踪

创建:

2019-12-29

决议:

https://mail.python.org/archives/list/[email protected]/message/Z6AA5CL47NHBNEPTFWYOTSUVSRDGHYPN/

摘要#

NEP-18 的 __array_function__ 取得了喜忧参半的成功。一些项目(例如,dask、CuPy、xarray、sparse、Pint、MXNet)热情地采用了它。其他项目(例如,JAX)则更为犹豫。在这里,我们提出一个新的协议,__array_module__,我们预计它最终可以取代 __array_function__ 的大多数用例。该协议要求用户和库作者明确采用,这确保了向后兼容性,并且也比 __array_function__ 简单得多,我们预计这将使其更容易采用。

为什么 __array_function__ 不够#

NEP-18 有两种主要方式未能实现其目标

  1. 向后兼容性问题__array_function__ 对使用它的库有重大影响

    • JAX 一直不愿实现 __array_function__,部分原因是担心破坏现有代码:用户期望 NumPy 函数(如 np.concatenate)返回 NumPy 数组。这是 __array_function__ 设计的一个根本性限制,我们选择允许重写现有的 numpy 命名空间。像 Dask 和 CuPy 这样的库已经研究并接受了 __array_function__ 带来的向后兼容性影响;如果这种影响不存在,对它们来说仍然会更好。

      请注意,像 PyTorchscipy.sparse 这样的项目也尚未采用 __array_function__,因为它们没有与 NumPy 兼容的 API 或语义。对于 PyTorch,这很可能在将来添加。 scipy.sparsenumpy.matrix 处于相同的情况:它的语义与 numpy.ndarray 不兼容,因此添加 __array_function__(除了可能返回 NotImplemented)不是一个好主意。

    • __array_function__ 目前需要一种“全有或全无”的方法来实现 NumPy 的 API。没有好的途径来实现增量采用,这对于采用 __array_function__ 将导致重大更改的已建立项目来说尤其成问题。

  2. 对可重写内容的限制。 __array_function__ 有一些重要的差距,最显著的是数组创建和强制转换函数

    • 数组创建例程(例如,np.arangenp.random 中的例程)需要其他机制来指示要创建哪种类型的数组。NEP 35 提出向没有现有数组参数的函数添加可选的 like= 参数。但是,我们仍然缺乏任何重写对象方法的机制,例如 np.random.RandomState 所需的那些方法。

    • 数组转换无法重用现有的强制转换函数,如 np.asarray,因为 np.asarray 有时意味着“转换为精确的 np.ndarray”,而其他时候则意味着“转换为类似于 NumPy 数组的东西”。这导致了 NEP 30 对单独的 np.duckarray 函数的提案,但这仍然没有解决如何将一个鸭子数组转换为与另一个鸭子数组类型匹配的问题。

其他提出的可维护性问题包括

  • 在支持重写的模块中,不再可以使用 NumPy 函数的别名。例如,CuPy 和 JAX 都设置了 result_type = np.result_type,现在必须将它们自己的 result_type 函数包装在 np.result_type 的使用中。

  • 通过使用 NumPy 的实现来实现回退机制以处理未实现的 NumPy 函数很难做到正确(但请参见 来自 dask 的版本),因为 __array_function__ 没有提供一致的接口。转换所有数组类型的参数需要递归进入 *args, **kwargs 形式的泛型参数。

get_array_module__array_module__ 协议#

我们提出了一种新的面向用户的面向鸭子数组实现的调度机制,numpy.get_array_moduleget_array_module 执行与 __array_function__ 相同的类型解析,并返回一个模块,该模块具有承诺与 numpy 的标准接口匹配的 API,可以对所有提供的数组类型实现操作。

该协议本身比 __array_function__ 更简单、更强大,因为它不需要担心实际实现函数。我们相信它解决了 __array_function__ 的大多数可维护性和功能限制。

新协议是选择加入的,明确的,并且具有本地控制;请参阅 附录:API 重写的设计选择,了解有关这些设计功能重要性的讨论。

数组模块契约#

get_array_module/__array_module__返回的模块应尽最大努力在新数组类型(s)上实现NumPy的核心功能。未实现的功能应简单地省略(例如,访问未实现的函数应引发AttributeError)。将来,我们预计将对请求受限的numpy子集进行编码;更多详细信息,请参见请求NumPy API的受限子集

如何使用get_array_module#

想要支持通用鸭阵列的代码应该显式调用get_array_module来确定从中调用函数的适当数组模块,而不是直接使用numpy命名空间。例如

# calls the appropriate version of np.something for x and y
module = np.get_array_module(x, y)
module.something(x, y)

数组创建和数组转换都受支持,因为调度是由get_array_module处理的,而不是通过函数参数的类型处理的。例如,要使用随机数生成函数或方法,我们可以简单地提取相应的子模块

def duckarray_add_random(array):
    module = np.get_array_module(array)
    noise = module.random.randn(*array.shape)
    return array + noise

我们还可以编写来自NEP 30的鸭阵列stack函数,而无需新的np.duckarray函数

def duckarray_stack(arrays):
    module = np.get_array_module(*arrays)
    arrays = [module.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[module.newaxis, ...] for arr in arrays]
    return module.concatenate(expanded_arrays, axis=0)

默认情况下,如果没有任何参数是数组,get_array_module将返回numpy模块。可以通过提供module关键字参数来显式控制此回退。也可以通过设置module=None来指示应引发异常而不是返回默认数组模块。

如何实现__array_module__#

实现想要支持get_array_module的鸭阵列类型的库需要实现相应的协议__array_module__。这个新的协议基于Python的算术调度协议,本质上是__array_function__的简化版本。

只将一个参数传递给__array_module__,这是一个传递给get_array_module的唯一数组类型的Python集合,即所有具有__array_module__属性的参数。

特殊方法应该返回一个具有与numpy匹配的API的命名空间,或者返回NotImplemented,表示它不知道如何处理该操作。

class MyArray:
    def __array_module__(self, types):
        if not all(issubclass(t, MyArray) for t in types):
            return NotImplemented
        return my_array_module

__array_module__返回自定义对象#

my_array_module通常(但不总是)是一个Python模块。返回自定义对象(例如,通过__getattr__实现函数)可能对某些高级用例有用。

例如,自定义对象可以允许鸭阵列模块的部分实现回退到NumPy(尽管通常不推荐这样做,因为这种回退行为可能会导致错误)。

class MyArray:
    def __array_module__(self, types):
        if all(issubclass(t, MyArray) for t in types):
            return ArrayModule()
        else:
            return NotImplemented

class ArrayModule:
    def __getattr__(self, name):
        import base_module
        return getattr(base_module, name, getattr(numpy, name))

numpy.ndarray继承#

NEP-18中关于定义良好的类型转换层次结构的所有相同指导仍然适用。numpy.ndarray本身包含__array_module__的匹配实现,这对于子类来说很方便。

class ndarray:
    def __array_module__(self, types):
        if all(issubclass(t, ndarray) for t in types):
            return numpy
        else:
            return NotImplemented

NumPy的内部机制#

get_array_module的类型解析规则遵循与Python和NumPy现有调度协议相同的模型:子类在超类之前调用,否则从左到右。保证__array_module__在每个唯一类型上只调用一次。

get_array_module的实际实现将用C语言编写,但应该等同于这段Python代码。

def get_array_module(*arrays, default=numpy):
    implementing_arrays, types = _implementing_arrays_and_types(arrays)
    if not implementing_arrays and default is not None:
        return default
    for array in implementing_arrays:
        module = array.__array_module__(types)
        if module is not NotImplemented:
            return module
    raise TypeError("no common array module found")

def _implementing_arrays_and_types(relevant_arrays):
    types = []
    implementing_arrays = []
    for array in relevant_arrays:
        t = type(array)
        if t not in types and hasattr(t, '__array_module__'):
            types.append(t)
            # Subclasses before superclasses, otherwise left to right
            index = len(implementing_arrays)
            for i, old_array in enumerate(implementing_arrays):
                if issubclass(t, type(old_array)):
                    index = i
                    break
            implementing_arrays.insert(index, array)
    return implementing_arrays, types

__array_ufunc____array_function__的关系#

这些较旧的协议具有不同的用例,应该保留#

__array_module__旨在解决__array_function__的局限性,因此自然会考虑它是否可以完全替换__array_function__。这将带来双重好处:(1)简化关于如何覆盖NumPy的用户故事;(2)消除调用每个NumPy函数时与检查调度相关的减速。

但是,从用户的角度来看,__array_module____array_function__差别很大:它需要显式调用get_array_function,而不是简单地重用原始numpy函数。对于依赖于鸭阵列的来说,这可能很好,但对于交互式使用来说可能会过于冗长。

__array_ufunc__的一些调度用例也由__array_module__解决,但并非所有用例都解决。例如,仍然可以使用通用的方法在非NumPy数组(例如,使用dask.array)上定义非NumPy ufunc(例如,来自Numba或SciPy)。

鉴于它们现有的采用情况和不同的用例,我们认为目前没有必要移除或弃用__array_function____array_ufunc__

实现__array_function____array_ufunc__的Mixin类#

尽管用户界面有所不同,但__array_module__和实现NumPy API的模块仍然包含实现现有鸭阵列协议调度所需的功能。

例如,以下Mixin类将根据get_array_module__array_module__为这些特殊方法提供合理的默认值。

class ArrayUfuncFromModuleMixin:

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        arrays = inputs + kwargs.get('out', ())
        try:
            array_module = np.get_array_module(*arrays)
        except TypeError:
            return NotImplemented

        try:
            # Note this may have false positive matches, if ufunc.__name__
            # matches the name of a ufunc defined by NumPy. Unfortunately
            # there is no way to determine in which module a ufunc was
            # defined.
            new_ufunc = getattr(array_module, ufunc.__name__)
        except AttributeError:
            return NotImplemented

        try:
            callable = getattr(new_ufunc, method)
        except AttributeError:
            return NotImplemented

        return callable(*inputs, **kwargs)

class ArrayFunctionFromModuleMixin:

    def __array_function__(self, func, types, args, kwargs):
        array_module = self.__array_module__(types)
        if array_module is NotImplemented:
            return NotImplemented

        # Traverse submodules to find the appropriate function
        modules = func.__module__.split('.')
        assert modules[0] == 'numpy'
        for submodule in modules[1:]:
            module = getattr(module, submodule, None)
        new_func = getattr(module, func.__name__, None)
        if new_func is None:
            return NotImplemented

        return new_func(*args, **kwargs)

为了更容易编写鸭阵列,我们也可以将这些Mixin类添加到numpy.lib.mixins中(但上面的例子可能就足够了)。

考虑过的替代方案#

命名#

我们喜欢__array_module__这个名称,因为它与现有的__array_function____array_ufunc__协议相呼应。另一个合理的选择可能是__array_namespace__

调用此协议的NumPy函数应该叫什么还不清楚(在这个提案中是get_array_module)。一些可能的替代方案:array_modulecommon_array_moduleresolve_array_moduleget_namespaceget_numpyget_numpylike_moduleget_duck_array_module

请求NumPy API的受限子集#

随着时间的推移,NumPy积累了非常大的API表面积,仅顶级numpy模块就拥有超过600个属性。任何鸭阵列库都不可能或不想实现所有这些函数和类,因为NumPy常用的子集要小得多。

我们认为定义NumPy API的“最小”子集(省略很少使用或不推荐的功能)将是有益的练习。例如,最小NumPy可能包括stack,但不包括其他堆叠函数column_stackdstackhstackvstack。这可以清楚地向鸭阵列的作者和用户表明哪些功能是核心功能,哪些功能可以跳过。

支持请求NumPy API的受限子集将是get_array_function__array_module__中自然包含的功能,例如:

# array_module is only guaranteed to contain "minimal" NumPy
array_module = np.get_array_module(*arrays, request='minimal')

为了方便使用NumPy进行测试以及与任何有效的鸭子类型数组库一起使用,当仅对NumPy数组调用get_array_module时,NumPy本身将返回numpy模块的受限版本。省略的函数将根本不存在。

不幸的是,我们还没有确定这些受限子集应该是什么,所以现在这样做没有意义。当/如果我们这样做时,我们可以向get_array_module添加新的关键字参数,或者添加新的顶级函数,例如get_minimal_array_module。我们还需要添加一个新的基于__array_module__的协议(例如__array_module_minimal__),或者可以向__array_module__添加一个可选的第二个参数(使用try/except捕获错误)。

用于隐式调度的新的命名空间#

与其使用__array_function__在主要的numpy命名空间中支持覆盖,我们可以创建一个新的选择加入命名空间,例如numpy.api,其中包含支持调度的NumPy函数版本。这些覆盖需要新的选择加入协议,例如__array_function_api__,其模式与__array_function__类似。

通过选择加入,这将解决__array_function__最大的局限性,并且还可以明确地覆盖诸如asarray之类的函数,因为np.api.asarray将始终意味着“转换类数组对象”。但是它并不能解决__array_module__满足的所有调度需求,并且将使我们不得不为数组用户和实现者支持一个相当复杂的协议。

我们可能可以通过__array_module__协议实现这样一个新的命名空间。当然,一些用户会觉得这很方便,因为它稍微减少了样板代码。但这将使用户面临一个令人困惑的选择:何时应该使用get_array_modulenp.api.something。此外,我们必须添加和维护一个全新的模块,这比仅仅添加一个函数要昂贵得多。

根据类型和数组进行调度,而不仅仅是类型#

与其仅通过唯一的数组类型支持调度,我们还可以通过数组对象支持调度,例如,通过将arrays参数作为__array_module__协议的一部分传递。这对于具有元数据的数组(例如Dask和Pint提供的数组)的调度可能很有用,但会在类型安全性和复杂性方面带来成本。

例如,一个同时支持CPU和GPU上数组的库可能会根据输入参数决定从ones之类的函数创建新数组的设备。

class Array:
    def __array_module__(self, types, arrays):
        useful_arrays = tuple(a in arrays if isinstance(a, Array))
        if not useful_arrays:
            return NotImplemented
        prefer_gpu = any(a.prefer_gpu for a in useful_arrays)
        return ArrayModule(prefer_gpu)

class ArrayModule:
    def __init__(self, prefer_gpu):
        self.prefer_gpu = prefer_gpu

    def __getattr__(self, name):
        import base_module
        base_func = getattr(base_module, name)
        return functools.partial(base_func, prefer_gpu=self.prefer_gpu)

这可能很有用,但我们还不清楚我们是否真的需要它。Pint似乎在没有任何显式的数组创建例程的情况下也能很好地工作(倾向于用单位进行乘法,例如np.ones(5) * ureg.m),并且在大多数情况下,Dask也能很好地使用现有的__array_function__样式覆盖(例如,倾向于使用np.ones_like而不是np.ones)。选择将数组放置在CPU还是GPU上可以通过使数组创建变为惰性来解决。

附录:API 覆盖的设计选择#

覆盖NumPy的API有很多可能的设计选择。在这里,我们讨论了设计决策的三个主要方面,这些方面指导了我们对__array_module__的设计。

用户选择加入与选择退出#

__array_ufunc____array_function__协议提供了一种机制,用于在NumPy的现有命名空间中覆盖NumPy函数。这意味着如果用户不希望有任何被覆盖的行为,则需要显式选择退出,例如,通过使用np.asarray()转换数组。

理论上,这种方法降低了在用户代码和库中采用这些协议的障碍,因为使用标准NumPy命名空间的代码会自动兼容。但在实践中,这并没有奏效。例如,大多数维护良好的使用NumPy的库都遵循使用np.asarray()转换所有输入的最佳实践,他们必须显式地放松此限制才能使用__array_function__。我们的经验是,使库与新的鸭子类型数组兼容通常需要至少少量的工作来适应数据模型和可以有效实现的操作的差异。

这些选择退出方法还会大大使采用这些协议的库的向后兼容性复杂化,因为库通过选择加入,它们也选择了它们的用户的加入,无论他们是否期望它。为了赢得那些无法采用__array_function__的库的支持,选择加入方法似乎是必须的。

显式与隐式实现选择#

__array_ufunc____array_function__都对调度进行隐式控制:调度函数通过每次函数调用中的适当协议确定。这很好地概括了处理许多不同类型的对象,正如它在Python中用于实现算术运算符所证明的那样,但它对于**可读性**有一个重要的缺点:调用函数时发生的事情不再对代码阅读者立即显而易见,因为函数的实现可能会被它的任何参数覆盖。

**速度**影响是

  • 当使用鸭子类型数组时,get_array_module意味着类型检查只需要在每个支持鸭子类型的函数内部发生一次,而使用__array_function__时,它会在每次调用NumPy函数时发生。显然,这取决于函数,但是如果一个典型的支持鸭子类型数组的函数调用其他NumPy函数3-5次,那么这将是3-5倍的额外开销。

  • 当使用NumPy数组时,get_array_module每个函数有一个额外的调用(__array_function__开销保持不变),这意味着少量额外的开销。

显式和隐式实现选择并非相互排斥的选项。事实上,我们熟悉的通过__array_function__实现NumPy API覆盖的大多数实现(即Dask、CuPy和Sparse,但不是Pint)也包括一种显式的方法来通过直接导入模块来使用它们自己的NumPy API版本(分别为dask.arraycupysparse)。

局部控制、非局部控制与全局控制#

最终的设计轴是用户如何控制API的选择

  • **局部控制**,例如多重分派和Python算术协议,通过检查类型或调用函数的直接参数上的方法来确定使用哪个实现。

  • **非局部控制**,例如np.errstate,通过函数装饰器或上下文管理器使用全局状态来覆盖行为。控制是分层确定的,通过最内部的上下文。

  • **全局控制**为用户提供了一种机制来设置默认行为,可以通过函数调用或配置文件来实现。例如,matplotlib允许设置绘图后端的全局选择。

局部控制通常被认为是API设计的最佳实践,因为控制流是完全明确的,这使得它最容易理解。非局部和全局控制偶尔会被使用,但通常是由于缺乏知识或缺乏更好的替代方案。

对于NumPy公共API的鸭子类型,我们认为非局部或全局控制将是错误的,主要是因为它们**不能很好地组合**。如果一个库设置/需要一组覆盖,然后内部调用一个期望另一组覆盖的例程,则结果行为可能会非常令人惊讶。高阶函数尤其成问题,因为函数求值的环境可能与函数定义的环境不同。

我们认为非局部和全局控制是合适的覆盖用例的一类是选择一个保证具有完全一致接口的后端系统,例如NumPy数组上numpy.fft的更快替代实现。但是,这些不在当前提案的范围内,该提案重点关注鸭子类型数组。