# ai-content-maker/.venv/Lib/site-packages/numba/tests/test_withlifting.py
#
# 1215 lines
# 34 KiB
# Python

import copy
import warnings
import numpy as np
import numba
from numba.core.transforms import find_setupwiths, with_lifting
from numba.core.withcontexts import bypass_context, call_context, objmode_context
from numba.core.bytecode import FunctionIdentity, ByteCode
from numba.core.interpreter import Interpreter
from numba.core import errors
from numba.core.registry import cpu_target
from numba.core.compiler import compile_ir, DEFAULT_FLAGS
from numba import njit, typeof, objmode, types
from numba.core.extending import overload
from numba.tests.support import (MemoryLeak, TestCase, captured_stdout,
skip_unless_scipy, linux_only,
strace_supported, strace,
expected_failure_py311,
expected_failure_py312)
from numba.core.utils import PYVERSION
from numba.experimental import jitclass
import unittest
def get_func_ir(func):
    """Return the Numba ``FunctionIR`` for the Python function *func*.

    The IR is produced directly by ``Interpreter.interpret`` from the
    function's bytecode; no further compiler passes are run here.
    """
    func_id = FunctionIdentity.from_function(func)
    bc = ByteCode(func_id=func_id)
    interp = Interpreter(func_id)
    func_ir = interp.interpret(bc)
    return func_ir
# Fixture: a single bypassed with-block. ``b`` is intentionally undefined;
# the bypass context removes that code so it is never executed.
def lift1():
    print("A")
    with bypass_context:
        print("B")
        b()
    print("C")
# Fixture: two sibling bypassed with-blocks; updates to ``x`` made inside
# them must not be visible afterwards.
def lift2():
    x = 1
    print("A", x)
    x = 1
    with bypass_context:
        print("B", x)
        x += 100
        b()
    x += 1
    with bypass_context:
        print("C", x)
        b()
        x += 10
    x += 1
    print("D", x)
# Fixture: a bypassed with-block nested inside another one. Only the
# outer-most with-block is found/lifted, hence a count of one.
def lift3():
    x = 1
    y = 100
    print("A", x, y)
    with bypass_context:
        print("B")
        b()
        x += 100
        with bypass_context:
            print("C")
            y += 100000
            b()
    x += 1
    y += 1
    print("D", x, y)
# Fixture: an outer with-block containing a loop of nested with-blocks,
# followed by a second top-level with-block containing a branch.
def lift4():
    x = 0
    print("A", x)
    x += 10
    with bypass_context:
        print("B")
        b()
        x += 1
        for i in range(10):
            with bypass_context:
                print("C")
                b()
                x += i
    with bypass_context:
        print("D")
        b()
        if x:
            x *= 10
    x += 1
    print("E", x)
# Fixture: a function without any with-block.
def lift5():
    print("A")
# Fixture: a single call_context block whose update to ``x`` must be
# visible after the block.
def liftcall1():
    x = 1
    print("A", x)
    with call_context:
        x += 1
    print("B", x)
    return x
# Fixture: two sibling call_context blocks.
def liftcall2():
    x = 1
    print("A", x)
    with call_context:
        x += 1
    print("B", x)
    with call_context:
        x += 10
    print("C", x)
    return x
# Fixture: call_context blocks containing control flow (a branch and a loop).
def liftcall3():
    x = 1
    print("A", x)
    with call_context:
        if x > 0:
            x += 1
    print("B", x)
    with call_context:
        for i in range(10):
            x += i
    print("C", x)
    return x
# Fixture: directly nested call_context blocks (known to be unsupported).
def liftcall4():
    with call_context:
        with call_context:
            pass
# Fixture: a call_context block inside a loop that breaks out of the loop
# from within the block.
def liftcall5():
    for i in range(10):
        with call_context:
            print(i)
            if i == 5:
                print("A")
                break
    return i
# Fixture: uses an undefined global as a context manager.
def lift_undefiend():
    with undefined_global_var:
        pass


# An object that is not a supported (or even valid) context manager.
bogus_contextmanager = object()


# Fixture: uses an unsupported object as a context manager.
def lift_invalid():
    with bogus_contextmanager:
        pass


