# Source: ai-content-maker/.venv/Lib/site-packages/numba/tests/test_jitclasses.py
# Original file: 2021 lines, 58 KiB, Python (snapshot dated 2024-05-03 04:18:51 +03:00)
import ctypes
import itertools
import pickle
import random
import typing as pt
import unittest
from collections import OrderedDict
import numpy as np
from numba import (boolean, deferred_type, float32, float64, int16, int32,
njit, optional, typeof)
from numba.core import errors, types
from numba.core.dispatcher import Dispatcher
from numba.core.errors import LoweringError, TypingError
from numba.core.runtime.nrt import MemInfo
from numba.experimental import jitclass
from numba.experimental.jitclass import _box
from numba.experimental.jitclass.base import JitClassType
from numba.tests.support import MemoryLeakMixin, TestCase, skip_if_typeguard
from numba.tests.support import skip_unless_scipy
class TestClass1(object):
    """Plain Python class used as a jitclass source in the
    default-argument / keyword-only constructor tests
    (see ``test_default_args_keyonly``)."""

    def __init__(self, x, y, z=1, *, a=5):
        # ``z`` has a positional default; ``a`` is keyword-only.
        self.x = x
        self.y = y
        self.z = z
        self.a = a
class TestClass2(object):
    """Plain Python class whose constructor accepts ``*args``; jitclass is
    expected to reject it (see ``test_default_args_starargs_and_keyonly``)."""

    def __init__(self, x, y, z=1, *args, a=5):
        self.x = x
        self.y = y
        self.z = z
        # Any extra positional arguments are captured as a tuple.
        self.args = args
        self.a = a
def _get_meminfo(box):
    """Return a ``MemInfo`` wrapper for the NRT meminfo backing *box*.

    *box* is a boxed jitclass instance.  The meminfo pointer is read via
    ``_box.box_get_meminfoptr`` and a reference is explicitly acquired, so
    the refcounts the tests observe include the one held by the returned
    wrapper.  (Presumably the wrapper releases that reference when it is
    garbage collected -- the tests rely on ``del`` lowering the refcount.)
    """
    ptr = _box.box_get_meminfoptr(box)
    mi = MemInfo(ptr)
    mi.acquire()
    return mi
