ai-content-maker/.venv/Lib/site-packages/numba/tests/test_globals.py

267 lines
7.2 KiB
Python

import numpy as np
from numba import jit, njit, errors
from numba.extending import register_jitable
from numba.tests import usecases
import unittest
# Module-level array that ``global_ndarray_func`` reads (it is never written).
X = np.arange(10)


def global_ndarray_func(x):
    """Return ``x`` plus the length of the global array ``X``.

    Only ``X.shape[0]`` is read from the global, so the result for the
    ten-element ``X`` defined above is ``x + 10``.
    """
    y = x + X.shape[0]
    return y
# Create complex array with real and imaginary parts of distinct value
cplx_X = np.arange(10, dtype=np.complex128)
tmp = np.arange(10, dtype=np.complex128)
cplx_X += (tmp + 10) * 1j  # element k becomes k + (k + 10)j


def global_cplx_arr_copy(a):
    """Copy the global array ``cplx_X`` into ``a`` one element at a time.

    ``a`` is modified in place and nothing is returned.  Elements
    ``0 .. len(a) - 1`` of ``cplx_X`` are read, so ``a`` is expected to be
    no longer than ``cplx_X``.
    """
    for i in range(len(a)):
        a[i] = cplx_X[i]
# Create a recarray with fields of distinct value
x_dt = np.dtype([('a', np.int32), ('b', np.float32)])
rec_X = np.recarray(10, dtype=x_dt)
# The freshly created recarray is not given any initial values here, so
# every field of every record is assigned explicitly: a = i, b = i + 0.5.
for i in range(len(rec_X)):
    rec_X[i].a = i
    rec_X[i].b = i + 0.5
def global_rec_arr_copy(a):
    """Copy the records of the global ``rec_X`` into ``a``, one by one.

    ``a`` is modified in place and nothing is returned.  Records
    ``0 .. len(a) - 1`` of ``rec_X`` are read.
    """
    for i in range(len(a)):
        a[i] = rec_X[i]
def global_rec_arr_extract_fields(a, b):
    """Split the global ``rec_X`` into its two fields.

    For every index ``i`` in ``range(len(a))`` the ``a`` field of
    ``rec_X[i]`` is stored in ``a[i]`` and the ``b`` field in ``b[i]``.
    Both output arrays are modified in place; nothing is returned.
    """
    for i in range(len(a)):
        a[i] = rec_X[i].a
        b[i] = rec_X[i].b
# Create additional global recarray
y_dt = np.dtype([('c', np.int16), ('d', np.float64)])
rec_Y = np.recarray(10, dtype=y_dt)
# Assign every field of every record explicitly: c = i + 10, d = i + 10.5.
for i in range(len(rec_Y)):
    rec_Y[i].c = i + 10
    rec_Y[i].d = i + 10.5
def global_two_rec_arrs(a, b, c, d):
    """Extract the fields of two global recarrays in a single function.

    For every index ``i`` in ``range(len(a))``:

    * ``a[i]`` and ``b[i]`` receive fields ``a`` and ``b`` of ``rec_X[i]``;
    * ``c[i]`` and ``d[i]`` receive fields ``c`` and ``d`` of ``rec_Y[i]``.

    All four output arrays are modified in place; nothing is returned.
    """
    for i in range(len(a)):
        a[i] = rec_X[i].a
        b[i] = rec_X[i].b
        c[i] = rec_Y[i].c
        d[i] = rec_Y[i].d
# Test a global record
record_only_X = np.recarray(1, dtype=x_dt)[0]
record_only_X.a = 1
record_only_X.b = 1.5


@jit(nopython=True)
def global_record_func(x):
    """Return whether ``x.a`` equals the ``a`` field of ``record_only_X``."""
    return x.a == record_only_X.a
@jit(nopython=True)
def global_module_func(x, y):
    """Return ``usecases.andornopython(x, y)``.

    The callee is reached through the global module object ``usecases``
    from inside a nopython-compiled function.
    """
    return usecases.andornopython(x, y)
# Global tuples of assorted element types, read by the global_*_tuple
# functions in this module.

# Homogeneous tuples of Python scalars
tup_int = (1, 2)
tup_float = (1.2, 3.5)
tup_str = ('a', 'b')

