192 lines
4.7 KiB
Python
192 lines
4.7 KiB
Python
"""
Testing object mode specifics.

"""
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
import unittest
|
||
|
from numba import jit
|
||
|
from numba.core import utils
|
||
|
from numba.tests.support import TestCase
|
||
|
|
||
|
|
||
|
def complex_constant(n):
    """Use case: add an int constant, then a complex constant, to *n*.

    Returns ``n + 4 + 3j``.  The intermediate is kept in a local so the
    function body contains both an integer and a complex literal.
    """
    tmp = n + 4
    return tmp + 3j
|
||
|
|
||
|
|
||
|
def long_constant(n):
    """Use case: add a very large integer literal to *n*.

    The literal is far wider than a 64-bit machine integer, so the result
    is an arbitrary-precision Python int.
    """
    return n + 100000000000000000000000000000000000000000000000
|
||
|
|
||
|
|
||
|
def delitem_usecase(x):
    """Use case: delete every element of *x* in place via a full slice.

    Returns ``None``; *x* itself is mutated.  Objects that do not support
    item deletion (e.g. an int) raise ``TypeError``.
    """
    del x[:]
|
||
|
|
||
|
|
||
|
def loop_nest_3(x, y):
    """Use case: accumulate ``i * j`` inside three nested ``range`` loops.

    The innermost loop runs ``x + y`` times per ``(i, j)`` pair and its
    index is deliberately unused; the nesting itself is what the test
    exercises.  Returns the accumulated total.
    """
    n = 0
    for i in range(x):
        for j in range(y):
            for k in range(x + y):
                n += i * j

    return n
|
||
|
|
||
|
|
||
|
def array_of_object(x):
    """Use case: identity function; returns the very object it was given."""
    return x
|
||
|
|
||
|
|
||
|
class TestObjectMode(TestCase):
    """Tests for functions compiled with ``jit(..., forceobj=True)``."""

    def test_complex_constant(self):
        """Compiled and pure-Python results agree for a complex constant."""
        pyfunc = complex_constant
        cfunc = jit((), forceobj=True)(pyfunc)
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_long_constant(self):
        """Compiled and pure-Python results agree for a huge int constant."""
        pyfunc = long_constant
        cfunc = jit((), forceobj=True)(pyfunc)
        self.assertPreciseEqual(pyfunc(12), cfunc(12))

    def test_loop_nest(self):
        """
        Regression test for a bug that decref'ed the iterator too early.
        If the bug occurs, a segfault should occur.
        """
        pyfunc = loop_nest_3
        cfunc = jit((), forceobj=True)(pyfunc)
        self.assertEqual(pyfunc(5, 5), cfunc(5, 5))

        def bm_pyfunc():
            pyfunc(5, 5)

        def bm_cfunc():
            cfunc(5, 5)

        # The benchmark results are not inspected; the calls simply run
        # both variants again through utils.benchmark.
        utils.benchmark(bm_pyfunc)
        utils.benchmark(bm_cfunc)

    def test_array_of_object(self):
        """An object-dtype array is passed through and returned as-is."""
        cfunc = jit(forceobj=True)(array_of_object)
        objarr = np.array([object()] * 10)
        self.assertIs(cfunc(objarr), objarr)

    def test_sequence_contains(self):
        """
        Test handling of the `in` comparison
        """
        @jit(forceobj=True)
        def foo(x, y):
            return x in y

        self.assertTrue(foo(1, [0, 1]))
        self.assertTrue(foo(0, [0, 1]))
        self.assertFalse(foo(2, [0, 1]))

        with self.assertRaises(TypeError) as raises:
            foo(None, None)

        # Checked after the ``with`` block so the captured exception is set.
        self.assertIn("is not iterable", str(raises.exception))

    def test_delitem(self):
        """``del x[:]`` empties a list and raises TypeError for an int."""
        pyfunc = delitem_usecase
        cfunc = jit((), forceobj=True)(pyfunc)

        items = [3, 4, 5]
        cfunc(items)
        self.assertPreciseEqual(items, [])
        with self.assertRaises(TypeError):
            cfunc(42)

    def test_starargs_non_tuple(self):
        """Star-unpacking a non-tuple iterable (a str) into a call works."""
        def consumer(*x):
            return x

        @jit(forceobj=True)
        def foo(x):
            return consumer(*x)

        arg = "ijo"
        got = foo(arg)
        expect = foo.py_func(arg)
        self.assertEqual(got, tuple(arg))
        self.assertEqual(got, expect)

    def test_expr_undef(self):
        """A list comprehension compiles and matches the Python result."""
        @jit(forceobj=True)
        def foo():
            # In Py3.12, this will emit a Expr.undef.
            return [x for x in (1, 2)]

        self.assertEqual(foo(), foo.py_func())
|
||
|
|
||
|
|
||
|
class TestObjectModeInvalidRewrite(TestCase):
    """
    Tests to ensure that rewrite passes didn't affect objmode lowering.
    """

    def _ensure_objmode(self, disp):
        """Assert *disp* has compiled signatures but no nopython ones.

        Returns *disp* unchanged so the call can be chained.
        """
        self.assertTrue(disp.signatures)
        self.assertFalse(disp.nopython_signatures)
        return disp

    def test_static_raise_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2159
        """
        def test0(n):
            return n

        def test1(n):
            if n == 0:
                # static raise will fail in objmode if the IR is modified by
                # rewrite pass
                raise ValueError()
            return test0(n)  # trigger objmode fallback

        compiled = jit(forceobj=True)(test1)
        self.assertEqual(test1(10), compiled(10))
        self._ensure_objmode(compiled)

    def test_static_setitem_in_objmode_fallback(self):
        """
        Test code based on user submitted issue at
        https://github.com/numba/numba/issues/2169
        """

        def test0(n):
            return n

        def test(a1, a2):
            a1 = np.asarray(a1)
            # static setitem here will fail in objmode if the IR is modified by
            # rewrite pass
            a2[0] = 1
            return test0(a1.sum() + a2.sum())   # trigger objmode fallback

        compiled = jit(forceobj=True)(test)
        args = np.array([3]), np.array([4])
        self.assertEqual(test(*args), compiled(*args))
        self._ensure_objmode(compiled)

    def test_dynamic_func_objmode(self):
        """
        Test issue https://github.com/numba/numba/issues/3355
        """
        # Build the function from source text so it is created dynamically
        # via exec with a custom globals dict rather than defined in this
        # module.
        func_text = "def func():\n"
        func_text += " np.array([1,2,3])\n"
        loc_vars = {}
        custom_globals = {'np': np}
        exec(func_text, custom_globals, loc_vars)
        func = loc_vars['func']
        jitted = jit(forceobj=True)(func)
        jitted()
|
||
|
|
||
|
|
||
|
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|