NumPy 中的数据类型提升#

当混合两种不同的数据类型时,NumPy 必须确定操作结果的适当 dtype。此步骤称为提升查找公共 dtype

在典型情况下,用户无需担心提升的细节,因为提升步骤通常确保结果要么与输入的精度匹配,要么超过输入的精度。

例如,当输入具有相同的 dtype 时,结果的 dtype 与输入的 dtype 相匹配

>>> np.int8(1) + np.int8(1)
np.int8(2)

混合两个不同的 dtype 通常会产生一个具有更高精度输入的 dtype 的结果

>>> np.int8(4) + np.int64(8)  # 64 > 8
np.int64(12)
>>> np.float32(3) + np.float16(3)  # 32 > 16
np.float32(6.0)

在典型情况下,这不会导致意外情况。但是,如果您使用非默认 dtype(如无符号整数和低精度浮点数),或者如果您混合了 NumPy 整数、NumPy 浮点数和 Python 标量,那么 NumPy 提升规则的一些细节可能与您相关。请注意,这些详细规则并不总是与其他语言的规则相匹配[1]

数值 dtype 分为四种“类型”,并具有自然的层次结构。

  1. 无符号整数 (uint)

  2. 有符号整数 (int)

  3. 浮点数 (float)

  4. 复数 (complex)

除了类型之外,NumPy 数值 dtype 还与一个关联的精度相关联,该精度以位数指定。类型和精度共同指定了 dtype。例如,uint8 是一个使用 8 位存储的无符号整数。

操作的结果将始终等于或高于任何输入的类型。此外,结果的精度将始终大于或等于输入的精度。这已经可以导致一些可能出乎意料的例子

  1. 当混合浮点数和整数时,整数的精度可能会迫使结果提升到更高的精度浮点数。例如,涉及 int64float16 的操作的结果为 float64

  2. 当混合具有相同精度的无符号整数和有符号整数时,结果将具有更高的精度,而不是任何输入的精度。此外,如果其中一个已经具有 64 位精度,则没有更高的精度整数可用,例如,涉及 int64uint64 的操作将返回 float64

请参阅下面的 数值提升部分和图像,了解有关这两个方面的详细信息。

Python 标量的详细行为#

从 NumPy 2.0 开始[2],提升规则中一个重要的点是,虽然涉及两个 NumPy dtype 的操作永远不会丢失精度,但涉及 NumPy dtype 和 Python 标量 (intfloatcomplex) 的操作可以丢失精度。例如,Python 整数和 NumPy 整数之间的操作结果应该是 NumPy 整数,这可能很直观。但是,Python 整数具有任意精度,而所有 NumPy dtype 都有固定精度,因此无法保留 Python 整数的任意精度。

更一般地说,NumPy 会考虑 Python 标量的“类型”,但在确定结果 dtype 时会忽略它们的精度。这通常很方便。例如,在使用低精度 dtype 的数组时,通常希望使用 Python 标量进行简单的操作能够保留 dtype。

>>> arr_float32 = np.array([1, 2.5, 2.1], dtype="float32")
>>> arr_float32 + 10.0  # undesirable to promote to float64
array([11. , 12.5, 12.1], dtype=float32)
>>> arr_int16 = np.array([3, 5, 7], dtype="int16")
>>> arr_int16 + 10  # undesirable to promote to int64
array([13, 15, 17], dtype=int16)

在这两种情况下,结果精度都由 NumPy dtype 决定。因此,arr_float32 + 3.0 的行为与 arr_float32 + np.float32(3.0) 相同,arr_int16 + 10 的行为与 arr_int16 + np.int16(10.) 相同。

作为另一个例子,当混合 NumPy 整数和 Python floatcomplex 时,结果始终具有类型 float64complex128

>> np.int16(1) + 1.0 np.float64(2.0)

但是,这些规则在使用低精度 dtype 时也会导致意外行为。

首先,由于 Python 值在执行操作之前被转换为 NumPy 值,因此当结果看起来很明显时,操作可能会失败并出现错误。例如,np.int8(1) + 1000 无法继续,因为 1000 超出了 int8 的最大值。当 Python 标量无法强制转换为 NumPy dtype 时,会引发错误

>>> np.int8(1) + 1000
Traceback (most recent call last):
  ...
OverflowError: Python integer 1000 out of bounds for int8
>>> np.int64(1) * 10**100
Traceback (most recent call last):
...
OverflowError: Python int too large to convert to C long
>>> np.float32(1) + 1e300
np.float32(inf)
... RuntimeWarning: overflow encountered in cast

其次,由于始终会忽略 Python 浮点数或整数精度,因此低精度 NumPy 标量会一直使用其较低的精度,除非明确转换为更高精度的 NumPy dtype 或 Python 标量(例如,通过 int()float()scalar.item())。这种较低的精度可能会对某些计算有害或导致错误的结果,尤其是在整数溢出情况下

