广播#
另请参阅
术语“广播”描述了 NumPy 在算术运算中如何处理不同形状的数组。在满足某些约束的条件下,较小的数组会被“广播”到较大的数组上,以便它们具有兼容的形状。广播提供了一种向量化数组运算的方法,从而使循环在 C 语言而不是 Python 中执行。它这样做不会复制不必要的数据,并且通常能实现高效的算法。然而,在某些情况下,广播是一个糟糕的主意,因为它会导致内存的低效使用,从而减慢计算速度。
NumPy 操作通常是逐个元素地对数组对进行。在最简单的情况下,两个数组必须具有完全相同的形状,如下例所示。
>>> import numpy as np
>>> a = np.array([1.0, 2.0, 3.0])
>>> b = np.array([2.0, 2.0, 2.0])
>>> a * b
array([2., 4., 6.])
当数组的形状满足某些约束时,NumPy 的广播规则放宽了这一限制。最简单的广播示例发生在数组和标量值在运算中组合时。
>>> import numpy as np
>>> a = np.array([1.0, 2.0, 3.0])
>>> b = 2.0
>>> a * b
array([2., 4., 6.])
结果等同于之前将 b 视为数组的示例。我们可以将标量 b 在算术运算中“拉伸”成一个与 a 形状相同的数组。如图 1所示,b 中的新元素只是原始标量值的副本。拉伸的比喻只是概念性的。NumPy 足够智能,可以在不实际复制的情况下使用原始标量值,从而使广播操作在内存和计算上尽可能高效。
图 1#
在最简单的广播示例中,标量 b 被拉伸成与 a 形状相同的数组,以便形状兼容进行逐元素乘法。
第二个示例中的代码比第一个示例中的代码更有效,因为广播在乘法过程中移动的内存更少(b 是一个标量而不是一个数组)。
通用广播规则#
当对两个数组进行运算时,NumPy 会逐个元素地比较它们的形状。它从最后一个(即最右边的)维度开始,然后向左移动。当满足以下条件时,两个维度是兼容的:
它们相等,或
其中一个维度的大小为 1。
如果不满足这些条件,则会抛出 ValueError: operands could not broadcast together 异常,表明数组的形状不兼容。
输入数组不必具有相同数量的维度。结果数组将具有输入数组中维度最多的那个数组的维度数,并且每个维度的大小是输入数组中对应维度中的最大大小。请注意,缺失的维度被假定为大小为 1。
例如,如果您有一个 256x256x3 的 RGB 值数组,并且想用不同的值缩放图像中的每种颜色,您可以将图像乘以一个具有 3 个值的一维数组。根据广播规则对这些数组的最后一个轴的大小进行对齐,显示它们是兼容的。
Image (3d array): 256 x 256 x 3
Scale (1d array): 3
Result (3d array): 256 x 256 x 3
当比较的两个维度中有一个为 1 时,则使用另一个维度。换句话说,大小为 1 的维度会被拉伸或“复制”以匹配另一个维度。
在以下示例中,A 和 B 数组都有长度为 1 的轴,这些轴在广播操作期间被扩展到更大的尺寸。
A (4d array): 8 x 1 x 6 x 1
B (3d array): 7 x 1 x 5
Result (4d array): 8 x 7 x 6 x 5
可广播数组#
如果上述规则产生有效的结果,则一组数组被称为可以广播到同一形状。
例如,如果 a.shape 是 (5,1),b.shape 是 (1,6),c.shape 是 (6,),d.shape 是 (),那么 a, b, c, 和 d 都可以广播到维度 (5,6);并且
a 表现得像一个 (5,6) 的数组,其中
a[:,0]被广播到其他列,b 表现得像一个 (5,6) 的数组,其中
b[0,:]被广播到其他行,c 表现得像一个 (1,6) 的数组,因此表现得像一个 (5,6) 的数组,其中
c[:]被广播到每一行,最后,d 表现得像一个 (5,6) 的数组,其中单个值被重复。
这里还有一些例子
A (2d array): 5 x 4
B (1d array): 1
Result (2d array): 5 x 4
A (2d array): 5 x 4
B (1d array): 4
Result (2d array): 5 x 4
A (3d array): 15 x 3 x 5
B (3d array): 15 x 1 x 5
Result (3d array): 15 x 3 x 5
A (3d array): 15 x 3 x 5
B (2d array): 3 x 5
Result (3d array): 15 x 3 x 5
A (3d array): 15 x 3 x 5
B (2d array): 3 x 1
Result (3d array): 15 x 3 x 5
这里是形状不广播的例子
A (1d array): 3
B (1d array): 4 # trailing dimensions do not match
A (2d array): 2 x 1
B (3d array): 8 x 4 x 3 # second from last dimensions mismatched
当一个 1 维数组被添加到 2 维数组时的广播示例
>>> import numpy as np
>>> a = np.array([[ 0.0, 0.0, 0.0],
... [10.0, 10.0, 10.0],
... [20.0, 20.0, 20.0],
... [30.0, 30.0, 30.0]])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a + b
array([[ 1., 2., 3.],
[11., 12., 13.],
[21., 22., 23.],
[31., 32., 33.]])
>>> b = np.array([1.0, 2.0, 3.0, 4.0])
>>> a + b
Traceback (most recent call last):
ValueError: operands could not be broadcast together with shapes (4,3) (4,)
如图 2所示,b 被添加到 a 的每一行。在图 3中,由于形状不兼容而引发了异常。
图 2#
如果一维数组的元素数量与二维数组的列数匹配,则一维数组添加到二维数组中会导致广播。
图 3#
当数组的最后一个维度不相等时,广播会失败,因为它不可能将第一个数组的行与第二个数组的元素对齐进行逐元素相加。
广播提供了一种方便的方式来计算两个数组的外积(或任何其他外运算)。下面的例子展示了两个一维数组的外加运算。
>>> import numpy as np
>>> a = np.array([0.0, 10.0, 20.0, 30.0])
>>> b = np.array([1.0, 2.0, 3.0])
>>> a[:, np.newaxis] + b
array([[ 1., 2., 3.],
[11., 12., 13.],
[21., 22., 23.],
[31., 32., 33.]])
图 4#
在某些情况下,广播会拉伸两个数组以形成比任何一个初始数组都大的输出数组。
这里,newaxis 索引运算符将一个新轴插入到 a 中,使其成为一个二维 4x1 数组。将 4x1 数组与形状为 (3,) 的 b 组合,会产生一个 4x3 数组。
实际示例:向量量化#
广播在现实问题中经常出现。向量量化(VQ)算法是信息论、分类和其他相关领域中的一个典型例子。VQ 中的基本操作是找到一组点(在 VQ 术语中称为 codes)中最接近给定点(称为 observation)的点。在下面所示的非常简单的二维情况下,observation 中的值描述了要分类的运动员的体重和身高。codes 代表不同的运动员类别。[1] 找到最近的点需要计算 observation 和每个 code 之间的距离。最短的距离提供了最佳匹配。在此示例中,codes[0] 是最接近的类别,表明该运动员可能是篮球运动员。
>>> from numpy import array, argmin, sqrt, sum
>>> observation = array([111.0, 188.0])
>>> codes = array([[102.0, 203.0],
... [132.0, 193.0],
... [45.0, 155.0],
... [57.0, 173.0]])
>>> diff = codes - observation # the broadcast happens here
>>> dist = sqrt(sum(diff**2,axis=-1))
>>> argmin(dist)
0
在此示例中,observation 数组被拉伸以匹配 codes 数组的形状。
Observation (1d array): 2
Codes (2d array): 4 x 2
Diff (2d array): 4 x 2
图 5#
向量量化的基本操作是计算待分类对象(暗方块)与多个已知代码(灰圆圈)之间的距离。在此简单情况下,代码代表单个类别。更复杂的情况每个类别使用多个代码。
通常,会比较大量 observations(可能从数据库读取)与一组 codes。考虑以下场景:
Observation (2d array): 10 x 3
Codes (3d array): 5 x 1 x 3
Diff (3d array): 5 x 10 x 3
三维数组 diff 是广播的结果,而不是计算的必需。大型数据集将生成一个大的中间数组,这在计算上效率低下。相反,如果使用 Python 循环单独计算每个 observation(如上面二维示例中的代码),则会使用一个更小的数组。
广播是一个强大的工具,可以编写简洁且通常直观的代码,这些代码在 C 语言中进行高效计算。但是,在某些情况下,广播对于特定算法会使用不必要的大量内存。在这些情况下,最好在 Python 中编写算法的外部循环。这可能也会产生更具可读性的代码,因为使用广播的算法随着广播维度的增加而变得更难解释。
脚注