import unittest

import sys

import numpy

from numba import jit, njit
from numba.core import types, utils
from numba.tests.support import tag

from numba.cpython.rangeobj import length_of_iterator


def loop1(n):
    """Return the sum of all values produced by ``range(n)``."""
    s = 0
    for i in range(n):
        s += i
    return s


def loop2(a, b):
    """Return the sum of all values produced by ``range(a, b)``."""
    s = 0
    for i in range(a, b):
        s += i
    return s


def loop3(a, b, c):
    """Return the sum of all values produced by ``range(a, b, c)``."""
    s = 0
    for i in range(a, b, c):
        s += i
    return s


def range_len1(n):
    """Return ``len(range(n))``."""
    return len(range(n))


def range_len2(a, b):
    """Return ``len(range(a, b))``."""
    return len(range(a, b))


def range_len3(a, b, c):
    """Return ``len(range(a, b, c))``."""
    return len(range(a, b, c))


def range_iter_len1(a):
    """Return ``length_of_iterator`` applied to an iterator over ``range(a)``.

    The tests below only ever call this through ``njit``.
    """
    return length_of_iterator(iter(range(a)))


def range_iter_len2(a):
    """Return ``length_of_iterator`` applied to ``iter(a)``.

    The tests below only ever call this through ``njit``.
    """
    return length_of_iterator(iter(a))


def range_attrs(start, stop, step):
    """Return ``(start, stop, step)`` triples for three range objects.

    The ranges are built with one, two and three arguments respectively;
    note that the one-argument form is ``range(start)``.
    """
    r1 = range(start)
    r2 = range(start, stop)
    r3 = range(start, stop, step)
    tmp = []
    for r in (r1, r2, r3):
        tmp.append((r.start, r.stop, r.step))
    return tmp


def range_contains(val, start, stop, step):
    """Return a list of ``val in r`` results for three range objects.

    The ranges are built with one, two and three arguments respectively;
    note that the one-argument form is ``range(start)``.
    """
    r1 = range(start)
    r2 = range(start, stop)
    r3 = range(start, stop, step)
    return [val in r for r in (r1, r2, r3)]


class TestRange(unittest.TestCase):
    """Compare compiled ``range`` behaviour against the pure-Python functions.

    Each test compiles one of the module-level helpers and asserts that the
    compiled function returns the same result as the uncompiled one.
    """

    def test_loop1_int16(self):
        pyfunc = loop1
        cfunc = njit((types.int16,))(pyfunc)
        # assertEqual, not assertTrue: assertTrue would treat the second
        # argument as the failure message and never compare the results.
        self.assertEqual(cfunc(5), pyfunc(5))

    def test_loop2_int16(self):
        pyfunc = loop2
        cfunc = njit((types.int16, types.int16))(pyfunc)
        # assertEqual, not assertTrue: assertTrue would treat the second
        # argument as the failure message and never compare the results.
        self.assertEqual(cfunc(1, 6), pyfunc(1, 6))

    def test_loop3_int32(self):
        pyfunc = loop3
        cfunc = njit((types.int32, types.int32, types.int32))(pyfunc)
        # Includes negative steps and ranges that produce no values.
        arglist = [
            (1, 2, 1),
            (2, 8, 3),
            (-10, -11, -10),
            (-10, -10, -2),
        ]
        for args in arglist:
            self.assertEqual(cfunc(*args), pyfunc(*args))

    def test_range_len1(self):
        pyfunc = range_len1
        typelist = [types.int16, types.int32, types.int64]
        arglist = [5, 0, -5]
        for typ in typelist:
            cfunc = njit((typ,))(pyfunc)
            for arg in arglist:
                # Both sides receive the same typ(...)-converted value.
                self.assertEqual(cfunc(typ(arg)), pyfunc(typ(arg)))

    def test_range_len2(self):
        pyfunc = range_len2
        typelist = [types.int16, types.int32, types.int64]
        arglist = [(1, 6), (6, 1), (-5, -1)]
        for typ in typelist:
            cfunc = njit((typ, typ))(pyfunc)
            for args in arglist:
                args_ = tuple(typ(x) for x in args)
                self.assertEqual(cfunc(*args_), pyfunc(*args_))

    def test_range_len3(self):
        pyfunc = range_len3
        typelist = [types.int16, types.int32, types.int64]
        arglist = [
            (1, 2, 1),
            (2, 8, 3),
            (-10, -11, -10),
            (-10, -10, -2),
        ]
        for typ in typelist:
            cfunc = njit((typ, typ, typ))(pyfunc)
            for args in arglist:
                args_ = tuple(typ(x) for x in args)
                self.assertEqual(cfunc(*args_), pyfunc(*args_))

    def test_range_iter_len1(self):
        # The reference value comes from len(range(n)) rather than from the
        # uncompiled range_iter_len1.
        range_func = range_len1
        range_iter_func = range_iter_len1
        typelist = [types.int16, types.int32, types.int64]
        arglist = [5, 0, -5]
        for typ in typelist:
            cfunc = njit((typ,))(range_iter_func)
            for arg in arglist:
                self.assertEqual(cfunc(typ(arg)), range_func(typ(arg)))

    def test_range_iter_list(self):
        # Same helper family, but iterating a reflected list instead of a
        # range; the expected value is the list's len().
        range_iter_func = range_iter_len2
        cfunc = njit((types.List(types.intp, reflected=True),))(range_iter_func)
        arglist = [1, 2, 3, 4, 5]
        self.assertEqual(cfunc(arglist), len(arglist))

    def test_range_attrs(self):
        pyfunc = range_attrs
        arglist = [
            (0, 0, 1),
            (0, -1, 1),
            (-1, 1, 1),
            (-1, 4, 1),
            (-1, 4, 10),
            (5, -5, -2),
        ]

        cfunc = njit((types.int64, types.int64, types.int64),)(pyfunc)
        for arg in arglist:
            self.assertEqual(cfunc(*arg), pyfunc(*arg))

    def test_range_contains(self):
        pyfunc = range_contains
        arglist = [
            (0, 0, 1),
            (-1, 0, 1),
            (1, 0, -1),
            (0, -1, 1),
            (0, 1, -1),
            (-1, 1, 1),
            (-1, 4, 1),
            (-1, 4, 10),
            (5, -5, -2),
        ]

        bool_vals = [True, False]
        int_vals = [-10, -6, -5, -4, -2, -1, 0,
                    1, 2, 4, 5, 6, 10]
        float_vals = [-1.1, -1.0, 0.0, 1.0, 1.1]
        complex_vals = [1 + 0j, 1 + 1j, 1.1 + 0j, 1.0 + 1.1j]

        vallist = (bool_vals + int_vals + float_vals
                   + complex_vals)

        # Numeric candidates go through the lazily-typed njit function.
        cfunc = njit(pyfunc)
        for arg in arglist:
            for val in vallist:
                self.assertEqual(cfunc(val, *arg), pyfunc(val, *arg))

        # Non-numeric candidates are checked with forceobj=True instead.
        non_numeric_vals = [{'a': 1}, [1, ], 'abc', (1,)]

        cfunc_obj = jit(pyfunc, forceobj=True)
        for arg in arglist:
            for val in non_numeric_vals:
                self.assertEqual(cfunc_obj(val, *arg), pyfunc(val, *arg))


if __name__ == '__main__':
    unittest.main()