ai-content-maker/.venv/Lib/site-packages/numba/tests/test_recursion.py

130 lines
4.5 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import math
import warnings
from numba import jit
from numba.core.errors import TypingError, NumbaWarning
from numba.tests.support import TestCase
import unittest
class TestSelfRecursion(TestCase):
    """Tests for functions that call themselves directly."""

    def check_fib(self, cfunc):
        """Assert that calling *cfunc* with 10 returns exactly 55."""
        result = cfunc(10)
        self.assertPreciseEqual(result, 55)

    def test_global_explicit_sig(self):
        # Module-level recursive Fibonacci (explicit-signature variant,
        # per the test name).
        from numba.tests.recursion_usecases import fib1
        self.check_fib(fib1)

    def test_inner_explicit_sig(self):
        # Inner-function variant with an explicit signature (per the name).
        from numba.tests.recursion_usecases import fib2
        self.check_fib(fib2)

    def test_global_implicit_sig(self):
        # Module-level variant whose signature is inferred (per the name).
        from numba.tests.recursion_usecases import fib3
        self.check_fib(fib3)

    def test_runaway(self):
        # Typing must reject this self-recursion with a "runaway" error.
        from numba.tests.recursion_usecases import runaway_self
        with self.assertRaises(TypingError) as ctx:
            runaway_self(123)
        message = str(ctx.exception)
        self.assertIn("cannot type infer runaway recursion", message)

    def test_type_change(self):
        # The interpreted and nopython builds must agree exactly.
        from numba.tests.recursion_usecases import make_type_change_self
        interpreted = make_type_change_self()
        jitted = make_type_change_self(jit(nopython=True))
        x, y = 13, 0.125
        expected = interpreted(x, y)
        got = jitted(x, y)
        self.assertPreciseEqual(expected, got)

    def test_raise(self):
        # An exception raised deep in the recursion must propagate intact.
        from numba.tests.recursion_usecases import raise_self
        with self.assertRaises(ValueError) as ctx:
            raise_self(3)
        self.assertEqual(str(ctx.exception), "raise_self")

    def test_optional_return(self):
        # Compare interpreted vs. nopython results over several inputs.
        from numba.tests.recursion_usecases import make_optional_return_case
        interpreted = make_optional_return_case()
        jitted = make_optional_return_case(jit(nopython=True))
        for value in (0, 5, 10, 15):
            self.assertEqual(interpreted(value), jitted(value))

    def test_growing_return_tuple(self):
        # A return type that keeps growing must be reported as
        # non-converging rather than looping forever in typing.
        from numba.tests.recursion_usecases import make_growing_tuple_case
        jitted = make_growing_tuple_case(jit(nopython=True))
        with self.assertRaises(TypingError) as ctx:
            jitted(100)
        self.assertIn(
            "Return type of recursive function does not converge",
            str(ctx.exception),
        )
class TestMutualRecursion(TestCase):
    """Tests for groups of functions that call each other recursively."""

    def test_mutual_1(self):
        # outer_fac(10) must equal the standard-library factorial.
        from numba.tests.recursion_usecases import outer_fac
        reference = math.factorial(10)
        self.assertPreciseEqual(outer_fac(10), reference)

    def test_mutual_2(self):
        # Compare interpreted and nopython versions of a mutually
        # recursive pair, invoked with keyword arguments.
        from numba.tests.recursion_usecases import make_mutual2
        py_foo, py_bar = make_mutual2()
        jit_foo, jit_bar = make_mutual2(jit(nopython=True))
        for value in [-1, 0, 1, 3]:
            self.assertPreciseEqual(py_foo(x=value), jit_foo(x=value))
            self.assertPreciseEqual(py_bar(y=value, z=1),
                                    jit_bar(y=value, z=1))

    def test_runaway(self):
        # Typing must reject this mutual recursion with a "runaway" error.
        from numba.tests.recursion_usecases import runaway_mutual
        with self.assertRaises(TypingError) as ctx:
            runaway_mutual(123)
        message = str(ctx.exception)
        self.assertIn("cannot type infer runaway recursion", message)

    def test_type_change(self):
        # The interpreted and nopython builds must agree exactly.
        from numba.tests.recursion_usecases import make_type_change_mutual
        interpreted = make_type_change_mutual()
        jitted = make_type_change_mutual(jit(nopython=True))
        x, y = 13, 0.125
        expected = interpreted(x, y)
        got = jitted(x, y)
        self.assertPreciseEqual(expected, got)

    def test_four_level(self):
        # Recursion cycle spanning four functions (per the test name).
        from numba.tests.recursion_usecases import make_four_level
        interpreted = make_four_level()
        jitted = make_four_level(jit(nopython=True))
        value = 7
        self.assertPreciseEqual(interpreted(value), jitted(value))

    def test_inner_error(self):
        from numba.tests.recursion_usecases import make_inner_error

        # nopython mode: compilation fails on the unknown 'ndim' attribute.
        jitted = make_inner_error(jit(nopython=True))
        with self.assertRaises(TypingError) as ctx:
            jitted(2)
        self.assertIn('Unknown attribute \'ndim\'', str(ctx.exception))

        # Object mode: the error is never triggered and the function
        # returns normally, so its result must match plain Python.
        objmode = make_inner_error(jit(forceobj=True))
        interpreted = make_inner_error()
        with warnings.catch_warnings():
            # Suppress any NumbaWarning raised by the object-mode call.
            warnings.filterwarnings('ignore', category=NumbaWarning)
            got = objmode(6)
        self.assertEqual(got, interpreted(6))

    def test_raise(self):
        from numba.tests.recursion_usecases import make_raise_mutual
        # NOTE(review): no jit decorator is passed here (a
        # ``jit(nopython=True)`` argument was left commented out), so this
        # presumably exercises only the interpreted implementation.
        # Confirm whether the nopython variant should be tested as well.
        raise_mutual = make_raise_mutual()
        with self.assertRaises(ValueError) as ctx:
            raise_mutual(2)
        self.assertEqual(str(ctx.exception), "raise_mutual")
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()