NEP 50 — Python 标量的提升规则#
- 作者:
塞巴斯蒂安·博格
- 状态:
最终
- 类型:
标准跟踪
- 创建:
2021-05-25
摘要#
自 NumPy 1.7 以来,提升规则使用了所谓的“安全转换”,它依赖于对所涉及的值的检查。这有助于识别用户的一些极端情况,但实施起来很复杂,而且也很难预测行为。
有两类令人困惑的结果
基于值的提升意味着一个值,例如一个 Python 整数,可以确定输出类型,正如
np.result_type
发现的那样np.result_type(np.int8, 1) == np.int8 np.result_type(np.int8, 255) == np.int16
此逻辑出现是因为
1
可以用uint8
或int8
表示,而255
不能用int8
表示,只能用uint8
或int16
表示。使用 0D 数组(所谓的“标量数组”)时,这也是成立的
int64_0d_array = np.array(1, dtype=np.int64) np.result_type(np.int8, int64_0d_array) == np.int8
其中
int64_0d_array
具有int64
dtype 的事实对所得的 dtype 没有任何影响。本例中实际上未考虑dtype=np.int64
,因为仅其值至关重要。对于 Python
int
、float
或complex
,会像前面展示的那样检查值。但令人惊讶的是,当 NumPy 对象为 0-D 数组或 NumPy 标量时不会检查值np.result_type(np.array(1, dtype=np.uint8), 1) == np.int64 np.result_type(np.int8(1), 1) == np.int64
原因在于,当所有对象均为标量或 0-D 数组时,会禁用基于值提升。因此,NumPy 返回与
np.array(1)
相同的类型,通常是int64
(具体取决于系统)。
请注意,这些示例也适用于乘法、加法、比较以及类似 np.multiply
函数等操作。
此 NEP 提议依据以下两条指导原则改进行为
值绝不能影响结果类型。
NumPy 标量和 0-D 数组的行为应与其 N-D 对应项保持一致。
我们建议删除所有基于值的逻辑,并为 Python 标量添加特殊处理,以保留一些便捷行为。Python 标量将被视为“弱”类型。当 NumPy 数组/标量与 Python 标量结合时,它将被转换为 NumPy dtype,从而
np.array([1, 2, 3], dtype=np.uint8) + 1 # returns a uint8 array
np.array([1, 2, 3], dtype=np.float32) + 2. # returns a float32 array
不会依赖于 Python 值本身。
提出的更改也适用于 np.can_cast(100, np.int8)
;但是,我们预期在实践中,函数(提升)中的行为将比铸造更改本身重要得多。
请注意
截至 NumPy 1.24.x 系列,NumPy 已提供初步的和有限的支持来测试此提案。
还有必要设置以下环境变量
export NPY_PROMOTION_STATE=weak
有效值为 weak
、weak_and_warn
和 legacy
。注意,weak_and_warn
实现了 NEP 中提出的可选警告,预计非常嘈杂。我们建议先使用 weak
选项,主要使用 weak_and_warn
来了解行为中的特定观察到的变化。
存在以下附加 API
np._set_promotion_state()
和np._get_promotion_state()
与环境变量等效。(非线程/上下文安全。)with np._no_nep50_warning():
可在使用weak_and_warn
提临时禁止警告。(线程和上下文安全。)
此时整数幂运算中的溢出警告已消失。此外,np.can_cast
在 weak_and_warn
模式中未发出警告。它对 Python 标量输入的行为可能会发生变化(这应影响很少的用户)。
新提议的提升规则架构#
更改后,NumPy 中的提升将遵循以下架构。提升始终沿绿色线条进行:从左向右提升到同类,仅在必要时提升到更高类。结果类别始终是输入的最大类别。请注意,float32
的精度低于 int32
或 uint32
,因此在示意图中稍稍排在左边。这是因为 float32
无法完全表示所有 int32
值。但是,出于实际原因,NumPy 允许将 int64
提升到 float64
,有效地认为它们的精度相同。
Python 标量插入在每个“类别”的最左边,而 Python 整数不区分带符号和不带符号。因此,NumPy 提升使用以下有序类别:
布尔
整数:有符号或无符号整数
不精确:浮点数和复浮点数
当用类别较高的类别提升类别较低的 Python 标量(布尔 < 整数 < 不精确)时,我们使用最小/默认精度:即 float64
、complex128
或 int64
(在某些系统上使用 int32
,例如 Windows)。
请参见下一部分以了解阐明提议行为的示例。可在下表中找到与当前行为比较的更多示例。
新行为示例#
为方便解释上述文本和图表,我们提供一些新的行为示例。下面,Python 整数对结果类型无影响
np.uint8(1) + 1 == np.uint8(2)
np.int16(2) + 2 == np.int16(4)
在以下示例中,Python float
和 complex
为“不精确”,但 NumPy 值是整数,因此我们至少使用 float64
/complex128
np.uint16(3) + 3.0 == np.float64(6.0)
np.int16(4) + 4j == np.complex128(4+4j)
但是,对于提升 float
到 complex
的情况不适用,其中 float32
和 complex64
具有相同的精度
np.float32(5) + 5j == np.complex64(5+5j)
请注意,该示意图遗漏了 bool
。它设置为“整数”以下,以便保留以下内容
np.bool_(True) + 1 == np.int64(2)
True + np.uint8(2) == np.uint8(3)
请注意,尽管此 NEP 使用简单的操作符作为示例,但所述规则通常适用于所有 NumPy 操作。
比较新旧行为的表格#
下表列出了相关更改和未更改的行为。请参阅 旧实现 以详细了解导致“旧结果”的规则,以及以下部分以了解详细说明新规则的内容。向后兼容性部分讨论了这些更改如何影响用户。
请注意,类似 array(2)
的 0-D 数组与不是 0-D 的数组之间存在重要区别,例如 array([2])
。
表达式 |
旧结果 |
新结果 |
---|---|---|
|
|
|
|
|
|
|
|
|
|
|
unchanged |
|
|
unchanged |
|
|
unchanged [T3] |
|
|
Exception [T4] |
|
|
Exception [T5] |
|
|
|
|
|
|
|
|
unchanged |
|
|
|
|
|
unchanged |
|
|
|
|
|
|
|
|
|
|
|
unchanged [T12] |
新行为尊重 uint8
标量的 `dtype`。
将当前的 NumPy 与数组组合使用时,会忽略 0-D 数组或 NumPy 标量的精度。
将当前的 NumPy 与数组组合使用时,会忽略 0-D 数组或 NumPy 标量的精度。
因为 300
不适合 uint8
,所以旧行为使用 uint16
,而新行为出于相同原因引发错误。
300
无法转换为 uint8
。
最危险的更改之一。保留类型会导致溢出。 RuntimeWarning
表明溢出已提供给 NumPy 标量。
np.float32(3e100)
溢出为无限大并附带警告。
1 + 1e-14
在 float32 中完成时会损失精度,而在 float64 中不会。旧行为会根据数组的维度,将标量参数强制转换为 float32 或 float64;对于新行为,运算始终在数组精度(本例为 float32)中完成。
NumPy 将 float32
和 int64
提升为 float64
。旧行为在这里忽略 int64
。
新行为在 array(3, complex64)
和 array([3], complex64)
之间保持一致:结果的 `dtype` 是数组参数的 `dtype`。
新行为使用与数组参数兼容精度的复数数据类型,float32
。
由于数组种类为整数,结果使用默认复数精度,即 complex128
。
动机和范围#
针对检查 Python 标量和 NumPy 标量/0-D 数组的值更改行为的动机有以下三个方面
将 NumPy 标量/0-D 数组视为特殊处理以及值检查可能会让用户感到惊讶,
值检查逻辑难以解释和实现。而且,很难通过 NEP 42 向用户定义的数据类型提供此逻辑。目前,这导致新系统和旧系统(区分值的系统)的实现重复。修复此问题将极大简化内部逻辑,使结果更加一致。
这在很大程度上与 JAX 和 data-apis.org 等其他项目的做法保持一致(另请参见 相关工作)。
我们相信“弱”Python 标量的提议将通过为用户提供明确的心理模型来帮助用户理解操作将产生哪种数据类型。该模型与目前 NumPy 经常遵循的数组精度保持一致,并且也用于就地操作
arr += value
只要不跨越“kind”边界(否则将引发错误),就会保留精度。
虽然一些用户可能不会丢失值检查行为,但即使对于那些看似有用的情况,它也会很快带来意外。这可能是预期之中的
np.array([100], dtype=np.uint8) + 1000 == np.array([1100], dtype=np.uint16)
但以下情况会令人惊讶
np.array([100], dtype=np.uint8) + 200 == np.array([44], dtype=np.uint8)
考虑到该提议与就地操作数的行为一致,并避免了仅在某些情况下才避免结果溢出的令人惊讶的行为转换,我们相信该提议遵循了“最小惊讶原则”。
用法和影响#
该 NEP 预计将在 不设置任何过渡期的前提下实现,该过渡期会对所有更改发出警告。这样的过渡期会生成许多(通常是无害的)警告,这些警告很难消音。我们预计,大多数用户将长期受益于更明确的提升规则,并且很少有人直接(消极地)受到更改的影响。但是,某些使用模式可能会导致问题更改,这些在向后兼容性部分中进行了详细说明。
解决方案将是一种可选的警告模式,能够通知用户潜在的行为变化。预计此模式会生成许多无害的警告,但如果观察到问题,它将提供系统性审查代码和跟踪更改的方法。
对 can_cast
的影响#
can_cast 以后将不再检查值。因此,预期以下结果将从 True
更改为 False
np.can_cast(np.int64(100), np.uint8)
np.can_cast(np.array(100, dtype=np.int64), np.uint8)
np.can_cast(100, np.uint8)
我们预期此更改的影响将小于以下更改的影响。
请注意
最后一个示例,其中输入是一个 Python 标量,_可能_会保留,因为 100
可以由 uint8
表示。
涉及 NumPy 数组或标量的运算符和函数的影响#
对不涉及 Python 标量的操作(float
、int
、complex
)的主要影响是,对 0 维数组和 NumPy 标量的操作将永远不取决于它们的值。这消除了目前令人惊讶的情况。例如
np.arange(10, dtype=np.uint8) + np.int64(1)
# and:
np.add(np.arange(10, dtype=np.uint8), np.int64(1))
将来将返回一个 int64
数组,因为 np.int64(1)
的类型会被严格遵守。目前返回一个 uint8
数组。
涉及 Python int
、float
和 complex
的运算符的影响#
此 NEP 尝试在处理文字值时保留旧行为的便利性。在涉及“无类型”文字 Python 标量时,基于当前价值的逻辑具有一些不错的特性
np.arange(10, dtype=np.int8) + 1 # returns an int8 array
np.array([1., 2.], dtype=np.float32) * 3.5 # returns a float32 array
但在涉及“不可表示”值时会出现令人惊讶的情况
np.arange(10, dtype=np.int8) + 256 # returns int16
np.array([1., 2.], dtype=np.float32) * 1e200 # returns float64
该提案的目的是在很大程度上保留这种行为。这是通过在操作中将 Python int
、float
和 complex
视为“弱”类型来实现的。但是,为了避免惊喜,我们计划对新类型的转换更加严格:前两个示例中的结果将保持不变,但在第二个示例中,它将发生如下更改
np.arange(10, dtype=np.int8) + 256 # raises a TypeError
np.array([1., 2.], dtype=np.float32) * 1e200 # warning and returns infinity
第二个示例发出警告,因为 np.float32(1e200)
会溢出为无穷大。然后它将继续像往常一样使用 inf
进行计算。
其他库中的行为
转换中出现溢出而非引发错误是一种选择,在大多数 C 设置中它是默认选项(类似于 NumPy,C 可被设置为在溢出时引发错误)。例如,这也是 pytorch
1.10 的行为。
Python 整数的特定行为#
NEPs 晋升规则声明的结果数据类型,它通常也是操作数据类型(就结果精度而言)。这导致的结果看起来像是 Python 整数出现异常:虽然 uint8(3) + 1000
必须被拒绝,因为在 uint8
中进行操作是不可能的,但 uint8(3) / 1000
会返回一个 float64
,并且可以将两个输入转换为 float64
来找出结果。
实际上,这意味着在以下情况下会接受任意 Python 整数值
NumPy 和 Python 整数之间的所有比较(
==
、<
等)始终都已明确定义。类似于
np.sqrt
这类会产生浮点数结果的一元函数能够并将会把 Python 整数转换为浮点数。将整数相除会返回浮点数,方法是将输入强制转换为
float64
。
请注意,可能有额外的函数可以应用这些异常,但没有这样做。在这些情况下,允许它们被认为是一种改进,但当用户影响较小时,我们可能不会出于简单性这么做。
向后兼容性#
通常情况下,仅使用默认数据类型 float64,或 int32/int64 或更精确的数据类型而编写的代码不会受到影响。
但是,建议的变更会在许多场合中修改结果,在这种情况下,0 维或标量值(具有非默认数据类型)被混合。在许多情况下,这些将是错误修复,不过对于最终用户来说,某些更改可能会有问题。
最重要的可能失败可能是以下示例
arr = np.arange(100, dtype=np.uint8) # storage array with low precision
value = arr[10]
# calculation continues with "value" without considering where it came from
value * 100
以前,value * 100
会向上转换为 int32
/int64
(因为 value 是标量)。新的行为会在未明确进行处理的情况下保持低精度(就像 value
是数组一样)。这可能导致整数溢出,从而造成超出精度的错误结果。在许多情况下,这可能是静默的,尽管 NumPy 通常会对标量运算符发出警告。
同样,如果存储数组是 float32
,则计算可能会保留较低的 float32
精度,而不是使用默认的 float64
。
可能会出现更多问题。例如
浮点比较(尤其是相等比较)在混合精度时可能会改变
np.float32(1/3) == 1/3 # was False, will be True.
预计某些操作将开始失败
np.array([1], np.uint8) * 1000 np.array([1], np.uint8) == 1000 # possibly also
为了保护用户,在以前基于值的转换导致向上转换的情况下。(在将
1000
转换为uint8
时,将会发生故障。)在奇怪的情况下可能发生浮点溢出
np.float32(1e-30) * 1e50 # will return ``inf`` and a warning
因为
np.float32(1e50)
返回inf
。以前,即使1e50
不是 0-D 数组,也会返回一个双精度结果
在其他情况下,精度可能会提高。例如
np.multiple(float32_arr, 2.)
float32_arr * np.float64(2.)
都将返回 float64,而不是 float32
。这提高了精度,但会略微改变结果并使用双倍的内存。
由于整数“精度阶梯”导致的变化#
从 Python 整数创建数组时,NumPy 将按以下顺序尝试以下类型,结果取决于值
long (usually int64) → int64 → uint64 -> object
这与上面描述的提升有细微差别。
该 NEP 目前不包括更改此阶梯(尽管它可能会在单独的文档中提出)。然而,在混合操作中,将忽略此阶梯,因为会忽略该值。这意味着,操作永远不会在静默中使用 object
数据类型
np.array([3]) + 2**100 # Will error
用户将不得不编写其中之一
np.array([3]) + np.array(2**100)
np.array([3]) + np.array(2**100, dtype=object)
由于隐式转换为 object
应该很少见,并且解决方法很明确,我们预计向后兼容性问题相当小。
详细描述#
下面提供一些关于当前“基于值”提升逻辑的附加详细信息,然后提供“弱标量”提升以及它如何在内部处理的相关信息。
原来“基于值”提升的实施#
本节回顾当前基于值的逻辑在实践中的工作方式,请查阅以下部分,了解它如何发挥作用的示例。
当 NumPy 看到一个“标量”值(该值可以是 Python int、float、complex、NumPy 标量或数组)时
1000 # Python scalar
int32(1000) # NumPy scalar
np.array(1000, dtype=int64) # zero dimensional
或者 float/complex 等效值,NumPy 将忽略 dtype 的精度并找到可以保存该值的最小可能的 dtype。也就是说,它将尝试以下 dtype
整数:
uint8
、int8
、uint16
、int16
、uint32
、int32
、uint64
、int64
。浮点数:
float16
、float32
、float64
、longdouble
。复数:
complex64
、complex128
、clongdouble
。
请注意,对于10
的整数,最小的 dtype 可以同时是 uint8
或int8
。
当所有参数都是标量值时,NumPy 绝不会应用这条规则
np.int64(1) + np.int32(2) == np.int64(3)
对于整数,一个值是否合适完全取决于该值能否用 dtype 表示。对于 float 和 complex,如果满足以下任一项,则 dtype 即可被认为已足够
float16
:-65000 < value < 65000
(或 NaN/Inf)float32
:-3.4e38 < value < 3.4e38
(或 NaN/Inf)float64
:-1.7e308 < value < 1.7e308
(或 Nan/Inf)longdouble
:(最大范围,因此没有限制)
对于复杂的范围,这些范围应用于实部和虚部组件。这些值大致对应于 np.finfo(np.float32).max
。(NumPy 从未强制使用 float64
作为值 float32(3.402e38)
,而它会强制作为 Python 值 3.402e38
。)
当前“基于值”提升的状态#
在我们提出当前数据类型系统的替代方案之前,回顾一下如何使用“基于值”提升是有帮助的,并且它可能是有用的。基于值的提升允许以下代码工作
# Create uint8 array, as this is sufficient:
uint8_arr = np.array([1, 2, 3], dtype=np.uint8)
result = uint8_arr + 4
result.dtype == np.uint8
result = uint8_arr * (-1)
result.dtype == np.int16 # upcast as little as possible.
其中特别是有用的部分之一:用户知道输入是一个具有特定精度的整数数组。考虑到简单的 + 4
保留前一个数据类型是直观的。将其示例替换为 np.float32
可能更清晰,因为浮点数很少会溢出。如果没有此行为,上面的示例将需要编写 np.uint8(4)
,并且缺乏此行为会使以下内容令人惊讶
result = np.array([1, 2, 3], dtype=np.float32) * 2.
result.dtype == np.float32
其中缺乏特殊情况会导致返回 float64
。
需要注意的是,该行为也适用于通用函数和零维数组
# This logic is also used for ufuncs:
np.add(uint8_arr, 4).dtype == np.uint8
# And even if the other array is explicitly typed:
np.add(uint8_arr, np.array(4, dtype=np.int64)).dtype == np.uint8
回顾一下,如果我们将 4
替换为 [4]
以使其变为一维,则结果将会不同
# This logic is also used for ufuncs:
np.add(uint8_arr, [4]).dtype == np.int64 # platform dependent
# And even if the other array is explicitly typed:
np.add(uint8_arr, np.array([4], dtype=np.int64)).dtype == np.int64
建议的弱提升#
此提案使用“弱标量”逻辑。这意味着 Python int
、float
和 complex
没有赋予 float64 或 int64 等典型数据类型。相反,它们被赋予了一个特殊的抽象 DType,类似于“标量”层次结构名称:Integral、Floating、ComplexFloating。
当提升发生时(就像对于 ufunc 没有精确的循环匹配一样),其他 DType 能够决定如何看待 Python 标量。例如一个 UInt16
与 Integral
提升将产生 UInt16
。
请注意
将来很可能会为用户定义的 DType 提供默认值。这个最终很可能会变成缺省整数/浮点数,但原则上可以实现更复杂的方案。
这个值从不会用于确定此推广的结果。该值仅在转换为新数据类型时才会考虑;这可能会引发错误。
实现#
实现此 NEP 需要将一些附加机制添加到所有二元运算符(或 ufunc),以便它们尝试在可能的情况下使用“弱”逻辑。有两种可能的方法
如果出现这种情况,二元运算符会尝试调用
np.result_type()
,并将 Python 标量转换为结果类型(如果定义)。二元运算符指示输入是一个 Python 标量,并且 ufunc 调度/推广机制用于其余部分(参见NEP 42)。这允许更大的灵活性,但在 ufunc 机制中需要一些额外的逻辑。
请注意
在目前,尚不清楚哪种方法更好,毕竟两种给出的结果相当并且如果必要,1. 可以在将来通过 2. 进行扩展。
它还要求移除所有当前基于特殊值代码路径。
不合直觉的是,实现中更大的步骤可能是实现一种解决方案以允许在以下示例中引发错误
np.arange(10, dtype=np.uint8) + 1000
即使np.uint8(1000)
返回的值与np.uint8(232)
相同。
请注意
请参阅其他选择,我们可能决定此静默溢出是可以接受的,或至少是一个单独的问题。
备选方案#
有几个设计轴,不同的选择在其中都是可能的。以下章节概述了它们。
使用强类型标量或两者的混合#
解决基于值的推广/强制转换问题的最简单方法是使用强类型 Python 标量,即 Python 浮点被认为是双精度,Python 整数始终被认为与默认整数数据类型相同。
这将是最简单的解决方案,然而它会导致在使用 float32
或 int16
等的数组时发生多次向上转换。这些情况的解决方案将是依靠就地操作。我们目前认为,尽管较不危险,但此项更改将影响许多用户,并且大多数情况下都会令人惊讶(尽管预期差异很大)。
原则上,弱与强行为不需要一致。还可以让 Python 浮点使用弱行为,但 Python 整数使用强行为,因为整数溢出更令人惊讶。
请勿在函数中使用弱标量逻辑#
此 NEP 提议的一个替代方案是将弱类型的使用范围缩小到 Python 运算符。
这利弊兼有
主要优势在于,将其限制到 Python 运算符意味着这些“弱”类型/数据类型显然适用于 Python 简短语句。
缺点是,
np.multiply
和*
的可互换性较差。仅对运算符使用“弱”提升功能意味着库不必担心是否要“记住”一项输入最初为 Python 标量。另一方面,它将为 Python 运算符添加轻微或额外的逻辑需求。(技术上而言,可能作为切换弱逻辑的 ufunc 调度机制的标志。)
__array_ufunc__
通常单独使用,以向实现它的 类数组提供 Python 运算符支持。如果运算符是特殊的,这些类数组可能需要一个机制来匹配 NumPy(例如,用于启用弱提升的 ufunc 的 kwarg。)
NumPy 标量可以是特殊的#
许多用户期望 NumPy 标量应该不同于 NumPy 数组,在于 np.uint8(3) + 3
应该返回一个 int64
(或 Python 整数),而 uint8_arr + 3
保留 uint8
数据类型。
此替代方案将非常接近 NumPy 标量的当前行为,但它将巩固数组和标量之间的区别(NumPy 数组“强于”Python 标量,但 NumPy 标量则不然)。
然而这种区分在很大的程度上也是可能的,此时 NumPy 通常会(并且会悄悄地)将 0-D 数组转换为标量。因此,若我们也改变这种无提示转换(有时称为“衰减”)行为,那么或许只考虑此种替代方案是合理的。
处理标量转换时不安全#
诸如
np.arange(10, dtype=np.uint8) + 1000
应该根据此 NEP 提出一个错误。此项措施可以放宽为提出一个警告,甚至忽略该“不安全”转换(在所有相关硬件上),这将导致 np.uint8(1000) == np.uint8(232)
被使用。
允许弱类型数组#
使用弱类型 Python 标量,而非弱类型数组存在的一个问题是,在许多情况下会对输入不加区别地调用 np.asarray()
。
JAX 注意到,这一点可能会令人感到困惑
np.broadcast_to(np.asarray(1), (100, 100))
是一个非 0-D 数组,它“继承”了弱类型。[2]
不同于 JAX 张量,NumPy 数组是可变动的,因此分配可能需要使其变为强类型?
一个标志可能在实现详情中(如在 ufuncs 中)是有用的,然而,此时我们不希望将其作为用户 API。最主要的原因是,如果该标志是从一个函数而不是仅在非常本地化的情况下使用中传出作为结果,那么这样的标志可能会令用户感到惊讶。
待办事项
在接受 NEP 之前,可以进一步讨论此问题。库可能需要更加清晰的模式来“传播”该“弱”类型,这可能仅仅是一个 np.asarray_or_literal()
来保留 Python 标量,或是先调用 np.result_type()
再调用 np.asarray()
的模式。
持续对 Python 标量使用基于值的逻辑#
当前逻辑的一些主要问题出现了,因为我们将其应用于 NumPy 标量和 0-D 数组,而不是应用于 Python 标量。因此,我们可以考虑继续检查 Python 标量的值。
我们拒绝这一想法,理由是它不会消除前面给出的意外结果。
np.uint8(100) + 1000 == np.uint16(1100)
np.uint8(100) + 200 == np.uint8(44)
可能会根据结果值而非输入值来改进标量操作的精度,但这种方式对数组操作不可行。这是因为数组操作会在执行计算之前分配结果数组。
讨论#
参考和脚注#
版权#
此文档已放置在公有领域中。 [1]