NEP 50 — Python 标量类型的提升规则#
- 作者:
Sebastian Berg
- 状态:
最终
- 类型:
标准跟踪
- 创建时间:
2021-05-25
摘要#
自 NumPy 1.7 以来,提升规则使用所谓的“安全类型转换”,该规则依赖于对所涉及的值的检查。 这有助于识别用户的许多边缘情况,但实现起来很复杂,也使得行为难以预测。
有两种令人困惑的结果
基于值的提升意味着该值,例如 Python 整数的值,可以决定
np.result_type
找到的输出类型np.result_type(np.int8, 1) == np.int8 np.result_type(np.int8, 255) == np.int16
出现此逻辑的原因是
1
可以用uint8
或int8
表示,而255
不能用int8
表示,而只能用uint8
或int16
表示。当使用 0-D 数组(所谓的“标量数组”)时,情况也是如此
int64_0d_array = np.array(1, dtype=np.int64) np.result_type(np.int8, int64_0d_array) == np.int8
其中
int64_0d_array
具有int64
dtype 的事实对结果 dtype 没有影响。 在此示例中,dtype=np.int64
被有效地忽略了,因为只有它的值才重要。对于 Python
int
、float
或complex
,该值会像之前显示的那样进行检查。 但令人惊讶的是,当 NumPy 对象是 0-D 数组或 NumPy 标量时,不会进行检查np.result_type(np.array(1, dtype=np.uint8), 1) == np.int64 np.result_type(np.int8(1), 1) == np.int64
原因是当所有对象都是标量或 0-D 数组时,基于值的提升会被禁用。 因此,NumPy 返回与
np.array(1)
相同的类型,这通常是int64
(这取决于系统)。
请注意,这些示例也适用于乘法、加法、比较等操作,以及它们对应的函数,如 np.multiply
。
此 NEP 建议围绕以下两个指导原则重构行为
值绝不能影响结果类型。
NumPy 标量和 0-D 数组的行为应与其 N-D 对等项保持一致。
我们建议删除所有基于值的逻辑,并为 Python 标量添加特殊处理,以保留一些便捷的行为。Python 标量将被视为“弱”类型。 当 NumPy 数组/标量与 Python 标量组合时,它将被转换为 NumPy dtype,因此
np.array([1, 2, 3], dtype=np.uint8) + 1 # returns a uint8 array
np.array([1, 2, 3], dtype=np.float32) + 2. # returns a float32 array
将不会依赖于 Python 值本身。
建议的更改也适用于 np.can_cast(100, np.int8)
,但是,我们预计在实践中,函数(提升)中的行为将比类型转换更改本身重要得多。
注意
在 NumPy 1.24.x 系列中,NumPy 已有初步且有限的支持来测试此提案。
此外,有必要设置以下环境变量
export NPY_PROMOTION_STATE=weak
有效值为 weak
、weak_and_warn
和 legacy
。 请注意,weak_and_warn
实现了此 NEP 中提出的可选警告,预计会非常嘈杂。 我们建议首先使用 weak
选项,并主要使用 weak_and_warn
来了解所观察到的特定行为更改。
存在以下其他 API
np._set_promotion_state()
和np._get_promotion_state()
,它们等效于环境变量。 (非线程/上下文安全。)with np._no_nep50_warning():
允许在使用weak_and_warn
提升时抑制警告。 (线程和上下文安全。)
此时,缺少整数幂的溢出警告。 此外,np.can_cast
在 weak_and_warn
模式下无法发出警告。 其关于 Python 标量输入的行为可能仍在变化中(这应该会影响极少数用户)。
新提议的提升规则的架构#
更改后,NumPy 中的提升将遵循以下架构。 提升始终沿着绿色线发生:在其类型内从左到右,并且仅在必要时才提升到更高的类型。 结果类型始终是输入的最大类型。 请注意,float32
的精度低于 int32
或 uint32
,因此在示意图中略微向左排序。 这是因为 float32
不能精确地表示所有 int32
值。 但是,出于实际原因,NumPy 允许将 int64
提升为 float64
,从而有效地认为它们具有相同的精度。
Python 标量被插入到每个“类型”的最左侧,并且 Python 整数不区分有符号和无符号。 因此,NumPy 提升使用以下有序类型类别
布尔值
integral:有符号或无符号整数
inexact:浮点数和复数浮点数
当使用更高类别(布尔 < 整数 < 非精确)的数据类型提升 Python 标量时,我们使用最小/默认精度:即 float64
、complex128
或 int64
(在某些系统上,例如 Windows,使用 int32
)。
请参阅下一节中的示例,其中阐明了提议的行为。下表提供了更多示例,并与当前行为进行了比较。
新行为示例#
为了更容易理解上面的文本和图表,我们提供一些新行为的示例。在下面,Python 整数对结果类型没有影响
np.uint8(1) + 1 == np.uint8(2)
np.int16(2) + 2 == np.int16(4)
在下面,Python float
和 complex
是“非精确的”,但 NumPy 值是整数,因此我们至少使用 float64
/complex128
np.uint16(3) + 3.0 == np.float64(6.0)
np.int16(4) + 4j == np.complex128(4+4j)
但这不会发生在 float
到 complex
的提升中,其中 float32
和 complex64
具有相同的精度
np.float32(5) + 5j == np.complex64(5+5j)
请注意,示意图中省略了 bool
。它设置在“整数”之下,因此以下成立
np.bool_(True) + 1 == np.int64(2)
True + np.uint8(2) == np.uint8(3)
请注意,虽然此 NEP 使用简单的运算符作为示例,但所描述的规则通常适用于 NumPy 的所有操作。
新旧行为比较表#
下表列出了相关的更改和未更改的行为。请参阅旧实现,详细了解导致“旧结果”的规则,以及以下部分详细介绍新规则。向后兼容性部分讨论了这些更改可能如何影响用户。
请注意像 array(2)
这样的 0 维数组和像 array([2])
这样的非 0 维数组之间的重要区别。
表达式 |
旧结果 |
新结果 |
---|---|---|
|
|
|
|
|
|
|
|
|
|
|
未更改 |
|
|
未更改 |
|
|
未更改 [T3] |
|
|
异常 [T4] |
|
|
异常 [T5] |
|
|
|
|
|
|
|
|
未更改 |
|
|
|
|
|
未更改 |
|
|
|
|
|
|
|
|
|
|
|
未更改 [T12] |
新行为遵从 uint8
标量的数据类型。
当前 NumPy 在与数组组合时会忽略 0 维数组或 NumPy 标量的精度。
当前 NumPy 在与数组组合时会忽略 0 维数组或 NumPy 标量的精度。
旧的行为使用 uint16
,因为 300
不适合 uint8
,新的行为出于相同的原因会引发错误。
300
无法转换为 uint8
。
这可能是最危险的更改之一。保留类型会导致溢出。对于 NumPy 标量,会给出指示溢出的 RuntimeWarning
。
np.float32(3e100)
溢出到无穷大并发出警告。
当在 float32 中完成时,1 + 1e-14
会丢失精度,但在 float64 中则不会。旧的行为根据数组的维度以不同的方式将标量参数转换为 float32 或 float64;在新行为中,计算始终以数组精度(在本例中为 float32)完成。
NumPy 将 float32
和 int64
提升为 float64
。旧的行为在这里忽略了 int64
。
新行为在 array(3, complex64)
和 array([3], complex64)
之间保持一致:结果的数据类型与数组参数的数据类型相同。
新行为使用与数组参数 float32
兼容的精度的复数数据类型。
由于数组类型是整数,因此结果使用默认的复数精度,即 complex128
。
动机和范围#
更改关于检查 Python 标量和 NumPy 标量/0 维数组值的行为的动机有三个方面:
NumPy 标量/0 维数组的特殊处理以及值检查可能会让用户非常惊讶,
值检查逻辑更难解释和实现。通过NEP 42使其可用于用户定义的DType更困难。目前,这导致新旧(值敏感)系统的双重实现。解决这个问题将大大简化内部逻辑,并使结果更加一致。
这在很大程度上与其他项目(如JAX和data-apis.org)的选择一致(另请参见相关工作)。
我们认为,“弱”Python标量的提议将通过为用户提供一个清晰的心智模型,了解操作将导致哪种数据类型来帮助用户。该模型非常适合NumPy当前经常遵循的数组精度保留,也适用于就地操作。
arr += value
只要不跨越“kind”边界,就保留精度(否则会引发错误)。
虽然一些用户可能会怀念值检查行为,但即使在某些看似有用的情况下,它也很快会导致意外。这可能是意料之中的
np.array([100], dtype=np.uint8) + 1000 == np.array([1100], dtype=np.uint16)
但以下情况会让人感到意外
np.array([100], dtype=np.uint8) + 200 == np.array([44], dtype=np.uint8)
考虑到该提案与就地操作数的行为一致,并避免了仅在某些情况下避免结果溢出的令人惊讶的行为切换,我们认为该提案遵循“最小惊讶原则”。
使用和影响#
预计此NEP的实施将不会有任何过渡期来警告所有更改。这样的过渡期会产生许多(通常是无害的)警告,这些警告很难消除。我们预计,大多数用户将长期受益于更清晰的提升规则,并且很少有用户会直接(负面)受到更改的影响。但是,某些使用模式可能会导致有问题的更改,这些更改在向后兼容性部分中详细说明。
解决此问题的方法是使用可选的警告模式,该模式能够通知用户行为的潜在更改。此模式预计会生成许多无害的警告,但会提供一种系统地审查代码并在观察到问题时跟踪更改的方法。
对can_cast
的影响#
can_cast 将不再检查值。因此,以下结果预计会从 True
更改为 False
np.can_cast(np.int64(100), np.uint8)
np.can_cast(np.array(100, dtype=np.int64), np.uint8)
np.can_cast(100, np.uint8)
我们预计,与以下更改相比,此更改的影响将很小。
注意
最后一个输入为Python标量的示例_可能_会被保留,因为 100
可以用 uint8
表示。
对涉及NumPy数组或标量的运算符和函数的影响#
对不涉及Python标量(float
、int
、complex
)的操作的主要影响是,对0-D数组和NumPy标量的操作将永远不会依赖于它们的值。这消除了目前令人惊讶的情况。例如
np.arange(10, dtype=np.uint8) + np.int64(1)
# and:
np.add(np.arange(10, dtype=np.uint8), np.int64(1))
未来将返回 int64
数组,因为 np.int64(1)
的类型会被严格遵守。目前,返回的是 uint8
数组。
对涉及Python int
、float
和 complex
的运算符的影响#
当处理字面量值时,此NEP尝试保留旧行为的便利性。当前的基于值的逻辑在涉及“无类型”的字面量Python标量时,具有一些不错的属性
np.arange(10, dtype=np.int8) + 1 # returns an int8 array
np.array([1., 2.], dtype=np.float32) * 3.5 # returns a float32 array
但当涉及到“无法表示”的值时,会导致意外
np.arange(10, dtype=np.int8) + 256 # returns int16
np.array([1., 2.], dtype=np.float32) * 1e200 # returns float64
该提案旨在在很大程度上保留这种行为。这是通过考虑在操作中将Python int
、float
和 complex
视为“弱”类型来实现的。但是,为了避免意外,我们计划使转换为新类型的过程更加严格:在前两个示例中,结果将保持不变,但在第二个示例中,它将按以下方式更改
np.arange(10, dtype=np.int8) + 256 # raises a TypeError
np.array([1., 2.], dtype=np.float32) * 1e200 # warning and returns infinity
第二个示例会发出警告,因为 np.float32(1e200)
溢出为无穷大。然后,它将像往常一样继续使用 inf
进行计算。
其他库中的行为
在转换中溢出而不是引发错误是一种选择;这是大多数 C 设置中的默认设置(类似于 NumPy,C 可以设置为因溢出而引发错误)。例如,这也是 pytorch
1.10 的行为。
Python 整数的特定行为#
NEP的提升规则以结果dtype(通常也是操作dtype(在结果精度方面))来表示。这会导致Python整数出现看似异常的情况:虽然 uint8(3) + 1000
必须被拒绝,因为在 uint8
中运算是不可能的,但 uint8(3) / 1000
返回 float64
,并且可以将两个输入转换为 float64
来找到结果。
实际上,这意味着在以下情况下,可以接受任意Python整数值
NumPy和Python整数之间的所有比较(
==
、<
等)始终是明确定义的。像
np.sqrt
这样给出浮点结果的一元函数可以并且将把Python整数转换为浮点数。整数除法通过将输入转换为
float64
返回浮点数。
请注意,可能还有其他函数可以应用这些例外情况,但实际上没有。在这些情况下,允许它们应该被视为一种改进,但是当用户影响较小时,为了简单起见,我们可能不会这样做。
向后兼容性#
一般来说,仅使用默认dtype float64、int32/int64或更精确的dtype的代码不应受到影响。
但是,所提出的更改将在很多情况下修改混合使用0-D或标量值(具有非默认dtype)的结果。在许多情况下,这些将是错误修复,但是,某些更改可能对最终用户造成问题。
最有可能出现的故障可能是以下示例
arr = np.arange(100, dtype=np.uint8) # storage array with low precision
value = arr[10]
# calculation continues with "value" without considering where it came from
value * 100
之前,value * 100
会导致向上转换为 int32
/int64
(因为value是标量)。除非明确处理(就像 value
是一个数组一样),否则新行为将保留较低的精度。这会导致整数溢出,从而导致超出精度的错误结果。在许多情况下,这可能是无声的,尽管NumPy通常会为标量运算符发出警告。
同样,如果存储数组是 float32
,则计算可能会保留较低的 float32
精度,而不是使用默认的 float64
。
还会发生其他问题。例如
当混合精度时,浮点比较,尤其是相等性,可能会发生变化
np.float32(1/3) == 1/3 # was False, will be True.
预计某些操作将开始失败
np.array([1], np.uint8) * 1000 np.array([1], np.uint8) == 1000 # possibly also
以保护用户,防止以前基于值的强制转换导致向上转换的情况。(当将
1000
转换为uint8
时,会发生故障。)在更奇怪的情况下,可能会发生浮点溢出
np.float32(1e-30) * 1e50 # will return ``inf`` and a warning
因为
np.float32(1e50)
返回inf
。以前,即使1e50
不是 0-D 数组,也会返回双精度结果
在其他情况下,可能会发生精度提高。例如
np.multiple(float32_arr, 2.)
float32_arr * np.float64(2.)
都将返回 float64 而不是 float32
。这提高了精度,但略微改变了结果并使用了双倍的内存。
由于整数“精度阶梯”而导致的更改#
当从Python整数创建数组时,NumPy将按顺序尝试以下类型,结果取决于该值
long (usually int64) → int64 → uint64 -> object
这与上述提升略有不同。
此NEP目前不包括更改此阶梯(尽管可能会在单独的文档中建议)。但是,在混合操作中,此阶梯将被忽略,因为该值将被忽略。这意味着,操作永远不会无声地使用 object
dtype
np.array([3]) + 2**100 # Will error
用户将不得不编写以下其中一个
np.array([3]) + np.array(2**100)
np.array([3]) + np.array(2**100, dtype=object)
因此,隐式转换为 object
的情况应该很少见,并且解决方法很明确,我们预计向后兼容性问题相当小。
详细描述#
以下内容提供了一些关于当前“基于值”的类型提升逻辑的额外细节,以及“弱标量”类型提升及其内部处理方式的细节。
“基于值”的类型提升的旧实现#
本节回顾当前基于值的逻辑在实践中如何工作,请参阅以下章节,了解其有用的示例。
当 NumPy 看到一个“标量”值时,它可以是 Python 的 int、float、complex,NumPy 标量或数组
1000 # Python scalar
int32(1000) # NumPy scalar
np.array(1000, dtype=int64) # zero dimensional
或者浮点/复数等效值,NumPy 将忽略 dtype 的精度,并找到可以容纳该值的最小可能 dtype。也就是说,它将尝试以下 dtypes
整数:
uint8
,int8
,uint16
,int16
,uint32
,int32
,uint64
,int64
。浮点数:
float16
,float32
,float64
,longdouble
。复数:
complex64
,complex128
,clongdouble
。
请注意,例如,对于整数值 10
,最小的 dtype 可以是 uint8
或 int8
中的任何一个。
当所有参数都是标量值时,NumPy 从未应用此规则
np.int64(1) + np.int32(2) == np.int64(3)
对于整数,一个值是否适合由它是否可以用 dtype 表示来精确决定。对于浮点数和复数,如果满足以下条件,则认为 dtype 是足够的:
float16
:-65000 < value < 65000
(或 NaN/Inf)float32
:-3.4e38 < value < 3.4e38
(或 NaN/Inf)float64
:-1.7e308 < value < 1.7e308
(或 Nan/Inf)longdouble
:(最大范围,因此没有限制)
对于复数,这些界限应用于实部和虚部。这些值大致对应于 np.finfo(np.float32).max
。(NumPy 从未强制将 float64
用于 float32(3.402e38)
的值,但它将用于 Python 值 3.402e38
。)
当前“基于值”的类型提升的状态#
在我们提出当前数据类型系统的替代方案之前,回顾一下“基于值的类型提升”是如何使用以及如何有用的,这很有帮助。基于值的类型提升允许以下代码工作:
# Create uint8 array, as this is sufficient:
uint8_arr = np.array([1, 2, 3], dtype=np.uint8)
result = uint8_arr + 4
result.dtype == np.uint8
result = uint8_arr * (-1)
result.dtype == np.int16 # upcast as little as possible.
其中,特别是第一部分可能很有用:用户知道输入是一个具有特定精度的整数数组。考虑到简单的 + 4
保留以前的数据类型是很直观的。用 np.float32
替换此示例可能更清晰,因为浮点数很少会溢出。如果没有这种行为,上面的示例将需要编写 np.uint8(4)
,并且缺少该行为会使以下内容令人惊讶:
result = np.array([1, 2, 3], dtype=np.float32) * 2.
result.dtype == np.float32
其中,缺少特殊情况会导致返回 float64
。
重要的是要注意,该行为也适用于通用函数和零维数组
# This logic is also used for ufuncs:
np.add(uint8_arr, 4).dtype == np.uint8
# And even if the other array is explicitly typed:
np.add(uint8_arr, np.array(4, dtype=np.int64)).dtype == np.uint8
回顾一下,如果我们将 4
替换为 [4]
使其变为一维,则结果将不同
# This logic is also used for ufuncs:
np.add(uint8_arr, [4]).dtype == np.int64 # platform dependent
# And even if the other array is explicitly typed:
np.add(uint8_arr, np.array([4], dtype=np.int64)).dtype == np.int64
建议的弱类型提升#
此提案使用“弱标量”逻辑。这意味着 Python 的 int
、float
和 complex
不会被分配诸如 float64 或 int64 之类的典型 dtypes。相反,它们被分配一个特殊的抽象 DType,类似于“标量”层次结构的名称:Integral、Floating、ComplexFloating。
当发生类型提升时(如果不存在精确的循环匹配,则对于 ufuncs 会发生类型提升),另一个 DType 可以决定如何看待 Python 标量。例如,UInt16
与 Integral
进行类型提升将得到 UInt16
。
注意
将来很可能会为用户定义的 DTypes 提供默认值。最有可能的是,这将最终成为默认的整数/浮点数,但原则上可以实现更复杂的方案。
在任何时候都不会使用该值来决定此类型提升的结果。该值仅在将其转换为新 dtype 时才会被考虑;这可能会引发错误。
实现#
实现此 NEP 需要向所有二元运算符(或 ufuncs)添加一些额外的机制,以便它们在可能的情况下尝试使用“弱”逻辑。对此有两种可能的方法:
二元运算符只是尝试在出现这种情况时调用
np.result_type()
,并将 Python 标量转换为结果类型(如果已定义)。二元运算符指示输入是 Python 标量,其余部分使用 ufunc 分派/类型提升机制(请参阅 NEP 42)。这允许更大的灵活性,但需要在 ufunc 机制中添加一些额外的逻辑。
注意
到目前为止,尚不清楚哪种方法更好,这两种方法都将给出相当等效的结果,如果将来有必要,可以由 2. 扩展 1。
它还需要删除所有当前基于值的特殊代码路径。
违反直觉的是,实现中的一个更大步骤可能是实现一种解决方案,允许在以下示例中引发错误:
np.arange(10, dtype=np.uint8) + 1000
即使 np.uint8(1000)
返回的值与 np.uint8(232)
相同。
注意
请参阅替代方案,我们可能仍然认为这种静默溢出是可以接受的,或者至少是另一个问题。
替代方案#
有几个设计轴可以做出不同的选择。以下各节概述了这些。
使用强类型标量或两者混合#
解决基于值提升/转换问题的最简单方法是使用强类型 Python 标量,即 Python 浮点数被认为是双精度,Python 整数始终被认为与默认整数 dtype 相同。
这将是最简单的解决方案,但是,当使用 float32
或 int16
等数组时,会导致许多向上转换。这些情况的解决方案是依赖于就地操作。我们目前认为,虽然这种更改不太危险,但它会影响许多用户,并且会比不常见的情况更令人惊讶(尽管期望差异很大)。
原则上,弱行为与强行为不必是统一的。也可以使 Python 浮点数使用弱行为,但 Python 整数使用强行为,因为整数溢出更令人惊讶。
不要在函数中使用弱标量逻辑#
此 NEP 提案的一种替代方法是将弱类型的使用范围缩小到 Python 运算符。
这有优点和缺点:
主要的优点是将它限制在 Python 运算符意味着这些“弱”类型/dtypes 对于短的 Python 语句来说显然是短暂的。
一个缺点是
np.multiply
和*
的互换性较差。仅对运算符使用“弱”类型提升意味着库不必担心它们是否要“记住”输入最初是 Python 标量。另一方面,它会为 Python 运算符增加一些稍微不同(或额外)的逻辑。(从技术上讲,可能是作为 ufunc 分派机制的一个标志,用于切换弱逻辑。)
__array_ufunc__
通常单独使用,为实现它的类数组对象提供 Python 运算符支持。如果运算符是特殊的,这些类数组对象可能需要一种机制来匹配 NumPy(例如,为 ufunc 提供一个 kwarg 以启用弱提升)。
NumPy 标量可能是特殊的#
许多用户期望 NumPy 标量应该与 NumPy 数组不同,例如 np.uint8(3) + 3
应该返回一个 int64
(或 Python 整数),而 uint8_arr + 3
则保留 uint8
dtype。
这种替代方案将非常接近 NumPy 标量当前的行为,但它将巩固数组和标量之间的区别(NumPy 数组比 Python 标量“更强”,但 NumPy 标量则不是)。
这种区别是完全有可能的,但是,目前 NumPy 通常会(默默地)将 0 维数组转换为标量。因此,如果我们也改变这种静默转换(有时称为“衰减”)行为,那么考虑这种替代方案可能才有意义。
处理标量在不安全情况下的转换#
诸如以下情况
np.arange(10, dtype=np.uint8) + 1000
应按照此 NEP 产生错误。这可以放宽为给出警告,甚至忽略“不安全”的转换,这(在所有相关的硬件上)会导致 np.uint8(1000) == np.uint8(232)
被使用。
允许弱类型数组#
拥有弱类型 Python 标量,但没有弱类型数组的一个问题是,在许多情况下, np.asarray()
会不加区分地在输入上调用。为了解决这个问题,JAX 将考虑 np.asarray(1)
的结果也是弱类型的。然而,这有两个难点
JAX 注意到以下情况可能会令人困惑
np.broadcast_to(np.asarray(1), (100, 100))
是一个非 0 维数组,它“继承”了弱类型。[2]
与 JAX 张量不同,NumPy 数组是可变的,因此赋值可能需要使其成为强类型的?
一个标志很可能作为实现细节(例如,在 ufunc 中)是有用的,但是,到目前为止,我们不希望将其作为用户 API。主要原因是,如果将这样的标志作为函数的结果传递出来,而不是仅在非常局部的地方使用,可能会让用户感到惊讶。
待办事项
在接受 NEP 之前,最好进一步讨论这个问题。库可能需要更清晰的模式来“传播” “弱”类型,这可以只是一个 np.asarray_or_literal()
来保留 Python 标量,或者在 np.asarray()
之前调用 np.result_type()
的模式。
对 Python 标量保持使用基于值的逻辑#
当前逻辑的一些主要问题出现的原因是,我们将其应用于 NumPy 标量和 0 维数组,而不是应用于 Python 标量。因此,我们可以考虑继续检查 Python 标量的值。
我们拒绝这个想法,理由是它不会消除之前给出的意外情况
np.uint8(100) + 1000 == np.uint16(1100)
np.uint8(100) + 200 == np.uint8(44)
并且基于结果值而不是输入值来调整精度可能对标量操作是可行的,但对数组操作是不可行的。这是因为数组操作需要在执行计算之前分配结果数组。
讨论#
参考文献和脚注#
版权#
本文档已置于公共领域。[1]