编写自定义数组容器#

NumPy 的分派机制(在 numpy 版本 v1.16 中引入)是编写与 numpy API 兼容并提供 NumPy 功能自定义实现的自定义 N 维数组容器的推荐方法。应用包括 dask 数组(跨多个节点的 N 维数组)和 cupy 数组(GPU 上的 N 维数组)。

为了体验编写自定义数组容器,我们将从一个实用性相当有限但能说明所涉及概念的简单示例开始。

>>> import numpy as np
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)

我们的自定义数组可以像这样实例化:

>>> arr = DiagonalArray(5, 1)
>>> arr
DiagonalArray(N=5, value=1)

我们可以使用 numpy.arraynumpy.asarray 将其转换为 NumPy 数组,这将调用其 __array__ 方法以获得标准的 numpy.ndarray

>>> np.asarray(arr)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

__array__ 方法可以选择接受一个 dtype 参数。如果提供了此参数,则此参数指定所需 NumPy 数组的数据类型。您的实现应尝试将数据转换为此 dtype(如果可能)。如果不支持转换,通常最好回退到默认类型或引发 TypeErrorValueError

以下是一个展示其在 dtype 规范中使用的示例:

>>> np.asarray(arr, dtype=np.float32)
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)

如果我们使用 NumPy 函数对 arr 进行操作,NumPy 将再次使用 __array__ 接口将其转换为数组,然后按通常方式应用该函数。

>>> np.multiply(arr, 2)
array([[2., 0., 0., 0., 0.],
       [0., 2., 0., 0., 0.],
       [0., 0., 2., 0., 0.],
       [0., 0., 0., 2., 0.],
       [0., 0., 0., 0., 2.]])

请注意,返回类型是标准的 numpy.ndarray

>>> type(np.multiply(arr, 2))
<class 'numpy.ndarray'>

我们如何将自定义数组类型传递给此函数?NumPy 允许类通过 __array_ufunc____array_function__ 接口来指示其希望以自定义定义的方式处理计算。我们一次处理一个,先从 __array_ufunc__ 开始。此方法涵盖 通用函数(ufunc),这是一类函数,例如 numpy.multiplynumpy.sin

__array_ufunc__ 接收:

  • ufunc,一个类似 numpy.multiply 的函数。

  • method,一个字符串,区分 numpy.multiply(...) 和变体,如 numpy.multiply.outernumpy.multiply.accumulate 等。对于常见情况 numpy.multiply(...)method == '__call__'

  • inputs,它可以是不同类型的混合。

  • kwargs,传递给函数的关键字参数。

在此示例中,我们将仅处理 __call__ 方法。

>>> from numbers import Number
>>> class DiagonalArray:
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented

现在我们的自定义数组类型可以通过 NumPy 函数。

>>> arr = DiagonalArray(5, 1)
>>> np.multiply(arr, 3)
DiagonalArray(N=5, value=3)
>>> np.add(arr, 3)
DiagonalArray(N=5, value=4)
>>> np.sin(arr)
DiagonalArray(N=5, value=0.8414709848078965)

此时 arr + 3 不起作用。

>>> arr + 3
Traceback (most recent call last):
...
TypeError: unsupported operand type(s) for +: 'DiagonalArray' and 'int'

要支持它,我们需要定义 Python 接口 __add____lt__ 等来分派到相应的 ufunc。我们可以通过继承混合类 NDArrayOperatorsMixin 来方便地实现这一点。

>>> import numpy.lib.mixins
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
>>> arr = DiagonalArray(5, 1)
>>> arr + 3
DiagonalArray(N=5, value=4)
>>> arr > 0
DiagonalArray(N=5, value=True)

现在我们来处理 __array_function__。我们将创建一个字典,将 NumPy 函数映射到我们的自定义变体。