class TestJitClass(MemoryLeakMixin, TestCase):
    """Tests for the core ``@jitclass`` machinery: spec handling, boxing,
    attribute/method/property access, and error reporting."""

    # NOTE: MemoryLeakMixin is listed before TestCase (as in
    # TestJitClassOverloads) so that its setUp/tearDown hooks come first in
    # the MRO.  With the reverse order, unittest.TestCase's no-op hooks --
    # which do not call super() -- shadow the mixin and the NRT leak check
    # never runs.

    def _check_spec(self, spec=None, test_cls=None, all_expected=None):
        """Check that a jitclass's struct fields match the expected layout.

        If *test_cls* is None, a trivial jitclass is built from *spec*.
        *all_expected* defaults to *spec* (its ``.items()`` for an
        OrderedDict) and must be a sized iterable of ``(name, type)`` pairs
        in field order.
        """
        if test_cls is None:
            @jitclass(spec)
            class Test(object):
                def __init__(self):
                    pass
            test_cls = Test

        clsty = test_cls.class_type.instance_type
        names = list(clsty.struct.keys())
        values = list(clsty.struct.values())

        if all_expected is None:
            if isinstance(spec, OrderedDict):
                all_expected = spec.items()
            else:
                all_expected = spec

        # Use a unittest assertion rather than a bare ``assert`` so the
        # check still runs under ``python -O``.
        self.assertIsNotNone(all_expected)
        self.assertEqual(len(names), len(all_expected))
        for got, expected in zip(zip(names, values), all_expected):
            self.assertEqual(got[0], expected[0])
            self.assertEqual(got[1], expected[1])

    def test_ordereddict_spec(self):
        """An OrderedDict spec defines the struct fields in order."""
        spec = OrderedDict()
        spec["x"] = int32
        spec["y"] = float32
        self._check_spec(spec)

    def test_list_spec(self):
        """A list of (name, type) pairs defines the struct fields."""
        spec = [("x", int32),
                ("y", float32)]
        self._check_spec(spec)

    def test_type_annotations(self):
        """Annotations add fields; explicit spec entries take precedence."""
        spec = [("x", int32)]

        @jitclass(spec)
        class Test1(object):
            x: int
            y: pt.List[float]

            def __init__(self):
                pass

        # ``x`` keeps the spec's int32; ``y`` comes from the annotation.
        self._check_spec(spec, Test1, spec + [("y", types.ListType(float64))])

    def test_type_annotation_inheritance(self):
        """Annotations on a plain base class are inherited as fields."""
        class Foo:
            x: int

        @jitclass
        class Bar(Foo):
            y: float

            def __init__(self, value: float) -> None:
                self.x = int(value)
                self.y = value

        self._check_spec(
            test_cls=Bar, all_expected=[("x", typeof(0)), ("y", typeof(0.0))]
        )

    def test_spec_errors(self):
        """Invalid spec values or keys raise TypeError."""
        spec1 = [("x", int), ("y", float32[:])]
        spec2 = [(1, int32), ("y", float32[:])]

        class Test(object):
            def __init__(self):
                pass

        with self.assertRaises(TypeError) as raises:
            jitclass(Test, spec1)
        self.assertIn("spec values should be Numba type instances",
                      str(raises.exception))

        with self.assertRaises(TypeError) as raises:
            jitclass(Test, spec2)
        self.assertEqual(str(raises.exception),
                         "spec keys should be strings, got 1")

    def test_init_errors(self):
        """An ``__init__`` returning a value is a typing error."""
        @jitclass([])
        class Test:
            def __init__(self):
                return 7

        with self.assertRaises(errors.TypingError) as raises:
            Test()

        self.assertIn("__init__() should return None, not",
                      str(raises.exception))

    def _make_Float2AndArray(self):
        """Build a jitclass with two float32 scalars and a float32 array."""
        spec = OrderedDict()
        spec["x"] = float32
        spec["y"] = float32
        spec["arr"] = float32[:]

        @jitclass(spec)
        class Float2AndArray(object):
            def __init__(self, x, y, arr):
                self.x = x
                self.y = y
                self.arr = arr

            def add(self, val):
                self.x += val
                self.y += val
                return val

        return Float2AndArray

    def _make_Vector2(self):
        """Build a jitclass with two int32 fields ``x`` and ``y``."""
        spec = OrderedDict()
        spec["x"] = int32
        spec["y"] = int32

        @jitclass(spec)
        class Vector2(object):
            def __init__(self, x, y):
                self.x = x
                self.y = y

        return Vector2

    def test_jit_class_1(self):
        """Jitclasses can be constructed and used inside njit functions."""
        Float2AndArray = self._make_Float2AndArray()
        Vector2 = self._make_Vector2()

        @njit
        def bar(obj):
            return obj.x + obj.y

        @njit
        def foo(a):
            obj = Float2AndArray(1, 2, a)
            obj.add(123)

            vec = Vector2(3, 4)
            return bar(obj), bar(vec), obj.arr

        inp = np.ones(10, dtype=np.float32)
        a, b, c = foo(inp)
        self.assertEqual(a, 123 + 1 + 123 + 2)
        self.assertEqual(b, 3 + 4)
        self.assertPreciseEqual(c, inp)

    def test_jitclass_usage_from_python(self):
        """Boxed instances round-trip through njit with correct refcounts."""
        Float2AndArray = self._make_Float2AndArray()

        @njit
        def identity(obj):
            return obj

        @njit
        def retrieve_attributes(obj):
            return obj.x, obj.y, obj.arr

        arr = np.arange(10, dtype=np.float32)
        obj = Float2AndArray(1, 2, arr)
        obj_meminfo = _get_meminfo(obj)
        self.assertEqual(obj_meminfo.refcount, 2)
        self.assertEqual(obj_meminfo.data, _box.box_get_dataptr(obj))
        self.assertEqual(obj._numba_type_.class_type,
                         Float2AndArray.class_type)

        # Use jit class instance in numba
        other = identity(obj)
        other_meminfo = _get_meminfo(other)  # duplicates MemInfo object to obj
        self.assertEqual(obj_meminfo.refcount, 4)
        self.assertEqual(other_meminfo.refcount, 4)
        self.assertEqual(other_meminfo.data, _box.box_get_dataptr(other))
        self.assertEqual(other_meminfo.data, obj_meminfo.data)

        # Check dtor
        del other, other_meminfo
        self.assertEqual(obj_meminfo.refcount, 2)

        # Check attributes
        out_x, out_y, out_arr = retrieve_attributes(obj)
        self.assertEqual(out_x, 1)
        self.assertEqual(out_y, 2)
        self.assertIs(out_arr, arr)

        # Access attributes from python
        self.assertEqual(obj.x, 1)
        self.assertEqual(obj.y, 2)
        self.assertIs(obj.arr, arr)

        # Access methods from python
        self.assertEqual(obj.add(123), 123)
        self.assertEqual(obj.x, 1 + 123)
        self.assertEqual(obj.y, 2 + 123)

        # Setter from python
        obj.x = 333
        obj.y = 444
        obj.arr = newarr = np.arange(5, dtype=np.float32)
        self.assertEqual(obj.x, 333)
        self.assertEqual(obj.y, 444)
        self.assertIs(obj.arr, newarr)

    def test_jitclass_datalayout(self):
        """Boolean fields are stored and read back correctly."""
        spec = OrderedDict()
        # Boolean has different layout as value vs data
        spec["val"] = boolean

        @jitclass(spec)
        class Foo(object):
            def __init__(self, val):
                self.val = val

        self.assertTrue(Foo(True).val)
        self.assertFalse(Foo(False).val)

    def test_deferred_type(self):
        """A self-referential (linked list) jitclass via deferred_type."""
        node_type = deferred_type()

        spec = OrderedDict()
        spec["data"] = float32
        spec["next"] = optional(node_type)

        @njit
        def get_data(node):
            return node.data

        @jitclass(spec)
        class LinkedNode(object):
            def __init__(self, data, next):
                self.data = data
                self.next = next

            def get_next_data(self):
                # use deferred type as argument
                return get_data(self.next)

            def append_to_tail(self, other):
                cur = self
                while cur.next is not None:
                    cur = cur.next
                cur.next = other

        node_type.define(LinkedNode.class_type.instance_type)

        first = LinkedNode(123, None)
        self.assertEqual(first.data, 123)
        self.assertIsNone(first.next)

        second = LinkedNode(321, first)

        first_meminfo = _get_meminfo(first)
        second_meminfo = _get_meminfo(second)
        self.assertEqual(first_meminfo.refcount, 3)
        self.assertEqual(second.next.data, first.data)
        self.assertEqual(first_meminfo.refcount, 3)
        self.assertEqual(second_meminfo.refcount, 2)

        # Test using deferred type as argument
        first_val = second.get_next_data()
        self.assertEqual(first_val, first.data)

        # Check setattr (issue #2606)
        self.assertIsNone(first.next)
        second.append_to_tail(LinkedNode(567, None))
        self.assertIsNotNone(first.next)
        self.assertEqual(first.next.data, 567)
        self.assertIsNone(first.next.next)
        second.append_to_tail(LinkedNode(678, None))
        self.assertIsNotNone(first.next.next)
        self.assertEqual(first.next.next.data, 678)

        # Check ownership
        self.assertEqual(first_meminfo.refcount, 3)
        del second, second_meminfo
        self.assertEqual(first_meminfo.refcount, 2)

    def test_c_structure(self):
        """The instance data pointer has the layout of a C struct."""
        spec = OrderedDict()
        spec["a"] = int32
        spec["b"] = int16
        spec["c"] = float64

        @jitclass(spec)
        class Struct(object):
            def __init__(self, a, b, c):
                self.a = a
                self.b = b
                self.c = c

        st = Struct(0xabcd, 0xef, 3.1415)

        class CStruct(ctypes.Structure):
            _fields_ = [
                ("a", ctypes.c_int32),
                ("b", ctypes.c_int16),
                ("c", ctypes.c_double),
            ]

        ptr = ctypes.c_void_p(_box.box_get_dataptr(st))
        cstruct = ctypes.cast(ptr, ctypes.POINTER(CStruct))[0]
        self.assertEqual(cstruct.a, st.a)
        self.assertEqual(cstruct.b, st.b)
        self.assertEqual(cstruct.c, st.c)

    def test_is(self):
        """``is`` between jitclass instances fails to lower."""
        Vector = self._make_Vector2()
        vec_a = Vector(1, 2)

        @njit
        def do_is(a, b):
            return a is b

        with self.assertRaises(LoweringError) as raises:
            # trigger compilation
            do_is(vec_a, vec_a)
        self.assertIn("no default `is` implementation", str(raises.exception))

    def test_isinstance(self):
        """Boxed instances are instances of the jitclass in Python."""
        Vector2 = self._make_Vector2()
        vec = Vector2(1, 2)
        self.assertIsInstance(vec, Vector2)

    def test_subclassing(self):
        """Subclassing a jitclass is rejected."""
        Vector2 = self._make_Vector2()
        with self.assertRaises(TypeError) as raises:
            class SubV(Vector2):
                pass
        self.assertEqual(str(raises.exception),
                         "cannot subclass from a jitclass")

    def test_base_class(self):
        """Methods inherited from a plain Python base class are compiled."""
        class Base(object):

            def what(self):
                return self.attr

        @jitclass([("attr", int32)])
        class Test(Base):

            def __init__(self, attr):
                self.attr = attr

        obj = Test(123)
        self.assertEqual(obj.what(), 123)

    def test_globals(self):
        """Class-level data members are rejected."""

        class Mine(object):
            constant = 123

            def __init__(self):
                pass

        with self.assertRaises(TypeError) as raises:
            jitclass(Mine)

        self.assertEqual(str(raises.exception),
                         "class members are not yet supported: constant")

    def test_user_getter_setter(self):
        """User properties work from Python and from nopython mode."""
        @jitclass([("attr", int32)])
        class Foo(object):
            def __init__(self, attr):
                self.attr = attr

            @property
            def value(self):
                return self.attr + 1

            @value.setter
            def value(self, val):
                self.attr = val - 1

        foo = Foo(123)
        self.assertEqual(foo.attr, 123)
        # Getter
        self.assertEqual(foo.value, 123 + 1)
        # Setter
        foo.value = 789
        self.assertEqual(foo.attr, 789 - 1)
        self.assertEqual(foo.value, 789)

        # Test nopython mode usage of getter and setter
        @njit
        def bar(foo, val):
            a = foo.value
            foo.value = val
            b = foo.value
            c = foo.attr
            return a, b, c

        a, b, c = bar(foo, 567)
        self.assertEqual(a, 789)
        self.assertEqual(b, 567)
        self.assertEqual(c, 567 - 1)

    def test_user_deleter_error(self):
        """Property deleters are rejected."""
        class Foo(object):
            def __init__(self):
                pass

            @property
            def value(self):
                return 1

            @value.deleter
            def value(self):
                pass

        with self.assertRaises(TypeError) as raises:
            jitclass(Foo)
        self.assertEqual(str(raises.exception),
                         "deleter is not supported: value")

    def test_name_shadowing_error(self):
        """Spec fields may not shadow properties or methods."""
        class Foo(object):
            def __init__(self):
                pass

            @property
            def my_property(self):
                pass

            def my_method(self):
                pass

        with self.assertRaises(NameError) as raises:
            jitclass(Foo, [("my_property", int32)])
        self.assertEqual(str(raises.exception), "name shadowing: my_property")

        with self.assertRaises(NameError) as raises:
            jitclass(Foo, [("my_method", int32)])
        self.assertEqual(str(raises.exception), "name shadowing: my_method")

    def test_distinct_classes(self):
        """Same-named jitclasses remain distinct types."""
        # Different classes with the same names shouldn't confuse the compiler
        @jitclass([("x", int32)])
        class Foo(object):
            def __init__(self, x):
                self.x = x + 2

            def run(self):
                return self.x + 1

        FirstFoo = Foo

        @jitclass([("x", int32)])
        class Foo(object):
            def __init__(self, x):
                self.x = x - 2

            def run(self):
                return self.x - 1

        SecondFoo = Foo
        foo = FirstFoo(5)
        self.assertEqual(foo.x, 7)
        self.assertEqual(foo.run(), 8)
        foo = SecondFoo(5)
        self.assertEqual(foo.x, 3)
        self.assertEqual(foo.run(), 2)

    def test_parameterized(self):
        """One Python class can be jitted with several field types."""
        class MyClass(object):
            def __init__(self, value):
                self.value = value

        def create_my_class(value):
            cls = jitclass(MyClass, [("value", typeof(value))])
            return cls(value)

        a = create_my_class(123)
        self.assertEqual(a.value, 123)

        b = create_my_class(12.3)
        self.assertEqual(b.value, 12.3)

        c = create_my_class(np.array([123]))
        np.testing.assert_equal(c.value, [123])

        d = create_my_class(np.array([12.3]))
        np.testing.assert_equal(d.value, [12.3])

    def test_protected_attrs(self):
        """Protected, private (name-mangled) and dunder members work."""
        spec = {
            "value": int32,
            "_value": float32,
            "__value": int32,
            "__value__": int32,
        }

        @jitclass(spec)
        class MyClass(object):
            def __init__(self, value):
                self.value = value
                self._value = value / 2
                self.__value = value * 2
                self.__value__ = value - 1

            @property
            def private_value(self):
                return self.__value

            @property
            def _inner_value(self):
                return self._value

            @_inner_value.setter
            def _inner_value(self, v):
                self._value = v

            @property
            def __private_value(self):
                return self.__value

            @__private_value.setter
            def __private_value(self, v):
                self.__value = v

            def swap_private_value(self, new):
                old = self.__private_value
                self.__private_value = new
                return old

            def _protected_method(self, factor):
                return self._value * factor

            def __private_method(self, factor):
                return self.__value * factor

            def check_private_method(self, factor):
                return self.__private_method(factor)

        value = 123
        inst = MyClass(value)
        # test attributes
        self.assertEqual(inst.value, value)
        self.assertEqual(inst._value, value / 2)
        self.assertEqual(inst.private_value, value * 2)
        # test properties
        self.assertEqual(inst._inner_value, inst._value)
        freeze_inst_value = inst._value
        inst._inner_value -= 1
        self.assertEqual(inst._inner_value, freeze_inst_value - 1)

        self.assertEqual(inst.swap_private_value(321), value * 2)
        self.assertEqual(inst.swap_private_value(value * 2), 321)
        # test methods
        self.assertEqual(inst._protected_method(3), inst._value * 3)
        self.assertEqual(inst.check_private_method(3), inst.private_value * 3)
        # test special
        self.assertEqual(inst.__value__, value - 1)
        inst.__value__ -= 100
        self.assertEqual(inst.__value__, value - 101)

        # test errors
        @njit
        def access_dunder(inst):
            return inst.__value

        with self.assertRaises(errors.TypingError) as raises:
            access_dunder(inst)
        # It will appear as "_TestJitClass__value" because the `access_dunder`
        # is under the scope of "TestJitClass".
        self.assertIn("_TestJitClass__value", str(raises.exception))

        with self.assertRaises(AttributeError) as raises:
            access_dunder.py_func(inst)
        self.assertIn("_TestJitClass__value", str(raises.exception))

    @skip_if_typeguard
    def test_annotations(self):
        """
        Methods with annotations should compile fine (issue #1911).
        """
        from .annotation_usecases import AnnotatedClass

        spec = {"x": int32}
        cls = jitclass(AnnotatedClass, spec)

        obj = cls(5)
        self.assertEqual(obj.x, 5)
        self.assertEqual(obj.add(2), 7)

    def test_docstring(self):
        """Class, method and property docstrings are preserved."""

        @jitclass
        class Apple(object):
            "Class docstring"

            def __init__(self):
                "init docstring"

            def foo(self):
                "foo method docstring"

            @property
            def aval(self):
                "aval property docstring"

        self.assertEqual(Apple.__doc__, "Class docstring")
        self.assertEqual(Apple.__init__.__doc__, "init docstring")
        self.assertEqual(Apple.foo.__doc__, "foo method docstring")
        self.assertEqual(Apple.aval.__doc__, "aval property docstring")

    def test_kwargs(self):
        """Constructor accepts keyword arguments and ``**kwargs``."""
        spec = [("a", int32),
                ("b", float64)]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self, x, y, z):
                self.a = x * y
                self.b = z

        x = 2
        y = 2
        z = 1.1
        kwargs = {"y": y, "z": z}
        tc = TestClass(x=2, **kwargs)
        self.assertEqual(tc.a, x * y)
        self.assertEqual(tc.b, z)

    def test_default_args(self):
        """Constructor default arguments are honoured."""
        spec = [("x", int32),
                ("y", int32),
                ("z", int32)]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self, x, y, z=1):
                self.x = x
                self.y = y
                self.z = z

        tc = TestClass(1, 2, 3)
        self.assertEqual(tc.x, 1)
        self.assertEqual(tc.y, 2)
        self.assertEqual(tc.z, 3)

        tc = TestClass(1, 2)
        self.assertEqual(tc.x, 1)
        self.assertEqual(tc.y, 2)
        self.assertEqual(tc.z, 1)

        tc = TestClass(y=2, z=5, x=1)
        self.assertEqual(tc.x, 1)
        self.assertEqual(tc.y, 2)
        self.assertEqual(tc.z, 5)

    def test_default_args_keyonly(self):
        """Keyword-only constructor arguments with defaults are honoured."""
        spec = [("x", int32),
                ("y", int32),
                ("z", int32),
                ("a", int32)]

        TestClass = jitclass(TestClass1, spec)

        tc = TestClass(2, 3)
        self.assertEqual(tc.x, 2)
        self.assertEqual(tc.y, 3)
        self.assertEqual(tc.z, 1)
        self.assertEqual(tc.a, 5)

        tc = TestClass(y=4, x=2, a=42, z=100)
        self.assertEqual(tc.x, 2)
        self.assertEqual(tc.y, 4)
        self.assertEqual(tc.z, 100)
        self.assertEqual(tc.a, 42)

        tc = TestClass(y=4, x=2, a=42)
        self.assertEqual(tc.x, 2)
        self.assertEqual(tc.y, 4)
        self.assertEqual(tc.z, 1)
        self.assertEqual(tc.a, 42)

        tc = TestClass(y=4, x=2)
        self.assertEqual(tc.x, 2)
        self.assertEqual(tc.y, 4)
        self.assertEqual(tc.z, 1)
        self.assertEqual(tc.a, 5)

    def test_default_args_starargs_and_keyonly(self):
        """A constructor taking ``*args`` is rejected."""
        spec = [("x", int32),
                ("y", int32),
                ("z", int32),
                ("args", types.UniTuple(int32, 2)),
                ("a", int32)]

        with self.assertRaises(errors.UnsupportedError) as raises:
            jitclass(TestClass2, spec)

        msg = "VAR_POSITIONAL argument type unsupported"
        self.assertIn(msg, str(raises.exception))

    def test_generator_method(self):
        """Generator methods yield the same values as a Python generator."""
        spec = []

        @jitclass(spec)
        class TestClass(object):
            def __init__(self):
                pass

            def gen(self, niter):
                for i in range(niter):
                    yield np.arange(i)

        def expected_gen(niter):
            for i in range(niter):
                yield np.arange(i)

        for niter in range(10):
            # Materialise both sides: zip() alone would silently accept a
            # jit generator that yields too few items, and would leave it
            # partially consumed.
            expected = list(expected_gen(niter))
            got = list(TestClass().gen(niter))
            self.assertEqual(len(expected), len(got))
            for expect, item in zip(expected, got):
                self.assertPreciseEqual(expect, item)

    def test_getitem(self):
        """``__getitem__``/``__setitem__`` work inside njit functions."""
        spec = [("data", int32[:])]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self):
                self.data = np.zeros(10, dtype=np.int32)

            def __setitem__(self, key, data):
                self.data[key] = data

            def __getitem__(self, key):
                return self.data[key]

        @njit
        def create_and_set_indices():
            t = TestClass()
            t[1] = 1
            t[2] = 2
            t[3] = 3
            return t

        @njit
        def get_index(t, n):
            return t[n]

        t = create_and_set_indices()
        self.assertEqual(get_index(t, 1), 1)
        self.assertEqual(get_index(t, 2), 2)
        self.assertEqual(get_index(t, 3), 3)

    def test_getitem_unbox(self):
        """Item access works on an instance created in Python."""
        spec = [("data", int32[:])]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self):
                self.data = np.zeros(10, dtype=np.int32)

            def __setitem__(self, key, data):
                self.data[key] = data

            def __getitem__(self, key):
                return self.data[key]

        t = TestClass()
        t[1] = 10

        @njit
        def set2return1(t):
            t[2] = 20
            return t[1]

        t_1 = set2return1(t)
        self.assertEqual(t_1, 10)
        self.assertEqual(t[2], 20)

    def test_getitem_complex_key(self):
        """Item access with a complex-number key."""
        spec = [("data", int32[:, :])]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self):
                self.data = np.zeros((10, 10), dtype=np.int32)

            def __setitem__(self, key, data):
                self.data[int(key.real), int(key.imag)] = data

            def __getitem__(self, key):
                return self.data[int(key.real), int(key.imag)]

        t = TestClass()

        t[complex(1, 1)] = 3

        @njit
        def get_key(t, real, imag):
            return t[complex(real, imag)]

        @njit
        def set_key(t, real, imag, data):
            t[complex(real, imag)] = data

        self.assertEqual(get_key(t, 1, 1), 3)
        set_key(t, 2, 2, 4)
        self.assertEqual(t[complex(2, 2)], 4)

    def test_getitem_tuple_key(self):
        """Item access with a tuple key."""
        spec = [("data", int32[:, :])]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self):
                self.data = np.zeros((10, 10), dtype=np.int32)

            def __setitem__(self, key, data):
                self.data[key[0], key[1]] = data

            def __getitem__(self, key):
                return self.data[key[0], key[1]]

        t = TestClass()
        t[1, 1] = 11

        @njit
        def get11(t):
            return t[1, 1]

        @njit
        def set22(t, data):
            t[2, 2] = data

        self.assertEqual(get11(t), 11)
        set22(t, 22)
        self.assertEqual(t[2, 2], 22)

    def test_getitem_slice_key(self):
        """Item access with a slice key exposes start/stop/step."""
        spec = [("data", int32[:])]

        @jitclass(spec)
        class TestClass(object):
            def __init__(self):
                self.data = np.zeros(10, dtype=np.int32)

            def __setitem__(self, slc, data):
                self.data[slc.start] = data
                self.data[slc.stop] = data + slc.step

            def __getitem__(self, slc):
                return self.data[slc.start]

        t = TestClass()
        # set t.data[1] = 1 and t.data[5] = 2
        t[1:5:1] = 1

        self.assertEqual(t[1:1:1], 1)
        self.assertEqual(t[5:5:5], 2)

        @njit
        def get5(t):
            return t[5:6:1]

        self.assertEqual(get5(t), 2)

        # sets t.data[2] = data, and t.data[6] = data + 1
        @njit
        def set26(t, data):
            t[2:6:1] = data

        set26(t, 2)
        self.assertEqual(t[2:2:1], 2)
        self.assertEqual(t[6:6:1], 3)

    def test_jitclass_longlabel_not_truncated(self):
        """Very long field names do not break code generation."""
        # See issue #3872, llvm 7 introduced a max label length of 1024 chars
        # Numba ships patched llvm 7.1 (ppc64le) and patched llvm 8 to undo this
        # change, this test is here to make sure long labels are ok:
        alphabet = [chr(ord("a") + x) for x in range(26)]

        spec = [(letter * 10, float64) for letter in alphabet]
        spec.extend([(letter.upper() * 10, float64) for letter in alphabet])

        @jitclass(spec)
        class TruncatedLabel(object):
            def __init__(self,):
                self.aaaaaaaaaa = 10.

            def meth1(self):
                self.bbbbbbbbbb = random.gauss(self.aaaaaaaaaa, self.aaaaaaaaaa)

            def meth2(self):
                self.meth1()

        # unpatched LLVMs will raise here...
        TruncatedLabel().meth2()

    def test_pickling(self):
        """The instance type pickles back to the identical object."""
        @jitclass
        class PickleTestSubject(object):
            def __init__(self):
                pass

        inst = PickleTestSubject()
        ty = typeof(inst)
        self.assertIsInstance(ty, types.ClassInstanceType)
        pickled = pickle.dumps(ty)
        self.assertIs(pickle.loads(pickled), ty)

    def test_static_methods(self):
        """Static methods are Dispatchers callable on class and instance."""
        @jitclass([("x", int32)])
        class Test1:
            def __init__(self, x):
                self.x = x

            def increase(self, y):
                self.x = self.add(self.x, y)
                return self.x

            @staticmethod
            def add(a, b):
                return a + b

            @staticmethod
            def sub(a, b):
                return a - b

        @jitclass([("x", int32)])
        class Test2:
            def __init__(self, x):
                self.x = x

            def increase(self, y):
                self.x = self.add(self.x, y)
                return self.x

            @staticmethod
            def add(a, b):
                return a - b

        self.assertIsInstance(Test1.add, Dispatcher)
        self.assertIsInstance(Test1.sub, Dispatcher)
        self.assertIsInstance(Test2.add, Dispatcher)
        self.assertNotEqual(Test1.add, Test2.add)
        self.assertEqual(3, Test1.add(1, 2))
        self.assertEqual(-1, Test2.add(1, 2))
        self.assertEqual(4, Test1.sub(6, 2))

        t1 = Test1(0)
        t2 = Test2(0)
        self.assertEqual(1, t1.increase(1))
        self.assertEqual(-1, t2.increase(1))
        self.assertEqual(2, t1.add(1, 1))
        self.assertEqual(0, t1.sub(1, 1))
        self.assertEqual(0, t2.add(1, 1))
        self.assertEqual(2j, t1.add(1j, 1j))
        self.assertEqual(1j, t1.sub(2j, 1j))
        self.assertEqual("foobar", t1.add("foo", "bar"))

        with self.assertRaises(AttributeError) as raises:
            Test2.sub(3, 1)
        self.assertIn("has no attribute 'sub'",
                      str(raises.exception))

        with self.assertRaises(TypeError) as raises:
            Test1.add(3)
        self.assertIn("not enough arguments: expected 2, got 1",
                      str(raises.exception))

        # Check error message for calling a static method as a class attr from
        # another method (currently unsupported).

        @jitclass([])
        class Test3:
            def __init__(self):
                pass

            @staticmethod
            def a_static_method(a, b):
                pass

            def call_static(self):
                return Test3.a_static_method(1, 2)

        invalid = Test3()
        with self.assertRaises(errors.TypingError) as raises:
            invalid.call_static()
        self.assertIn("Unknown attribute 'a_static_method'",
                      str(raises.exception))

    def test_jitclass_decorator_usecases(self):
        """All decorator spellings of ``@jitclass`` produce the same spec."""
        spec = OrderedDict(x=float64)

        @jitclass()
        class Test1:
            x: float

            def __init__(self):
                self.x = 0

        self.assertIsInstance(Test1, JitClassType)
        self.assertDictEqual(Test1.class_type.struct, spec)

        @jitclass(spec=spec)
        class Test2:

            def __init__(self):
                self.x = 0

        self.assertIsInstance(Test2, JitClassType)
        self.assertDictEqual(Test2.class_type.struct, spec)

        @jitclass
        class Test3:
            x: float

            def __init__(self):
                self.x = 0

        self.assertIsInstance(Test3, JitClassType)
        self.assertDictEqual(Test3.class_type.struct, spec)

        @jitclass(spec)
        class Test4:

            def __init__(self):
                self.x = 0

        self.assertIsInstance(Test4, JitClassType)
        self.assertDictEqual(Test4.class_type.struct, spec)

    def test_jitclass_function_usecases(self):
        """Calling ``jitclass`` as a function, with or without a spec."""
        spec = OrderedDict(x=float64)

        class AnnotatedTest:
            x: float

            def __init__(self):
                self.x = 0

        JitTest1 = jitclass(AnnotatedTest)
        self.assertIsInstance(JitTest1, JitClassType)
        self.assertDictEqual(JitTest1.class_type.struct, spec)

        class UnannotatedTest:

            def __init__(self):
                self.x = 0

        JitTest2 = jitclass(UnannotatedTest, spec)
        self.assertIsInstance(JitTest2, JitClassType)
        self.assertDictEqual(JitTest2.class_type.struct, spec)

    def test_jitclass_isinstance(self):
        """``isinstance`` against jitclasses works in nopython mode."""
        spec = OrderedDict(value=int32)

        @jitclass(spec)
        class Foo(object):
            def __init__(self, value):
                self.value = value

            def getValue(self):
                return self.value

            def getValueIncr(self):
                return self.value + 1

        @jitclass(spec)
        class Bar(object):
            def __init__(self, value):
                self.value = value

            def getValue(self):
                return self.value

        def test_jitclass_isinstance(obj):
            if isinstance(obj, (Foo, Bar)):
                # call something that both classes implements
                x = obj.getValue()
                if isinstance(obj, Foo):  # something that only Foo implements
                    return obj.getValueIncr() + x, 'Foo'
                else:
                    return obj.getValue() + x, 'Bar'
            else:
                return 'no match'

        pyfunc = test_jitclass_isinstance
        cfunc = njit(test_jitclass_isinstance)

        self.assertIsInstance(Foo, JitClassType)
        self.assertEqual(pyfunc(Foo(3)), cfunc(Foo(3)))
        self.assertEqual(pyfunc(Bar(123)), cfunc(Bar(123)))
        self.assertEqual(pyfunc(0), cfunc(0))

    def test_jitclass_unsupported_dunder(self):
        """Unsupported dunder methods are rejected at decoration time."""
        with self.assertRaises(TypeError) as e:
            @jitclass
            class Foo(object):
                def __init__(self):
                    return

                def __enter__(self):
                    return None
            Foo()
        self.assertIn("Method '__enter__' is not supported.", str(e.exception))

    def test_modulename(self):
        """A jitclass reports the defining module as ``__module__``."""
        @jitclass
        class TestModname(object):
            def __init__(self):
                self.x = 12

        thisModule = __name__
        classModule = TestModname.__module__
        self.assertEqual(thisModule, classModule)
