NumPy 中的数据类型提升#
当混合两种不同的数据类型时,NumPy 需要确定操作结果的合适 dtype。这个步骤被称为*提升*或*查找公共 dtype*。
在典型情况下,用户无需担心提升的细节,因为提升步骤通常会确保结果要么与输入的精度相同,要么超过输入的精度。
例如,当输入具有相同的 dtype 时,结果的 dtype 与输入的 dtype 匹配
>>> np.int8(1) + np.int8(1)
np.int8(2)
混合两种不同的 dtype 通常会产生具有较高精度输入 dtype 的结果
>>> np.int8(4) + np.int64(8) # 64 > 8
np.int64(12)
>>> np.float32(3) + np.float16(3) # 32 > 16
np.float32(6.0)
在典型情况下,这不会带来意外。但是,如果您处理非默认 dtype,如无符号整数和低精度浮点数,或者您混合 NumPy 整数、NumPy 浮点数和 Python 标量,NumPy 提升规则的某些细节可能很重要。请注意,这些详细规则并不总是与其他语言的规则相匹配[1]。
数值 dtype 分为四种“种类”,具有自然层级。
无符号整数 (
uint)有符号整数 (
int)浮点数 (
float)复数 (
complex)
除了种类之外,NumPy 数值 dtype 还具有关联的精度,以比特为单位指定。种类和精度共同指定了 dtype。例如,uint8 是一个使用 8 比特存储的无符号整数。
操作的结果将始终具有与任何输入相等或更高的种类。此外,结果的精度将始终大于或等于输入的精度。这已经可能导致一些意想不到的例子
混合浮点数和整数时,整数的精度可能会迫使结果提高到更高的浮点精度。例如,涉及
int64和float16的操作结果是float64。当混合具有相同精度的无符号和有符号整数时,结果的精度将比任一输入都*高*。此外,如果其中一个已经具有 64 位精度,则没有更高的整数精度可用,例如,涉及
int64和uint64的操作将得到float64。
有关两者的详细信息,请参阅*数值提升*部分和下图。
Python 标量的详细行为#
自 NumPy 2.0 起[2],我们提升规则中的一个重要方面是,虽然涉及两个 NumPy dtypes 的操作永远不会丢失精度,但涉及 NumPy dtype 和 Python 标量(int、float 或 complex)的操作*可能*会丢失精度。例如,Python 整数和 NumPy 整数之间的操作结果应为 NumPy 整数,这可能是符合直觉的。然而,Python 整数具有任意精度,而所有 NumPy dtypes 都具有固定精度,因此 Python 整数的任意精度无法保留。
更一般地说,NumPy 在确定结果 dtype 时会考虑 Python 标量的“种类”,但会忽略其精度。这通常很方便。例如,当使用低精度 dtype 的数组时,简单的 Python 标量操作通常希望保留 dtype。
>>> arr_float32 = np.array([1, 2.5, 2.1], dtype="float32")
>>> arr_float32 + 10.0 # undesirable to promote to float64
array([11. , 12.5, 12.1], dtype=float32)
>>> arr_int16 = np.array([3, 5, 7], dtype="int16")
>>> arr_int16 + 10 # undesirable to promote to int64
array([13, 15, 17], dtype=int16)
在这两种情况下,结果精度均由 NumPy dtype 决定。因此,arr_float32 + 3.0 的行为与 arr_float32 + np.float32(3.0) 相同,而 arr_int16 + 10 的行为与 arr_int16 + np.int16(10.) 相同。
作为另一个例子,当混合 NumPy 整数与 Python float 或 complex 时,结果类型始终为 float64 或 complex128。
>>> np.int16(1) + 1.0 np.float64(2.0)
然而,当处理低精度 dtype 时,这些规则也可能导致意外行为。
首先,由于 Python 值在执行操作之前会被转换为 NumPy 值,当结果看似明显时,操作可能会因错误而失败。例如,np.int8(1) + 1000 无法继续,因为 1000 超出了 int8 的最大值。当 Python 标量无法强制转换为 NumPy dtype 时,会引发错误。
>>> np.int8(1) + 1000
Traceback (most recent call last):
...
OverflowError: Python integer 1000 out of bounds for int8
>>> np.int64(1) * 10**100
Traceback (most recent call last):
...
OverflowError: Python int too large to convert to C long
>>> np.float32(1) + 1e300
np.float32(inf)
... RuntimeWarning: overflow encountered in cast
其次,由于 Python 浮点数或整数的精度总是被忽略,低精度的 NumPy 标量将继续使用其较低的精度,除非显式转换为更高精度的 NumPy dtype 或 Python 标量(例如,通过 int()、float() 或 scalar.item())。这种较低的精度可能会损害某些计算或导致结果不正确,尤其是在整数溢出的情况下。
>>> np.int8(100) + 100 # the result exceeds the capacity of int8
np.int8(-56)
... RuntimeWarning: overflow encountered in scalar add
请注意,NumPy 在标量发生溢出时会发出警告,但对数组则不会;例如,np.array(100, dtype="uint8") + 100 不会发出警告。
数值提升#
下图显示了数值提升规则,其中种类位于垂直轴上,精度位于水平轴上。
具有较高种类的输入 dtype 决定了结果 dtype 的种类。结果 dtype 的精度应尽可能低,但不能出现在图中输入 dtype 的左侧。
请注意以下具体规则和观察结果:
当 Python
float或complex与 NumPy 整数交互时,结果将是float64或complex128(黄色边框)。NumPy 布尔值也将转换为默认整数[3]。当同时涉及 NumPy 浮点值时,这不相关。精度绘制方式是
float16 < int16 < uint16,因为大的uint16不适合int16,而大的int16存储在float16中会丢失精度。然而,这种模式被打破了,因为 NumPy 始终认为float64和complex128可以作为任何整数值的提升结果。一个特例是,NumPy 将许多有符号和无符号整数的组合提升为
float64。这里使用更高的种类,因为没有足够精确的有符号整数 dtype 可以容纳uint64。
一般提升规则的例外情况#
在 NumPy 中,提升指的是特定函数对结果的处理方式,在某些情况下,这意味着 NumPy 可能偏离 np.result_type 会给出的结果。
sum 和 prod 的行为#
np.sum 和 np.prod 在对整数值(或布尔值)求和时,始终返回默认整数类型。通常是 int64。这是因为整数求和否则极有可能溢出并产生混淆的结果。此规则也适用于底层的 np.add.reduce 和 np.multiply.reduce。
NumPy 或 Python 整数标量的显著行为#
NumPy 提升指的是结果 dtype 和操作精度,但操作有时会决定结果。除法总是返回浮点值,而比较总是返回布尔值。
这导致了可能看起来是“例外”于规则的行为:
NumPy 与 Python 整数或混合精度整数的比较始终返回正确的结果。输入永远不会以丢失精度的方式进行转换。
无法提升的类型之间的相等性比较将被视为全部
False(相等)或全部True(不相等)。一元数学函数,如
np.sin,总是返回浮点值,它通过将任何 Python 整数输入转换为float64来接受。除法总是返回浮点值,因此也允许任何 NumPy 整数与任何 Python 整数值进行除法,方法是将两者都转换为
float64。
原则上,其中一些例外情况可能对其他函数也有意义。如果您认为如此,请提出一个 issue。
Python 内建类型类的显著行为#
当结合 Python 的内建标量*类型*(即 float、int 或 complex,而不是标量*值*)时,提升规则可能看起来很奇怪。
>>> np.result_type(7, np.array([1], np.float32))
dtype('float32') # The scalar value '7' does not impact type promotion
>>> np.result_type(type(7), np.array([1], np.float32))
dtype('float64') # The *type* of the scalar value '7' does impact promotion
# Similar situations happen with Python's float and complex types
这种行为的原因是 NumPy 将 int 转换为其默认整数类型,并使用该类型进行提升。
>>> np.result_type(int)
dtype('int64')
有关更多详细信息,请参阅内建 Python 类型。
非数值数据类型的提升#
NumPy 将提升扩展到非数值类型,尽管在许多情况下,提升定义不明确且被直接拒绝。
适用以下规则:
NumPy 字节字符串 (
np.bytes_) 可以提升为 Unicode 字符串 (np.str_)。但是,将字节转换为 Unicode 对于非 ASCII 字符会失败。出于某些目的,NumPy 会将几乎任何其他数据类型提升为字符串。这适用于数组创建或连接。
当没有可行的提升时,
np.array()等数组构造函数将使用objectdtype。结构化 dtype 当其字段名称和顺序匹配时可以提升。在这种情况下,所有字段都会单独提升。
NumPy
timedelta在某些情况下可以与整数提升。
注意
其中一些规则有些令人惊讶,并且正在考虑未来进行更改。但是,任何向后不兼容的更改都必须权衡打破现有代码的风险。如果您对提升应该如何工作有具体的想法,请提出一个 issue。
提升的 dtype 实例的详细信息#
上述讨论主要涉及混合不同 DType 类的行为。附加到数组的 dtype 实例可以携带额外信息,如字节序、元数据、字符串长度或精确的结构化 dtype 布局。
虽然结构化 dtype 的字符串长度或字段名称很重要,但 NumPy 将字节序、元数据和结构化 dtype 的精确布局视为存储细节。
在提升过程中,NumPy*不*考虑这些存储细节。
字节序被转换为本地字节序。
附加到 dtype 的元数据可能会或可能不会被保留。
结果的结构化 dtype 将被打包(如果输入被打包,则对齐)。
这种行为对于大多数不相关存储细节对最终结果的程序来说是最佳行为,并且使用不正确的字节序可能会大大降低评估速度。