# Global Numba type referenced by name from objmode type annotations.
gv_type = types.intp
class TestWithFinding(TestCase):
    # Checks how many (outer-most) with-blocks find_setupwiths detects.

    def check_num_of_with(self, func, expect_count):
        """Assert that *func*'s IR contains *expect_count* with-blocks."""
        the_ir = get_func_ir(func)
        ct = len(find_setupwiths(the_ir)[0])
        self.assertEqual(ct, expect_count)

    def test_lift1(self):
        self.check_num_of_with(lift1, expect_count=1)

    def test_lift2(self):
        self.check_num_of_with(lift2, expect_count=2)

    def test_lift3(self):
        self.check_num_of_with(lift3, expect_count=1)

    def test_lift4(self):
        self.check_num_of_with(lift4, expect_count=2)

    def test_lift5(self):
        self.check_num_of_with(lift5, expect_count=0)
class BaseTestWithLifting(TestCase):
    # Shared helpers: run with_lifting on a function's IR and compile the
    # remaining IR with the CPU typing/target contexts.

    def setUp(self):
        super(BaseTestWithLifting, self).setUp()
        self.typingctx = cpu_target.typing_context
        self.targetctx = cpu_target.target_context
        self.flags = DEFAULT_FLAGS

    def check_extracted_with(self, func, expect_count, expected_stdout):
        """Lift the with-blocks out of *func* and check the result.

        Asserts that *expect_count* blocks were extracted, then compiles
        and runs the remaining IR, asserting that it prints exactly
        *expected_stdout*.
        """
        the_ir = get_func_ir(func)
        new_ir, extracted = with_lifting(
            the_ir, self.typingctx, self.targetctx, self.flags,
            locals={},
        )
        self.assertEqual(len(extracted), expect_count)
        cres = self.compile_ir(new_ir)

        with captured_stdout() as out:
            cres.entry_point()

        self.assertEqual(out.getvalue(), expected_stdout)

    def compile_ir(self, the_ir, args=(), return_type=None):
        """Compile *the_ir* using this test's contexts and flags and return
        the compile result."""
        typingctx = self.typingctx
        targetctx = self.targetctx
        flags = self.flags
        return compile_ir(typingctx, targetctx, the_ir, args,
                          return_type, flags, locals={})
class TestLiftByPass(BaseTestWithLifting):
    # bypass_context removes the body of the with-block entirely, so only
    # the code outside the blocks produces output.

    def test_lift1(self):
        self.check_extracted_with(lift1, expect_count=1,
                                  expected_stdout="A\nC\n")

    def test_lift2(self):
        self.check_extracted_with(lift2, expect_count=2,
                                  expected_stdout="A 1\nD 3\n")

    def test_lift3(self):
        self.check_extracted_with(lift3, expect_count=1,
                                  expected_stdout="A 1 100\nD 2 101\n")

    def test_lift4(self):
        self.check_extracted_with(lift4, expect_count=2,
                                  expected_stdout="A 0\nE 11\n")

    def test_lift5(self):
        self.check_extracted_with(lift5, expect_count=0,
                                  expected_stdout="A\n")
class TestLiftCall(BaseTestWithLifting):
    # call_context lifts the with-body into a separately compiled function
    # that is called in its place; semantics must match plain Python.

    def check_same_semantic(self, func):
        """Ensure same semantic with non-jitted code.

        Both the printed output and the return value of the jitted
        function must match those of the pure-Python function.
        """
        jitted = njit(func)
        with captured_stdout() as got:
            got_ret = jitted()

        with captured_stdout() as expect:
            expect_ret = func()

        self.assertEqual(got.getvalue(), expect.getvalue())
        self.assertEqual(got_ret, expect_ret)

    def test_liftcall1(self):
        self.check_extracted_with(liftcall1, expect_count=1,
                                  expected_stdout="A 1\nB 2\n")
        self.check_same_semantic(liftcall1)

    def test_liftcall2(self):
        self.check_extracted_with(liftcall2, expect_count=2,
                                  expected_stdout="A 1\nB 2\nC 12\n")
        self.check_same_semantic(liftcall2)

    def test_liftcall3(self):
        self.check_extracted_with(liftcall3, expect_count=2,
                                  expected_stdout="A 1\nB 2\nC 47\n")
        self.check_same_semantic(liftcall3)

    def test_liftcall4(self):
        accept = (errors.TypingError, errors.NumbaRuntimeError,
                  errors.NumbaValueError, errors.CompilerError)
        with self.assertRaises(accept) as raises:
            njit(liftcall4)()
        # Known error. We only support one context manager per function
        # for bodies that are lifted.
        msg = ("compiler re-entrant to the same function signature")
        self.assertIn(msg, str(raises.exception))

    @expected_failure_py311
    @expected_failure_py312
    def test_liftcall5(self):
        self.check_extracted_with(liftcall5, expect_count=1,
                                  expected_stdout="0\n1\n2\n3\n4\n5\nA\n")
        self.check_same_semantic(liftcall5)
