NEP 13 — 通用函数覆盖机制#

作者:

Blake Griffith

联系方式:

blake.g@utexas.edu

日期:

2013-07-10

作者:

Pauli Virtanen

作者:

Nathaniel Smith

作者:

Marten van Kerkwijk

作者:

Stephan Hoyer

日期:

2017-03-31

状态:

最终

更新:

2023-02-19

作者:

Roy Smart

执行摘要#

NumPy 的通用函数 (ufunc) 目前具有一些有限的功能,可用于使用 __array_prepare____array_wrap__ [1]ndarray 的用户定义子类进行操作,并且对任意对象的支持很少或根本没有。例如 SciPy 的稀疏矩阵 [2] [3]

这里我们建议添加一种基于 ufunc 检查其每个参数是否存在 __array_ufunc__ 方法来覆盖 ufunc 的机制。如果发现 __array_ufunc__,则 ufunc 将操作传递给该方法。

这涵盖了 Travis Oliphant 使用多方法 [4] 对 NumPy 进行改造的提案中的一些相同内容,该提案将解决相同的问题。此处的机制更紧密地遵循 Python 允许类覆盖 __mul__ 和其他二元运算符的方式。它还专门解决了二元运算符和 ufunc 如何交互的问题。(请注意,在早期版本中,覆盖称为 __numpy_ufunc__。已经实现了该功能,但行为并不完全正确,因此名称发生了更改。)

如下所述的 __array_ufunc__ 要求任何相应的 Python 二元运算符(__mul__ 等)都应以特定方式实现并与 NumPy 的 ndarray 语义兼容。不满足此条件的对象无法覆盖任何 NumPy ufunc。我们没有指定未来兼容的路径,通过该路径可以放宽此要求 - 这里出现的任何更改都需要在第三方代码中进行相应的更改。

动机#

人们普遍认为,当前用于分派通用函数的机制不足。已经进行了长时间的讨论和其他提出的解决方案 [5][6]

使用 ndarray 的子类与通用函数的交互仅限于 __array_prepare____array_wrap__ 来准备输出参数,但这些参数不允许您例如更改参数的形状或数据。尝试对不继承自 ndarray 的内容使用通用函数更加困难,因为输入参数往往会被转换为对象数组,这最终会产生令人惊讶的结果。

以通用函数与稀疏矩阵的互操作性为例。

In [1]: import numpy as np
import scipy.sparse as sp

a = np.random.randint(5, size=(3,3))
b = np.random.randint(5, size=(3,3))

asp = sp.csr_matrix(a)
bsp = sp.csr_matrix(b)

In [2]: a, b
Out[2]:(array([[0, 4, 4],
               [1, 3, 2],
               [1, 3, 1]]),
        array([[0, 1, 0],
               [0, 0, 1],
               [4, 0, 1]]))

In [3]: np.multiply(a, b) # The right answer
Out[3]: array([[0, 4, 0],
               [0, 0, 2],
               [4, 0, 1]])

In [4]: np.multiply(asp, bsp).todense() # calls __mul__ which does matrix multi
Out[4]: matrix([[16,  0,  8],
                [ 8,  1,  5],
                [ 4,  1,  4]], dtype=int64)

In [5]: np.multiply(a, bsp) # Returns NotImplemented to user, bad!
Out[5]: NotImplemented

向用户返回 NotImplemented 不应该发生。此外

In [6]: np.multiply(asp, b)
Out[6]: array([[ <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>,
                    <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>,
                    <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>],
                   [ <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>,
                    <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>,
                    <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>],
                   [ <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>,
                    <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>,
                    <3x3 sparse matrix of type '<class 'numpy.int64'>'
                with 8 stored elements in Compressed Sparse Row format>]], dtype=object)

这里,似乎稀疏矩阵被转换为对象数组标量,然后将其与 b 数组的所有元素相乘。但是,这种行为比有用更令人困惑,并且最好出现 TypeError

