NEP 11 — 延迟 UFunc 计算#
- 作者:
Mark Wiebe <mwwiebe@gmail.com>
- 内容类型:
text/x-rst
- 创建:
2010年11月30日
- 状态:
延迟
摘要#
本 NEP 描述了一个向 NumPy 的 UFunc 添加延迟计算的提案。这将允许像“a[:] = b + c + d + e”这样的 Python 表达式一次性遍历所有变量进行计算,而无需临时数组。由此产生的性能可能与 numexpr 库相当,但语法更自然。
这个想法与 UFunc 错误处理和 UPDATEIFCOPY 标志有一些交互,影响了设计和实现,但结果允许在 Python 用户几乎无需付出任何努力的情况下使用延迟计算。
动机#
NumPy 的 UFunc 执行风格会导致大型表达式的性能不佳,因为会分配多个临时变量,并且会多次遍历输入。numexpr 库可以通过在小的缓存友好块中执行并对每个元素计算整个表达式来胜过 NumPy 处理此类大型表达式。这导致对每个输入进行一次遍历,这对于缓存来说要好得多。
为了了解如何在不更改 Python 代码的情况下获得这种行为,请考虑 C++ 中表达式模板的技术。这些可以用于相当任意地重新排列使用向量或其他数据结构的表达式,例如
A = B + C + D;
可以转换为等效于
for(i = 0; i < A.size; ++i) {
A[i] = B[i] + C[i] + D[i];
}
这是通过返回一个知道如何计算结果的代理对象来完成的,而不是返回实际的对象。使用现代 C++ 优化编译器,生成的机器代码通常与手写循环相同。有关示例,请参阅 Blitz++ 库。一个最近创建的用于帮助编写表达式模板的库是 Boost Proto。
通过在 Python 中使用返回代理对象相同的思想,我们可以动态地实现相同的功能。返回的对象是未分配其缓冲区的 ndarray,并且具有足够的知识在需要时自行计算。当最终计算“延迟数组”时,我们可以使用由所有操作数延迟数组组成的表达式树,有效地创建一个新的 UFunc 来动态计算。
Python 代码示例#
以下是如何在 NumPy 中使用它。
# a, b, c are large ndarrays
with np.deferredstate(True):
d = a + b + c
# Now d is a 'deferred array,' a, b, and c are marked READONLY
# similar to the existing UPDATEIFCOPY mechanism.
print d
# Since the value of d was required, it is evaluated so d becomes
# a regular ndarray and gets printed.
d[:] = a*b*c
# Here, the automatically combined "ufunc" that computes
# a*b*c effectively gets an out= parameter, so no temporary
# arrays are needed whatsoever.
e = a+b+c*d
# Now e is a 'deferred array,' a, b, c, and d are marked READONLY
d[:] = a
# d was marked readonly, but the assignment could see that
# this was due to it being a deferred expression operand.
# This triggered the deferred evaluation so it could assign
# the value of a to d.
不过,可能存在一些意外的行为。
with np.deferredstate(True):
d = a + b + c
# d is deferred
e[:] = d
f[:] = d
g[:] = d
# d is still deferred, and its deferred expression
# was evaluated three times, once for each assignment.
# This could be detected, with d being converted to
# a regular ndarray the second time it is evaluated.
我相信在文档中应该推荐的使用方法是将延迟状态保留为默认状态,除非在计算可以从中受益的大型表达式时。
# calculations
with np.deferredstate(True):
x = <big expression>
# more calculations
这将避免由于始终保持延迟使用为 True 而导致的意外情况,例如在以后使用延迟表达式时出现意外时间的浮点警告或异常。通过推荐这种方法,可以避免用户提出的诸如“为什么我的打印语句抛出除以零错误?”之类的问题。
建议的延迟计算 API#
为了使延迟计算工作,C API 需要知道它的存在,并且能够在必要时触发计算。ndarray 将获得两个新标志。
NPY_ISDEFERRED
指示此 ndarray 实例的表达式计算已被延迟。
NPY_DEFERRED_WASWRITEABLE
仅当
PyArray_GetDeferredUsageCount(arr) > 0
时才能设置。它指示当arr
首次用于延迟表达式时,它是一个可写数组。如果设置了此标志,则调用PyArray_CalculateAllDeferred()
将使arr
再次可写。
注意
问题
NPY_DEFERRED 和 NPY_DEFERRED_WASWRITEABLE 应该对 Python 可见,还是应该从 python 访问标志在必要时触发 PyArray_CalculateAllDeferred?
API 将通过多个函数进行扩展。
int PyArray_CalculateAllDeferred()
此函数强制执行所有当前延迟的计算。
例如,如果错误状态设置为忽略所有内容,并且 np.seterr({all=’raise’}),这将改变已经延迟表达式的行为。因此,在更改错误状态之前,应该计算所有现有的延迟数组。
int PyArray_CalculateDeferred(PyArrayObject* arr)
如果“arr”是延迟数组,则为其分配内存并计算延迟表达式。如果“arr”不是延迟数组,则只需返回成功。返回 NPY_SUCCESS 或 NPY_FAILURE。
int PyArray_CalculateDeferredAssignment(PyArrayObject* arr, PyArrayObject* out)
如果“arr”是延迟数组,则将延迟表达式计算到“out”中,“arr”保持为延迟数组。如果“arr”不是延迟数组,则将其值复制到 out。返回 NPY_SUCCESS 或 NPY_FAILURE。
int PyArray_GetDeferredUsageCount(PyArrayObject* arr)
返回使用此数组作为操作数的延迟表达式的数量。
Python API 将如下扩展。
numpy.setdeferred(state)
启用或禁用延迟计算。True 表示始终使用延迟计算。False 表示永远不使用延迟计算。None 表示如果错误处理状态设置为忽略所有内容,则使用延迟计算。在 NumPy 初始化时,延迟状态为 None。
返回之前的延迟状态。
numpy.getdeferred()
返回当前的延迟状态。
numpy.deferredstate(state)
用于延迟状态处理的上下文管理器,类似于
numpy.errstate
。
错误处理#
错误处理对于延迟计算来说是一个棘手的问题。如果 NumPy 错误状态为 {all=’ignore’},则引入延迟计算作为默认值可能是合理的,但是如果 UFunc 可能会引发错误,则稍后的“打印”语句引发异常而不是导致错误的实际操作将非常奇怪。
一个可能的好方法是默认情况下仅在错误状态设置为忽略所有内容时启用延迟计算,但允许用户使用“setdeferred”和“getdeferred”函数进行控制。True 表示始终使用延迟计算,False 表示从不使用它,None 表示仅在安全时使用它(即错误状态设置为忽略所有内容)。
与 UPDATEIFCOPY 的交互#
NPY_UPDATEIFCOPY
文档说明
数据区域表示一个(行为良好的)副本,当此数组被删除时,其信息应传输回原始数组。
这是一个特殊的标志,如果该数组表示一个副本(因为用户在 PyArray_FromAny 中要求某些标志,并且必须复制其他一些数组,并且用户要求在这种情况下设置此标志),则会设置该标志。然后,base 属性指向“行为异常”的数组(该数组设置为只读)。当设置了此标志的数组被释放时,它将把其内容复制回“行为异常”的数组(如果需要则进行类型转换),并将“行为异常”的数组重置为 NPY_WRITEABLE。如果“行为异常”数组最初不是 NPY_WRITEABLE,那么 PyArray_FromAny 将返回错误,因为 NPY_UPDATEIFCOPY 将不可用。
UPDATEIFCOPY 的当前实现假设它是唯一以这种方式修改可写标志的机制。这些机制必须相互了解才能正常工作。以下是如何出现错误的示例。
使用 UPDATEIFCOPY 创建 'arr' 的临时副本('arr' 变为只读)。
在延迟表达式中使用 'arr'(延迟使用计数变为 1,NPY_DEFERRED_WASWRITEABLE **未** 设置,因为 'arr' 是只读的)。
销毁临时副本,导致 'arr' 变为可写。
写入 'arr' 会破坏延迟表达式的值。
为了解决此问题,我们使这两种状态互斥。
UPDATEIFCOPY 的使用检查
NPY_DEFERRED_WASWRITEABLE
标志,如果已设置,则调用PyArray_CalculateAllDeferred
以在继续之前刷新所有延迟计算。ndarray 获取一个新的标志
NPY_UPDATEIFCOPY_TARGET
,表示该数组将在将来的某个时间点被更新并变为可写。如果延迟评估机制在任何操作数中看到此标志,它将触发立即评估。
其他实现细节#
创建延迟数组时,它会获取对 UFunc 的所有操作数以及 UFunc 本身的引用。每个操作数的“DeferredUsageCount”都会递增,稍后在计算延迟表达式或销毁延迟数组时递减。
跟踪一个全局弱引用列表,其中包含所有延迟数组,按创建顺序排列。当调用 PyArray_CalculateAllDeferred
时,首先计算最新的延迟数组。这可能会释放对延迟表达式树中包含的其他延迟数组的引用,然后这些数组永远不必计算。
进一步优化#
不是在任何错误未设置为“忽略”时保守地禁用延迟评估,每个 UFunc 可以提供其生成的可能错误集。然后,如果所有这些错误都设置为“忽略”,即使其他错误未设置为忽略,也可以使用延迟评估。
一旦显式存储了表达式树,就可以对其进行转换。例如,add(add(a,b),c) 可以转换为 add3(a,b,c),或者 add(multiply(a,b),c) 可以使用 CPU 融合乘加指令(如果可用)转换为 fma(a,b,c)。
虽然我将延迟评估描述为仅适用于 UFunc,但它可以扩展到其他函数,例如 dot()。例如,可以重新排序链式矩阵乘法以最小化中间结果的大小,或者窥孔式优化器传递可以搜索与优化 BLAS/其他高性能库调用匹配的模式。
对于非常大的数组上的操作,将像 LLVM 这样的 JIT 集成到此系统中可能是一个很大的好处。UFunc 和其他操作将提供位代码,这些位代码可以内联在一起并由 LLVM 优化器优化,然后执行。实际上,迭代器本身也可以用位代码表示,允许 LLVM 在进行优化时考虑整个迭代。