def _expect_typing_error(message):
    """Return a test-method decorator asserting that the decorated test
    raises ``errors.TypingError`` whose text contains *message*."""
    import functools

    def decorator(fn):
        @functools.wraps(fn)
        def core(self, *args, **kwargs):
            with self.assertRaises(errors.TypingError) as raises:
                fn(self, *args, **kwargs)
            self.assertIn(message, str(raises.exception))
        return core
    return decorator


def expected_failure_for_list_arg(fn):
    """Mark test *fn* as expected to fail typing because a list argument
    is passed into an objmode context."""
    return _expect_typing_error('Does not support list type')(fn)


def expected_failure_for_function_arg(fn):
    """Mark test *fn* as expected to fail typing because a function
    argument is passed into an objmode context."""
    return _expect_typing_error('Does not support function type')(fn)
class TestLiftObj(MemoryLeak, TestCase):
    # Tests for ``with objmode`` (objmode_context) lifting. NumbaWarnings
    # are promoted to errors for the duration of every test.

    def setUp(self):
        super().setUp()
        # catch_warnings() snapshots the warning filters and restores them
        # when the cleanup runs (after tearDown). This replaces a blanket
        # warnings.resetwarnings(), which discarded every filter in the
        # process, including those installed by the test runner.
        catcher = warnings.catch_warnings()
        catcher.__enter__()
        self.addCleanup(catcher.__exit__, None, None, None)
        warnings.simplefilter("error", errors.NumbaWarning)

    def assert_equal_return_and_stdout(self, pyfunc, *args):
        """Run *pyfunc* as plain Python and under ``njit`` (each on its own
        deep copy of *args*) and assert equal stdout and return values."""
        py_args = copy.deepcopy(args)
        c_args = copy.deepcopy(args)
        cfunc = njit(pyfunc)

        with captured_stdout() as stream:
            expect_res = pyfunc(*py_args)
            expect_out = stream.getvalue()

        # avoid compiling during stdout-capturing for easier print-debugging
        cfunc.compile(tuple(map(typeof, c_args)))

        with captured_stdout() as stream:
            got_res = cfunc(*c_args)
            got_out = stream.getvalue()

        self.assertEqual(expect_out, got_out)
        self.assertPreciseEqual(expect_res, got_res)

    def test_lift_objmode_basic(self):
        def bar(ival):
            print("ival =", {'ival': ival // 2})

        def foo(ival):
            ival += 1
            with objmode_context:
                bar(ival)
            return ival + 1

        def foo_nonglobal(ival):
            ival += 1
            with numba.objmode:
                bar(ival)
            return ival + 1

        self.assert_equal_return_and_stdout(foo, 123)
        self.assert_equal_return_and_stdout(foo_nonglobal, 123)

    def test_lift_objmode_array_in(self):
        def bar(arr):
            print({'arr': arr // 2})
            # arr is modified. the effect is visible outside.
            arr *= 2

        def foo(nelem):
            arr = np.arange(nelem).astype(np.int64)
            with objmode_context:
                # arr is modified inplace inside bar()
                bar(arr)
            return arr + 1

        nelem = 10
        self.assert_equal_return_and_stdout(foo, nelem)

    def test_lift_objmode_define_new_unused(self):
        def bar(y):
            print(y)

        def foo(x):
            with objmode_context():
                y = 2 + x           # defined but unused outside
                a = np.arange(y)    # defined but unused outside
                bar(a)
            return x

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)

    def test_lift_objmode_return_simple(self):
        def inverse(x):
            print(x)
            return 1 / x

        def foo(x):
            with objmode_context(y="float64"):
                y = inverse(x)
            return x, y

        def foo_nonglobal(x):
            with numba.objmode(y="float64"):
                y = inverse(x)
            return x, y

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)
        self.assert_equal_return_and_stdout(foo_nonglobal, arg)

    def test_lift_objmode_return_array(self):
        def inverse(x):
            print(x)
            return 1 / x

        def foo(x):
            with objmode_context(y="float64[:]", z="int64"):
                y = inverse(x)
                z = int(y[0])
            return x, y, z

        arg = np.arange(1, 10, dtype=np.float64)
        self.assert_equal_return_and_stdout(foo, arg)

    @expected_failure_for_list_arg
    def test_lift_objmode_using_list(self):
        def foo(x):
            with objmode_context(y="float64[:]"):
                print(x)
                x[0] = 4
                print(x)
                y = [1, 2, 3] + x
                y = np.asarray([1 / i for i in y])
            return x, y

        arg = [1, 2, 3]
        self.assert_equal_return_and_stdout(foo, arg)

    def test_lift_objmode_var_redef(self):
        def foo(x):
            for x in range(x):
                pass
            if x:
                x += 1
            with objmode_context(x="intp"):
                print(x)
                x -= 1
            print(x)
            for i in range(x):
                x += i
                print(x)
            return x

        arg = 123
        self.assert_equal_return_and_stdout(foo, arg)

    @expected_failure_for_list_arg
    def test_case01_mutate_list_ahead_of_ctx(self):
        def foo(x, z):
            x[2] = z

            with objmode_context():
                # should print [1, 2, 15] but prints [1, 2, 3]
                print(x)

            with objmode_context():
                x[2] = 2 * z
                # should print [1, 2, 30] but prints [1, 2, 15]
                print(x)

            return x

        self.assert_equal_return_and_stdout(foo, [1, 2, 3], 15)

    def test_case02_mutate_array_ahead_of_ctx(self):
        def foo(x, z):
            x[2] = z

            with objmode_context():
                # should print [1, 2, 15]
                print(x)

            with objmode_context():
                x[2] = 2 * z
                # should print [1, 2, 30]
                print(x)

            return x

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x, 15)

    @expected_failure_for_list_arg
    def test_case03_create_and_mutate(self):
        def foo(x):
            with objmode_context(y='List(int64)'):
                y = [1, 2, 3]
            with objmode_context():
                y[2] = 10
            return y

        self.assert_equal_return_and_stdout(foo, 1)

    def test_case04_bogus_variable_type_info(self):

        def foo(x):
            # should specifying nonsense type info be considered valid?
            with objmode_context(k="float64[:]"):
                print(x)
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.TypingError) as raises:
            cfoo(x)
        self.assertIn(
            "Invalid type annotation on non-outgoing variables",
            str(raises.exception),
        )

    def test_case05_bogus_type_info(self):
        def foo(x):
            # should specifying the wrong type info be considered valid?
            # z is complex.
            # Note: for now, we will coerce for scalar and raise for array
            with objmode_context(z="float64[:]"):
                z = x + 1.j
            return z

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(TypeError) as raises:
            cfoo(x)
        self.assertIn(
            ("can't unbox array from PyObject into native value."
             " The object maybe of a different type"),
            str(raises.exception),
        )

    def test_case06_double_objmode(self):
        def foo(x):
            # would nested ctx in the same scope ever make sense? Is this
            # pattern useful?
            with objmode_context():
                # with npmmode_context(): not implemented yet
                with objmode_context():
                    print(x)
            return x

        with self.assertRaises(errors.TypingError) as raises:
            njit(foo)(123)
        # Check that an error occurred in with-lifting in objmode
        pat = ("During: resolving callee type: "
               r"type\(ObjModeLiftedWith\(<.*>\)\)")
        self.assertRegex(str(raises.exception), pat)

    def test_case07_mystery_key_error(self):
        # this raises a key error
        def foo(x):
            with objmode_context():
                t = {'a': x}
                u = 3
            return x, t, u

        x = np.array([1, 2, 3])
        cfoo = njit(foo)

        with self.assertRaises(errors.TypingError) as raises:
            cfoo(x)

        exstr = str(raises.exception)
        self.assertIn("Missing type annotation on outgoing variable(s): "
                      "['t', 'u']",
                      exstr)
        self.assertIn("Example code: with objmode"
                      "(t='<add_type_as_string_here>')",
                      exstr)

    def test_case08_raise_from_external(self):
        # This used to segfault. The lookup has to raise because '2' is not
        # a key of the dict until a later loop iteration (looking up
        # `d['0']` works fine).
        d = dict()

        def foo(x):
            for i in range(len(x)):
                with objmode_context():
                    k = str(i)
                    v = x[i]
                    d[k] = v
                    print(d['2'])
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(KeyError) as raises:
            cfoo(x)
        self.assertEqual(str(raises.exception), "'2'")

    def test_case09_explicit_raise(self):
        def foo(x):
            with objmode_context():
                raise ValueError()
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(errors.CompilerError) as raises:
            cfoo(x)
        self.assertIn(
            ('unsupported control flow due to raise statements inside '
             'with block'),
            str(raises.exception),
        )

    @expected_failure_for_list_arg
    def test_case10_mutate_across_contexts(self):
        # This shouldn't work due to using List as input.
        def foo(x):
            with objmode_context(y='List(int64)'):
                y = [1, 2, 3]
            with objmode_context():
                y[2] = 10
            return y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case10_mutate_array_across_contexts(self):
        # Sub-case of case-10.
        def foo(x):
            with objmode_context(y='int64[:]'):
                y = np.asarray([1, 2, 3], dtype='int64')
            with objmode_context():
                # Note: `y` is not an output.
                y[2] = 10
            return y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case11_define_function_in_context(self):
        # should this work? no, global name 'bar' is not defined
        def foo(x):
            with objmode_context():
                def bar(y):
                    return y + 1
            return x

        x = np.array([1, 2, 3])
        cfoo = njit(foo)
        with self.assertRaises(NameError) as raises:
            cfoo(x)
        self.assertIn(
            "global name 'bar' is not defined",
            str(raises.exception),
        )

    def test_case12_njit_inside_a_objmode_ctx(self):
        # TODO: is this still the case?
        # this works locally but not inside this test, probably due to the way
        # compilation is being done
        def bar(y):
            return y + 1

        def foo(x):
            with objmode_context(y='int64[:]'):
                y = njit(bar)(x).astype('int64')
            return x + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case14_return_direct_from_objmode_ctx(self):
        # NOTE(review): only the pure-Python function is exercised here;
        # the njit-compiled version is never called.
        def foo(x):
            with objmode_context(x='int64[:]'):
                x += 1
                return x

        result = foo(np.array([1, 2, 3]))
        np.testing.assert_array_equal(np.array([2, 3, 4]), result)

    # No easy way to handle this yet.
    @unittest.expectedFailure
    def test_case15_close_over_objmode_ctx(self):
        # Fails with Unsupported constraint encountered: enter_with $phi8.1
        def foo(x):
            j = 10

            def bar(x):
                with objmode_context(x='int64[:]'):
                    print(x)
                return x + j

            return bar(x) + 2

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    @skip_unless_scipy
    def test_case16_scipy_call_in_objmode_ctx(self):
        from scipy import sparse as sp

        def foo(x):
            with objmode_context(k='int64'):
                print(x)
                spx = sp.csr_matrix(x)
                # the np.int64 call is pointless, works around:
                # https://github.com/scipy/scipy/issues/10206
                # which hit the SciPy 1.3 release.
                k = np.int64(spx[0, 0])
            return k

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case17_print_own_bytecode(self):
        import dis

        def foo(x):
            with objmode_context():
                dis.dis(foo)

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    @expected_failure_for_function_arg
    def test_case18_njitfunc_passed_to_objmode_ctx(self):
        def foo(func, x):
            with objmode_context():
                func(x[0])

        x = np.array([1, 2, 3])
        fn = njit(lambda z: z + 5)
        self.assert_equal_return_and_stdout(foo, fn, x)

    @expected_failure_py311
    @expected_failure_py312
    def test_case19_recursion(self):
        def foo(x):
            with objmode_context():
                if x == 0:
                    return 7
            ret = foo(x - 1)
            return ret

        with self.assertRaises((errors.TypingError,
                                errors.CompilerError)) as raises:
            cfoo = njit(foo)
            cfoo(np.array([1, 2, 3]))

        msg = "Untyped global name 'foo'"
        self.assertIn(msg, str(raises.exception))

    @unittest.expectedFailure
    def test_case20_rng_works_ok(self):
        def foo(x):
            np.random.seed(0)
            y = np.random.rand()
            with objmode_context(z="float64"):
                # It's known that the random state does not sync
                z = np.random.rand()
            return x + z + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_case21_rng_seed_works_ok(self):
        def foo(x):
            np.random.seed(0)
            y = np.random.rand()
            with objmode_context(z="float64"):
                # Similar to test_case20_rng_works_ok but call seed
                np.random.seed(0)
                z = np.random.rand()
            return x + z + y

        x = np.array([1, 2, 3])
        self.assert_equal_return_and_stdout(foo, x)

    def test_example01(self):
        # Example from _ObjModeContextType.__doc__
        def bar(x):
            return np.asarray(list(reversed(x.tolist())))

        @njit
        def foo():
            x = np.arange(5)
            with objmode(y='intp[:]'):  # annotate return type
                # this region is executed by object-mode.
                y = x + bar(x)
            return y

        self.assertPreciseEqual(foo(), foo.py_func())
        self.assertIs(objmode, objmode_context)

    def test_objmode_in_overload(self):
        def foo(s):
            pass

        @overload(foo)
        def foo_overload(s):
            def impl(s):
                with objmode(out='intp'):
                    out = s + 3
                return out
            return impl

        @numba.njit
        def f():
            return foo(1)

        self.assertEqual(f(), 1 + 3)

    def test_objmode_gv_variable(self):
        @njit
        def global_var():
            with objmode(val=gv_type):
                val = 12.3
            return val

        ret = global_var()
        # the result is truncated because of the intp return-type
        self.assertIsInstance(ret, int)
        self.assertEqual(ret, 12)

    def test_objmode_gv_variable_error(self):
        @njit
        def global_var():
            with objmode(val=gv_type2):
                val = 123
            return val

        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'val'. "
             r"Global 'gv_type2' is not defined.")
        ):
            global_var()

    def test_objmode_gv_mod_attr(self):
        @njit
        def modattr1():
            with objmode(val=types.intp):
                val = 12.3
            return val

        @njit
        def modattr2():
            with objmode(val=numba.types.intp):
                val = 12.3
            return val

        for fn in (modattr1, modattr2):
            with self.subTest(fn=str(fn)):
                ret = fn()
                # the result is truncated because of the intp return-type
                self.assertIsInstance(ret, int)
                self.assertEqual(ret, 12)

    def test_objmode_gv_mod_attr_error(self):
        @njit
        def moderror():
            with objmode(val=types.THIS_DOES_NOT_EXIST):
                val = 12.3
            return val

        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'val'. "
             "Getattr cannot be resolved at compile-time"),
        ):
            moderror()

    def test_objmode_gv_mod_attr_error_multiple(self):
        @njit
        def moderror():
            with objmode(v1=types.intp, v2=types.THIS_DOES_NOT_EXIST,
                         v3=types.float32):
                v1 = 12.3
                v2 = 12.3
                v3 = 12.3
            # Return the declared outputs (previously referenced an
            # undefined name ``val``).
            return v1, v2, v3

        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'v2'. "
             "Getattr cannot be resolved at compile-time"),
        ):
            moderror()

    def test_objmode_closure_type_in_overload(self):
        def foo():
            pass

        @overload(foo)
        def foo_overload():
            shrubbery = types.float64[:]

            def impl():
                with objmode(out=shrubbery):
                    out = np.arange(10).astype(np.float64)
                return out
            return impl

        @njit
        def bar():
            return foo()

        self.assertPreciseEqual(bar(), np.arange(10).astype(np.float64))

    def test_objmode_closure_type_in_overload_error(self):
        def foo():
            pass

        @overload(foo)
        def foo_overload():
            shrubbery = types.float64[:]

            def impl():
                with objmode(out=shrubbery):
                    out = np.arange(10).astype(np.float64)
                return out
            # Remove closure var.
            # Otherwise, "shrubbery" will be a global
            del shrubbery
            return impl

        @njit
        def bar():
            return foo()

        with self.assertRaisesRegex(
            errors.TypingError,
            ("Error handling objmode argument 'out'. "
             "Freevar 'shrubbery' is not defined"),
        ):
            bar()

    def test_objmode_invalid_use(self):
        @njit
        def moderror():
            with objmode(bad=1 + 1):
                out = 1
            # Previously referenced an undefined name ``val``.
            return out

        with self.assertRaisesRegex(
            errors.CompilerError,
            ("Error handling objmode argument 'bad'. "
             "The value must be a compile-time constant either as "
             "a non-local variable or a getattr expression that "
             "refers to a Numba type."),
        ):
            moderror()

    def test_objmode_multi_type_args(self):
        array_ty = types.int32[:]

        @njit
        def foo():
            # t1 is a string
            # t2 is a global type
            # t3 is a non-local/freevar
            with objmode(t1="float64", t2=gv_type, t3=array_ty):
                t1 = 793856.5
                t2 = t1         # to observe truncation
                t3 = np.arange(5).astype(np.int32)
            return t1, t2, t3

        t1, t2, t3 = foo()
        self.assertPreciseEqual(t1, 793856.5)
        self.assertPreciseEqual(t2, 793856)
        self.assertPreciseEqual(t3, np.arange(5).astype(np.int32))

    def test_objmode_jitclass(self):
        spec = [
            ('value', types.int32),               # a simple scalar field
            ('array', types.float32[:]),          # an array field
        ]

        @jitclass(spec)
        class Bag(object):
            def __init__(self, value):
                self.value = value
                self.array = np.zeros(value, dtype=np.float32)

            @property
            def size(self):
                return self.array.size

            def increment(self, val):
                for i in range(self.size):
                    self.array[i] += val
                return self.array

            @staticmethod
            def add(x, y):
                return x + y

        n = 21
        mybag = Bag(n)

        def foo():
            pass

        @overload(foo)
        def foo_overload():
            shrubbery = mybag._numba_type_

            def impl():
                with objmode(out=shrubbery):
                    out = Bag(123)
                    out.increment(3)
                return out
            return impl

        @njit
        def bar():
            return foo()

        z = bar()
        self.assertIsInstance(z, Bag)
        self.assertEqual(z.add(2, 3), 2 + 3)
        exp_array = np.zeros(123, dtype=np.float32) + 3
        self.assertPreciseEqual(z.array, exp_array)

    @staticmethod
    def case_objmode_cache(x):
        with objmode(output='float64'):
            output = x / 10
        return output

    def test_objmode_reflected_list(self):
        ret_type = typeof([1, 2, 3, 4, 5])

        @njit
        def test2():
            with objmode(out=ret_type):
                out = [1, 2, 3, 4, 5]
            return out

        with self.assertRaises(errors.CompilerError) as raises:
            test2()
        self.assertRegex(
            str(raises.exception),
            (r"Objmode context failed. "
             r"Argument 'out' is declared as an unsupported type: "
             r"reflected list\(int(32|64)\)<iv=None>. "
             r"Reflected types are not supported."),
        )

    def test_objmode_reflected_set(self):
        ret_type = typeof({1, 2, 3, 4, 5})

        @njit
        def test2():
            with objmode(result=ret_type):
                result = {1, 2, 3, 4, 5}
            return result

        with self.assertRaises(errors.CompilerError) as raises:
            test2()
        self.assertRegex(
            str(raises.exception),
            (r"Objmode context failed. "
             r"Argument 'result' is declared as an unsupported type: "
             r"reflected set\(int(32|64)\). "
             r"Reflected types are not supported."),
        )

    def test_objmode_typed_dict(self):
        ret_type = types.DictType(types.unicode_type, types.int64)

        @njit
        def test4():
            with objmode(res=ret_type):
                res = {'A': 1, 'B': 2}
            return res

        with self.assertRaises(TypeError) as raises:
            test4()
        self.assertIn(
            ("can't unbox a <class 'dict'> "
             "as a <class 'numba.typed.typeddict.Dict'>"),
            str(raises.exception),
        )

    def test_objmode_typed_list(self):
        ret_type = types.ListType(types.int64)

        @njit
        def test4():
            with objmode(res=ret_type):
                res = [1, 2]
            return res

        with self.assertRaises(TypeError) as raises:
            test4()
        self.assertRegex(
            str(raises.exception),
            (r"can't unbox a <class 'list'> "
             r"as a (<class ')?numba.typed.typedlist.List('>)?"),
        )

    def test_objmode_use_of_view(self):
        # See issue #7158, npm functionality should only be validated if in
        # npm.
        @njit
        def foo(x):
            with numba.objmode(y="int64[::1]"):
                y = x.view("int64")
            return y

        a = np.ones(1, np.int64).view('float64')
        expected = foo.py_func(a)
        got = foo(a)
        self.assertPreciseEqual(expected, got)