此提案 *不会* 解决 scipy.sparse 矩阵的问题,这些矩阵具有与 NumPy 数组不兼容的乘法语义。但是,目的是能够编写其他具有严格与 ndarray 兼容的语义的自定义数组类型。

提议的接口#

标准数组类 ndarray 获得了一个 __array_ufunc__ 方法,并且对象可以通过覆盖此方法(如果它们是 ndarray 的子类)或定义自己的方法来覆盖通用函数。方法签名是

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs)

这里

  • ufunc 是调用的通用函数对象。

  • method 是一个字符串,指示通用函数是如何调用的,可以是 "__call__" 表示直接调用,也可以是其方法之一:"reduce""accumulate""reduceat""outer""at"

  • inputs 是通用函数输入参数的元组

  • kwargs 包含传递给函数的任何可选或关键字参数。这包括任何 out 参数,这些参数始终包含在元组中。

因此,参数被规范化:仅将必需的输入参数(inputs)作为位置参数传递,所有其他参数都作为关键字参数的字典(kwargs)传递。特别是,如果存在输出参数,否则为位置参数,且不为 None,则它们作为 out 关键字参数中的元组传递(即使对于 reduceaccumulatereduceat 方法,在所有当前情况下只有一个输出是有意义的)。

函数分派按以下方式进行

  • 如果输入、输出或 where 参数之一实现了 __array_ufunc__,则会执行它而不是通用函数。

  • 如果多个参数实现了 __array_ufunc__,则按以下顺序尝试:子类优先于超类,输入优先于输出,输出优先于 where,否则从左到右。

  • 第一个 __array_ufunc__ 方法返回除 NotImplemented 之外的任何内容都决定了通用函数的返回值。

  • 如果所有输入参数的 __array_ufunc__ 方法都返回 NotImplemented,则会引发 TypeError

  • 如果 __array_ufunc__ 方法引发错误,则会立即传播该错误。

  • 如果没有任何输入参数具有 __array_ufunc__ 方法,则执行将回退到默认的通用函数行为。

在上述内容中,有一个前提:如果一个类具有 __array_ufunc__ 属性,但它与 ndarray.__array_ufunc__ 相同,则会忽略该属性。这发生在 ndarray 的实例以及未覆盖其继承的 __array_ufunc__ 实现的 ndarray 子类上。

类型转换层次结构#

Python 运算符覆盖机制在如何编写覆盖方法方面提供了很大的自由度,并且需要一定的纪律才能获得可预测的结果。在这里,我们讨论了一种理解某些含义的方法,这可以为设计提供输入。

保持对哪些类型可以“向上转换”到其他类型(可能间接地,例如实现间接 A->B->C 但不实现直接 A->C)的清晰认识很有用。如果 __array_ufunc__ 的实现遵循一致的类型转换层次结构,则可以使用它来理解操作的结果。

类型转换可以表示为如下定义的

对于每个 __array_ufunc__ 方法,从每个可能的输入类型到每个可能的输出类型绘制有向边。

也就是说,在 y = x.__array_ufunc__(a, b, c, ...) 执行除返回 NotImplemented 或引发错误之外的操作的每种情况下,都绘制边 type(a) -> type(y)type(b) -> type(y)、…

如果生成的图是无环的,则它定义了一个连贯的类型转换层次结构(类型之间明确的部分排序)。在这种情况下,涉及多种类型的操作通常会可预测地生成“最高”类型的结果,或者引发TypeError。请参阅本节末尾的示例。

如果图具有循环,则__array_ufunc__类型转换未定义,并且诸如type(multiply(a, b)) != type(multiply(b, a))type(add(a, add(b, c))) != type(add(add(a, b), c))之类的情况不被排除(然后可能总是存在)。

如果类型转换层次结构定义良好,则对于每个类A,所有其他定义了__array_ufunc__的类都属于以下三组之一

  • 高于A:A可以在ufunc中(间接)向上转换到的类型。

  • 低于A:可以在ufunc中(间接)向上转换到A的类型。

  • 不兼容:既不高于也不低于A;无法进行(间接)向上转换的类型。