>>> HANDLED_FUNCTIONS = {}
>>> class DiagonalArray(numpy.lib.mixins.NDArrayOperatorsMixin):
...     def __init__(self, N, value):
...         self._N = N
...         self._i = value
...     def __repr__(self):
...         return f"{self.__class__.__name__}(N={self._N}, value={self._i})"
...     def __array__(self, dtype=None, copy=None):
...         if copy is False:
...             raise ValueError(
...                 "`copy=False` isn't supported. A copy is always created."
...             )
...         return self._i * np.eye(self._N, dtype=dtype)
...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
...         if method == '__call__':
...             N = None
...             scalars = []
...             for input in inputs:
...                 # In this case we accept only scalar numbers or DiagonalArrays.
...                 if isinstance(input, Number):
...                     scalars.append(input)
...                 elif isinstance(input, self.__class__):
...                     scalars.append(input._i)
...                     if N is not None:
...                         if N != input._N:
...                             raise TypeError("inconsistent sizes")
...                     else:
...                         N = input._N
...                 else:
...                     return NotImplemented
...             return self.__class__(N, ufunc(*scalars, **kwargs))
...         else:
...             return NotImplemented
...     def __array_function__(self, func, types, args, kwargs):
...         if func not in HANDLED_FUNCTIONS:
...             return NotImplemented
...         # Note: this allows subclasses that don't override
...         # __array_function__ to handle DiagonalArray objects.
...         if not all(issubclass(t, self.__class__) for t in types):
...             return NotImplemented
...         return HANDLED_FUNCTIONS[func](*args, **kwargs)
...

一种方便的模式是定义一个 implements 装饰器,它可以用来将函数添加到 HANDLED_FUNCTIONS

>>> def implements(np_function):
...    "Register an __array_function__ implementation for DiagonalArray objects."
...    def decorator(func):
...        HANDLED_FUNCTIONS[np_function] = func
...        return func
...    return decorator
...

现在我们为 DiagonalArray 编写 NumPy 函数的实现。为了完整起见,为了支持用法 arr.sum(),请添加一个调用 numpy.sum(self)sum 方法,对于 mean 也是如此。

>>> @implements(np.sum)
... def sum(arr):
...     "Implementation of np.sum for DiagonalArray objects"
...     return arr._i * arr._N
...
>>> @implements(np.mean)
... def mean(arr):
...     "Implementation of np.mean for DiagonalArray objects"
...     return arr._i / arr._N
...
>>> arr = DiagonalArray(5, 1)
>>> np.sum(arr)
5
>>> np.mean(arr)
0.2

如果用户尝试使用 HANDLED_FUNCTIONS 中未包含的任何 NumPy 函数,NumPy 将引发 TypeError,表明此操作不受支持。例如,连接两个 DiagonalArrays 不会产生另一个对角线数组,因此不受支持。

>>> np.concatenate([arr, arr])
Traceback (most recent call last):
...
TypeError: no implementation found for 'numpy.concatenate' on types that implement __array_function__: [<class '__main__.DiagonalArray'>]

此外,我们对 summean 的实现不支持 NumPy 实现所支持的可选参数。

>>> np.sum(arr, axis=0)
Traceback (most recent call last):
...
TypeError: sum() got an unexpected keyword argument 'axis'

用户始终可以选择使用 numpy.asarray 转换为常规 numpy.ndarray,然后从中进行标准的 NumPy 操作。

>>> np.concatenate([np.asarray(arr), np.asarray(arr)])
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])

此示例中 DiagonalArray 的实现仅出于简洁性考虑,处理了 np.sumnp.mean 函数。NumPy API 中还有许多其他函数可供包装,一个功能齐全的自定义数组容器可以显式支持 NumPy 提供的所有可包装函数。

NumPy 在 numpy.testing.overrides 命名空间中提供了一些实用程序,用于帮助测试实现 __array_ufunc____array_function__ 协议的自定义数组容器。

要检查 NumPy 函数是否可以通过 __array_ufunc__ 进行覆盖,您可以使用 allows_array_ufunc_override

>>> from numpy.testing.overrides import allows_array_ufunc_override
>>> allows_array_ufunc_override(np.add)
True

类似地,您可以使用 allows_array_function_override 检查函数是否可以通过 __array_function__ 进行覆盖。

NumPy API 中每个可覆盖函数的列表也可通过 get_overridable_numpy_array_functions(对于支持 __array_function__ 协议的函数)和 get_overridable_numpy_ufuncs(对于支持 __array_ufunc__ 协议的函数)获得。这两个函数都返回 NumPy 公共 API 中存在的函数集。用户定义的 ufunc 或在依赖于 NumPy 的其他库中定义的 ufunc 不包含在这些集中。

有关更完整的自定义数组容器示例,请参阅 dask 源代码cupy 源代码

另请参阅 NEP 18