# Module-level helper called from an objmode block by case_objmode_cache.
def case_inner_pyfunc(x):
    return x / 10
# Module-level objmode function used by the caching tests. It calls a
# module-level helper from inside the objmode block.
def case_objmode_cache(x):
    with objmode(output='float64'):
        output = case_inner_pyfunc(x)
    return output
class TestLiftObjCaching(MemoryLeak, TestCase):
    # Warnings in this test class are converted to errors

    def setUp(self):
        super().setUp()
        # catch_warnings() snapshots the warning filters and restores them
        # when the cleanup runs (after tearDown), instead of wiping every
        # global filter with warnings.resetwarnings().
        catcher = warnings.catch_warnings()
        catcher.__enter__()
        self.addCleanup(catcher.__exit__, None, None, None)
        warnings.simplefilter("error", errors.NumbaWarning)

    def check(self, py_func):
        """Compile *py_func* with ``cache=True`` twice and check that the
        second dispatcher is served from the on-disk cache."""
        first = njit(cache=True)(py_func)
        self.assertEqual(first(123), 12.3)

        second = njit(cache=True)(py_func)
        self.assertFalse(second._cache_hits)
        self.assertEqual(second(123), 12.3)
        self.assertTrue(second._cache_hits)

    def test_objmode_caching_basic(self):
        def pyfunc(x):
            with objmode(output='float64'):
                output = x / 10
            return output

        self.check(pyfunc)

    def test_objmode_caching_call_closure_bad(self):
        def other_pyfunc(x):
            return x / 10

        def pyfunc(x):
            with objmode(output='float64'):
                output = other_pyfunc(x)
            return output

        self.check(pyfunc)

    def test_objmode_caching_call_closure_good(self):
        self.check(case_objmode_cache)