请注意,NumPy ufunc的传统行为是尝试通过ndarray将未知对象转换为np.asarray()。这相当于在图中将ndarray置于这些对象之上。由于我们在上面定义了ndarray对于具有自定义__array_ufunc__的类返回NotImplemented,这使得ndarray在类型层次结构中位于此类类的下方,从而允许覆盖操作。

鉴于以上,描述传递操作的二元ufunc应该旨在定义一个定义良好的转换层次结构。这对于所有ufunc来说可能也是一种明智的方法——对此的例外情况应仔细考虑是否会导致任何意外行为。

示例

类型转换层次结构。

_images/nep0013_image1.png

类型A的__array_ufunc__可以处理返回C的ndarray,B可以处理返回B的ndarray和D,C可以处理返回C的A和B,但不能处理ndarray或D。结果是一个有向无环图,并定义了一个类型转换层次结构,关系为C > AC > ndarrayC > B > ndarrayC > B > D。类型A与B、D、ndarray不兼容,D与A和ndarray不兼容。涉及这些类的ufunc表达式应产生所涉及的最高类型的结果,或引发TypeError

示例

__array_ufunc__图中的一个循环。

_images/nep0013_image2.png

在这种情况下,__array_ufunc__关系具有长度为1的循环,并且不存在类型转换层次结构。二元运算不是可交换的:type(a + b) is Atype(b + a) is B

示例

__array_ufunc__图中的较长循环。

_images/nep0013_image3.png

在这种情况下,__array_ufunc__关系具有更长的循环,并且不存在类型转换层次结构。二元运算仍然是可交换的,但类型传递性丢失了:type(a + (b + c)) is Atype((a + b) + c) is C

子类层次结构#

通常,希望在ufunc类型转换层次结构中反映类层次结构。建议是,除非输入是同一类或超类的实例,否则类的__array_ufunc__实现通常应返回NotImplemented。这保证了在类型转换层次结构中,超类在下方,子类在上方,其他类不兼容。对此的例外需要检查它们是否尊重隐式类型转换层次结构。

注意

请注意,此处定义的类型转换层次结构和类层次结构的方向“相反”。原则上,让__array_ufunc__也处理子类的实例也是一致的。在这种情况下,“子类优先”调度规则将确保相对相似的结果。但是,行为的指定就不那么明确了。

如果方法一致地使用super()遍历类层次结构[7],则可以轻松地构建子类。为了支持这一点,ndarray有自己的__array_ufunc__方法,等效于

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    # Cannot handle items that have __array_ufunc__ (other than our own).
    outputs = kwargs.get('out', ())
    objs = inputs + outputs
    if "where" in kwargs:
        objs = objs + (kwargs["where"], )
    for item in objs:
        if (hasattr(item, '__array_ufunc__') and
                type(item).__array_ufunc__ is not ndarray.__array_ufunc__):
            return NotImplemented

    # If we didn't have to support legacy behaviour (__array_prepare__,
    # __array_wrap__, etc.), we might here convert python floats,
    # lists, etc, to arrays with
    # items = [np.asarray(item) for item in inputs]
    # and then start the right iterator for the given method.
    # However, we do have to support legacy, so call back into the ufunc.
    # Its arguments are now guaranteed not to have __array_ufunc__
    # overrides, and it will do the coercion to array for us.
    return getattr(ufunc, method)(*items, **kwargs)

请注意,作为特殊情况,即使对于尚未覆盖默认ndarray实现的ndarray子类,ufunc调度机制也不会调用此ndarray.__array_ufunc__方法。因此,调用ndarray.__array_ufunc__不会导致嵌套的ufunc调度循环。

使用super()对于仅添加属性(如单位)的ndarray的子类应该特别有用。在其__array_ufunc__实现中,此类类可以对与其自身类相关的参数进行可能的调整,并使用super()传递到超类实现,直到ufunc实际完成,然后对输出进行可能的调整。