>>> np.int8(100) + 100  # the result exceeds the capacity of int8
np.int8(-56)
... RuntimeWarning: overflow encountered in scalar add

请注意,当标量发生溢出时,NumPy 会发出警告,但不会对数组发出警告;例如,np.array(100, dtype="uint8") + 100 *不会*发出警告。

数值提升#

下图显示了数值提升规则,类型在垂直轴上,精度在水平轴上。

../_images/nep-0050-promotion-no-fonts.svg

具有更高类型的输入 dtype 决定结果 dtype 的类型。结果 dtype 的精度尽可能低,但不会出现在图中任何输入 dtype 的左侧。

请注意以下具体规则和观察结果

  1. 当 Python floatcomplex 与 NumPy 整数交互时,结果将是 float64complex128(黄色边框)。NumPy 布尔值也将被强制转换为默认整数。[#default-int] 当另外涉及 NumPy 浮点值时,这无关紧要。

  2. 精度的绘制方式为 float16 < int16 < uint16,因为较大的 uint16 不适合 int16,并且较大的 int16 在存储为 float16 时会丢失精度。但是,这种模式被打破了,因为 NumPy 始终认为 float64complex128 是任何整数值的可接受提升结果。

  3. 一个特例是 NumPy 将许多有符号和无符号整数的组合提升为 float64。在这里使用更高类型是因为没有足够精度的有符号整数 dtype 可以容纳 uint64

一般提升规则的例外#

在 NumPy 中,提升指的是特定函数对结果的处理方式,在某些情况下,这意味着 NumPy 可能会偏离 np.result_type 所返回的值。

sumprod 的行为#

``np.sum`` 和 ``np.prod``:在对整数值(或布尔值)求和时,始终返回默认整数类型。这通常是 int64。原因是整数求和否则很容易溢出并给出令人困惑的结果。此规则也适用于底层的 np.add.reducenp.multiply.reduce

使用 NumPy 或 Python 整数标量的显著行为#

NumPy 提升指的是结果 dtype 和操作精度,但操作有时会决定结果。除法始终返回浮点值,比较始终返回布尔值。

这会导致看起来像“例外”的行为

  • NumPy 与 Python 整数或混合精度整数的比较始终返回正确的结果。输入永远不会以导致精度丢失的方式进行强制转换。

  • 无法提升的类型之间的相等比较将被认为全部为 False(相等)或全部为 True(不相等)。

  • np.sin 这样的始终返回浮点值的单目数学函数,通过将任何 Python 整数转换为 float64 来接受任何 Python 整数输入。

  • 除法始终返回浮点值,因此也允许在任何 NumPy 整数与任何 Python 整数之间进行除法,方法是将两者都转换为 float64

原则上,这些例外中的一部分可能对其他函数有意义。如果您认为情况确实如此,请提出问题。

非数值数据类型的提升#

NumPy 将提升扩展到非数值类型,尽管在许多情况下,提升没有明确定义,并且简单地被拒绝。

以下规则适用

  • NumPy 字节字符串 (np.bytes_) 可以提升为 Unicode 字符串 (np.str_)。但是,将字节转换为 Unicode 将对非 ASCII 字符失败。

  • 出于某些目的,NumPy 将几乎所有其他数据类型提升为字符串。这适用于数组创建或连接。

  • 当没有可行的提升时,数组构造器(如 np.array())将使用 object 数据类型。

  • 当结构化数据类型的字段名称和顺序匹配时,结构化数据类型可以提升。在这种情况下,所有字段将被单独提升。

  • NumPy timedelta 在某些情况下可以与整数一起提升。

注意

其中一些规则有点令人惊讶,并且正在考虑在将来进行更改。但是,任何向后不兼容的更改都必须权衡破坏现有代码的风险。如果您对提升应该如何运作有任何具体的想法,请提出问题。

已提升 dtype 实例的详细信息#

以上讨论主要涉及混合不同 DType 类时的行为。附加到数组的 dtype 实例可以携带其他信息,例如字节序、元数据、字符串长度或精确的结构化数据类型布局。

虽然结构化数据类型的字符串长度或字段名称很重要,但 NumPy 将字节序、元数据以及结构化数据类型的精确布局视为存储细节。在提升过程中,NumPy *不会* 考虑这些存储细节:* 字节序被转换为本机字节序。* 附加到数据类型的元数据可能会或可能不会被保留。* 生成的结构化数据类型将被打包(但如果输入是,则对齐)。

这种行为对于大多数程序来说是最佳行为,因为存储细节与最终结果无关,并且使用不正确的字节序可能会大幅降低评估速度。