497 lines
13 KiB
Python
497 lines
13 KiB
Python
|
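# numpy is deliberately imported under two names: the closure tests below
# reference both spellings (`numpy.cos` in outer20, `np.cos` in outer21).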
|
||
|
import numpy as np
|
||
|
import numpy
|
||
|
|
||
|
import unittest
|
||
|
from numba import njit, jit
|
||
|
from numba.core.errors import TypingError, UnsupportedError
|
||
|
from numba.core import ir
|
||
|
from numba.tests.support import TestCase, IRPreservingTestPipeline
|
||
|
|
||
|
|
||
|
class TestClosure(TestCase):
    """Closures compiled directly with ``jit`` in object and nopython mode.

    Each ``run_*`` helper holds the actual scenario and receives the ``jit``
    keyword arguments; the ``test_*`` methods invoke it once with
    ``forceobj=True`` and once with ``nopython=True``.
    """

    # ------------------------------------------------------------------
    # Scenario runners (parameterised over the jit mode)
    # ------------------------------------------------------------------

    def run_jit_closure_variable(self, **jitargs):
        """A closure variable is frozen at compile time."""
        offset = 10

        def add_offset(x):
            return x + offset

        compiled = jit('i4(i4)', **jitargs)(add_offset)
        self.assertEqual(compiled(1), 11)

        # As with globals in Numba, the closure's value is captured when the
        # function is compiled, so rebinding it afterwards must not change
        # the compiled result.
        offset = 12
        self.assertEqual(compiled(1), 11)

    def run_rejitting_closure(self, **jitargs):
        """Recompiling the same closure captures the variable's new value."""
        offset = 10

        def add_offset(x):
            return x + offset

        first = jit('i4(i4)', **jitargs)(add_offset)
        self.assertEqual(first(1), 11)

        # A second compilation of the same Python function sees the
        # current binding of the closure variable...
        offset = 12
        second = jit('i4(i4)', **jitargs)(add_offset)
        self.assertEqual(second(1), 13)

        # ...while later rebinding affects neither compiled version.
        offset = 13
        self.assertEqual(second(1), 13)
        self.assertEqual(first(1), 11)

    def run_jit_multiple_closure_variables(self, **jitargs):
        """Several closure variables are captured together."""
        offset, factor = 10, 2

        def add_then_scale(x):
            return (x + offset) * factor

        compiled = jit('i4(i4)', **jitargs)(add_then_scale)
        self.assertEqual(compiled(1), 22)

    def run_jit_inner_function(self, **jitargs):
        """A jitted function called from another jitted function's closure.

        Compilation is disabled on both dispatchers after the explicit
        signature is compiled, and the call is checked with
        ``assertRefCount`` on both dispatchers.
        """
        signature = 'intp(intp)'

        def times_ten(a):
            return a * 10

        jitted_times_ten = jit(signature, **jitargs)(times_ten)
        jitted_times_ten.disable_compile()

        def shift_then_times_ten(x):
            return jitted_times_ten(x + 4)

        jitted_outer = jit(signature, **jitargs)(shift_then_times_ten)
        jitted_outer.disable_compile()

        with self.assertRefCount(jitted_outer, jitted_times_ten):
            self.assertEqual(jitted_outer(1), 50)

    # ------------------------------------------------------------------
    # Test entry points: object mode, then nopython mode
    # ------------------------------------------------------------------

    def test_jit_closure_variable(self):
        self.run_jit_closure_variable(forceobj=True)

    def test_jit_closure_variable_npm(self):
        self.run_jit_closure_variable(nopython=True)

    def test_rejitting_closure(self):
        self.run_rejitting_closure(forceobj=True)

    def test_rejitting_closure_npm(self):
        self.run_rejitting_closure(nopython=True)

    def test_jit_multiple_closure_variables(self):
        self.run_jit_multiple_closure_variables(forceobj=True)

    def test_jit_multiple_closure_variables_npm(self):
        self.run_jit_multiple_closure_variables(nopython=True)

    def test_jit_inner_function(self):
        self.run_jit_inner_function(forceobj=True)

    def test_jit_inner_function_npm(self):
        self.run_jit_inner_function(nopython=True)
|
||
|
|
||
|
|
||
|
class TestInlinedClosure(TestCase):
    """
    Tests for (partial) closure support in njit. The support is partial
    because it only works for closures that can be successfully inlined
    at compile time.
    """

    def test_inner_function(self):
        """An inner function with no free variables is inlined."""

        def outer(x):

            def inner(x):
                return x * x

            return inner(x) + inner(x)

        cfunc = njit(outer)
        self.assertEqual(cfunc(10), outer(10))

    def test_inner_function_with_closure(self):
        """An inner function reads a variable of the enclosing function."""

        def outer(x):
            y = x + 1

            def inner(x):
                return x * x + y

            return inner(x) + inner(x)

        cfunc = njit(outer)
        self.assertEqual(cfunc(10), outer(10))

    def test_inner_function_with_closure_2(self):
        """The captured variable is rebound between two calls of the closure."""

        def outer(x):
            y = x + 1

            def inner(x):
                return x * y

            y = inner(x)
            return y + inner(x)

        cfunc = njit(outer)
        self.assertEqual(cfunc(10), outer(10))

    def test_inner_function_with_closure_3(self):
        """The closure mutates an enclosing variable through ``nonlocal``."""

        # The function is built from source text and compiled with exec();
        # after .strip() the ``def`` line starts at column 0 and the body
        # keeps a consistent (deeper) indentation, which is valid Python.
        code = """
            def outer(x):
                y = x + 1
                z = 0

                def inner(x):
                    nonlocal z
                    z += x * x
                    return z + y

                return inner(x) + inner(x) + z
            """
        ns = {}
        exec(code.strip(), ns)

        cfunc = njit(ns['outer'])
        self.assertEqual(cfunc(10), ns['outer'](10))

    def test_inner_function_nested(self):
        """Closures nested two levels deep, the innermost called in a loop."""

        def outer(x):

            def inner(y):

                def innermost(z):
                    return x + y + z

                s = 0
                for i in range(y):
                    s += innermost(i)
                return s

            return inner(x * x)

        cfunc = njit(outer)
        self.assertEqual(cfunc(10), outer(10))

    def test_bulk_use_cases(self):
        """ Tests the large number of use cases defined below """

        # jitted function used in some tests
        @njit
        def fib3(n):
            if n < 2:
                return n
            return fib3(n - 1) + fib3(n - 2)

        def outer1(x):
            """ Test calling recursive function from inner """
            def inner(x):
                return fib3(x)
            return inner(x)

        def outer2(x):
            """ Test calling recursive function from closure """
            z = x + 1

            def inner(x):
                return x + fib3(z)
            return inner(x)

        def outer3(x):
            """ Test recursive inner """
            def inner(x):
                if x < 2:
                    return 10
                else:
                    inner(x - 1)
            return inner(x)

        def outer4(x):
            """ Test recursive closure """
            y = x + 1

            def inner(x):
                if x + y < 2:
                    return 10
                else:
                    inner(x - 1)
            return inner(x)

        def outer5(x):
            """ Test nested closure """
            y = x + 1

            def inner1(x):
                z = y + x + 2

                def inner2(x):
                    return x + z

                return inner2(x) + y

            return inner1(x)

        def outer6(x):
            """ Test closure with list comprehension in body """
            y = x + 1

            def inner1(x):
                z = y + x + 2
                return [t for t in range(z)]
            return inner1(x)

        # Enclosing-scope variable read by both outer7 and outer8.
        _OUTER_SCOPE_VAR = 9

        def outer7(x):
            """ Test use of outer scope var, no closure """
            z = x + 1
            return x + z + _OUTER_SCOPE_VAR

        def outer8(x):
            """ Test use of outer scope var, with closure """
            z = x + 1

            def inner(x):
                return x + z + _OUTER_SCOPE_VAR
            return inner(x)

        def outer9(x):
            """ Test closure assignment"""
            z = x + 1

            def inner(x):
                return x + z
            f = inner
            return f(x)

        def outer10(x):
            """ Test two inner, one calls other """
            z = x + 1

            def inner(x):
                return x + z

            def inner2(x):
                return inner(x)

            return inner2(x)

        def outer11(x):
            """ return the closure """
            z = x + 1

            def inner(x):
                return x + z
            return inner

        def outer12(x):
            """ closure with kwarg"""
            z = x + 1

            def inner(x, kw=7):
                return x + z + kw
            return inner(x)

        def outer13(x, kw=7):
            """ outer with kwarg no closure"""
            z = x + 1 + kw
            return z

        def outer14(x, kw=7):
            """ outer with kwarg used in closure"""
            z = x + 1

            def inner(x):
                return x + z + kw
            return inner(x)

        def outer15(x, kw=7):
            """ outer with kwarg as arg to closure"""
            z = x + 1

            def inner(x, kw):
                return x + z + kw
            return inner(x, kw)

        def outer16(x):
            """ closure is generator, consumed locally """
            z = x + 1

            def inner(x):
                yield x + z

            return list(inner(x))

        def outer17(x):
            """ closure is generator, returned """
            z = x + 1

            def inner(x):
                yield x + z

            return inner(x)

        def outer18(x):
            """ closure is generator, consumed in loop """
            z = x + 1

            def inner(x):
                yield x + z

            for i in inner(x):
                t = i

            return t

        def outer19(x):
            """ closure as arg to another closure """
            z1 = x + 1
            z2 = x + 2

            def inner(x):
                return x + z1

            def inner2(f, x):
                return f(x) + z2

            return inner2(inner, x)

        def outer20(x):
            """ Test calling numpy in closure """
            z = x + 1

            def inner(x):
                return x + numpy.cos(z)
            return inner(x)

        def outer21(x):
            """ Test calling numpy import as in closure """
            z = x + 1

            def inner(x):
                return x + np.cos(z)
            return inner(x)

        def outer22():
            """Test to ensure that unsupported *args raises correctly"""
            def bar(a, b):
                pass
            x = 1, 2
            bar(*x)

        # Argument passed to every single-argument use case. It is bound
        # here, rather than inside the first loop, so the failure checks
        # below do not depend on a variable leaking out of that loop.
        var = 10

        # Functions that must compile and agree with the interpreter.
        expected_to_pass = [outer1, outer2, outer5, outer6, outer7, outer8,
                            outer9, outer10, outer12, outer13, outer14,
                            outer15, outer19, outer20, outer21]
        for pyfunc in expected_to_pass:
            with self.subTest(pyfunc=pyfunc.__name__):
                cfunc = njit(pyfunc)
                self.assertEqual(cfunc(var), pyfunc(var))

        # Functions that must be rejected: (function, call arguments,
        # expected exception type, fragment of the expected message).
        load_closure_msg = "Unsupported use of op_LOAD_CLOSURE encountered"
        yield_msg = "The use of yield in a closure is unsupported."
        expected_to_fail = [
            (outer3, (var,), NotImplementedError, load_closure_msg),
            (outer4, (var,), NotImplementedError, load_closure_msg),
            (outer11, (var,), TypingError,
             "Cannot capture the non-constant value"),
            (outer16, (var,), UnsupportedError, yield_msg),
            (outer17, (var,), UnsupportedError, yield_msg),
            (outer18, (var,), UnsupportedError, yield_msg),
            (outer22, (), UnsupportedError,
             "Calling a closure with *args is unsupported."),
        ]
        for pyfunc, args, exc_type, msg in expected_to_fail:
            with self.subTest(pyfunc=pyfunc.__name__):
                with self.assertRaises(exc_type) as raises:
                    cfunc = jit(nopython=True)(pyfunc)
                    cfunc(*args)
                self.assertIn(msg, str(raises.exception))

    def test_closure_renaming_scheme(self):
        # See #7380, this checks that inlined (from closure) variables have a
        # name derived from the function they were defined in.

        @njit(pipeline_class=IRPreservingTestPipeline)
        def foo(a, b):
            def bar(z):
                x = 5
                y = 10
                return x + y + z
            return bar(a), bar(b)

        self.assertEqual(foo(10, 20), (25, 35))

        # check IR. Look for the `x = 5`... there should be
        # Two lots of `const(int, 5)`, one for each inline
        # The LHS of the assignment will have a name like:
        # closure__locals__bar_v2_x
        # Ensure that this is the case!
        func_ir = foo.overloads[foo.signatures[0]].metadata['preserved_ir']
        store = []
        for blk in func_ir.blocks.values():
            for stmt in blk.body:
                if (isinstance(stmt, ir.Assign)
                        and isinstance(stmt.value, ir.Const)
                        and stmt.value.value == 5):
                    store.append(stmt)

        self.assertEqual(len(store), 2)
        # The separator before the variable name is matched as "_" (as in
        # the example above) or "."; an unescaped "." would accept any
        # character and make the check much weaker.
        regex = r'closure__locals__bar_v[0-9]+[._]x'
        for stmt in store:
            self.assertRegex(stmt.target.name, regex)

    def test_issue9222(self):
        # Ensures that float default arguments are handled correctly in
        # closures.

        @njit
        def foo():
            def bar(x, y=1.1):
                return x + y
            return bar

        @njit
        def consume():
            return foo()(4)

        # In Issue #9222, the result was completely wrong - 15 instead of 5.1 -
        # so allclose should be sufficient for comparison here.
        np.testing.assert_allclose(consume(), 4 + 1.1)
|
||
|
|
||
|
|
||
|
# Run this module's tests through unittest when it is executed directly
# as a script (nothing runs when it is imported).
if __name__ == '__main__':
    unittest.main()
|