class TestBogusContext(BaseTestWithLifting):
    # Invalid or unsupported context managers must produce clear errors.

    def test_undefined_global(self):
        the_ir = get_func_ir(lift_undefiend)

        with self.assertRaises(errors.CompilerError) as raises:
            with_lifting(
                the_ir, self.typingctx, self.targetctx, self.flags, locals={},
            )
        self.assertIn(
            "Undefined variable used as context manager",
            str(raises.exception),
        )

    def test_invalid(self):
        the_ir = get_func_ir(lift_invalid)

        with self.assertRaises(errors.CompilerError) as raises:
            with_lifting(
                the_ir, self.typingctx, self.targetctx, self.flags, locals={},
            )
        self.assertIn(
            "Unsupported context manager in use",
            str(raises.exception),
        )

    def test_with_as_fails_gracefully(self):
        @njit
        def foo():
            # The ``as`` target is intentional: the construct itself is
            # what is being rejected.
            with open('') as f:
                pass

        with self.assertRaises(errors.UnsupportedError) as raises:
            foo()

        excstr = str(raises.exception)
        msg = ("The 'with (context manager) as (variable):' construct is not "
               "supported.")
        self.assertIn(msg, excstr)
class TestMisc(TestCase):
    # Tests for miscellaneous objmode issues. Run serially.
    _numba_parallel_test_ = False

    @linux_only
    @TestCase.run_test_in_subprocess
    def test_no_fork_in_compilation(self):
        # Checks that there is no fork/clone/execve during compilation, see
        # issue #7881. This needs running in a subprocess as the offending fork
        # call that triggered #7881 occurs on the first call to uuid1 as it's
        # part of the initialisation process for that function (gets hardware
        # address of machine).

        if not strace_supported():
            # Needs strace support.
            self.skipTest("strace support missing")

        def force_compile():
            @njit('void()')  # force compilation
            def f():
                with numba.objmode():
                    pass

        # capture these syscalls:
        syscalls = ['fork', 'clone', 'execve']

        # check that compilation does not trigger fork, clone or execve
        strace_data = strace(force_compile, syscalls)
        self.assertFalse(strace_data)
if __name__ == '__main__':
    # Allow running this test module directly.
    unittest.main()