# Heterogeneous tuple: used as (record index, field name)
tup_mixed = (1, 'a')

# NumPy integer scalars of different width and signedness
tup_npy_ints = (np.uint64(12), np.int8(3))

# Arrays nested inside tuples, alone and alongside non-array values
tup_tup_array = ((np.ones(5),),)
mixed_tup_tup_array = (('Z', np.ones(5)), 2j, 'A')
def global_int_tuple():
    """Return the sum of the two elements of the global ``tup_int``."""
    return tup_int[0] + tup_int[1]
def global_str_tuple():
    """Return the concatenation of the two strings in ``tup_str``."""
    return tup_str[0] + tup_str[1]
def global_mixed_tuple():
    """Read one field of one ``rec_X`` record, as selected by ``tup_mixed``.

    ``tup_mixed[0]`` is used as the record index and ``tup_mixed[1]`` as
    the field name.
    """
    idx = tup_mixed[0]
    field = tup_mixed[1]
    return rec_X[idx][field]
def global_float_tuple():
    """Return the sum of the two elements of the global ``tup_float``."""
    return tup_float[0] + tup_float[1]
def global_npy_int_tuple():
    """Return the sum of the two NumPy integers in ``tup_npy_ints``."""
    return tup_npy_ints[0] + tup_npy_ints[1]
def global_write_to_arr_in_tuple():
    """Assign to element 0 of the array nested inside ``tup_tup_array``.

    ``TestGlobals.test_global_write_to_arr_in_tuple`` expects compiling
    and calling this with ``njit`` to raise a ``TypingError``.
    """
    tup_tup_array[0][0][0] = 10.
def global_write_to_arr_in_mixed_tuple():
    """Assign to element 0 of the array inside ``mixed_tup_tup_array``.

    ``TestGlobals.test_global_write_to_arr_in_tuple`` expects compiling
    and calling this with ``njit`` to raise a ``TypingError``.
    """
    mixed_tup_tup_array[0][1][0] = 10.
# Global NumPy booleans read by ``global_npy_bool``.
_glbl_np_bool_T = np.bool_(True)
_glbl_np_bool_F = np.bool_(False)


@register_jitable  # consumer function
def _sink(*args):
    """Accept any positional arguments and do nothing with them."""
    pass


def global_npy_bool():
    """Pass both global NumPy booleans to ``_sink``, then return them.

    Returns the pair ``(_glbl_np_bool_T, _glbl_np_bool_F)``.
    """
    _sink(_glbl_np_bool_T, _glbl_np_bool_F)
    return _glbl_np_bool_T, _glbl_np_bool_F