通常,__array_ufunc__的自定义实现应避免嵌套调度循环,其中一个不仅通过getattr(ufunc, method)(*items, **kwargs)调用ufunc,而且捕获可能的异常等。一如既往,可能存在例外。例如,对于像MaskedArray这样的类,它只关心它包含的内容是否是ndarray的子类,使用__array_ufunc__的重新实现可能更容易通过直接将ufunc应用于其数据,然后调整掩码来完成。实际上,可以将其视为类的一部分,用于确定它是否可以处理另一个参数(即它在类型层次结构中的位置)。在这种情况下,如果尝试失败,则应返回NotImplemented。因此,实现将类似于

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
    # for simplicity, outputs are ignored here.
    unmasked_items = tuple((item.data if isinstance(item, MaskedArray)
                            else item) for item in inputs)
    try:
        unmasked_result = getattr(ufunc, method)(*unmasked_items, **kwargs)
    except TypeError:
        return NotImplemented
    # for simplicity, ignore that unmasked_result could be a tuple
    # or a scalar.
    if not isinstance(unmasked_result, np.ndarray):
        return NotImplemented
    # now combine masks and view as MaskedArray instance
    ...

作为一个具体的例子,考虑一个量和一个掩码数组类,它们都覆盖了__array_ufunc__,具有特定的实例qma,后者包含一个常规数组。执行np.multiply(q, ma),ufunc将首先分派到q.__array_ufunc__,后者返回NotImplemented(因为量类将其自身转换为数组并调用super(),后者传递到ndarray.__array_ufunc__,后者看到ma上的覆盖)。接下来,ma.__array_ufunc__获得机会。它不知道量,如果它也只返回NotImplemented,则会产生TypeError。但在我们的示例实现中,它使用getattr(ufunc, method)来有效地评估np.multiply(q, ma.data)。这将再次传递到q.__array_ufunc__,但这一次,由于ma.data是常规数组,因此它将返回一个也是量的结果。由于它是ndarray的子类,因此ma.__array_ufunc__可以将其转换为掩码数组,从而返回结果(显然,如果它不是数组子类,它仍然可以返回NotImplemented)。

请注意,在上面讨论的类型层次结构的上下文中,这是一个有点棘手的示例,因为MaskedArray的位置很奇怪:它位于ndarray的所有子类之上,因为它可以将其转换为自己的类型,但它本身不知道如何在ufunc中与它们交互。

关闭Ufunc#