class TestJitClassOverloads(MemoryLeakMixin, TestCase):
class PyList:
def __init__(self):
self.x = [0]
def append(self, y):
self.x.append(y)
def clear(self):
self.x.clear()
def __abs__(self):
return len(self.x) * 7
def __bool__(self):
return len(self.x) % 3 != 0
def __complex__(self):
c = complex(2)
if self.x:
c += self.x[0]
return c
def __contains__(self, y):
return y in self.x
def __float__(self):
f = 3.1415
if self.x:
f += self.x[0]
return f
def __int__(self):
i = 5
if self.x:
i += self.x[0]
return i
def __len__(self):
return len(self.x) + 1
def __str__(self):
if len(self.x) == 0:
return "PyList empty"
else:
return "PyList non-empty"
@staticmethod
def get_int_wrapper():
@jitclass([("x", types.intp)])
class IntWrapper:
def __init__(self, value):
self.x = value
def __eq__(self, other):
return self.x == other.x
def __hash__(self):
return self.x
def __lshift__(self, other):
return IntWrapper(self.x << other.x)
def __rshift__(self, other):
return IntWrapper(self.x >> other.x)
def __and__(self, other):
return IntWrapper(self.x & other.x)
def __or__(self, other):
return IntWrapper(self.x | other.x)
def __xor__(self, other):
return IntWrapper(self.x ^ other.x)
return IntWrapper
@staticmethod
def get_float_wrapper():
@jitclass([("x", types.float64)])
class FloatWrapper:
def __init__(self, value):
self.x = value
def __eq__(self, other):
return self.x == other.x
def __hash__(self):
return self.x
def __ge__(self, other):
return self.x >= other.x
def __gt__(self, other):
return self.x > other.x
def __le__(self, other):
return self.x <= other.x
def __lt__(self, other):
return self.x < other.x
def __add__(self, other):
return FloatWrapper(self.x + other.x)
def __floordiv__(self, other):
return FloatWrapper(self.x // other.x)
def __mod__(self, other):
return FloatWrapper(self.x % other.x)
def __mul__(self, other):
return FloatWrapper(self.x * other.x)
def __neg__(self, other):
return FloatWrapper(-self.x)
def __pos__(self, other):
return FloatWrapper(+self.x)
def __pow__(self, other):
return FloatWrapper(self.x ** other.x)
def __sub__(self, other):
return FloatWrapper(self.x - other.x)
def __truediv__(self, other):
return FloatWrapper(self.x / other.x)
return FloatWrapper
def assertSame(self, first, second, msg=None):
self.assertEqual(type(first), type(second), msg=msg)
self.assertEqual(first, second, msg=msg)
def test_overloads(self):
# Check that the dunder methods are exposed on ClassInstanceType.
JitList = jitclass({"x": types.List(types.intp)})(self.PyList)
py_funcs = [
lambda x: abs(x),
lambda x: x.__abs__(),
lambda x: bool(x),
lambda x: x.__bool__(),
lambda x: complex(x),
lambda x: x.__complex__(),
lambda x: 0 in x, # contains
lambda x: x.__contains__(0),
lambda x: float(x),
lambda x: x.__float__(),
lambda x: int(x),
lambda x: x.__int__(),
lambda x: len(x),
lambda x: x.__len__(),
lambda x: str(x),
lambda x: x.__str__(),
lambda x: 1 if x else 0, # truth
]
jit_funcs = [njit(f) for f in py_funcs]
py_list = self.PyList()
jit_list = JitList()
for py_f, jit_f in zip(py_funcs, jit_funcs):
self.assertSame(py_f(py_list), py_f(jit_list))
self.assertSame(py_f(py_list), jit_f(jit_list))
py_list.append(2)
jit_list.append(2)
for py_f, jit_f in zip(py_funcs, jit_funcs):
self.assertSame(py_f(py_list), py_f(jit_list))
self.assertSame(py_f(py_list), jit_f(jit_list))
py_list.append(-5)
jit_list.append(-5)
for py_f, jit_f in zip(py_funcs, jit_funcs):
self.assertSame(py_f(py_list), py_f(jit_list))
self.assertSame(py_f(py_list), jit_f(jit_list))
py_list.clear()
jit_list.clear()
for py_f, jit_f in zip(py_funcs, jit_funcs):
self.assertSame(py_f(py_list), py_f(jit_list))
self.assertSame(py_f(py_list), jit_f(jit_list))
def test_bool_fallback(self):
def py_b(x):
return bool(x)
jit_b = njit(py_b)
@jitclass([("x", types.List(types.intp))])
class LenClass:
def __init__(self, x):
self.x = x
def __len__(self):
return len(self.x) % 4
def append(self, y):
self.x.append(y)
def pop(self):
self.x.pop(0)
obj = LenClass([1, 2, 3])
self.assertTrue(py_b(obj))
self.assertTrue(jit_b(obj))
obj.append(4)
self.assertFalse(py_b(obj))
self.assertFalse(jit_b(obj))
obj.pop()
self.assertTrue(py_b(obj))
self.assertTrue(jit_b(obj))
@jitclass([("y", types.float64)])
class NormalClass:
def __init__(self, y):
self.y = y
obj = NormalClass(0)
self.assertTrue(py_b(obj))
self.assertTrue(jit_b(obj))
def test_numeric_fallback(self):
def py_c(x):
return complex(x)
def py_f(x):
return float(x)
def py_i(x):
return int(x)
jit_c = njit(py_c)
jit_f = njit(py_f)
jit_i = njit(py_i)
@jitclass([])
class FloatClass:
def __init__(self):
pass
def __float__(self):
return 3.1415
obj = FloatClass()
self.assertSame(py_c(obj), complex(3.1415))
self.assertSame(jit_c(obj), complex(3.1415))
self.assertSame(py_f(obj), 3.1415)
self.assertSame(jit_f(obj), 3.1415)
with self.assertRaises(TypeError) as e:
py_i(obj)
self.assertIn("int", str(e.exception))
with self.assertRaises(TypingError) as e:
jit_i(obj)
self.assertIn("int", str(e.exception))
@jitclass([])
class IntClass:
def __init__(self):
pass
def __int__(self):
return 7
obj = IntClass()
self.assertSame(py_i(obj), 7)
self.assertSame(jit_i(obj), 7)
with self.assertRaises(TypeError) as e:
py_c(obj)
self.assertIn("complex", str(e.exception))
with self.assertRaises(TypingError) as e:
jit_c(obj)
self.assertIn("complex", str(e.exception))
with self.assertRaises(TypeError) as e:
py_f(obj)
self.assertIn("float", str(e.exception))
with self.assertRaises(TypingError) as e:
jit_f(obj)
self.assertIn("float", str(e.exception))
@jitclass([])
class IndexClass:
def __init__(self):
pass
def __index__(self):
return 1
obj = IndexClass()
self.assertSame(py_c(obj), complex(1))
self.assertSame(jit_c(obj), complex(1))
self.assertSame(py_f(obj), 1.)
self.assertSame(jit_f(obj), 1.)
self.assertSame(py_i(obj), 1)
self.assertSame(jit_i(obj), 1)
@jitclass([])
class FloatIntIndexClass:
def __init__(self):
pass
def __float__(self):
return 3.1415
def __int__(self):
return 7
def __index__(self):
return 1
obj = FloatIntIndexClass()
self.assertSame(py_c(obj), complex(3.1415))
self.assertSame(jit_c(obj), complex(3.1415))
self.assertSame(py_f(obj), 3.1415)
self.assertSame(jit_f(obj), 3.1415)
self.assertSame(py_i(obj), 7)
self.assertSame(jit_i(obj), 7)
def test_arithmetic_logical(self):
IntWrapper = self.get_int_wrapper()
FloatWrapper = self.get_float_wrapper()
float_py_funcs = [
lambda x, y: x == y,
lambda x, y: x != y,
lambda x, y: x >= y,
lambda x, y: x > y,
lambda x, y: x <= y,
lambda x, y: x < y,
lambda x, y: x + y,
lambda x, y: x // y,
lambda x, y: x % y,
lambda x, y: x * y,
lambda x, y: x ** y,
lambda x, y: x - y,
lambda x, y: x / y,
]
int_py_funcs = [
lambda x, y: x == y,
lambda x, y: x != y,
lambda x, y: x << y,
lambda x, y: x >> y,
lambda x, y: x & y,
lambda x, y: x | y,
lambda x, y: x ^ y,
]
test_values = [
(0.0, 2.0),
(1.234, 3.1415),
(13.1, 1.01),
]
def unwrap(value):
return getattr(value, "x", value)
for jit_f, (x, y) in itertools.product(
map(njit, float_py_funcs), test_values):
py_f = jit_f.py_func
expected = py_f(x, y)
jit_x = FloatWrapper(x)
jit_y = FloatWrapper(y)
check = (
self.assertEqual
if type(expected) is not float
else self.assertAlmostEqual
)
check(expected, jit_f(x, y))
check(expected, unwrap(py_f(jit_x, jit_y)))
check(expected, unwrap(jit_f(jit_x, jit_y)))
for jit_f, (x, y) in itertools.product(
map(njit, int_py_funcs), test_values):
py_f = jit_f.py_func
x, y = int(x), int(y)
expected = py_f(x, y)
jit_x = IntWrapper(x)
jit_y = IntWrapper(y)
self.assertEqual(expected, jit_f(x, y))
self.assertEqual(expected, unwrap(py_f(jit_x, jit_y)))
self.assertEqual(expected, unwrap(jit_f(jit_x, jit_y)))
def test_arithmetic_logical_inplace(self):
    # Augmented assignment (``x op= y``) on jitclass instances. The plain
    # wrappers define no __i*__ methods, so the fallback to the ordinary
    # binary method must be reproduced. The *UpdateWrapper classes define
    # __i*__ methods whose arithmetic presumably differs from the binary
    # methods, so that dispatching to them is observable.
    JitIntWrapper = self.get_int_wrapper()
    JitFloatWrapper = self.get_float_wrapper()
    # The second MRO entry is used as the pure-Python reference class that
    # the jitclass was built from.
    PyIntWrapper = JitIntWrapper.mro()[1]
    PyFloatWrapper = JitFloatWrapper.mro()[1]

    @jitclass([("x", types.intp)])
    class JitIntUpdateWrapper(PyIntWrapper):
        def __init__(self, value):
            self.x = value

        def __ilshift__(self, other):
            return JitIntUpdateWrapper(self.x << other.x)

        def __irshift__(self, other):
            return JitIntUpdateWrapper(self.x >> other.x)

        def __iand__(self, other):
            return JitIntUpdateWrapper(self.x & other.x)

        def __ior__(self, other):
            return JitIntUpdateWrapper(self.x | other.x)

        def __ixor__(self, other):
            return JitIntUpdateWrapper(self.x ^ other.x)

    @jitclass({"x": types.float64})
    class JitFloatUpdateWrapper(PyFloatWrapper):
        def __init__(self, value):
            self.x = value

        def __iadd__(self, other):
            return JitFloatUpdateWrapper(self.x + 2.718 * other.x)

        def __ifloordiv__(self, other):
            return JitFloatUpdateWrapper(self.x * 2.718 // other.x)

        def __imod__(self, other):
            return JitFloatUpdateWrapper(self.x % (other.x + 1))

        def __imul__(self, other):
            return JitFloatUpdateWrapper(self.x * other.x + 1)

        def __ipow__(self, other):
            return JitFloatUpdateWrapper(self.x ** other.x + 1)

        def __isub__(self, other):
            return JitFloatUpdateWrapper(self.x - 3.1415 * other.x)

        def __itruediv__(self, other):
            return JitFloatUpdateWrapper((self.x + 1) / other.x)

    PyIntUpdateWrapper = JitIntUpdateWrapper.mro()[1]
    PyFloatUpdateWrapper = JitFloatUpdateWrapper.mro()[1]

    def get_update_func(op):
        # Build ``f(x, y)`` whose body is literally ``x <op>= y`` so that the
        # augmented-assignment statement itself is what gets compiled.
        template = f"""
def f(x, y):
    x {op}= y
    return x
"""
        namespace = {}
        exec(template, namespace)
        return namespace["f"]

    operand_pairs = [
        (0.0, 2.0),
        (1.234, 3.1415),
        (13.1, 1.01),
    ]
    float_class_pairs = [
        (PyFloatWrapper, JitFloatWrapper),
        (PyFloatUpdateWrapper, JitFloatUpdateWrapper),
    ]
    int_class_pairs = [
        (PyIntWrapper, JitIntWrapper),
        (PyIntUpdateWrapper, JitIntUpdateWrapper),
    ]

    for op in ("+", "//", "%", "*", "**", "-", "/"):
        jit_f = njit(get_update_func(op))
        py_f = jit_f.py_func
        for py_cls, jit_cls in float_class_pairs:
            for a, b in operand_pairs:
                expected = py_f(py_cls(a), py_cls(b)).x
                self.assertAlmostEqual(
                    expected, py_f(jit_cls(a), jit_cls(b)).x)
                self.assertAlmostEqual(
                    expected, jit_f(jit_cls(a), jit_cls(b)).x)

    for op in ("<<", ">>", "&", "|", "^"):
        jit_f = njit(get_update_func(op))
        py_f = jit_f.py_func
        for py_cls, jit_cls in int_class_pairs:
            for a, b in operand_pairs:
                a, b = int(a), int(b)
                expected = py_f(py_cls(a), py_cls(b)).x
                self.assertEqual(
                    expected, py_f(jit_cls(a), jit_cls(b)).x)
                self.assertEqual(
                    expected, jit_f(jit_cls(a), jit_cls(b)).x)
def test_hash_eq_ne(self):
    # User-defined __hash__, __eq__ and __ne__ on a jitclass must be honoured
    # both from the interpreter and from njit-compiled code.
    # HashEqTest has no __ne__, so `!=` falls back to negating __eq__.
    # HashEqNeTest's __ne__ is not the negation of __eq__: the two disagree
    # whenever (self.x - o.x) % 20 == 1.
    class HashEqTest:
        x: int

        def __init__(self, x):
            self.x = x

        def __hash__(self):
            return self.x % 10

        def __eq__(self, o):
            return (self.x - o.x) % 20 == 0

    class HashEqNeTest(HashEqTest):
        def __ne__(self, o):
            return (self.x - o.x) % 20 > 1

    def py_hash(x):
        return hash(x)

    def py_eq(x, y):
        return x == y

    def py_ne(x, y):
        return x != y

    def identity_decorator(f):
        return f

    comparisons = [
        (0, 1),
        # (1 - 0) % 20 == 1: __eq__ is False and HashEqNeTest.__ne__ is also
        # False, while the fallback `not __eq__` is True. This is the pair
        # that tells a real __ne__ call apart from the __eq__ fallback.
        # ((0, 1) gives 19, where both agree.)
        (1, 0),
        (2, 22),
        (7, 10),
        (3, 3),
    ]
    for base_cls, use_jit in itertools.product(
        [HashEqTest, HashEqNeTest], [False, True]
    ):
        decorator = njit if use_jit else identity_decorator
        hash_func = decorator(py_hash)
        eq_func = decorator(py_eq)
        ne_func = decorator(py_ne)
        jit_cls = jitclass(base_cls)
        for v in [0, 2, 10, 24, -8]:
            self.assertEqual(hash_func(jit_cls(v)), v % 10)
        for x, y in comparisons:
            self.assertEqual(
                eq_func(jit_cls(x), jit_cls(y)),
                base_cls(x) == base_cls(y),
            )
            self.assertEqual(
                ne_func(jit_cls(x), jit_cls(y)),
                base_cls(x) != base_cls(y),
            )
def test_bool_fallback_len(self):
    # With no __bool__() defined, truth testing must fall back to len(obj),
    # as described in
    # https://docs.python.org/3/library/stdtypes.html#truth-value-testing
    #
    # Two rules from that page matter here:
    #
    # "By default, an object is considered true unless its class defines
    # either a __bool__() method that returns False or a __len__() method
    # that returns zero, when called with the object."
    #
    # "Operations and built-in functions that have a Boolean result always
    # return 0 or False for false and 1 or True for true, unless otherwise
    # stated."
    class NoBoolHasLen:
        def __init__(self, val):
            self.val = val

        def __len__(self):
            return self.val

        def get_bool(self):
            return bool(self)

    JittedNoBoolHasLen = jitclass([('val', types.int64)])(NoBoolHasLen)

    for length in (0, 2):
        py_result = NoBoolHasLen(length).get_bool()
        jit_result = JittedNoBoolHasLen(length).get_bool()
        # Same truth value...
        self.assertEqual(py_result, jit_result)
        # ...and the same result type as the Python bool() call.
        self.assertEqual(type(py_result), type(jit_result))
def test_bool_fallback_default(self):
    # Counterpart of test_bool_fallback_len for a class that defines neither
    # __bool__() nor __len__(). The object should then always be truthy, and
    # the jitclass must agree with Python in both value and type.
    class NoBoolNoLen:
        def __init__(self):
            pass

        def get_bool(self):
            return bool(self)

    JittedNoBoolNoLen = jitclass([])(NoBoolNoLen)

    expected = NoBoolNoLen().get_bool()
    actual = JittedNoBoolNoLen().get_bool()

    self.assertEqual(expected, actual)
    self.assertEqual(type(expected), type(actual))
def test_operator_reflection(self):
    # The left operand defines no rich comparisons. Each comparison must
    # therefore be resolved through the reflected method of the right
    # operand, and the jitclass versions must give the same results as the
    # plain Python classes.
    class OperatorsDefined:
        def __init__(self, x):
            self.x = x

        def __eq__(self, other):
            return self.x == other.x

        def __le__(self, other):
            return self.x <= other.x

        def __lt__(self, other):
            return self.x < other.x

        def __ge__(self, other):
            return self.x >= other.x

        def __gt__(self, other):
            return self.x > other.x

    class NoOperatorsDefined:
        def __init__(self, x):
            self.x = x

    spec = [('x', types.int32)]
    JitOperatorsDefined = jitclass(spec)(OperatorsDefined)
    JitNoOperatorsDefined = jitclass(spec)(NoOperatorsDefined)

    py_rhs = OperatorsDefined(2)
    py_lhs = NoOperatorsDefined(3)
    jit_rhs = JitOperatorsDefined(2)
    jit_lhs = JitNoOperatorsDefined(3)

    comparisons = (
        lambda a, b: a == b,
        lambda a, b: a <= b,
        lambda a, b: a < b,
        lambda a, b: a >= b,
        lambda a, b: a > b,
    )
    for compare in comparisons:
        self.assertEqual(compare(py_lhs, py_rhs), compare(jit_lhs, jit_rhs))
@skip_unless_scipy
def test_matmul_operator(self):
    # Exercises ``@`` and ``@=`` between a class that defines __matmul__,
    # __rmatmul__ and __imatmul__ and one that defines none of them. Both
    # vector and matrix payloads are used, and the Python and jitclass
    # results are compared.
    class ArrayAt:
        def __init__(self, array):
            self.arr = array

        def __matmul__(self, other):
            return self.arr @ other.arr

        def __rmatmul__(self, other):
            return other.arr @ self.arr

        def __imatmul__(self, other):
            self.arr = self.arr @ other.arr
            return self

    class ArrayNoAt:
        def __init__(self, array):
            self.arr = array

    def make_jitted(cls, arr_type, value):
        # A fresh spec dict per jitclass call, as in a literal per call site.
        return jitclass(cls, spec={"arr": arr_type})(value)

    n = 3
    np.random.seed(1)
    vec = np.random.random(size=(n,))
    mat = np.random.random(size=(n, n))

    vector_noat = ArrayNoAt(vec)
    vector_at = ArrayAt(vec)
    jit_vector_noat = make_jitted(ArrayNoAt, float64[::1], vec)
    jit_vector_at = make_jitted(ArrayAt, float64[::1], vec)
    matrix_noat = ArrayNoAt(mat)
    matrix_at = ArrayAt(mat)
    jit_matrix_noat = make_jitted(ArrayNoAt, float64[:, ::1], mat)
    jit_matrix_at = make_jitted(ArrayAt, float64[:, ::1], mat)

    py_at = {"vector": vector_at, "matrix": matrix_at}
    py_noat = {"vector": vector_noat, "matrix": matrix_noat}
    jit_at = {"vector": jit_vector_at, "matrix": jit_matrix_at}
    jit_noat = {"vector": jit_vector_noat, "matrix": jit_matrix_noat}
    shape_pairs = list(itertools.product(("vector", "matrix"), repeat=2))

    # __matmul__: the left operand defines the operator.
    for lhs, rhs in shape_pairs:
        np.testing.assert_allclose(py_at[lhs] @ py_noat[rhs],
                                   jit_at[lhs] @ jit_noat[rhs])

    # __rmatmul__: only the right operand defines the operator.
    for lhs, rhs in shape_pairs:
        np.testing.assert_allclose(py_noat[lhs] @ py_at[rhs],
                                   jit_noat[lhs] @ jit_at[rhs])

    # __imatmul__
    vector_at @= matrix_noat
    matrix_at @= matrix_noat
    jit_vector_at @= jit_matrix_noat
    jit_matrix_at @= jit_matrix_noat
    np.testing.assert_allclose(vector_at.arr, jit_vector_at.arr)
    np.testing.assert_allclose(matrix_at.arr, jit_matrix_at.arr)
def test_arithmetic_logical_reflection(self):
    # Only the right-hand operand defines the reflected (__r*__) methods, so
    # `lhs <op> rhs` must dispatch to rhs.__r<op>__(lhs). Jitclass instances
    # must produce the same results as the plain Python classes.
    class OperatorsDefined:
        def __init__(self, x):
            self.x = x

        def __radd__(self, other):
            return other.x + self.x

        def __rsub__(self, other):
            return other.x - self.x

        def __rmul__(self, other):
            return other.x * self.x

        def __rtruediv__(self, other):
            return other.x / self.x

        def __rfloordiv__(self, other):
            return other.x // self.x

        def __rmod__(self, other):
            return other.x % self.x

        def __rpow__(self, other):
            return other.x ** self.x

        def __rlshift__(self, other):
            return other.x << self.x

        def __rrshift__(self, other):
            return other.x >> self.x

        def __rand__(self, other):
            return other.x & self.x

        def __rxor__(self, other):
            return other.x ^ self.x

        def __ror__(self, other):
            return other.x | self.x

    class NoOperatorsDefined:
        def __init__(self, x):
            self.x = x

    # Apply each operator directly instead of building source strings for
    # eval(). The symbol is kept only to label failures via subTest.
    float_ops = [
        ("+", lambda a, b: a + b),
        ("-", lambda a, b: a - b),
        ("*", lambda a, b: a * b),
        ("**", lambda a, b: a ** b),
        ("/", lambda a, b: a / b),
        ("//", lambda a, b: a // b),
        ("%", lambda a, b: a % b),
    ]
    int_ops = float_ops + [
        ("<<", lambda a, b: a << b),
        (">>", lambda a, b: a >> b),
        ("&", lambda a, b: a & b),
        ("^", lambda a, b: a ^ b),
        ("|", lambda a, b: a | b),
    ]
    cases = [
        (int32, int_ops, (2, 4)),
        (float64, float_ops, (2., 4.)),
        (float64[::1], float_ops,
         (np.array([1., 2., 4.]), np.array([20., -24., 1.]))),
    ]
    for test_type, test_ops, (defined_value, other_value) in cases:
        spec = {"x": test_type}
        JitOperatorsDefined = jitclass(OperatorsDefined, spec)
        JitNoOperatorsDefined = jitclass(NoOperatorsDefined, spec)
        py_ops_defined = OperatorsDefined(defined_value)
        py_ops_not_defined = NoOperatorsDefined(other_value)
        jit_ops_defined = JitOperatorsDefined(defined_value)
        jit_ops_not_defined = JitNoOperatorsDefined(other_value)
        # Check the numba type itself rather than its string form.
        is_array = isinstance(test_type, types.Array)
        for symbol, apply_op in test_ops:
            with self.subTest(type=str(test_type), op=symbol):
                expected = apply_op(py_ops_not_defined, py_ops_defined)
                got = apply_op(jit_ops_not_defined, jit_ops_defined)
                if is_array:
                    # Exact element-wise comparison of the array results.
                    np.testing.assert_array_equal(got, expected)
                else:
                    self.assertEqual(expected, got)
def test_implicit_hash_compiles(self):
    # Defining __eq__ without __hash__ leaves the class with an implicit
    # ``__hash__ = None``. jitclass must ignore that member instead of
    # failing, and the user __eq__ must still be used.
    class ImplicitHash:
        def __init__(self):
            pass

        def __eq__(self, other):
            return False

    obj = jitclass([])(ImplicitHash)()
    self.assertFalse(obj == obj)
# Allow this test module to be run directly as a script.
if __name__ == "__main__":
    unittest.main()