import ctypes import itertools import pickle import random import typing as pt import unittest from collections import OrderedDict import numpy as np from numba import (boolean, deferred_type, float32, float64, int16, int32, njit, optional, typeof) from numba.core import errors, types from numba.core.dispatcher import Dispatcher from numba.core.errors import LoweringError, TypingError from numba.core.runtime.nrt import MemInfo from numba.experimental import jitclass from numba.experimental.jitclass import _box from numba.experimental.jitclass.base import JitClassType from numba.tests.support import MemoryLeakMixin, TestCase, skip_if_typeguard from numba.tests.support import skip_unless_scipy class TestClass1(object): def __init__(self, x, y, z=1, *, a=5): self.x = x self.y = y self.z = z self.a = a class TestClass2(object): def __init__(self, x, y, z=1, *args, a=5): self.x = x self.y = y self.z = z self.args = args self.a = a def _get_meminfo(box): ptr = _box.box_get_meminfoptr(box) mi = MemInfo(ptr) mi.acquire() return mi class TestJitClass(TestCase, MemoryLeakMixin): def _check_spec(self, spec=None, test_cls=None, all_expected=None): if test_cls is None: @jitclass(spec) class Test(object): def __init__(self): pass test_cls = Test clsty = test_cls.class_type.instance_type names = list(clsty.struct.keys()) values = list(clsty.struct.values()) if all_expected is None: if isinstance(spec, OrderedDict): all_expected = spec.items() else: all_expected = spec assert all_expected is not None self.assertEqual(len(names), len(all_expected)) for got, expected in zip(zip(names, values), all_expected): self.assertEqual(got[0], expected[0]) self.assertEqual(got[1], expected[1]) def test_ordereddict_spec(self): spec = OrderedDict() spec["x"] = int32 spec["y"] = float32 self._check_spec(spec) def test_list_spec(self): spec = [("x", int32), ("y", float32)] self._check_spec(spec) def test_type_annotations(self): spec = [("x", int32)] @jitclass(spec) class Test1(object): x: int y: pt.List[float] def __init__(self): pass self._check_spec(spec, Test1, spec + [("y", types.ListType(float64))]) def test_type_annotation_inheritance(self): class Foo: x: int @jitclass class Bar(Foo): y: float def __init__(self, value: float) -> None: self.x = int(value) self.y = value self._check_spec( test_cls=Bar, all_expected=[("x", typeof(0)), ("y", typeof(0.0))] ) def test_spec_errors(self): spec1 = [("x", int), ("y", float32[:])] spec2 = [(1, int32), ("y", float32[:])] class Test(object): def __init__(self): pass with self.assertRaises(TypeError) as raises: jitclass(Test, spec1) self.assertIn("spec values should be Numba type instances", str(raises.exception)) with self.assertRaises(TypeError) as raises: jitclass(Test, spec2) self.assertEqual(str(raises.exception), "spec keys should be strings, got 1") def test_init_errors(self): @jitclass([]) class Test: def __init__(self): return 7 with self.assertRaises(errors.TypingError) as raises: Test() self.assertIn("__init__() should return None, not", str(raises.exception)) def _make_Float2AndArray(self): spec = OrderedDict() spec["x"] = float32 spec["y"] = float32 spec["arr"] = float32[:] @jitclass(spec) class Float2AndArray(object): def __init__(self, x, y, arr): self.x = x self.y = y self.arr = arr def add(self, val): self.x += val self.y += val return val return Float2AndArray def _make_Vector2(self): spec = OrderedDict() spec["x"] = int32 spec["y"] = int32 @jitclass(spec) class Vector2(object): def __init__(self, x, y): self.x = x self.y = y return Vector2 def test_jit_class_1(self): Float2AndArray = self._make_Float2AndArray() Vector2 = self._make_Vector2() @njit def bar(obj): return obj.x + obj.y @njit def foo(a): obj = Float2AndArray(1, 2, a) obj.add(123) vec = Vector2(3, 4) return bar(obj), bar(vec), obj.arr inp = np.ones(10, dtype=np.float32) a, b, c = foo(inp) self.assertEqual(a, 123 + 1 + 123 + 2) self.assertEqual(b, 3 + 4) self.assertPreciseEqual(c, inp) def test_jitclass_usage_from_python(self): Float2AndArray = self._make_Float2AndArray() @njit def identity(obj): return obj @njit def retrieve_attributes(obj): return obj.x, obj.y, obj.arr arr = np.arange(10, dtype=np.float32) obj = Float2AndArray(1, 2, arr) obj_meminfo = _get_meminfo(obj) self.assertEqual(obj_meminfo.refcount, 2) self.assertEqual(obj_meminfo.data, _box.box_get_dataptr(obj)) self.assertEqual(obj._numba_type_.class_type, Float2AndArray.class_type) # Use jit class instance in numba other = identity(obj) other_meminfo = _get_meminfo(other) # duplicates MemInfo object to obj self.assertEqual(obj_meminfo.refcount, 4) self.assertEqual(other_meminfo.refcount, 4) self.assertEqual(other_meminfo.data, _box.box_get_dataptr(other)) self.assertEqual(other_meminfo.data, obj_meminfo.data) # Check dtor del other, other_meminfo self.assertEqual(obj_meminfo.refcount, 2) # Check attributes out_x, out_y, out_arr = retrieve_attributes(obj) self.assertEqual(out_x, 1) self.assertEqual(out_y, 2) self.assertIs(out_arr, arr) # Access attributes from python self.assertEqual(obj.x, 1) self.assertEqual(obj.y, 2) self.assertIs(obj.arr, arr) # Access methods from python self.assertEqual(obj.add(123), 123) self.assertEqual(obj.x, 1 + 123) self.assertEqual(obj.y, 2 + 123) # Setter from python obj.x = 333 obj.y = 444 obj.arr = newarr = np.arange(5, dtype=np.float32) self.assertEqual(obj.x, 333) self.assertEqual(obj.y, 444) self.assertIs(obj.arr, newarr) def test_jitclass_datalayout(self): spec = OrderedDict() # Boolean has different layout as value vs data spec["val"] = boolean @jitclass(spec) class Foo(object): def __init__(self, val): self.val = val self.assertTrue(Foo(True).val) self.assertFalse(Foo(False).val) def test_deferred_type(self): node_type = deferred_type() spec = OrderedDict() spec["data"] = float32 spec["next"] = optional(node_type) @njit def get_data(node): return node.data @jitclass(spec) class LinkedNode(object): def __init__(self, data, next): self.data = data self.next = next def get_next_data(self): # use deferred type as argument return get_data(self.next) def append_to_tail(self, other): cur = self while cur.next is not None: cur = cur.next cur.next = other node_type.define(LinkedNode.class_type.instance_type) first = LinkedNode(123, None) self.assertEqual(first.data, 123) self.assertIsNone(first.next) second = LinkedNode(321, first) first_meminfo = _get_meminfo(first) second_meminfo = _get_meminfo(second) self.assertEqual(first_meminfo.refcount, 3) self.assertEqual(second.next.data, first.data) self.assertEqual(first_meminfo.refcount, 3) self.assertEqual(second_meminfo.refcount, 2) # Test using deferred type as argument first_val = second.get_next_data() self.assertEqual(first_val, first.data) # Check setattr (issue #2606) self.assertIsNone(first.next) second.append_to_tail(LinkedNode(567, None)) self.assertIsNotNone(first.next) self.assertEqual(first.next.data, 567) self.assertIsNone(first.next.next) second.append_to_tail(LinkedNode(678, None)) self.assertIsNotNone(first.next.next) self.assertEqual(first.next.next.data, 678) # Check ownership self.assertEqual(first_meminfo.refcount, 3) del second, second_meminfo self.assertEqual(first_meminfo.refcount, 2) def test_c_structure(self): spec = OrderedDict() spec["a"] = int32 spec["b"] = int16 spec["c"] = float64 @jitclass(spec) class Struct(object): def __init__(self, a, b, c): self.a = a self.b = b self.c = c st = Struct(0xabcd, 0xef, 3.1415) class CStruct(ctypes.Structure): _fields_ = [ ("a", ctypes.c_int32), ("b", ctypes.c_int16), ("c", ctypes.c_double), ] ptr = ctypes.c_void_p(_box.box_get_dataptr(st)) cstruct = ctypes.cast(ptr, ctypes.POINTER(CStruct))[0] self.assertEqual(cstruct.a, st.a) self.assertEqual(cstruct.b, st.b) self.assertEqual(cstruct.c, st.c) def test_is(self): Vector = self._make_Vector2() vec_a = Vector(1, 2) @njit def do_is(a, b): return a is b with self.assertRaises(LoweringError) as raises: # trigger compilation do_is(vec_a, vec_a) self.assertIn("no default `is` implementation", str(raises.exception)) def test_isinstance(self): Vector2 = self._make_Vector2() vec = Vector2(1, 2) self.assertIsInstance(vec, Vector2) def test_subclassing(self): Vector2 = self._make_Vector2() with self.assertRaises(TypeError) as raises: class SubV(Vector2): pass self.assertEqual(str(raises.exception), "cannot subclass from a jitclass") def test_base_class(self): class Base(object): def what(self): return self.attr @jitclass([("attr", int32)]) class Test(Base): def __init__(self, attr): self.attr = attr obj = Test(123) self.assertEqual(obj.what(), 123) def test_globals(self): class Mine(object): constant = 123 def __init__(self): pass with self.assertRaises(TypeError) as raises: jitclass(Mine) self.assertEqual(str(raises.exception), "class members are not yet supported: constant") def test_user_getter_setter(self): @jitclass([("attr", int32)]) class Foo(object): def __init__(self, attr): self.attr = attr @property def value(self): return self.attr + 1 @value.setter def value(self, val): self.attr = val - 1 foo = Foo(123) self.assertEqual(foo.attr, 123) # Getter self.assertEqual(foo.value, 123 + 1) # Setter foo.value = 789 self.assertEqual(foo.attr, 789 - 1) self.assertEqual(foo.value, 789) # Test nopython mode usage of getter and setter @njit def bar(foo, val): a = foo.value foo.value = val b = foo.value c = foo.attr return a, b, c a, b, c = bar(foo, 567) self.assertEqual(a, 789) self.assertEqual(b, 567) self.assertEqual(c, 567 - 1) def test_user_deleter_error(self): class Foo(object): def __init__(self): pass @property def value(self): return 1 @value.deleter def value(self): pass with self.assertRaises(TypeError) as raises: jitclass(Foo) self.assertEqual(str(raises.exception), "deleter is not supported: value") def test_name_shadowing_error(self): class Foo(object): def __init__(self): pass @property def my_property(self): pass def my_method(self): pass with self.assertRaises(NameError) as raises: jitclass(Foo, [("my_property", int32)]) self.assertEqual(str(raises.exception), "name shadowing: my_property") with self.assertRaises(NameError) as raises: jitclass(Foo, [("my_method", int32)]) self.assertEqual(str(raises.exception), "name shadowing: my_method") def test_distinct_classes(self): # Different classes with the same names shouldn't confuse the compiler @jitclass([("x", int32)]) class Foo(object): def __init__(self, x): self.x = x + 2 def run(self): return self.x + 1 FirstFoo = Foo @jitclass([("x", int32)]) class Foo(object): def __init__(self, x): self.x = x - 2 def run(self): return self.x - 1 SecondFoo = Foo foo = FirstFoo(5) self.assertEqual(foo.x, 7) self.assertEqual(foo.run(), 8) foo = SecondFoo(5) self.assertEqual(foo.x, 3) self.assertEqual(foo.run(), 2) def test_parameterized(self): class MyClass(object): def __init__(self, value): self.value = value def create_my_class(value): cls = jitclass(MyClass, [("value", typeof(value))]) return cls(value) a = create_my_class(123) self.assertEqual(a.value, 123) b = create_my_class(12.3) self.assertEqual(b.value, 12.3) c = create_my_class(np.array([123])) np.testing.assert_equal(c.value, [123]) d = create_my_class(np.array([12.3])) np.testing.assert_equal(d.value, [12.3]) def test_protected_attrs(self): spec = { "value": int32, "_value": float32, "__value": int32, "__value__": int32, } @jitclass(spec) class MyClass(object): def __init__(self, value): self.value = value self._value = value / 2 self.__value = value * 2 self.__value__ = value - 1 @property def private_value(self): return self.__value @property def _inner_value(self): return self._value @_inner_value.setter def _inner_value(self, v): self._value = v @property def __private_value(self): return self.__value @__private_value.setter def __private_value(self, v): self.__value = v def swap_private_value(self, new): old = self.__private_value self.__private_value = new return old def _protected_method(self, factor): return self._value * factor def __private_method(self, factor): return self.__value * factor def check_private_method(self, factor): return self.__private_method(factor) value = 123 inst = MyClass(value) # test attributes self.assertEqual(inst.value, value) self.assertEqual(inst._value, value / 2) self.assertEqual(inst.private_value, value * 2) # test properties self.assertEqual(inst._inner_value, inst._value) freeze_inst_value = inst._value inst._inner_value -= 1 self.assertEqual(inst._inner_value, freeze_inst_value - 1) self.assertEqual(inst.swap_private_value(321), value * 2) self.assertEqual(inst.swap_private_value(value * 2), 321) # test methods self.assertEqual(inst._protected_method(3), inst._value * 3) self.assertEqual(inst.check_private_method(3), inst.private_value * 3) # test special self.assertEqual(inst.__value__, value - 1) inst.__value__ -= 100 self.assertEqual(inst.__value__, value - 101) # test errors @njit def access_dunder(inst): return inst.__value with self.assertRaises(errors.TypingError) as raises: access_dunder(inst) # It will appear as "_TestJitClass__value" because the `access_dunder` # is under the scope of "TestJitClass". self.assertIn("_TestJitClass__value", str(raises.exception)) with self.assertRaises(AttributeError) as raises: access_dunder.py_func(inst) self.assertIn("_TestJitClass__value", str(raises.exception)) @skip_if_typeguard def test_annotations(self): """ Methods with annotations should compile fine (issue #1911). """ from .annotation_usecases import AnnotatedClass spec = {"x": int32} cls = jitclass(AnnotatedClass, spec) obj = cls(5) self.assertEqual(obj.x, 5) self.assertEqual(obj.add(2), 7) def test_docstring(self): @jitclass class Apple(object): "Class docstring" def __init__(self): "init docstring" def foo(self): "foo method docstring" @property def aval(self): "aval property docstring" self.assertEqual(Apple.__doc__, "Class docstring") self.assertEqual(Apple.__init__.__doc__, "init docstring") self.assertEqual(Apple.foo.__doc__, "foo method docstring") self.assertEqual(Apple.aval.__doc__, "aval property docstring") def test_kwargs(self): spec = [("a", int32), ("b", float64)] @jitclass(spec) class TestClass(object): def __init__(self, x, y, z): self.a = x * y self.b = z x = 2 y = 2 z = 1.1 kwargs = {"y": y, "z": z} tc = TestClass(x=2, **kwargs) self.assertEqual(tc.a, x * y) self.assertEqual(tc.b, z) def test_default_args(self): spec = [("x", int32), ("y", int32), ("z", int32)] @jitclass(spec) class TestClass(object): def __init__(self, x, y, z=1): self.x = x self.y = y self.z = z tc = TestClass(1, 2, 3) self.assertEqual(tc.x, 1) self.assertEqual(tc.y, 2) self.assertEqual(tc.z, 3) tc = TestClass(1, 2) self.assertEqual(tc.x, 1) self.assertEqual(tc.y, 2) self.assertEqual(tc.z, 1) tc = TestClass(y=2, z=5, x=1) self.assertEqual(tc.x, 1) self.assertEqual(tc.y, 2) self.assertEqual(tc.z, 5) def test_default_args_keyonly(self): spec = [("x", int32), ("y", int32), ("z", int32), ("a", int32)] TestClass = jitclass(TestClass1, spec) tc = TestClass(2, 3) self.assertEqual(tc.x, 2) self.assertEqual(tc.y, 3) self.assertEqual(tc.z, 1) self.assertEqual(tc.a, 5) tc = TestClass(y=4, x=2, a=42, z=100) self.assertEqual(tc.x, 2) self.assertEqual(tc.y, 4) self.assertEqual(tc.z, 100) self.assertEqual(tc.a, 42) tc = TestClass(y=4, x=2, a=42) self.assertEqual(tc.x, 2) self.assertEqual(tc.y, 4) self.assertEqual(tc.z, 1) self.assertEqual(tc.a, 42) tc = TestClass(y=4, x=2) self.assertEqual(tc.x, 2) self.assertEqual(tc.y, 4) self.assertEqual(tc.z, 1) self.assertEqual(tc.a, 5) def test_default_args_starargs_and_keyonly(self): spec = [("x", int32), ("y", int32), ("z", int32), ("args", types.UniTuple(int32, 2)), ("a", int32)] with self.assertRaises(errors.UnsupportedError) as raises: jitclass(TestClass2, spec) msg = "VAR_POSITIONAL argument type unsupported" self.assertIn(msg, str(raises.exception)) def test_generator_method(self): spec = [] @jitclass(spec) class TestClass(object): def __init__(self): pass def gen(self, niter): for i in range(niter): yield np.arange(i) def expected_gen(niter): for i in range(niter): yield np.arange(i) for niter in range(10): for expect, got in zip(expected_gen(niter), TestClass().gen(niter)): self.assertPreciseEqual(expect, got) def test_getitem(self): spec = [("data", int32[:])] @jitclass(spec) class TestClass(object): def __init__(self): self.data = np.zeros(10, dtype=np.int32) def __setitem__(self, key, data): self.data[key] = data def __getitem__(self, key): return self.data[key] @njit def create_and_set_indices(): t = TestClass() t[1] = 1 t[2] = 2 t[3] = 3 return t @njit def get_index(t, n): return t[n] t = create_and_set_indices() self.assertEqual(get_index(t, 1), 1) self.assertEqual(get_index(t, 2), 2) self.assertEqual(get_index(t, 3), 3) def test_getitem_unbox(self): spec = [("data", int32[:])] @jitclass(spec) class TestClass(object): def __init__(self): self.data = np.zeros(10, dtype=np.int32) def __setitem__(self, key, data): self.data[key] = data def __getitem__(self, key): return self.data[key] t = TestClass() t[1] = 10 @njit def set2return1(t): t[2] = 20 return t[1] t_1 = set2return1(t) self.assertEqual(t_1, 10) self.assertEqual(t[2], 20) def test_getitem_complex_key(self): spec = [("data", int32[:, :])] @jitclass(spec) class TestClass(object): def __init__(self): self.data = np.zeros((10, 10), dtype=np.int32) def __setitem__(self, key, data): self.data[int(key.real), int(key.imag)] = data def __getitem__(self, key): return self.data[int(key.real), int(key.imag)] t = TestClass() t[complex(1, 1)] = 3 @njit def get_key(t, real, imag): return t[complex(real, imag)] @njit def set_key(t, real, imag, data): t[complex(real, imag)] = data self.assertEqual(get_key(t, 1, 1), 3) set_key(t, 2, 2, 4) self.assertEqual(t[complex(2, 2)], 4) def test_getitem_tuple_key(self): spec = [("data", int32[:, :])] @jitclass(spec) class TestClass(object): def __init__(self): self.data = np.zeros((10, 10), dtype=np.int32) def __setitem__(self, key, data): self.data[key[0], key[1]] = data def __getitem__(self, key): return self.data[key[0], key[1]] t = TestClass() t[1, 1] = 11 @njit def get11(t): return t[1, 1] @njit def set22(t, data): t[2, 2] = data self.assertEqual(get11(t), 11) set22(t, 22) self.assertEqual(t[2, 2], 22) def test_getitem_slice_key(self): spec = [("data", int32[:])] @jitclass(spec) class TestClass(object): def __init__(self): self.data = np.zeros(10, dtype=np.int32) def __setitem__(self, slc, data): self.data[slc.start] = data self.data[slc.stop] = data + slc.step def __getitem__(self, slc): return self.data[slc.start] t = TestClass() # set t.data[1] = 1 and t.data[5] = 2 t[1:5:1] = 1 self.assertEqual(t[1:1:1], 1) self.assertEqual(t[5:5:5], 2) @njit def get5(t): return t[5:6:1] self.assertEqual(get5(t), 2) # sets t.data[2] = data, and t.data[6] = data + 1 @njit def set26(t, data): t[2:6:1] = data set26(t, 2) self.assertEqual(t[2:2:1], 2) self.assertEqual(t[6:6:1], 3) def test_jitclass_longlabel_not_truncated(self): # See issue #3872, llvm 7 introduced a max label length of 1024 chars # Numba ships patched llvm 7.1 (ppc64le) and patched llvm 8 to undo this # change, this test is here to make sure long labels are ok: alphabet = [chr(ord("a") + x) for x in range(26)] spec = [(letter * 10, float64) for letter in alphabet] spec.extend([(letter.upper() * 10, float64) for letter in alphabet]) @jitclass(spec) class TruncatedLabel(object): def __init__(self,): self.aaaaaaaaaa = 10. def meth1(self): self.bbbbbbbbbb = random.gauss(self.aaaaaaaaaa, self.aaaaaaaaaa) def meth2(self): self.meth1() # unpatched LLVMs will raise here... TruncatedLabel().meth2() def test_pickling(self): @jitclass class PickleTestSubject(object): def __init__(self): pass inst = PickleTestSubject() ty = typeof(inst) self.assertIsInstance(ty, types.ClassInstanceType) pickled = pickle.dumps(ty) self.assertIs(pickle.loads(pickled), ty) def test_static_methods(self): @jitclass([("x", int32)]) class Test1: def __init__(self, x): self.x = x def increase(self, y): self.x = self.add(self.x, y) return self.x @staticmethod def add(a, b): return a + b @staticmethod def sub(a, b): return a - b @jitclass([("x", int32)]) class Test2: def __init__(self, x): self.x = x def increase(self, y): self.x = self.add(self.x, y) return self.x @staticmethod def add(a, b): return a - b self.assertIsInstance(Test1.add, Dispatcher) self.assertIsInstance(Test1.sub, Dispatcher) self.assertIsInstance(Test2.add, Dispatcher) self.assertNotEqual(Test1.add, Test2.add) self.assertEqual(3, Test1.add(1, 2)) self.assertEqual(-1, Test2.add(1, 2)) self.assertEqual(4, Test1.sub(6, 2)) t1 = Test1(0) t2 = Test2(0) self.assertEqual(1, t1.increase(1)) self.assertEqual(-1, t2.increase(1)) self.assertEqual(2, t1.add(1, 1)) self.assertEqual(0, t1.sub(1, 1)) self.assertEqual(0, t2.add(1, 1)) self.assertEqual(2j, t1.add(1j, 1j)) self.assertEqual(1j, t1.sub(2j, 1j)) self.assertEqual("foobar", t1.add("foo", "bar")) with self.assertRaises(AttributeError) as raises: Test2.sub(3, 1) self.assertIn("has no attribute 'sub'", str(raises.exception)) with self.assertRaises(TypeError) as raises: Test1.add(3) self.assertIn("not enough arguments: expected 2, got 1", str(raises.exception)) # Check error message for calling a static method as a class attr from # another method (currently unsupported). @jitclass([]) class Test3: def __init__(self): pass @staticmethod def a_static_method(a, b): pass def call_static(self): return Test3.a_static_method(1, 2) invalid = Test3() with self.assertRaises(errors.TypingError) as raises: invalid.call_static() self.assertIn("Unknown attribute 'a_static_method'", str(raises.exception)) def test_jitclass_decorator_usecases(self): spec = OrderedDict(x=float64) @jitclass() class Test1: x: float def __init__(self): self.x = 0 self.assertIsInstance(Test1, JitClassType) self.assertDictEqual(Test1.class_type.struct, spec) @jitclass(spec=spec) class Test2: def __init__(self): self.x = 0 self.assertIsInstance(Test2, JitClassType) self.assertDictEqual(Test2.class_type.struct, spec) @jitclass class Test3: x: float def __init__(self): self.x = 0 self.assertIsInstance(Test3, JitClassType) self.assertDictEqual(Test3.class_type.struct, spec) @jitclass(spec) class Test4: def __init__(self): self.x = 0 self.assertIsInstance(Test4, JitClassType) self.assertDictEqual(Test4.class_type.struct, spec) def test_jitclass_function_usecases(self): spec = OrderedDict(x=float64) class AnnotatedTest: x: float def __init__(self): self.x = 0 JitTest1 = jitclass(AnnotatedTest) self.assertIsInstance(JitTest1, JitClassType) self.assertDictEqual(JitTest1.class_type.struct, spec) class UnannotatedTest: def __init__(self): self.x = 0 JitTest2 = jitclass(UnannotatedTest, spec) self.assertIsInstance(JitTest2, JitClassType) self.assertDictEqual(JitTest2.class_type.struct, spec) def test_jitclass_isinstance(self): spec = OrderedDict(value=int32) @jitclass(spec) class Foo(object): def __init__(self, value): self.value = value def getValue(self): return self.value def getValueIncr(self): return self.value + 1 @jitclass(spec) class Bar(object): def __init__(self, value): self.value = value def getValue(self): return self.value def test_jitclass_isinstance(obj): if isinstance(obj, (Foo, Bar)): # call something that both classes implements x = obj.getValue() if isinstance(obj, Foo): # something that only Foo implements return obj.getValueIncr() + x, 'Foo' else: return obj.getValue() + x, 'Bar' else: return 'no match' pyfunc = test_jitclass_isinstance cfunc = njit(test_jitclass_isinstance) self.assertIsInstance(Foo, JitClassType) self.assertEqual(pyfunc(Foo(3)), cfunc(Foo(3))) self.assertEqual(pyfunc(Bar(123)), cfunc(Bar(123))) self.assertEqual(pyfunc(0), cfunc(0)) def test_jitclass_unsupported_dunder(self): with self.assertRaises(TypeError) as e: @jitclass class Foo(object): def __init__(self): return def __enter__(self): return None Foo() self.assertIn("Method '__enter__' is not supported.", str(e.exception)) def test_modulename(self): @jitclass class TestModname(object): def __init__(self): self.x = 12 thisModule = __name__ classModule = TestModname.__module__ self.assertEqual(thisModule, classModule) class TestJitClassOverloads(MemoryLeakMixin, TestCase): class PyList: def __init__(self): self.x = [0] def append(self, y): self.x.append(y) def clear(self): self.x.clear() def __abs__(self): return len(self.x) * 7 def __bool__(self): return len(self.x) % 3 != 0 def __complex__(self): c = complex(2) if self.x: c += self.x[0] return c def __contains__(self, y): return y in self.x def __float__(self): f = 3.1415 if self.x: f += self.x[0] return f def __int__(self): i = 5 if self.x: i += self.x[0] return i def __len__(self): return len(self.x) + 1 def __str__(self): if len(self.x) == 0: return "PyList empty" else: return "PyList non-empty" @staticmethod def get_int_wrapper(): @jitclass([("x", types.intp)]) class IntWrapper: def __init__(self, value): self.x = value def __eq__(self, other): return self.x == other.x def __hash__(self): return self.x def __lshift__(self, other): return IntWrapper(self.x << other.x) def __rshift__(self, other): return IntWrapper(self.x >> other.x) def __and__(self, other): return IntWrapper(self.x & other.x) def __or__(self, other): return IntWrapper(self.x | other.x) def __xor__(self, other): return IntWrapper(self.x ^ other.x) return IntWrapper @staticmethod def get_float_wrapper(): @jitclass([("x", types.float64)]) class FloatWrapper: def __init__(self, value): self.x = value def __eq__(self, other): return self.x == other.x def __hash__(self): return self.x def __ge__(self, other): return self.x >= other.x def __gt__(self, other): return self.x > other.x def __le__(self, other): return self.x <= other.x def __lt__(self, other): return self.x < other.x def __add__(self, other): return FloatWrapper(self.x + other.x) def __floordiv__(self, other): return FloatWrapper(self.x // other.x) def __mod__(self, other): return FloatWrapper(self.x % other.x) def __mul__(self, other): return FloatWrapper(self.x * other.x) def __neg__(self, other): return FloatWrapper(-self.x) def __pos__(self, other): return FloatWrapper(+self.x) def __pow__(self, other): return FloatWrapper(self.x ** other.x) def __sub__(self, other): return FloatWrapper(self.x - other.x) def __truediv__(self, other): return FloatWrapper(self.x / other.x) return FloatWrapper def assertSame(self, first, second, msg=None): self.assertEqual(type(first), type(second), msg=msg) self.assertEqual(first, second, msg=msg) def test_overloads(self): # Check that the dunder methods are exposed on ClassInstanceType. JitList = jitclass({"x": types.List(types.intp)})(self.PyList) py_funcs = [ lambda x: abs(x), lambda x: x.__abs__(), lambda x: bool(x), lambda x: x.__bool__(), lambda x: complex(x), lambda x: x.__complex__(), lambda x: 0 in x, # contains lambda x: x.__contains__(0), lambda x: float(x), lambda x: x.__float__(), lambda x: int(x), lambda x: x.__int__(), lambda x: len(x), lambda x: x.__len__(), lambda x: str(x), lambda x: x.__str__(), lambda x: 1 if x else 0, # truth ] jit_funcs = [njit(f) for f in py_funcs] py_list = self.PyList() jit_list = JitList() for py_f, jit_f in zip(py_funcs, jit_funcs): self.assertSame(py_f(py_list), py_f(jit_list)) self.assertSame(py_f(py_list), jit_f(jit_list)) py_list.append(2) jit_list.append(2) for py_f, jit_f in zip(py_funcs, jit_funcs): self.assertSame(py_f(py_list), py_f(jit_list)) self.assertSame(py_f(py_list), jit_f(jit_list)) py_list.append(-5) jit_list.append(-5) for py_f, jit_f in zip(py_funcs, jit_funcs): self.assertSame(py_f(py_list), py_f(jit_list)) self.assertSame(py_f(py_list), jit_f(jit_list)) py_list.clear() jit_list.clear() for py_f, jit_f in zip(py_funcs, jit_funcs): self.assertSame(py_f(py_list), py_f(jit_list)) self.assertSame(py_f(py_list), jit_f(jit_list)) def test_bool_fallback(self): def py_b(x): return bool(x) jit_b = njit(py_b) @jitclass([("x", types.List(types.intp))]) class LenClass: def __init__(self, x): self.x = x def __len__(self): return len(self.x) % 4 def append(self, y): self.x.append(y) def pop(self): self.x.pop(0) obj = LenClass([1, 2, 3]) self.assertTrue(py_b(obj)) self.assertTrue(jit_b(obj)) obj.append(4) self.assertFalse(py_b(obj)) self.assertFalse(jit_b(obj)) obj.pop() self.assertTrue(py_b(obj)) self.assertTrue(jit_b(obj)) @jitclass([("y", types.float64)]) class NormalClass: def __init__(self, y): self.y = y obj = NormalClass(0) self.assertTrue(py_b(obj)) self.assertTrue(jit_b(obj)) def test_numeric_fallback(self): def py_c(x): return complex(x) def py_f(x): return float(x) def py_i(x): return int(x) jit_c = njit(py_c) jit_f = njit(py_f) jit_i = njit(py_i) @jitclass([]) class FloatClass: def __init__(self): pass def __float__(self): return 3.1415 obj = FloatClass() self.assertSame(py_c(obj), complex(3.1415)) self.assertSame(jit_c(obj), complex(3.1415)) self.assertSame(py_f(obj), 3.1415) self.assertSame(jit_f(obj), 3.1415) with self.assertRaises(TypeError) as e: py_i(obj) self.assertIn("int", str(e.exception)) with self.assertRaises(TypingError) as e: jit_i(obj) self.assertIn("int", str(e.exception)) @jitclass([]) class IntClass: def __init__(self): pass def __int__(self): return 7 obj = IntClass() self.assertSame(py_i(obj), 7) self.assertSame(jit_i(obj), 7) with self.assertRaises(TypeError) as e: py_c(obj) self.assertIn("complex", str(e.exception)) with self.assertRaises(TypingError) as e: jit_c(obj) self.assertIn("complex", str(e.exception)) with self.assertRaises(TypeError) as e: py_f(obj) self.assertIn("float", str(e.exception)) with self.assertRaises(TypingError) as e: jit_f(obj) self.assertIn("float", str(e.exception)) @jitclass([]) class IndexClass: def __init__(self): pass def __index__(self): return 1 obj = IndexClass() self.assertSame(py_c(obj), complex(1)) self.assertSame(jit_c(obj), complex(1)) self.assertSame(py_f(obj), 1.) self.assertSame(jit_f(obj), 1.) self.assertSame(py_i(obj), 1) self.assertSame(jit_i(obj), 1) @jitclass([]) class FloatIntIndexClass: def __init__(self): pass def __float__(self): return 3.1415 def __int__(self): return 7 def __index__(self): return 1 obj = FloatIntIndexClass() self.assertSame(py_c(obj), complex(3.1415)) self.assertSame(jit_c(obj), complex(3.1415)) self.assertSame(py_f(obj), 3.1415) self.assertSame(jit_f(obj), 3.1415) self.assertSame(py_i(obj), 7) self.assertSame(jit_i(obj), 7) def test_arithmetic_logical(self): IntWrapper = self.get_int_wrapper() FloatWrapper = self.get_float_wrapper() float_py_funcs = [ lambda x, y: x == y, lambda x, y: x != y, lambda x, y: x >= y, lambda x, y: x > y, lambda x, y: x <= y, lambda x, y: x < y, lambda x, y: x + y, lambda x, y: x // y, lambda x, y: x % y, lambda x, y: x * y, lambda x, y: x ** y, lambda x, y: x - y, lambda x, y: x / y, ] int_py_funcs = [ lambda x, y: x == y, lambda x, y: x != y, lambda x, y: x << y, lambda x, y: x >> y, lambda x, y: x & y, lambda x, y: x | y, lambda x, y: x ^ y, ] test_values = [ (0.0, 2.0), (1.234, 3.1415), (13.1, 1.01), ] def unwrap(value): return getattr(value, "x", value) for jit_f, (x, y) in itertools.product( map(njit, float_py_funcs), test_values): py_f = jit_f.py_func expected = py_f(x, y) jit_x = FloatWrapper(x) jit_y = FloatWrapper(y) check = ( self.assertEqual if type(expected) is not float else self.assertAlmostEqual ) check(expected, jit_f(x, y)) check(expected, unwrap(py_f(jit_x, jit_y))) check(expected, unwrap(jit_f(jit_x, jit_y))) for jit_f, (x, y) in itertools.product( map(njit, int_py_funcs), test_values): py_f = jit_f.py_func x, y = int(x), int(y) expected = py_f(x, y) jit_x = IntWrapper(x) jit_y = IntWrapper(y) self.assertEqual(expected, jit_f(x, y)) self.assertEqual(expected, unwrap(py_f(jit_x, jit_y))) self.assertEqual(expected, unwrap(jit_f(jit_x, jit_y))) def test_arithmetic_logical_inplace(self): # If __i*__ methods are not defined, should fall back to normal methods. JitIntWrapper = self.get_int_wrapper() JitFloatWrapper = self.get_float_wrapper() PyIntWrapper = JitIntWrapper.mro()[1] PyFloatWrapper = JitFloatWrapper.mro()[1] @jitclass([("x", types.intp)]) class JitIntUpdateWrapper(PyIntWrapper): def __init__(self, value): self.x = value def __ilshift__(self, other): return JitIntUpdateWrapper(self.x << other.x) def __irshift__(self, other): return JitIntUpdateWrapper(self.x >> other.x) def __iand__(self, other): return JitIntUpdateWrapper(self.x & other.x) def __ior__(self, other): return JitIntUpdateWrapper(self.x | other.x) def __ixor__(self, other): return JitIntUpdateWrapper(self.x ^ other.x) @jitclass({"x": types.float64}) class JitFloatUpdateWrapper(PyFloatWrapper): def __init__(self, value): self.x = value def __iadd__(self, other): return JitFloatUpdateWrapper(self.x + 2.718 * other.x) def __ifloordiv__(self, other): return JitFloatUpdateWrapper(self.x * 2.718 // other.x) def __imod__(self, other): return JitFloatUpdateWrapper(self.x % (other.x + 1)) def __imul__(self, other): return JitFloatUpdateWrapper(self.x * other.x + 1) def __ipow__(self, other): return JitFloatUpdateWrapper(self.x ** other.x + 1) def __isub__(self, other): return JitFloatUpdateWrapper(self.x - 3.1415 * other.x) def __itruediv__(self, other): return JitFloatUpdateWrapper((self.x + 1) / other.x) PyIntUpdateWrapper = JitIntUpdateWrapper.mro()[1] PyFloatUpdateWrapper = JitFloatUpdateWrapper.mro()[1] def get_update_func(op): template = f""" def f(x, y): x {op}= y return x """ namespace = {} exec(template, namespace) return namespace["f"] float_py_funcs = [get_update_func(op) for op in [ "+", "//", "%", "*", "**", "-", "/", ]] int_py_funcs = [get_update_func(op) for op in [ "<<", ">>", "&", "|", "^", ]] test_values = [ (0.0, 2.0), (1.234, 3.1415), (13.1, 1.01), ] for jit_f, (py_cls, jit_cls), (x, y) in itertools.product( map(njit, float_py_funcs), [ (PyFloatWrapper, JitFloatWrapper), (PyFloatUpdateWrapper, JitFloatUpdateWrapper) ], test_values): py_f = jit_f.py_func expected = py_f(py_cls(x), py_cls(y)).x self.assertAlmostEqual(expected, py_f(jit_cls(x), jit_cls(y)).x) self.assertAlmostEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x) for jit_f, (py_cls, jit_cls), (x, y) in itertools.product( map(njit, int_py_funcs), [ (PyIntWrapper, JitIntWrapper), (PyIntUpdateWrapper, JitIntUpdateWrapper) ], test_values): x, y = int(x), int(y) py_f = jit_f.py_func expected = py_f(py_cls(x), py_cls(y)).x self.assertEqual(expected, py_f(jit_cls(x), jit_cls(y)).x) self.assertEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x) def test_hash_eq_ne(self): class HashEqTest: x: int def __init__(self, x): self.x = x def __hash__(self): return self.x % 10 def __eq__(self, o): return (self.x - o.x) % 20 == 0 class HashEqNeTest(HashEqTest): def __ne__(self, o): return (self.x - o.x) % 20 > 1 def py_hash(x): return hash(x) def py_eq(x, y): return x == y def py_ne(x, y): return x != y def identity_decorator(f): return f comparisons = [ (0, 1), # Will give different ne results. (2, 22), (7, 10), (3, 3), ] for base_cls, use_jit in itertools.product( [HashEqTest, HashEqNeTest], [False, True] ): decorator = njit if use_jit else identity_decorator hash_func = decorator(py_hash) eq_func = decorator(py_eq) ne_func = decorator(py_ne) jit_cls = jitclass(base_cls) for v in [0, 2, 10, 24, -8]: self.assertEqual(hash_func(jit_cls(v)), v % 10) for x, y in comparisons: self.assertEqual( eq_func(jit_cls(x), jit_cls(y)), base_cls(x) == base_cls(y), ) self.assertEqual( ne_func(jit_cls(x), jit_cls(y)), base_cls(x) != base_cls(y), ) def test_bool_fallback_len(self): # Check that the fallback to using len(obj) to determine truth of an # object is implemented correctly as per # https://docs.python.org/3/library/stdtypes.html#truth-value-testing # # Relevant points: # # "By default, an object is considered true unless its class defines # either a __bool__() method that returns False or a __len__() method # that returns zero, when called with the object." # # and: # # "Operations and built-in functions that have a Boolean result always # return 0 or False for false and 1 or True for true, unless otherwise # stated." class NoBoolHasLen: def __init__(self, val): self.val = val def __len__(self): return self.val def get_bool(self): return bool(self) py_class = NoBoolHasLen jitted_class = jitclass([('val', types.int64)])(py_class) py_class_0_bool = py_class(0).get_bool() py_class_2_bool = py_class(2).get_bool() jitted_class_0_bool = jitted_class(0).get_bool() jitted_class_2_bool = jitted_class(2).get_bool() # Truth values from bool(obj) should be equal self.assertEqual(py_class_0_bool, jitted_class_0_bool) self.assertEqual(py_class_2_bool, jitted_class_2_bool) # Truth values from bool(obj) should be the same type self.assertEqual(type(py_class_0_bool), type(jitted_class_0_bool)) self.assertEqual(type(py_class_2_bool), type(jitted_class_2_bool)) def test_bool_fallback_default(self): # Similar to test_bool_fallback, but checks the case where there is no # __bool__() or __len__() defined, so the object should always be True. class NoBoolNoLen: def __init__(self): pass def get_bool(self): return bool(self) py_class = NoBoolNoLen jitted_class = jitclass([])(py_class) py_class_bool = py_class().get_bool() jitted_class_bool = jitted_class().get_bool() # Truth values from bool(obj) should be equal self.assertEqual(py_class_bool, jitted_class_bool) # Truth values from bool(obj) should be the same type self.assertEqual(type(py_class_bool), type(jitted_class_bool)) def test_operator_reflection(self): class OperatorsDefined: def __init__(self, x): self.x = x def __eq__(self, other): return self.x == other.x def __le__(self, other): return self.x <= other.x def __lt__(self, other): return self.x < other.x def __ge__(self, other): return self.x >= other.x def __gt__(self, other): return self.x > other.x class NoOperatorsDefined: def __init__(self, x): self.x = x spec = [('x', types.int32)] JitOperatorsDefined = jitclass(spec)(OperatorsDefined) JitNoOperatorsDefined = jitclass(spec)(NoOperatorsDefined) py_ops_defined = OperatorsDefined(2) py_ops_not_defined = NoOperatorsDefined(3) jit_ops_defined = JitOperatorsDefined(2) jit_ops_not_defined = JitNoOperatorsDefined(3) self.assertEqual(py_ops_not_defined == py_ops_defined, jit_ops_not_defined == jit_ops_defined) self.assertEqual(py_ops_not_defined <= py_ops_defined, jit_ops_not_defined <= jit_ops_defined) self.assertEqual(py_ops_not_defined < py_ops_defined, jit_ops_not_defined < jit_ops_defined) self.assertEqual(py_ops_not_defined >= py_ops_defined, jit_ops_not_defined >= jit_ops_defined) self.assertEqual(py_ops_not_defined > py_ops_defined, jit_ops_not_defined > jit_ops_defined) @skip_unless_scipy def test_matmul_operator(self): class ArrayAt: def __init__(self, array): self.arr = array def __matmul__(self, other): return self.arr @ other.arr def __rmatmul__(self, other): return other.arr @ self.arr def __imatmul__(self, other): self.arr = self.arr @ other.arr return self class ArrayNoAt: def __init__(self, array): self.arr = array n = 3 np.random.seed(1) vec = np.random.random(size=(n,)) mat = np.random.random(size=(n, n)) vector_noat = ArrayNoAt(vec) vector_at = ArrayAt(vec) jit_vector_noat = jitclass(ArrayNoAt, spec={"arr": float64[::1]})(vec) jit_vector_at = jitclass(ArrayAt, spec={"arr": float64[::1]})(vec) matrix_noat = ArrayNoAt(mat) matrix_at = ArrayAt(mat) jit_matrix_noat = jitclass(ArrayNoAt, spec={"arr": float64[:,::1]})(mat) jit_matrix_at = jitclass(ArrayAt, spec={"arr": float64[:,::1]})(mat) # __matmul__ np.testing.assert_allclose(vector_at @ vector_noat, jit_vector_at @ jit_vector_noat) np.testing.assert_allclose(vector_at @ matrix_noat, jit_vector_at @ jit_matrix_noat) np.testing.assert_allclose(matrix_at @ vector_noat, jit_matrix_at @ jit_vector_noat) np.testing.assert_allclose(matrix_at @ matrix_noat, jit_matrix_at @ jit_matrix_noat) # __rmatmul__ np.testing.assert_allclose(vector_noat @ vector_at, jit_vector_noat @ jit_vector_at) np.testing.assert_allclose(vector_noat @ matrix_at, jit_vector_noat @ jit_matrix_at) np.testing.assert_allclose(matrix_noat @ vector_at, jit_matrix_noat @ jit_vector_at) np.testing.assert_allclose(matrix_noat @ matrix_at, jit_matrix_noat @ jit_matrix_at) # __imatmul__ vector_at @= matrix_noat matrix_at @= matrix_noat jit_vector_at @= jit_matrix_noat jit_matrix_at @= jit_matrix_noat np.testing.assert_allclose(vector_at.arr, jit_vector_at.arr) np.testing.assert_allclose(matrix_at.arr, jit_matrix_at.arr) def test_arithmetic_logical_reflection(self): class OperatorsDefined: def __init__(self, x): self.x = x def __radd__(self, other): return other.x + self.x def __rsub__(self, other): return other.x - self.x def __rmul__(self, other): return other.x * self.x def __rtruediv__(self, other): return other.x / self.x def __rfloordiv__(self, other): return other.x // self.x def __rmod__(self, other): return other.x % self.x def __rpow__(self, other): return other.x ** self.x def __rlshift__(self, other): return other.x << self.x def __rrshift__(self, other): return other.x >> self.x def __rand__(self, other): return other.x & self.x def __rxor__(self, other): return other.x ^ self.x def __ror__(self, other): return other.x | self.x class NoOperatorsDefined: def __init__(self, x): self.x = x float_op = ["+", "-", "*", "**", "/", "//", "%"] int_op = [*float_op, "<<", ">>" , "&", "^", "|"] for test_type, test_op, test_value in [ (int32, int_op, (2, 4)), (float64, float_op, (2., 4.)), (float64[::1], float_op, (np.array([1., 2., 4.]), np.array([20., -24., 1.]))) ]: spec = {"x": test_type} JitOperatorsDefined = jitclass(OperatorsDefined, spec) JitNoOperatorsDefined = jitclass(NoOperatorsDefined, spec) py_ops_defined = OperatorsDefined(test_value[0]) # noqa: F841 py_ops_not_defined = NoOperatorsDefined(test_value[1]) # noqa: F841 jit_ops_defined = JitOperatorsDefined(test_value[0]) # noqa: F841 jit_ops_not_defined = JitNoOperatorsDefined(test_value[1]) # noqa: F841 E501 for op in test_op: if not ("array" in str(test_type)): self.assertEqual( eval(f"py_ops_not_defined {op} py_ops_defined"), eval(f"jit_ops_not_defined {op} jit_ops_defined") ) else: self.assertTupleEqual( tuple(eval(f"py_ops_not_defined {op} py_ops_defined")), tuple(eval(f"jit_ops_not_defined {op} jit_ops_defined")) ) def test_implicit_hash_compiles(self): # Ensure that classes with __hash__ implicitly defined as None due to # the presence of __eq__ are correctly handled by ignoring the __hash__ # class member. class ImplicitHash: def __init__(self): pass def __eq__(self, other): return False jitted = jitclass([])(ImplicitHash) instance = jitted() self.assertFalse(instance == instance) if __name__ == "__main__": unittest.main()