class TestGlobals(unittest.TestCase):
    """Tests for reading module-level globals from jitted functions.

    Each ``check_*`` helper compiles one module-level function with the
    ``jit`` options it is given and compares the outcome with the global
    the function reads.  The paired ``test_*`` / ``test_*_npm`` methods run
    a helper with ``forceobj=True`` and ``nopython=True`` respectively.
    The remaining tests compile with ``njit`` (or use functions already
    decorated at module level) and compare against the pure-Python result.
    """

    def check_global_ndarray(self, **jitargs):
        # (see github issue #448)
        ctestfunc = jit(**jitargs)(global_ndarray_func)
        self.assertEqual(ctestfunc(1), 11)

    def test_global_ndarray(self):
        # This also checks we can access an unhashable global value
        # (see issue #697)
        self.check_global_ndarray(forceobj=True)

    def test_global_ndarray_npm(self):
        self.check_global_ndarray(nopython=True)

    def check_global_complex_arr(self, **jitargs):
        # (see github issue #897)
        ctestfunc = jit(**jitargs)(global_cplx_arr_copy)
        arr = np.zeros(len(cplx_X), dtype=np.complex128)
        ctestfunc(arr)
        np.testing.assert_equal(arr, cplx_X)

    def test_global_complex_arr(self):
        self.check_global_complex_arr(forceobj=True)

    def test_global_complex_arr_npm(self):
        self.check_global_complex_arr(nopython=True)

    def check_global_rec_arr(self, **jitargs):
        # (see github issue #897)
        ctestfunc = jit(**jitargs)(global_rec_arr_copy)
        arr = np.zeros(rec_X.shape, dtype=x_dt)
        ctestfunc(arr)
        np.testing.assert_equal(arr, rec_X)

    def test_global_rec_arr(self):
        self.check_global_rec_arr(forceobj=True)

    def test_global_rec_arr_npm(self):
        self.check_global_rec_arr(nopython=True)

    def check_global_rec_arr_extract(self, **jitargs):
        # (see github issue #897)
        ctestfunc = jit(**jitargs)(global_rec_arr_extract_fields)
        arr1 = np.zeros(rec_X.shape, dtype=np.int32)
        arr2 = np.zeros(rec_X.shape, dtype=np.float32)
        ctestfunc(arr1, arr2)
        np.testing.assert_equal(arr1, rec_X.a)
        np.testing.assert_equal(arr2, rec_X.b)

    def test_global_rec_arr_extract(self):
        self.check_global_rec_arr_extract(forceobj=True)

    def test_global_rec_arr_extract_npm(self):
        self.check_global_rec_arr_extract(nopython=True)

    def check_two_global_rec_arrs(self, **jitargs):
        # (see github issue #897)
        ctestfunc = jit(**jitargs)(global_two_rec_arrs)
        arr1 = np.zeros(rec_X.shape, dtype=np.int32)
        arr2 = np.zeros(rec_X.shape, dtype=np.float32)
        arr3 = np.zeros(rec_Y.shape, dtype=np.int16)
        arr4 = np.zeros(rec_Y.shape, dtype=np.float64)
        ctestfunc(arr1, arr2, arr3, arr4)
        np.testing.assert_equal(arr1, rec_X.a)
        np.testing.assert_equal(arr2, rec_X.b)
        np.testing.assert_equal(arr3, rec_Y.c)
        np.testing.assert_equal(arr4, rec_Y.d)

    def test_two_global_rec_arrs(self):
        self.check_two_global_rec_arrs(forceobj=True)

    def test_two_global_rec_arrs_npm(self):
        self.check_two_global_rec_arrs(nopython=True)

    def test_global_module(self):
        # (see github issue #1059)
        res = global_module_func(5, 6)
        self.assertEqual(True, res)

    def test_global_record(self):
        # (see github issue #1081)
        x = np.recarray(1, dtype=x_dt)[0]
        x.a = 1
        res = global_record_func(x)
        self.assertEqual(True, res)
        x.a = 2
        res = global_record_func(x)
        self.assertEqual(False, res)

    def test_global_int_tuple(self):
        pyfunc = global_int_tuple
        jitfunc = njit(pyfunc)
        self.assertEqual(pyfunc(), jitfunc())

    def test_global_str_tuple(self):
        pyfunc = global_str_tuple
        jitfunc = njit(pyfunc)
        self.assertEqual(pyfunc(), jitfunc())

    def test_global_mixed_tuple(self):
        pyfunc = global_mixed_tuple
        jitfunc = njit(pyfunc)
        self.assertEqual(pyfunc(), jitfunc())

    def test_global_float_tuple(self):
        pyfunc = global_float_tuple
        jitfunc = njit(pyfunc)
        self.assertEqual(pyfunc(), jitfunc())

    def test_global_npy_int_tuple(self):
        pyfunc = global_npy_int_tuple
        jitfunc = njit(pyfunc)
        self.assertEqual(pyfunc(), jitfunc())

    def test_global_write_to_arr_in_tuple(self):
        # Test writing to an array in a global tuple
        # See issue https://github.com/numba/numba/issues/7120
        for func in (global_write_to_arr_in_tuple,
                     global_write_to_arr_in_mixed_tuple):
            jitfunc = njit(func)
            # Only the call belongs inside the context manager; the message
            # check runs afterwards on the captured exception.
            with self.assertRaises(errors.TypingError) as e:
                jitfunc()
            msg = "Cannot modify readonly array of type:"
            self.assertIn(msg, str(e.exception))

    def test_global_npy_bool(self):
        # Test global NumPy bool
        # See issue https://github.com/numba/numba/issues/6979
        pyfunc = global_npy_bool
        jitfunc = njit(pyfunc)
        self.assertEqual(pyfunc(), jitfunc())
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()