441 lines
12 KiB
Python
441 lines
12 KiB
Python
"""
|
|
Test mutable struct, aka, structref
|
|
"""
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from numba import typed, njit, errors, typeof
|
|
from numba.core import types
|
|
from numba.experimental import structref
|
|
from numba.extending import overload_method, overload_attribute
|
|
from numba.tests.support import (
|
|
MemoryLeakMixin, TestCase, temp_directory, override_config,
|
|
)
|
|
|
|
|
|
@structref.register
|
|
class MySimplerStructType(types.StructRef):
|
|
"""
|
|
Test associated with this type represent the lowest level uses of structref.
|
|
"""
|
|
pass
|
|
|
|
|
|
my_struct_ty = MySimplerStructType(
|
|
fields=[("values", types.intp[:]), ("counter", types.intp)]
|
|
)
|
|
|
|
structref.define_boxing(MySimplerStructType, structref.StructRefProxy)
|
|
|
|
|
|
class MyStruct(structref.StructRefProxy):
|
|
|
|
def __new__(cls, values, counter):
|
|
# Define this method to customize the constructor.
|
|
# The default takes `*args`. Customizing allow the use of keyword-arg.
|
|
# The impl of the method calls `StructRefProxy.__new__`
|
|
return structref.StructRefProxy.__new__(cls, values, counter)
|
|
|
|
# The below defines wrappers for attributes and methods manually
|
|
|
|
@property
|
|
def values(self):
|
|
return get_values(self)
|
|
|
|
@values.setter
|
|
def values(self, val):
|
|
return set_values(self, val)
|
|
|
|
@property
|
|
def counter(self):
|
|
return get_counter(self)
|
|
|
|
def testme(self, arg):
|
|
return self.values * arg + self.counter
|
|
|
|
@property
|
|
def prop(self):
|
|
return self.values, self.counter
|
|
|
|
def __hash__(self):
|
|
return compute_fields(self)
|
|
|
|
|
|
@structref.register
|
|
class MyStructType(types.StructRef):
|
|
"""Test associated with this type represent the higher-level uses of
|
|
structef.
|
|
"""
|
|
pass
|
|
|
|
|
|
# Call to define_proxy is needed to register the use of `MyStruct` as a
|
|
# PyObject proxy for creating a Numba-allocated structref.
|
|
# The `MyStruct` class can then be used in both jit-code and interpreted-code.
|
|
structref.define_proxy(
|
|
MyStruct,
|
|
MyStructType,
|
|
['values', 'counter'],
|
|
)
|
|
|
|
|
|
@njit
|
|
def my_struct(values, counter):
|
|
st = structref.new(my_struct_ty)
|
|
my_struct_init(st, values, counter)
|
|
return st
|
|
|
|
|
|
@njit
|
|
def my_struct_init(self, values, counter):
|
|
self.values = values
|
|
self.counter = counter
|
|
|
|
|
|
@njit
|
|
def ctor_by_intrinsic(vs, ctr):
|
|
st = my_struct(vs, counter=ctr)
|
|
st.values += st.values
|
|
st.counter *= ctr
|
|
return st
|
|
|
|
|
|
@njit
|
|
def ctor_by_class(vs, ctr):
|
|
return MyStruct(values=vs, counter=ctr)
|
|
|
|
|
|
@njit
|
|
def get_values(st):
|
|
return st.values
|
|
|
|
|
|
@njit
|
|
def set_values(st, val):
|
|
st.values = val
|
|
|
|
|
|
@njit
|
|
def get_counter(st):
|
|
return st.counter
|
|
|
|
|
|
@njit
|
|
def compute_fields(st):
|
|
return st.values + st.counter
|
|
|
|
|
|
class TestStructRefBasic(MemoryLeakMixin, TestCase):
|
|
def test_structref_type(self):
|
|
sr = types.StructRef([('a', types.int64)])
|
|
self.assertEqual(sr.field_dict['a'], types.int64)
|
|
sr = types.StructRef([('a', types.int64), ('b', types.float64)])
|
|
self.assertEqual(sr.field_dict['a'], types.int64)
|
|
self.assertEqual(sr.field_dict['b'], types.float64)
|
|
# bad case
|
|
with self.assertRaisesRegex(ValueError,
|
|
"expecting a str for field name"):
|
|
types.StructRef([(1, types.int64)])
|
|
with self.assertRaisesRegex(ValueError,
|
|
"expecting a Numba Type for field type"):
|
|
types.StructRef([('a', 123)])
|
|
|
|
def test_invalid_uses(self):
|
|
with self.assertRaisesRegex(ValueError, "cannot register"):
|
|
structref.register(types.StructRef)
|
|
with self.assertRaisesRegex(ValueError, "cannot register"):
|
|
structref.define_boxing(types.StructRef, MyStruct)
|
|
|
|
def test_MySimplerStructType(self):
|
|
vs = np.arange(10, dtype=np.intp)
|
|
ctr = 13
|
|
|
|
first_expected = vs + vs
|
|
first_got = ctor_by_intrinsic(vs, ctr)
|
|
# the returned instance is a structref.StructRefProxy
|
|
# but not a MyStruct
|
|
self.assertNotIsInstance(first_got, MyStruct)
|
|
self.assertPreciseEqual(first_expected, get_values(first_got))
|
|
|
|
second_expected = first_expected + (ctr * ctr)
|
|
second_got = compute_fields(first_got)
|
|
self.assertPreciseEqual(second_expected, second_got)
|
|
|
|
def test_MySimplerStructType_wrapper_has_no_attrs(self):
|
|
vs = np.arange(10, dtype=np.intp)
|
|
ctr = 13
|
|
wrapper = ctor_by_intrinsic(vs, ctr)
|
|
self.assertIsInstance(wrapper, structref.StructRefProxy)
|
|
with self.assertRaisesRegex(AttributeError, 'values'):
|
|
wrapper.values
|
|
with self.assertRaisesRegex(AttributeError, 'counter'):
|
|
wrapper.counter
|
|
|
|
def test_MyStructType(self):
|
|
vs = np.arange(10, dtype=np.float64)
|
|
ctr = 11
|
|
|
|
first_expected_arr = vs.copy()
|
|
first_got = ctor_by_class(vs, ctr)
|
|
self.assertIsInstance(first_got, MyStruct)
|
|
self.assertPreciseEqual(first_expected_arr, first_got.values)
|
|
|
|
second_expected = first_expected_arr + ctr
|
|
second_got = compute_fields(first_got)
|
|
self.assertPreciseEqual(second_expected, second_got)
|
|
self.assertEqual(first_got.counter, ctr)
|
|
|
|
def test_MyStructType_mixed_types(self):
|
|
# structref constructor is generic
|
|
@njit
|
|
def mixed_type(x, y, m, n):
|
|
return MyStruct(x, y), MyStruct(m, n)
|
|
|
|
a, b = mixed_type(1, 2.3, 3.4j, (4,))
|
|
self.assertEqual(a.values, 1)
|
|
self.assertEqual(a.counter, 2.3)
|
|
self.assertEqual(b.values, 3.4j)
|
|
self.assertEqual(b.counter, (4,))
|
|
|
|
def test_MyStructType_in_dict(self):
|
|
td = typed.Dict()
|
|
td['a'] = MyStruct(1, 2.3)
|
|
self.assertEqual(td['a'].values, 1)
|
|
self.assertEqual(td['a'].counter, 2.3)
|
|
# overwrite
|
|
td['a'] = MyStruct(2, 3.3)
|
|
self.assertEqual(td['a'].values, 2)
|
|
self.assertEqual(td['a'].counter, 3.3)
|
|
# mutate
|
|
td['a'].values += 10
|
|
self.assertEqual(td['a'].values, 12) # changed
|
|
self.assertEqual(td['a'].counter, 3.3) # unchanged
|
|
# insert
|
|
td['b'] = MyStruct(4, 5.6)
|
|
|
|
def test_MyStructType_in_dict_mixed_type_error(self):
|
|
self.disable_leak_check()
|
|
|
|
td = typed.Dict()
|
|
td['a'] = MyStruct(1, 2.3)
|
|
self.assertEqual(td['a'].values, 1)
|
|
self.assertEqual(td['a'].counter, 2.3)
|
|
|
|
# ERROR: store different types
|
|
with self.assertRaisesRegex(errors.TypingError,
|
|
r"Cannot cast numba.MyStructType"):
|
|
# because first field is not a float;
|
|
# the second field is now an integer.
|
|
td['b'] = MyStruct(2.3, 1)
|
|
|
|
def test_MyStructType_hash_no_typeof_recursion(self):
|
|
# Tests that __hash__ is not called prematurely in typeof
|
|
# causing infinite recursion (see #8241).
|
|
st = MyStruct(1, 2)
|
|
typeof(st)
|
|
|
|
self.assertEqual(hash(st), 3)
|
|
|
|
|
|
@overload_method(MyStructType, "testme")
|
|
def _ol_mystructtype_testme(self, arg):
|
|
def impl(self, arg):
|
|
return self.values * arg + self.counter
|
|
return impl
|
|
|
|
|
|
@overload_attribute(MyStructType, "prop")
|
|
def _ol_mystructtype_prop(self):
|
|
def get(self):
|
|
return self.values, self.counter
|
|
return get
|
|
|
|
|
|
class TestStructRefExtending(MemoryLeakMixin, TestCase):
|
|
def test_overload_method(self):
|
|
@njit
|
|
def check(x):
|
|
vs = np.arange(10, dtype=np.float64)
|
|
ctr = 11
|
|
obj = MyStruct(vs, ctr)
|
|
return obj.testme(x)
|
|
|
|
x = 3
|
|
got = check(x)
|
|
expect = check.py_func(x)
|
|
self.assertPreciseEqual(got, expect)
|
|
|
|
def test_overload_attribute(self):
|
|
@njit
|
|
def check():
|
|
vs = np.arange(10, dtype=np.float64)
|
|
ctr = 11
|
|
obj = MyStruct(vs, ctr)
|
|
return obj.prop
|
|
|
|
got = check()
|
|
expect = check.py_func()
|
|
self.assertPreciseEqual(got, expect)
|
|
|
|
|
|
def caching_test_make(x, y):
|
|
struct = MyStruct(values=x, counter=y)
|
|
return struct
|
|
|
|
|
|
def caching_test_use(struct, z):
|
|
return struct.testme(z)
|
|
|
|
|
|
class TestStructRefCaching(MemoryLeakMixin, TestCase):
|
|
def setUp(self):
|
|
self._cache_dir = temp_directory(TestStructRefCaching.__name__)
|
|
self._cache_override = override_config('CACHE_DIR', self._cache_dir)
|
|
self._cache_override.__enter__()
|
|
warnings.simplefilter("error")
|
|
warnings.filterwarnings(action="ignore", module="typeguard")
|
|
|
|
def tearDown(self):
|
|
self._cache_override.__exit__(None, None, None)
|
|
warnings.resetwarnings()
|
|
|
|
def test_structref_caching(self):
|
|
def assert_cached(stats):
|
|
self.assertEqual(len(stats.cache_hits), 1)
|
|
self.assertEqual(len(stats.cache_misses), 0)
|
|
|
|
def assert_not_cached(stats):
|
|
self.assertEqual(len(stats.cache_hits), 0)
|
|
self.assertEqual(len(stats.cache_misses), 1)
|
|
|
|
def check(cached):
|
|
check_make = njit(cache=True)(caching_test_make)
|
|
check_use = njit(cache=True)(caching_test_use)
|
|
vs = np.random.random(3)
|
|
ctr = 17
|
|
factor = 3
|
|
st = check_make(vs, ctr)
|
|
got = check_use(st, factor)
|
|
expect = vs * factor + ctr
|
|
self.assertPreciseEqual(got, expect)
|
|
|
|
if cached:
|
|
assert_cached(check_make.stats)
|
|
assert_cached(check_use.stats)
|
|
else:
|
|
assert_not_cached(check_make.stats)
|
|
assert_not_cached(check_use.stats)
|
|
|
|
check(cached=False)
|
|
check(cached=True)
|
|
|
|
|
|
@structref.register
|
|
class PolygonStructType(types.StructRef):
|
|
|
|
def preprocess_fields(self, fields):
|
|
# temp name to allow Optional instantiation
|
|
self.name = f"numba.PolygonStructType#{id(self)}"
|
|
fields = tuple([
|
|
('value', types.Optional(types.int64)),
|
|
('parent', types.Optional(self)),
|
|
])
|
|
|
|
return fields
|
|
|
|
|
|
polygon_struct_type = PolygonStructType(fields=(
|
|
('value', types.Any),
|
|
('parent', types.Any)
|
|
))
|
|
|
|
|
|
class PolygonStruct(structref.StructRefProxy):
|
|
def __new__(cls, value, parent):
|
|
return structref.StructRefProxy.__new__(cls, value, parent)
|
|
|
|
@property
|
|
def value(self):
|
|
return PolygonStruct_get_value(self)
|
|
|
|
@property
|
|
def parent(self):
|
|
return PolygonStruct_get_parent(self)
|
|
|
|
|
|
@njit
|
|
def PolygonStruct_get_value(self):
|
|
return self.value
|
|
|
|
|
|
@njit
|
|
def PolygonStruct_get_parent(self):
|
|
return self.parent
|
|
|
|
|
|
structref.define_proxy(
|
|
PolygonStruct,
|
|
PolygonStructType,
|
|
["value", "parent"]
|
|
)
|
|
|
|
|
|
@overload_method(PolygonStructType, "flip")
|
|
def _ol_polygon_struct_flip(self):
|
|
def impl(self):
|
|
if self.value is not None:
|
|
self.value = -self.value
|
|
return impl
|
|
|
|
|
|
@overload_attribute(PolygonStructType, "prop")
|
|
def _ol_polygon_struct_prop(self):
|
|
def get(self):
|
|
return self.value, self.parent
|
|
return get
|
|
|
|
|
|
class TestStructRefForwardTyping(MemoryLeakMixin, TestCase):
|
|
def test_same_type_assignment(self):
|
|
@njit
|
|
def check(x):
|
|
poly = PolygonStruct(None, None)
|
|
p_poly = PolygonStruct(None, None)
|
|
poly.value = x
|
|
poly.parent = p_poly
|
|
p_poly.value = x
|
|
return poly.parent.value
|
|
|
|
x = 11
|
|
got = check(x)
|
|
expect = x
|
|
self.assertPreciseEqual(got, expect)
|
|
|
|
def test_overload_method(self):
|
|
@njit
|
|
def check(x):
|
|
poly = PolygonStruct(None, None)
|
|
p_poly = PolygonStruct(None, None)
|
|
poly.value = x
|
|
poly.parent = p_poly
|
|
p_poly.value = x
|
|
poly.flip()
|
|
poly.parent.flip()
|
|
return poly.parent.value
|
|
|
|
x = 3
|
|
got = check(x)
|
|
expect = -x
|
|
self.assertPreciseEqual(got, expect)
|
|
|
|
def test_overload_attribute(self):
|
|
@njit
|
|
def check():
|
|
obj = PolygonStruct(5, None)
|
|
return obj.prop[0]
|
|
|
|
got = check()
|
|
expect = 5
|
|
self.assertPreciseEqual(got, expect)
|