对于某些类,Ufunc毫无意义,并且,就像某些其他特殊方法(例如__hash____iter__[8]一样,可以通过将__array_ufunc__设置为None来指示Ufunc不可用。如果在任何将__array_ufunc__ = None设置的操作数上调用Ufunc,它将无条件地引发TypeError

在类型转换层次结构中,这明确表明该类型相对于ndarray是不兼容的。

与Python的二元运算符结合使用时的行为#

Python运算符重载机制在ndarray中与__array_ufunc__机制耦合。对于Python用于实现二元运算(如*+)的特殊方法调用(例如ndarray.__mul__(self, other)),NumPy的ndarray实现了以下行为

  • 如果other.__array_ufunc__ is None,则ndarray返回NotImplemented。控制权将返回到Python,后者将依次尝试在other上调用相应反射方法(例如other.__rmul__),如果存在。

  • 如果other上缺少__array_ufunc__属性,并且other.__array_priority__ > self.__array_priority__,则ndarray也返回NotImplemented(逻辑与前一种情况相同)。这确保了与旧版本的NumPy向后兼容。

  • 否则,ndarray 会单方面调用相应的 Ufunc。Ufuncs 永远不会返回 NotImplemented,因此**诸如** other.__rmul__ **之类的反射方法无法用于覆盖 NumPy 数组的算术运算,如果** __array_ufunc__ **被设置为** None **以外的任何值**。相反,需要通过以与相应 Ufunc(例如 np.multiply)一致的方式实现 __array_ufunc__ 来更改其行为。有关受影响运算符及其对应 ufunc 的列表,请参见 运算符和 NumPy Ufuncs 列表

希望修改与 ndarray 在二元运算中交互的类有两个选项

  1. 实现 __array_ufunc__ 并遵循 Python 二元运算的 NumPy 语义(见下文)。

  2. 设置 __array_ufunc__ = None,并自由实现 Python 二元运算。在这种情况下,在此参数上调用的 ufuncs 将引发 TypeError(参见 关闭 Ufuncs)。

实现二元运算的建议#

对于大多数数值类,覆盖二元运算最简单的方法是定义 __array_ufunc__ 并覆盖相应的 Ufunc。然后,该类可以像 ndarray 本身一样,根据 Ufuncs 定义二元运算符。在这里,必须注意确保允许其他类指示它们不兼容,即实现应该类似于

def _disables_array_ufunc(obj):
    try:
        return obj.__array_ufunc__ is None
    except AttributeError:
        return False

class ArrayLike:
    ...
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        ...
        return result

    # Option 1: call ufunc directly
    def __mul__(self, other):
        if _disables_array_ufunc(other):
            return NotImplemented
        return np.multiply(self, other)

    def __rmul__(self, other):
        if _disables_array_ufunc(other):
            return NotImplemented
        return np.multiply(other, self)

    def __imul__(self, other):
        return np.multiply(self, other, out=(self,))

    # Option 2: call into one's own __array_ufunc__
    def __mul__(self, other):
        return self.__array_ufunc__(np.multiply, '__call__', self, other)

    def __rmul__(self, other):
        return self.__array_ufunc__(np.multiply, '__call__', other, self)

    def __imul__(self, other):
        result = self.__array_ufunc__(np.multiply, '__call__', self, other,
                                      out=(self,))
        if result is NotImplemented:
            raise TypeError(...)

为了了解为什么需要谨慎,请考虑另一个类 other,它不知道如何处理数组和 ufuncs,因此已将 __array_ufunc__ 设置为 None,但知道如何进行乘法

class MyObject:
    __array_ufunc__ = None
    def __init__(self, value):
        self.value = value
    def __repr__(self):
        return "MyObject({!r})".format(self.value)
    def __mul__(self, other):
        return MyObject(1234)
    def __rmul__(self, other):
        return MyObject(4321)

对于上述任一选项,我们都会得到预期的结果

mine = MyObject(0)
arr = ArrayLike([0])

mine * arr    # -> MyObject(1234)
mine *= arr   # -> MyObject(1234)
arr * mine    # -> MyObject(4321)
arr *= mine   # -> TypeError

这里,在第一个和第二个示例中,会调用 mine.__mul__(arr),并且结果会立即到达。在第三个示例中,首先调用 arr.__mul__(mine)。在选项 (1) 中,对 mine.__array_ufunc__ is None 的检查将成功,因此返回 NotImplemented,这会导致执行 mine.__rmul__(arg)。在选项 (2) 中,它可能是在 arr.__array_ufunc__ 内部变得清楚另一个参数无法处理,并且再次返回 NotImplemented,导致控制权传递给 mine.__rmul__

对于第四个示例,对于就地运算符,我们在这里遵循了 ndarray 并确保我们永远不会返回 NotImplemented,而是引发 TypeError。在选项 (1) 中,这是间接发生的:我们传递给 np.multiply,它依次立即引发 TypeError,因为其操作数之一(out[0])禁用了 Ufuncs。在选项 (2) 中,我们直接传递给 arr.__array_ufunc__,它将返回 NotImplemented,我们会捕获它。

注意

不允许就地操作返回 NotImplemented 的原因是这些操作不能普遍地被简单的反向操作替换:大多数数组操作假设实例的内容会就地更改,并且不期望新实例。此外,ndarr[:] *= mine 会暗示什么?假设它表示 ndarr[:] = ndarr[:] * mine,正如 python 在 ndarr.__imul__ 返回 NotImplemented 时默认情况下所做的那样,这很可能是不正确的。

现在考虑如果我们没有添加检查会发生什么。对于选项 (1),相关的情况是我们没有检查 __array_func__ 是否设置为 None。在第三个示例中,调用 arr.__mul__(mine),如果没有检查,这将转到 np.multiply(arr, mine)。这会尝试 arr.__array_ufunc__,它返回 NotImplemented 并看到 mine.__array_ufunc__ is None,因此会引发 TypeError

对于选项 (2),相关的示例是第四个,使用 arr *= mine:如果我们让 NotImplemented 通过,python 将用 arr = mine.__rmul__(arr) 替换它,这不是我们想要的。

由于 Ufunc 覆盖和 Python 的二元运算的语义几乎相同,因此在大多数情况下,选项 (1) 和 (2) 将使用相同的 __array_ufunc__ 实现产生相同的结果。一个例外是当第二个参数是第一个参数的子类时尝试实现的顺序,这是由于 Python 的一个错误 [9] 预计将在 Python 3.7 中修复。

一般来说,我们建议采用选项 (1),这是与 ndarray 本身使用的选项最相似的选项。请注意,选项 (1) 是病毒式的,因为任何其他希望支持与您的类进行二元运算的类现在也必须遵循这些规则才能支持与 ndarray 进行二元算术运算(即,它们必须实现 __array_ufunc__ 或将其设置为 None)。我们认为这是一件好事,因为它确保了所有支持 ufuncs 和算术运算的对象的一致性。

为了使实现此类类似数组的类更容易,mixin 类 NDArrayOperatorsMixin 为所有具有对应 Ufuncs 的二元运算符提供了选项 (1) 样式的覆盖。希望为兼容版本的 NumPy 实现 __array_ufunc__ 但也需要在旧版本上支持与 NumPy 数组进行二元算术运算的类应确保 __array_ufunc__ 也可用于实现它们支持的所有二元运算。

最后,我们注意到我们对是否更有意义地要求像 MyObject 这样的类实现完整的 __array_ufunc__ 进行了广泛的讨论 [6]。最终,允许类选择退出是首选,上述推理使我们同意对 ndarray 本身进行类似的实现。选择退出机制要求禁用 Ufuncs,因此类不能定义 Ufuncs 以返回与相应的二元运算不同的结果(即,如果定义了 np.add(x, y),则它应与 x + y 匹配)。我们的目标是尽可能简化与 NumPy 数组进行二元运算的分派逻辑,方法是使其能够使用 Python 的分派规则或 NumPy 的分派规则,而不是同时使用两者的混合。

运算符和 NumPy Ufuncs 列表#

这是一个完整的 Python 二元运算符列表,以及 ndarrayNDArrayOperatorsMixin 使用的相应 NumPy Ufuncs

符号

运算符

NumPy Ufunc(s)

<

lt

less()

<=

le

less_equal()

==

eq

equal()

!=

ne

not_equal()

>

gt

greater()

>=

ge

greater_equal()

+

add

add()

-

sub

subtract()

*

mul

multiply()

/

truediv (Python 3)

true_divide()

/

div (Python 2)

divide()

//

floordiv

floor_divide()

%

mod

remainder()

NA

divmod

divmod()

**

pow

power() [10]

<<

lshift

left_shift()

>>

rshift

right_shift()

&

and_

bitwise_and()

^

xor_

bitwise_xor()

|

or_

bitwise_or()

@

matmul

尚未实现为 ufunc [11]

以下是单目运算符列表

符号

运算符

NumPy Ufunc(s)

-

neg

negative()

+

pos

positive() [12]

NA

abs

absolute()

~

invert

invert()

未来扩展到其他函数#

某些 NumPy 函数可以实现为(广义)Ufunc,在这种情况下,它们可以通过 __array_ufunc__ 方法被覆盖。一个主要的候选者是 matmul(),它目前不是 Ufunc,但可以相对容易地重写为(一组)广义 Ufunc。类似的情况也可能发生在诸如 median()min()argsort() 等函数上。