2021 lines
58 KiB
Python
2021 lines
58 KiB
Python
|
import ctypes
|
||
|
import itertools
|
||
|
import pickle
|
||
|
import random
|
||
|
import typing as pt
|
||
|
import unittest
|
||
|
|
||
|
from collections import OrderedDict
|
||
|
|
||
|
import numpy as np
|
||
|
from numba import (boolean, deferred_type, float32, float64, int16, int32,
|
||
|
njit, optional, typeof)
|
||
|
from numba.core import errors, types
|
||
|
from numba.core.dispatcher import Dispatcher
|
||
|
from numba.core.errors import LoweringError, TypingError
|
||
|
from numba.core.runtime.nrt import MemInfo
|
||
|
from numba.experimental import jitclass
|
||
|
from numba.experimental.jitclass import _box
|
||
|
from numba.experimental.jitclass.base import JitClassType
|
||
|
from numba.tests.support import MemoryLeakMixin, TestCase, skip_if_typeguard
|
||
|
from numba.tests.support import skip_unless_scipy
|
||
|
|
||
|
|
||
|
class TestClass1(object):
    """Plain Python class whose ``__init__`` mixes a defaulted parameter and a
    keyword-only parameter; the tests compile it with ``jitclass``."""

    def __init__(self, x, y, z=1, *, a=5):
        # The keyword-only value is stored first; these assignments are
        # independent of each other, so their order is irrelevant.
        self.a = a
        self.x = x
        self.y = y
        self.z = z
|
||
|
|
||
|
|
||
|
class TestClass2(object):
    """Plain Python class whose ``__init__`` also accepts ``*args``; the tests
    use it to check that ``jitclass`` rejects VAR_POSITIONAL parameters."""

    def __init__(self, x, y, z=1, *args, a=5):
        # Surplus positional arguments are kept as a tuple on the instance.
        self.args = args
        self.a = a
        self.x = x
        self.y = y
        self.z = z
|
||
|
|
||
|
|
||
|
def _get_meminfo(box):
    """Return a ``MemInfo`` wrapping the meminfo pointer of the boxed
    jitclass instance *box*.

    ``acquire()`` is called on the result before it is returned, so the
    caller's ``MemInfo`` holds its own reference.
    """
    meminfo = MemInfo(_box.box_get_meminfoptr(box))
    meminfo.acquire()
    return meminfo
|
||
|
|
||
|
|
||
|
class TestJitClass(TestCase, MemoryLeakMixin):
|
||
|
|
||
|
def _check_spec(self, spec=None, test_cls=None, all_expected=None):
|
||
|
if test_cls is None:
|
||
|
@jitclass(spec)
|
||
|
class Test(object):
|
||
|
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
test_cls = Test
|
||
|
|
||
|
clsty = test_cls.class_type.instance_type
|
||
|
names = list(clsty.struct.keys())
|
||
|
values = list(clsty.struct.values())
|
||
|
|
||
|
if all_expected is None:
|
||
|
if isinstance(spec, OrderedDict):
|
||
|
all_expected = spec.items()
|
||
|
else:
|
||
|
all_expected = spec
|
||
|
|
||
|
assert all_expected is not None
|
||
|
|
||
|
self.assertEqual(len(names), len(all_expected))
|
||
|
for got, expected in zip(zip(names, values), all_expected):
|
||
|
self.assertEqual(got[0], expected[0])
|
||
|
self.assertEqual(got[1], expected[1])
|
||
|
|
||
|
def test_ordereddict_spec(self):
|
||
|
spec = OrderedDict()
|
||
|
spec["x"] = int32
|
||
|
spec["y"] = float32
|
||
|
self._check_spec(spec)
|
||
|
|
||
|
def test_list_spec(self):
|
||
|
spec = [("x", int32),
|
||
|
("y", float32)]
|
||
|
self._check_spec(spec)
|
||
|
|
||
|
def test_type_annotations(self):
|
||
|
spec = [("x", int32)]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Test1(object):
|
||
|
x: int
|
||
|
y: pt.List[float]
|
||
|
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
self._check_spec(spec, Test1, spec + [("y", types.ListType(float64))])
|
||
|
|
||
|
def test_type_annotation_inheritance(self):
|
||
|
|
||
|
class Foo:
|
||
|
x: int
|
||
|
|
||
|
@jitclass
|
||
|
class Bar(Foo):
|
||
|
y: float
|
||
|
|
||
|
def __init__(self, value: float) -> None:
|
||
|
self.x = int(value)
|
||
|
self.y = value
|
||
|
|
||
|
self._check_spec(
|
||
|
test_cls=Bar, all_expected=[("x", typeof(0)), ("y", typeof(0.0))]
|
||
|
)
|
||
|
|
||
|
def test_spec_errors(self):
|
||
|
spec1 = [("x", int), ("y", float32[:])]
|
||
|
spec2 = [(1, int32), ("y", float32[:])]
|
||
|
|
||
|
class Test(object):
|
||
|
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
jitclass(Test, spec1)
|
||
|
self.assertIn("spec values should be Numba type instances",
|
||
|
str(raises.exception))
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
jitclass(Test, spec2)
|
||
|
self.assertEqual(str(raises.exception),
|
||
|
"spec keys should be strings, got 1")
|
||
|
|
||
|
def test_init_errors(self):
|
||
|
|
||
|
@jitclass([])
|
||
|
class Test:
|
||
|
def __init__(self):
|
||
|
return 7
|
||
|
|
||
|
with self.assertRaises(errors.TypingError) as raises:
|
||
|
Test()
|
||
|
|
||
|
self.assertIn("__init__() should return None, not",
|
||
|
str(raises.exception))
|
||
|
|
||
|
def _make_Float2AndArray(self):
|
||
|
spec = OrderedDict()
|
||
|
spec["x"] = float32
|
||
|
spec["y"] = float32
|
||
|
spec["arr"] = float32[:]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Float2AndArray(object):
|
||
|
|
||
|
def __init__(self, x, y, arr):
|
||
|
self.x = x
|
||
|
self.y = y
|
||
|
self.arr = arr
|
||
|
|
||
|
def add(self, val):
|
||
|
self.x += val
|
||
|
self.y += val
|
||
|
return val
|
||
|
|
||
|
return Float2AndArray
|
||
|
|
||
|
def _make_Vector2(self):
|
||
|
spec = OrderedDict()
|
||
|
spec["x"] = int32
|
||
|
spec["y"] = int32
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Vector2(object):
|
||
|
|
||
|
def __init__(self, x, y):
|
||
|
self.x = x
|
||
|
self.y = y
|
||
|
|
||
|
return Vector2
|
||
|
|
||
|
def test_jit_class_1(self):
|
||
|
Float2AndArray = self._make_Float2AndArray()
|
||
|
Vector2 = self._make_Vector2()
|
||
|
|
||
|
@njit
|
||
|
def bar(obj):
|
||
|
return obj.x + obj.y
|
||
|
|
||
|
@njit
|
||
|
def foo(a):
|
||
|
obj = Float2AndArray(1, 2, a)
|
||
|
obj.add(123)
|
||
|
|
||
|
vec = Vector2(3, 4)
|
||
|
return bar(obj), bar(vec), obj.arr
|
||
|
|
||
|
inp = np.ones(10, dtype=np.float32)
|
||
|
a, b, c = foo(inp)
|
||
|
self.assertEqual(a, 123 + 1 + 123 + 2)
|
||
|
self.assertEqual(b, 3 + 4)
|
||
|
self.assertPreciseEqual(c, inp)
|
||
|
|
||
|
def test_jitclass_usage_from_python(self):
|
||
|
Float2AndArray = self._make_Float2AndArray()
|
||
|
|
||
|
@njit
|
||
|
def identity(obj):
|
||
|
return obj
|
||
|
|
||
|
@njit
|
||
|
def retrieve_attributes(obj):
|
||
|
return obj.x, obj.y, obj.arr
|
||
|
|
||
|
arr = np.arange(10, dtype=np.float32)
|
||
|
obj = Float2AndArray(1, 2, arr)
|
||
|
obj_meminfo = _get_meminfo(obj)
|
||
|
self.assertEqual(obj_meminfo.refcount, 2)
|
||
|
self.assertEqual(obj_meminfo.data, _box.box_get_dataptr(obj))
|
||
|
self.assertEqual(obj._numba_type_.class_type,
|
||
|
Float2AndArray.class_type)
|
||
|
# Use jit class instance in numba
|
||
|
other = identity(obj)
|
||
|
other_meminfo = _get_meminfo(other) # duplicates MemInfo object to obj
|
||
|
self.assertEqual(obj_meminfo.refcount, 4)
|
||
|
self.assertEqual(other_meminfo.refcount, 4)
|
||
|
self.assertEqual(other_meminfo.data, _box.box_get_dataptr(other))
|
||
|
self.assertEqual(other_meminfo.data, obj_meminfo.data)
|
||
|
|
||
|
# Check dtor
|
||
|
del other, other_meminfo
|
||
|
self.assertEqual(obj_meminfo.refcount, 2)
|
||
|
|
||
|
# Check attributes
|
||
|
out_x, out_y, out_arr = retrieve_attributes(obj)
|
||
|
self.assertEqual(out_x, 1)
|
||
|
self.assertEqual(out_y, 2)
|
||
|
self.assertIs(out_arr, arr)
|
||
|
|
||
|
# Access attributes from python
|
||
|
self.assertEqual(obj.x, 1)
|
||
|
self.assertEqual(obj.y, 2)
|
||
|
self.assertIs(obj.arr, arr)
|
||
|
|
||
|
# Access methods from python
|
||
|
self.assertEqual(obj.add(123), 123)
|
||
|
self.assertEqual(obj.x, 1 + 123)
|
||
|
self.assertEqual(obj.y, 2 + 123)
|
||
|
|
||
|
# Setter from python
|
||
|
obj.x = 333
|
||
|
obj.y = 444
|
||
|
obj.arr = newarr = np.arange(5, dtype=np.float32)
|
||
|
self.assertEqual(obj.x, 333)
|
||
|
self.assertEqual(obj.y, 444)
|
||
|
self.assertIs(obj.arr, newarr)
|
||
|
|
||
|
def test_jitclass_datalayout(self):
|
||
|
spec = OrderedDict()
|
||
|
# Boolean has different layout as value vs data
|
||
|
spec["val"] = boolean
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Foo(object):
|
||
|
|
||
|
def __init__(self, val):
|
||
|
self.val = val
|
||
|
|
||
|
self.assertTrue(Foo(True).val)
|
||
|
self.assertFalse(Foo(False).val)
|
||
|
|
||
|
def test_deferred_type(self):
|
||
|
node_type = deferred_type()
|
||
|
|
||
|
spec = OrderedDict()
|
||
|
spec["data"] = float32
|
||
|
spec["next"] = optional(node_type)
|
||
|
|
||
|
@njit
|
||
|
def get_data(node):
|
||
|
return node.data
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class LinkedNode(object):
|
||
|
|
||
|
def __init__(self, data, next):
|
||
|
self.data = data
|
||
|
self.next = next
|
||
|
|
||
|
def get_next_data(self):
|
||
|
# use deferred type as argument
|
||
|
return get_data(self.next)
|
||
|
|
||
|
def append_to_tail(self, other):
|
||
|
cur = self
|
||
|
while cur.next is not None:
|
||
|
cur = cur.next
|
||
|
cur.next = other
|
||
|
|
||
|
node_type.define(LinkedNode.class_type.instance_type)
|
||
|
|
||
|
first = LinkedNode(123, None)
|
||
|
self.assertEqual(first.data, 123)
|
||
|
self.assertIsNone(first.next)
|
||
|
|
||
|
second = LinkedNode(321, first)
|
||
|
|
||
|
first_meminfo = _get_meminfo(first)
|
||
|
second_meminfo = _get_meminfo(second)
|
||
|
self.assertEqual(first_meminfo.refcount, 3)
|
||
|
self.assertEqual(second.next.data, first.data)
|
||
|
self.assertEqual(first_meminfo.refcount, 3)
|
||
|
self.assertEqual(second_meminfo.refcount, 2)
|
||
|
|
||
|
# Test using deferred type as argument
|
||
|
first_val = second.get_next_data()
|
||
|
self.assertEqual(first_val, first.data)
|
||
|
|
||
|
# Check setattr (issue #2606)
|
||
|
self.assertIsNone(first.next)
|
||
|
second.append_to_tail(LinkedNode(567, None))
|
||
|
self.assertIsNotNone(first.next)
|
||
|
self.assertEqual(first.next.data, 567)
|
||
|
self.assertIsNone(first.next.next)
|
||
|
second.append_to_tail(LinkedNode(678, None))
|
||
|
self.assertIsNotNone(first.next.next)
|
||
|
self.assertEqual(first.next.next.data, 678)
|
||
|
|
||
|
# Check ownership
|
||
|
self.assertEqual(first_meminfo.refcount, 3)
|
||
|
del second, second_meminfo
|
||
|
self.assertEqual(first_meminfo.refcount, 2)
|
||
|
|
||
|
def test_c_structure(self):
|
||
|
spec = OrderedDict()
|
||
|
spec["a"] = int32
|
||
|
spec["b"] = int16
|
||
|
spec["c"] = float64
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Struct(object):
|
||
|
|
||
|
def __init__(self, a, b, c):
|
||
|
self.a = a
|
||
|
self.b = b
|
||
|
self.c = c
|
||
|
|
||
|
st = Struct(0xabcd, 0xef, 3.1415)
|
||
|
|
||
|
class CStruct(ctypes.Structure):
|
||
|
_fields_ = [
|
||
|
("a", ctypes.c_int32),
|
||
|
("b", ctypes.c_int16),
|
||
|
("c", ctypes.c_double),
|
||
|
]
|
||
|
|
||
|
ptr = ctypes.c_void_p(_box.box_get_dataptr(st))
|
||
|
cstruct = ctypes.cast(ptr, ctypes.POINTER(CStruct))[0]
|
||
|
self.assertEqual(cstruct.a, st.a)
|
||
|
self.assertEqual(cstruct.b, st.b)
|
||
|
self.assertEqual(cstruct.c, st.c)
|
||
|
|
||
|
def test_is(self):
|
||
|
Vector = self._make_Vector2()
|
||
|
vec_a = Vector(1, 2)
|
||
|
|
||
|
@njit
|
||
|
def do_is(a, b):
|
||
|
return a is b
|
||
|
|
||
|
with self.assertRaises(LoweringError) as raises:
|
||
|
# trigger compilation
|
||
|
do_is(vec_a, vec_a)
|
||
|
self.assertIn("no default `is` implementation", str(raises.exception))
|
||
|
|
||
|
def test_isinstance(self):
|
||
|
Vector2 = self._make_Vector2()
|
||
|
vec = Vector2(1, 2)
|
||
|
self.assertIsInstance(vec, Vector2)
|
||
|
|
||
|
def test_subclassing(self):
|
||
|
Vector2 = self._make_Vector2()
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
class SubV(Vector2):
|
||
|
pass
|
||
|
self.assertEqual(str(raises.exception),
|
||
|
"cannot subclass from a jitclass")
|
||
|
|
||
|
def test_base_class(self):
|
||
|
class Base(object):
|
||
|
|
||
|
def what(self):
|
||
|
return self.attr
|
||
|
|
||
|
@jitclass([("attr", int32)])
|
||
|
class Test(Base):
|
||
|
|
||
|
def __init__(self, attr):
|
||
|
self.attr = attr
|
||
|
|
||
|
obj = Test(123)
|
||
|
self.assertEqual(obj.what(), 123)
|
||
|
|
||
|
def test_globals(self):
|
||
|
|
||
|
class Mine(object):
|
||
|
constant = 123
|
||
|
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
jitclass(Mine)
|
||
|
|
||
|
self.assertEqual(str(raises.exception),
|
||
|
"class members are not yet supported: constant")
|
||
|
|
||
|
def test_user_getter_setter(self):
|
||
|
@jitclass([("attr", int32)])
|
||
|
class Foo(object):
|
||
|
|
||
|
def __init__(self, attr):
|
||
|
self.attr = attr
|
||
|
|
||
|
@property
|
||
|
def value(self):
|
||
|
return self.attr + 1
|
||
|
|
||
|
@value.setter
|
||
|
def value(self, val):
|
||
|
self.attr = val - 1
|
||
|
|
||
|
foo = Foo(123)
|
||
|
self.assertEqual(foo.attr, 123)
|
||
|
# Getter
|
||
|
self.assertEqual(foo.value, 123 + 1)
|
||
|
# Setter
|
||
|
foo.value = 789
|
||
|
self.assertEqual(foo.attr, 789 - 1)
|
||
|
self.assertEqual(foo.value, 789)
|
||
|
|
||
|
# Test nopython mode usage of getter and setter
|
||
|
@njit
|
||
|
def bar(foo, val):
|
||
|
a = foo.value
|
||
|
foo.value = val
|
||
|
b = foo.value
|
||
|
c = foo.attr
|
||
|
return a, b, c
|
||
|
|
||
|
a, b, c = bar(foo, 567)
|
||
|
self.assertEqual(a, 789)
|
||
|
self.assertEqual(b, 567)
|
||
|
self.assertEqual(c, 567 - 1)
|
||
|
|
||
|
def test_user_deleter_error(self):
|
||
|
class Foo(object):
|
||
|
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
def value(self):
|
||
|
return 1
|
||
|
|
||
|
@value.deleter
|
||
|
def value(self):
|
||
|
pass
|
||
|
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
jitclass(Foo)
|
||
|
self.assertEqual(str(raises.exception),
|
||
|
"deleter is not supported: value")
|
||
|
|
||
|
def test_name_shadowing_error(self):
|
||
|
class Foo(object):
|
||
|
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
@property
|
||
|
def my_property(self):
|
||
|
pass
|
||
|
|
||
|
def my_method(self):
|
||
|
pass
|
||
|
|
||
|
with self.assertRaises(NameError) as raises:
|
||
|
jitclass(Foo, [("my_property", int32)])
|
||
|
self.assertEqual(str(raises.exception), "name shadowing: my_property")
|
||
|
|
||
|
with self.assertRaises(NameError) as raises:
|
||
|
jitclass(Foo, [("my_method", int32)])
|
||
|
self.assertEqual(str(raises.exception), "name shadowing: my_method")
|
||
|
|
||
|
def test_distinct_classes(self):
|
||
|
# Different classes with the same names shouldn't confuse the compiler
|
||
|
@jitclass([("x", int32)])
|
||
|
class Foo(object):
|
||
|
|
||
|
def __init__(self, x):
|
||
|
self.x = x + 2
|
||
|
|
||
|
def run(self):
|
||
|
return self.x + 1
|
||
|
|
||
|
FirstFoo = Foo
|
||
|
|
||
|
@jitclass([("x", int32)])
|
||
|
class Foo(object):
|
||
|
|
||
|
def __init__(self, x):
|
||
|
self.x = x - 2
|
||
|
|
||
|
def run(self):
|
||
|
return self.x - 1
|
||
|
|
||
|
SecondFoo = Foo
|
||
|
foo = FirstFoo(5)
|
||
|
self.assertEqual(foo.x, 7)
|
||
|
self.assertEqual(foo.run(), 8)
|
||
|
foo = SecondFoo(5)
|
||
|
self.assertEqual(foo.x, 3)
|
||
|
self.assertEqual(foo.run(), 2)
|
||
|
|
||
|
def test_parameterized(self):
|
||
|
class MyClass(object):
|
||
|
|
||
|
def __init__(self, value):
|
||
|
self.value = value
|
||
|
|
||
|
def create_my_class(value):
|
||
|
cls = jitclass(MyClass, [("value", typeof(value))])
|
||
|
return cls(value)
|
||
|
|
||
|
a = create_my_class(123)
|
||
|
self.assertEqual(a.value, 123)
|
||
|
|
||
|
b = create_my_class(12.3)
|
||
|
self.assertEqual(b.value, 12.3)
|
||
|
|
||
|
c = create_my_class(np.array([123]))
|
||
|
np.testing.assert_equal(c.value, [123])
|
||
|
|
||
|
d = create_my_class(np.array([12.3]))
|
||
|
np.testing.assert_equal(d.value, [12.3])
|
||
|
|
||
|
def test_protected_attrs(self):
|
||
|
spec = {
|
||
|
"value": int32,
|
||
|
"_value": float32,
|
||
|
"__value": int32,
|
||
|
"__value__": int32,
|
||
|
}
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class MyClass(object):
|
||
|
|
||
|
def __init__(self, value):
|
||
|
self.value = value
|
||
|
self._value = value / 2
|
||
|
self.__value = value * 2
|
||
|
self.__value__ = value - 1
|
||
|
|
||
|
@property
|
||
|
def private_value(self):
|
||
|
return self.__value
|
||
|
|
||
|
@property
|
||
|
def _inner_value(self):
|
||
|
return self._value
|
||
|
|
||
|
@_inner_value.setter
|
||
|
def _inner_value(self, v):
|
||
|
self._value = v
|
||
|
|
||
|
@property
|
||
|
def __private_value(self):
|
||
|
return self.__value
|
||
|
|
||
|
@__private_value.setter
|
||
|
def __private_value(self, v):
|
||
|
self.__value = v
|
||
|
|
||
|
def swap_private_value(self, new):
|
||
|
old = self.__private_value
|
||
|
self.__private_value = new
|
||
|
return old
|
||
|
|
||
|
def _protected_method(self, factor):
|
||
|
return self._value * factor
|
||
|
|
||
|
def __private_method(self, factor):
|
||
|
return self.__value * factor
|
||
|
|
||
|
def check_private_method(self, factor):
|
||
|
return self.__private_method(factor)
|
||
|
|
||
|
value = 123
|
||
|
inst = MyClass(value)
|
||
|
# test attributes
|
||
|
self.assertEqual(inst.value, value)
|
||
|
self.assertEqual(inst._value, value / 2)
|
||
|
self.assertEqual(inst.private_value, value * 2)
|
||
|
# test properties
|
||
|
self.assertEqual(inst._inner_value, inst._value)
|
||
|
freeze_inst_value = inst._value
|
||
|
inst._inner_value -= 1
|
||
|
self.assertEqual(inst._inner_value, freeze_inst_value - 1)
|
||
|
|
||
|
self.assertEqual(inst.swap_private_value(321), value * 2)
|
||
|
self.assertEqual(inst.swap_private_value(value * 2), 321)
|
||
|
# test methods
|
||
|
self.assertEqual(inst._protected_method(3), inst._value * 3)
|
||
|
self.assertEqual(inst.check_private_method(3), inst.private_value * 3)
|
||
|
# test special
|
||
|
self.assertEqual(inst.__value__, value - 1)
|
||
|
inst.__value__ -= 100
|
||
|
self.assertEqual(inst.__value__, value - 101)
|
||
|
|
||
|
# test errors
|
||
|
@njit
|
||
|
def access_dunder(inst):
|
||
|
return inst.__value
|
||
|
|
||
|
with self.assertRaises(errors.TypingError) as raises:
|
||
|
access_dunder(inst)
|
||
|
# It will appear as "_TestJitClass__value" because the `access_dunder`
|
||
|
# is under the scope of "TestJitClass".
|
||
|
self.assertIn("_TestJitClass__value", str(raises.exception))
|
||
|
|
||
|
with self.assertRaises(AttributeError) as raises:
|
||
|
access_dunder.py_func(inst)
|
||
|
self.assertIn("_TestJitClass__value", str(raises.exception))
|
||
|
|
||
|
@skip_if_typeguard
|
||
|
def test_annotations(self):
|
||
|
"""
|
||
|
Methods with annotations should compile fine (issue #1911).
|
||
|
"""
|
||
|
from .annotation_usecases import AnnotatedClass
|
||
|
|
||
|
spec = {"x": int32}
|
||
|
cls = jitclass(AnnotatedClass, spec)
|
||
|
|
||
|
obj = cls(5)
|
||
|
self.assertEqual(obj.x, 5)
|
||
|
self.assertEqual(obj.add(2), 7)
|
||
|
|
||
|
def test_docstring(self):
|
||
|
|
||
|
@jitclass
|
||
|
class Apple(object):
|
||
|
"Class docstring"
|
||
|
|
||
|
def __init__(self):
|
||
|
"init docstring"
|
||
|
|
||
|
def foo(self):
|
||
|
"foo method docstring"
|
||
|
|
||
|
@property
|
||
|
def aval(self):
|
||
|
"aval property docstring"
|
||
|
|
||
|
self.assertEqual(Apple.__doc__, "Class docstring")
|
||
|
self.assertEqual(Apple.__init__.__doc__, "init docstring")
|
||
|
self.assertEqual(Apple.foo.__doc__, "foo method docstring")
|
||
|
self.assertEqual(Apple.aval.__doc__, "aval property docstring")
|
||
|
|
||
|
def test_kwargs(self):
|
||
|
spec = [("a", int32),
|
||
|
("b", float64)]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self, x, y, z):
|
||
|
self.a = x * y
|
||
|
self.b = z
|
||
|
|
||
|
x = 2
|
||
|
y = 2
|
||
|
z = 1.1
|
||
|
kwargs = {"y": y, "z": z}
|
||
|
tc = TestClass(x=2, **kwargs)
|
||
|
self.assertEqual(tc.a, x * y)
|
||
|
self.assertEqual(tc.b, z)
|
||
|
|
||
|
def test_default_args(self):
|
||
|
spec = [("x", int32),
|
||
|
("y", int32),
|
||
|
("z", int32)]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self, x, y, z=1):
|
||
|
self.x = x
|
||
|
self.y = y
|
||
|
self.z = z
|
||
|
|
||
|
tc = TestClass(1, 2, 3)
|
||
|
self.assertEqual(tc.x, 1)
|
||
|
self.assertEqual(tc.y, 2)
|
||
|
self.assertEqual(tc.z, 3)
|
||
|
|
||
|
tc = TestClass(1, 2)
|
||
|
self.assertEqual(tc.x, 1)
|
||
|
self.assertEqual(tc.y, 2)
|
||
|
self.assertEqual(tc.z, 1)
|
||
|
|
||
|
tc = TestClass(y=2, z=5, x=1)
|
||
|
self.assertEqual(tc.x, 1)
|
||
|
self.assertEqual(tc.y, 2)
|
||
|
self.assertEqual(tc.z, 5)
|
||
|
|
||
|
def test_default_args_keyonly(self):
|
||
|
spec = [("x", int32),
|
||
|
("y", int32),
|
||
|
("z", int32),
|
||
|
("a", int32)]
|
||
|
|
||
|
TestClass = jitclass(TestClass1, spec)
|
||
|
|
||
|
tc = TestClass(2, 3)
|
||
|
self.assertEqual(tc.x, 2)
|
||
|
self.assertEqual(tc.y, 3)
|
||
|
self.assertEqual(tc.z, 1)
|
||
|
self.assertEqual(tc.a, 5)
|
||
|
|
||
|
tc = TestClass(y=4, x=2, a=42, z=100)
|
||
|
self.assertEqual(tc.x, 2)
|
||
|
self.assertEqual(tc.y, 4)
|
||
|
self.assertEqual(tc.z, 100)
|
||
|
self.assertEqual(tc.a, 42)
|
||
|
|
||
|
tc = TestClass(y=4, x=2, a=42)
|
||
|
self.assertEqual(tc.x, 2)
|
||
|
self.assertEqual(tc.y, 4)
|
||
|
self.assertEqual(tc.z, 1)
|
||
|
self.assertEqual(tc.a, 42)
|
||
|
|
||
|
tc = TestClass(y=4, x=2)
|
||
|
self.assertEqual(tc.x, 2)
|
||
|
self.assertEqual(tc.y, 4)
|
||
|
self.assertEqual(tc.z, 1)
|
||
|
self.assertEqual(tc.a, 5)
|
||
|
|
||
|
def test_default_args_starargs_and_keyonly(self):
|
||
|
spec = [("x", int32),
|
||
|
("y", int32),
|
||
|
("z", int32),
|
||
|
("args", types.UniTuple(int32, 2)),
|
||
|
("a", int32)]
|
||
|
|
||
|
with self.assertRaises(errors.UnsupportedError) as raises:
|
||
|
jitclass(TestClass2, spec)
|
||
|
|
||
|
msg = "VAR_POSITIONAL argument type unsupported"
|
||
|
self.assertIn(msg, str(raises.exception))
|
||
|
|
||
|
def test_generator_method(self):
|
||
|
spec = []
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
def gen(self, niter):
|
||
|
for i in range(niter):
|
||
|
yield np.arange(i)
|
||
|
|
||
|
def expected_gen(niter):
|
||
|
for i in range(niter):
|
||
|
yield np.arange(i)
|
||
|
|
||
|
for niter in range(10):
|
||
|
for expect, got in zip(expected_gen(niter), TestClass().gen(niter)):
|
||
|
self.assertPreciseEqual(expect, got)
|
||
|
|
||
|
def test_getitem(self):
|
||
|
spec = [("data", int32[:])]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self):
|
||
|
self.data = np.zeros(10, dtype=np.int32)
|
||
|
|
||
|
def __setitem__(self, key, data):
|
||
|
self.data[key] = data
|
||
|
|
||
|
def __getitem__(self, key):
|
||
|
return self.data[key]
|
||
|
|
||
|
@njit
|
||
|
def create_and_set_indices():
|
||
|
t = TestClass()
|
||
|
t[1] = 1
|
||
|
t[2] = 2
|
||
|
t[3] = 3
|
||
|
return t
|
||
|
|
||
|
@njit
|
||
|
def get_index(t, n):
|
||
|
return t[n]
|
||
|
|
||
|
t = create_and_set_indices()
|
||
|
self.assertEqual(get_index(t, 1), 1)
|
||
|
self.assertEqual(get_index(t, 2), 2)
|
||
|
self.assertEqual(get_index(t, 3), 3)
|
||
|
|
||
|
def test_getitem_unbox(self):
|
||
|
spec = [("data", int32[:])]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self):
|
||
|
self.data = np.zeros(10, dtype=np.int32)
|
||
|
|
||
|
def __setitem__(self, key, data):
|
||
|
self.data[key] = data
|
||
|
|
||
|
def __getitem__(self, key):
|
||
|
return self.data[key]
|
||
|
|
||
|
t = TestClass()
|
||
|
t[1] = 10
|
||
|
|
||
|
@njit
|
||
|
def set2return1(t):
|
||
|
t[2] = 20
|
||
|
return t[1]
|
||
|
|
||
|
t_1 = set2return1(t)
|
||
|
self.assertEqual(t_1, 10)
|
||
|
self.assertEqual(t[2], 20)
|
||
|
|
||
|
def test_getitem_complex_key(self):
|
||
|
spec = [("data", int32[:, :])]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self):
|
||
|
self.data = np.zeros((10, 10), dtype=np.int32)
|
||
|
|
||
|
def __setitem__(self, key, data):
|
||
|
self.data[int(key.real), int(key.imag)] = data
|
||
|
|
||
|
def __getitem__(self, key):
|
||
|
return self.data[int(key.real), int(key.imag)]
|
||
|
|
||
|
t = TestClass()
|
||
|
|
||
|
t[complex(1, 1)] = 3
|
||
|
|
||
|
@njit
|
||
|
def get_key(t, real, imag):
|
||
|
return t[complex(real, imag)]
|
||
|
|
||
|
@njit
|
||
|
def set_key(t, real, imag, data):
|
||
|
t[complex(real, imag)] = data
|
||
|
|
||
|
self.assertEqual(get_key(t, 1, 1), 3)
|
||
|
set_key(t, 2, 2, 4)
|
||
|
self.assertEqual(t[complex(2, 2)], 4)
|
||
|
|
||
|
def test_getitem_tuple_key(self):
|
||
|
spec = [("data", int32[:, :])]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self):
|
||
|
self.data = np.zeros((10, 10), dtype=np.int32)
|
||
|
|
||
|
def __setitem__(self, key, data):
|
||
|
self.data[key[0], key[1]] = data
|
||
|
|
||
|
def __getitem__(self, key):
|
||
|
return self.data[key[0], key[1]]
|
||
|
|
||
|
t = TestClass()
|
||
|
t[1, 1] = 11
|
||
|
|
||
|
@njit
|
||
|
def get11(t):
|
||
|
return t[1, 1]
|
||
|
|
||
|
@njit
|
||
|
def set22(t, data):
|
||
|
t[2, 2] = data
|
||
|
|
||
|
self.assertEqual(get11(t), 11)
|
||
|
set22(t, 22)
|
||
|
self.assertEqual(t[2, 2], 22)
|
||
|
|
||
|
def test_getitem_slice_key(self):
|
||
|
spec = [("data", int32[:])]
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TestClass(object):
|
||
|
def __init__(self):
|
||
|
self.data = np.zeros(10, dtype=np.int32)
|
||
|
|
||
|
def __setitem__(self, slc, data):
|
||
|
self.data[slc.start] = data
|
||
|
self.data[slc.stop] = data + slc.step
|
||
|
|
||
|
def __getitem__(self, slc):
|
||
|
return self.data[slc.start]
|
||
|
|
||
|
t = TestClass()
|
||
|
# set t.data[1] = 1 and t.data[5] = 2
|
||
|
t[1:5:1] = 1
|
||
|
|
||
|
self.assertEqual(t[1:1:1], 1)
|
||
|
self.assertEqual(t[5:5:5], 2)
|
||
|
|
||
|
@njit
|
||
|
def get5(t):
|
||
|
return t[5:6:1]
|
||
|
|
||
|
self.assertEqual(get5(t), 2)
|
||
|
|
||
|
# sets t.data[2] = data, and t.data[6] = data + 1
|
||
|
@njit
|
||
|
def set26(t, data):
|
||
|
t[2:6:1] = data
|
||
|
|
||
|
set26(t, 2)
|
||
|
self.assertEqual(t[2:2:1], 2)
|
||
|
self.assertEqual(t[6:6:1], 3)
|
||
|
|
||
|
def test_jitclass_longlabel_not_truncated(self):
|
||
|
# See issue #3872, llvm 7 introduced a max label length of 1024 chars
|
||
|
# Numba ships patched llvm 7.1 (ppc64le) and patched llvm 8 to undo this
|
||
|
# change, this test is here to make sure long labels are ok:
|
||
|
alphabet = [chr(ord("a") + x) for x in range(26)]
|
||
|
|
||
|
spec = [(letter * 10, float64) for letter in alphabet]
|
||
|
spec.extend([(letter.upper() * 10, float64) for letter in alphabet])
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class TruncatedLabel(object):
|
||
|
def __init__(self,):
|
||
|
self.aaaaaaaaaa = 10.
|
||
|
|
||
|
def meth1(self):
|
||
|
self.bbbbbbbbbb = random.gauss(self.aaaaaaaaaa, self.aaaaaaaaaa)
|
||
|
|
||
|
def meth2(self):
|
||
|
self.meth1()
|
||
|
|
||
|
# unpatched LLVMs will raise here...
|
||
|
TruncatedLabel().meth2()
|
||
|
|
||
|
def test_pickling(self):
|
||
|
@jitclass
|
||
|
class PickleTestSubject(object):
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
inst = PickleTestSubject()
|
||
|
ty = typeof(inst)
|
||
|
self.assertIsInstance(ty, types.ClassInstanceType)
|
||
|
pickled = pickle.dumps(ty)
|
||
|
self.assertIs(pickle.loads(pickled), ty)
|
||
|
|
||
|
def test_static_methods(self):
|
||
|
@jitclass([("x", int32)])
|
||
|
class Test1:
|
||
|
def __init__(self, x):
|
||
|
self.x = x
|
||
|
|
||
|
def increase(self, y):
|
||
|
self.x = self.add(self.x, y)
|
||
|
return self.x
|
||
|
|
||
|
@staticmethod
|
||
|
def add(a, b):
|
||
|
return a + b
|
||
|
|
||
|
@staticmethod
|
||
|
def sub(a, b):
|
||
|
return a - b
|
||
|
|
||
|
@jitclass([("x", int32)])
|
||
|
class Test2:
|
||
|
def __init__(self, x):
|
||
|
self.x = x
|
||
|
|
||
|
def increase(self, y):
|
||
|
self.x = self.add(self.x, y)
|
||
|
return self.x
|
||
|
|
||
|
@staticmethod
|
||
|
def add(a, b):
|
||
|
return a - b
|
||
|
|
||
|
self.assertIsInstance(Test1.add, Dispatcher)
|
||
|
self.assertIsInstance(Test1.sub, Dispatcher)
|
||
|
self.assertIsInstance(Test2.add, Dispatcher)
|
||
|
self.assertNotEqual(Test1.add, Test2.add)
|
||
|
|
||
|
self.assertEqual(3, Test1.add(1, 2))
|
||
|
self.assertEqual(-1, Test2.add(1, 2))
|
||
|
self.assertEqual(4, Test1.sub(6, 2))
|
||
|
|
||
|
t1 = Test1(0)
|
||
|
t2 = Test2(0)
|
||
|
self.assertEqual(1, t1.increase(1))
|
||
|
self.assertEqual(-1, t2.increase(1))
|
||
|
self.assertEqual(2, t1.add(1, 1))
|
||
|
self.assertEqual(0, t1.sub(1, 1))
|
||
|
self.assertEqual(0, t2.add(1, 1))
|
||
|
self.assertEqual(2j, t1.add(1j, 1j))
|
||
|
self.assertEqual(1j, t1.sub(2j, 1j))
|
||
|
self.assertEqual("foobar", t1.add("foo", "bar"))
|
||
|
|
||
|
with self.assertRaises(AttributeError) as raises:
|
||
|
Test2.sub(3, 1)
|
||
|
self.assertIn("has no attribute 'sub'",
|
||
|
str(raises.exception))
|
||
|
|
||
|
with self.assertRaises(TypeError) as raises:
|
||
|
Test1.add(3)
|
||
|
self.assertIn("not enough arguments: expected 2, got 1",
|
||
|
str(raises.exception))
|
||
|
|
||
|
# Check error message for calling a static method as a class attr from
|
||
|
# another method (currently unsupported).
|
||
|
|
||
|
@jitclass([])
|
||
|
class Test3:
|
||
|
def __init__(self):
|
||
|
pass
|
||
|
|
||
|
@staticmethod
|
||
|
def a_static_method(a, b):
|
||
|
pass
|
||
|
|
||
|
def call_static(self):
|
||
|
return Test3.a_static_method(1, 2)
|
||
|
|
||
|
invalid = Test3()
|
||
|
with self.assertRaises(errors.TypingError) as raises:
|
||
|
invalid.call_static()
|
||
|
|
||
|
self.assertIn("Unknown attribute 'a_static_method'",
|
||
|
str(raises.exception))
|
||
|
|
||
|
def test_jitclass_decorator_usecases(self):
|
||
|
spec = OrderedDict(x=float64)
|
||
|
|
||
|
@jitclass()
|
||
|
class Test1:
|
||
|
x: float
|
||
|
|
||
|
def __init__(self):
|
||
|
self.x = 0
|
||
|
|
||
|
self.assertIsInstance(Test1, JitClassType)
|
||
|
self.assertDictEqual(Test1.class_type.struct, spec)
|
||
|
|
||
|
@jitclass(spec=spec)
|
||
|
class Test2:
|
||
|
|
||
|
def __init__(self):
|
||
|
self.x = 0
|
||
|
|
||
|
self.assertIsInstance(Test2, JitClassType)
|
||
|
self.assertDictEqual(Test2.class_type.struct, spec)
|
||
|
|
||
|
@jitclass
|
||
|
class Test3:
|
||
|
x: float
|
||
|
|
||
|
def __init__(self):
|
||
|
self.x = 0
|
||
|
|
||
|
self.assertIsInstance(Test3, JitClassType)
|
||
|
self.assertDictEqual(Test3.class_type.struct, spec)
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Test4:
|
||
|
|
||
|
def __init__(self):
|
||
|
self.x = 0
|
||
|
|
||
|
self.assertIsInstance(Test4, JitClassType)
|
||
|
self.assertDictEqual(Test4.class_type.struct, spec)
|
||
|
|
||
|
def test_jitclass_function_usecases(self):
|
||
|
spec = OrderedDict(x=float64)
|
||
|
|
||
|
class AnnotatedTest:
|
||
|
x: float
|
||
|
|
||
|
def __init__(self):
|
||
|
self.x = 0
|
||
|
|
||
|
JitTest1 = jitclass(AnnotatedTest)
|
||
|
self.assertIsInstance(JitTest1, JitClassType)
|
||
|
self.assertDictEqual(JitTest1.class_type.struct, spec)
|
||
|
|
||
|
class UnannotatedTest:
|
||
|
|
||
|
def __init__(self):
|
||
|
self.x = 0
|
||
|
|
||
|
JitTest2 = jitclass(UnannotatedTest, spec)
|
||
|
self.assertIsInstance(JitTest2, JitClassType)
|
||
|
self.assertDictEqual(JitTest2.class_type.struct, spec)
|
||
|
|
||
|
def test_jitclass_isinstance(self):
|
||
|
spec = OrderedDict(value=int32)
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Foo(object):
|
||
|
def __init__(self, value):
|
||
|
self.value = value
|
||
|
|
||
|
def getValue(self):
|
||
|
return self.value
|
||
|
|
||
|
def getValueIncr(self):
|
||
|
return self.value + 1
|
||
|
|
||
|
@jitclass(spec)
|
||
|
class Bar(object):
|
||
|
def __init__(self, value):
|
||
|
self.value = value
|
||
|
|
||
|
def getValue(self):
|
||
|
return self.value
|
||
|
|
||
|
def test_jitclass_isinstance(obj):
|
||
|
if isinstance(obj, (Foo, Bar)):
|
||
|
# call something that both classes implements
|
||
|
x = obj.getValue()
|
||
|
if isinstance(obj, Foo): # something that only Foo implements
|
||
|
return obj.getValueIncr() + x, 'Foo'
|
||
|
else:
|
||
|
return obj.getValue() + x, 'Bar'
|
||
|
else:
|
||
|
return 'no match'
|
||
|
|
||
|
pyfunc = test_jitclass_isinstance
|
||
|
cfunc = njit(test_jitclass_isinstance)
|
||
|
|
||
|
self.assertIsInstance(Foo, JitClassType)
|
||
|
self.assertEqual(pyfunc(Foo(3)), cfunc(Foo(3)))
|
||
|
self.assertEqual(pyfunc(Bar(123)), cfunc(Bar(123)))
|
||
|
self.assertEqual(pyfunc(0), cfunc(0))
|
||
|
|
||
|
def test_jitclass_unsupported_dunder(self):
|
||
|
with self.assertRaises(TypeError) as e:
|
||
|
@jitclass
|
||
|
class Foo(object):
|
||
|
def __init__(self):
|
||
|
return
|
||
|
|
||
|
def __enter__(self):
|
||
|
return None
|
||
|
Foo()
|
||
|
self.assertIn("Method '__enter__' is not supported.", str(e.exception))
|
||
|
|
||
|
def test_modulename(self):
    # A jitclass must report the module it was defined in via __module__.
    @jitclass
    class TestModname(object):
        def __init__(self):
            self.x = 12

    self.assertEqual(__name__, TestModname.__module__)
|
||
|
|
||
|
|
||
|
class TestJitClassOverloads(MemoryLeakMixin, TestCase):
    """Tests for dunder-method (operator / protocol) overloads on jitclasses.

    Most tests build a plain Python class and the same class compiled with
    ``jitclass``, then check that the dunder methods behave identically when
    used from the interpreter and from inside ``njit`` functions.
    """

    # Pure-Python reference class; test_overloads jit-compiles it with a
    # ``List(intp)`` spec for ``x`` and compares each dunder's result.
    class PyList:
        def __init__(self):
            self.x = [0]

        def append(self, y):
            self.x.append(y)

        def clear(self):
            self.x.clear()

        def __abs__(self):
            return len(self.x) * 7

        def __bool__(self):
            return len(self.x) % 3 != 0

        def __complex__(self):
            c = complex(2)
            if self.x:
                c += self.x[0]
            return c

        def __contains__(self, y):
            return y in self.x

        def __float__(self):
            f = 3.1415
            if self.x:
                f += self.x[0]
            return f

        def __int__(self):
            i = 5
            if self.x:
                i += self.x[0]
            return i

        def __len__(self):
            return len(self.x) + 1

        def __str__(self):
            if len(self.x) == 0:
                return "PyList empty"
            else:
                return "PyList non-empty"

    @staticmethod
    def get_int_wrapper():
        """Return a new jitclass wrapping an ``intp`` with bitwise dunders.

        Binary operators return a new ``IntWrapper``; ``__eq__`` returns the
        comparison of the wrapped values.
        """
        @jitclass([("x", types.intp)])
        class IntWrapper:
            def __init__(self, value):
                self.x = value

            def __eq__(self, other):
                return self.x == other.x

            def __hash__(self):
                return self.x

            def __lshift__(self, other):
                return IntWrapper(self.x << other.x)

            def __rshift__(self, other):
                return IntWrapper(self.x >> other.x)

            def __and__(self, other):
                return IntWrapper(self.x & other.x)

            def __or__(self, other):
                return IntWrapper(self.x | other.x)

            def __xor__(self, other):
                return IntWrapper(self.x ^ other.x)

        return IntWrapper

    @staticmethod
    def get_float_wrapper():
        """Return a new jitclass wrapping a ``float64`` with arithmetic and
        rich-comparison dunders.

        Arithmetic operators return a new ``FloatWrapper``; comparisons
        return the comparison of the wrapped values.
        """
        @jitclass([("x", types.float64)])
        class FloatWrapper:

            def __init__(self, value):
                self.x = value

            def __eq__(self, other):
                return self.x == other.x

            def __hash__(self):
                # hash() requires an integer; returning the raw float made
                # hash(FloatWrapper(...)) raise TypeError.
                return hash(self.x)

            def __ge__(self, other):
                return self.x >= other.x

            def __gt__(self, other):
                return self.x > other.x

            def __le__(self, other):
                return self.x <= other.x

            def __lt__(self, other):
                return self.x < other.x

            def __add__(self, other):
                return FloatWrapper(self.x + other.x)

            def __floordiv__(self, other):
                return FloatWrapper(self.x // other.x)

            def __mod__(self, other):
                return FloatWrapper(self.x % other.x)

            def __mul__(self, other):
                return FloatWrapper(self.x * other.x)

            # Unary operators receive only ``self``. They were previously
            # declared with an extra ``other`` parameter, which made
            # ``-w`` / ``+w`` fail with a missing-argument TypeError.
            def __neg__(self):
                return FloatWrapper(-self.x)

            def __pos__(self):
                return FloatWrapper(+self.x)

            def __pow__(self, other):
                return FloatWrapper(self.x ** other.x)

            def __sub__(self, other):
                return FloatWrapper(self.x - other.x)

            def __truediv__(self, other):
                return FloatWrapper(self.x / other.x)

        return FloatWrapper

    def assertSame(self, first, second, msg=None):
        """Assert *first* and *second* have exactly the same type and are
        equal (stricter than assertEqual, which accepts e.g. 1 == 1.0)."""
        self.assertEqual(type(first), type(second), msg=msg)
        self.assertEqual(first, second, msg=msg)

    def test_overloads(self):
        # Check that the dunder methods are exposed on ClassInstanceType.

        JitList = jitclass({"x": types.List(types.intp)})(self.PyList)

        py_funcs = [
            lambda x: abs(x),
            lambda x: x.__abs__(),
            lambda x: bool(x),
            lambda x: x.__bool__(),
            lambda x: complex(x),
            lambda x: x.__complex__(),
            lambda x: 0 in x,  # contains
            lambda x: x.__contains__(0),
            lambda x: float(x),
            lambda x: x.__float__(),
            lambda x: int(x),
            lambda x: x.__int__(),
            lambda x: len(x),
            lambda x: x.__len__(),
            lambda x: str(x),
            lambda x: x.__str__(),
            lambda x: 1 if x else 0,  # truth
        ]
        jit_funcs = [njit(f) for f in py_funcs]

        py_list = self.PyList()
        jit_list = JitList()

        def check_all():
            # The pure-Python result is the reference for the jitclass
            # instance used both from the interpreter and from compiled code.
            for py_f, jit_f in zip(py_funcs, jit_funcs):
                expected = py_f(py_list)
                self.assertSame(expected, py_f(jit_list))
                self.assertSame(expected, jit_f(jit_list))

        check_all()

        # Exercise several list states: non-empty, longer, then cleared.
        for value in (2, -5):
            py_list.append(value)
            jit_list.append(value)
            check_all()

        py_list.clear()
        jit_list.clear()
        check_all()

    def test_bool_fallback(self):

        def py_b(x):
            return bool(x)

        jit_b = njit(py_b)

        # Without __bool__, truthiness falls back to __len__.
        @jitclass([("x", types.List(types.intp))])
        class LenClass:
            def __init__(self, x):
                self.x = x

            def __len__(self):
                return len(self.x) % 4

            def append(self, y):
                self.x.append(y)

            def pop(self):
                self.x.pop(0)

        obj = LenClass([1, 2, 3])
        self.assertTrue(py_b(obj))
        self.assertTrue(jit_b(obj))

        obj.append(4)  # 4 % 4 == 0 -> falsy
        self.assertFalse(py_b(obj))
        self.assertFalse(jit_b(obj))

        obj.pop()
        self.assertTrue(py_b(obj))
        self.assertTrue(jit_b(obj))

        # Neither __bool__ nor __len__: instances are always truthy.
        @jitclass([("y", types.float64)])
        class NormalClass:
            def __init__(self, y):
                self.y = y

        obj = NormalClass(0)
        self.assertTrue(py_b(obj))
        self.assertTrue(jit_b(obj))

    def test_numeric_fallback(self):
        # complex()/float()/int() should use the same fallbacks between
        # __float__/__int__/__index__ as CPython, and fail where it fails.
        def py_c(x):
            return complex(x)

        def py_f(x):
            return float(x)

        def py_i(x):
            return int(x)

        jit_c = njit(py_c)
        jit_f = njit(py_f)
        jit_i = njit(py_i)

        def assert_conversion_fails(obj, py_func, jit_func, name):
            # CPython raises TypeError at call time; numba rejects the call
            # with a TypingError. Both messages must mention *name*.
            with self.assertRaises(TypeError) as e:
                py_func(obj)
            self.assertIn(name, str(e.exception))
            with self.assertRaises(TypingError) as e:
                jit_func(obj)
            self.assertIn(name, str(e.exception))

        @jitclass([])
        class FloatClass:
            def __init__(self):
                pass

            def __float__(self):
                return 3.1415

        obj = FloatClass()
        self.assertSame(py_c(obj), complex(3.1415))
        self.assertSame(jit_c(obj), complex(3.1415))
        self.assertSame(py_f(obj), 3.1415)
        self.assertSame(jit_f(obj), 3.1415)
        assert_conversion_fails(obj, py_i, jit_i, "int")

        @jitclass([])
        class IntClass:
            def __init__(self):
                pass

            def __int__(self):
                return 7

        obj = IntClass()
        self.assertSame(py_i(obj), 7)
        self.assertSame(jit_i(obj), 7)
        assert_conversion_fails(obj, py_c, jit_c, "complex")
        assert_conversion_fails(obj, py_f, jit_f, "float")

        @jitclass([])
        class IndexClass:
            def __init__(self):
                pass

            def __index__(self):
                return 1

        obj = IndexClass()

        self.assertSame(py_c(obj), complex(1))
        self.assertSame(jit_c(obj), complex(1))
        self.assertSame(py_f(obj), 1.)
        self.assertSame(jit_f(obj), 1.)
        self.assertSame(py_i(obj), 1)
        self.assertSame(jit_i(obj), 1)

        # When several conversions exist, the specific one takes priority
        # over __index__.
        @jitclass([])
        class FloatIntIndexClass:
            def __init__(self):
                pass

            def __float__(self):
                return 3.1415

            def __int__(self):
                return 7

            def __index__(self):
                return 1

        obj = FloatIntIndexClass()
        self.assertSame(py_c(obj), complex(3.1415))
        self.assertSame(jit_c(obj), complex(3.1415))
        self.assertSame(py_f(obj), 3.1415)
        self.assertSame(jit_f(obj), 3.1415)
        self.assertSame(py_i(obj), 7)
        self.assertSame(jit_i(obj), 7)

    def test_arithmetic_logical(self):
        IntWrapper = self.get_int_wrapper()
        FloatWrapper = self.get_float_wrapper()

        float_py_funcs = [
            lambda x, y: x == y,
            lambda x, y: x != y,
            lambda x, y: x >= y,
            lambda x, y: x > y,
            lambda x, y: x <= y,
            lambda x, y: x < y,
            lambda x, y: x + y,
            lambda x, y: x // y,
            lambda x, y: x % y,
            lambda x, y: x * y,
            lambda x, y: x ** y,
            lambda x, y: x - y,
            lambda x, y: x / y,
        ]
        int_py_funcs = [
            lambda x, y: x == y,
            lambda x, y: x != y,
            lambda x, y: x << y,
            lambda x, y: x >> y,
            lambda x, y: x & y,
            lambda x, y: x | y,
            lambda x, y: x ^ y,
        ]

        test_values = [
            (0.0, 2.0),
            (1.234, 3.1415),
            (13.1, 1.01),
        ]

        def unwrap(value):
            # Operators returning a wrapper are reduced to the wrapped value;
            # any other result (e.g. a bool) is passed through unchanged.
            return getattr(value, "x", value)

        for jit_f, (x, y) in itertools.product(
                map(njit, float_py_funcs), test_values):

            py_f = jit_f.py_func

            expected = py_f(x, y)
            jit_x = FloatWrapper(x)
            jit_y = FloatWrapper(y)

            # Float results may differ in the last bits between the
            # interpreter and compiled code.
            check = (
                self.assertEqual
                if type(expected) is not float
                else self.assertAlmostEqual
            )
            check(expected, jit_f(x, y))
            check(expected, unwrap(py_f(jit_x, jit_y)))
            check(expected, unwrap(jit_f(jit_x, jit_y)))

        for jit_f, (x, y) in itertools.product(
                map(njit, int_py_funcs), test_values):

            py_f = jit_f.py_func
            x, y = int(x), int(y)

            expected = py_f(x, y)
            jit_x = IntWrapper(x)
            jit_y = IntWrapper(y)

            self.assertEqual(expected, jit_f(x, y))
            self.assertEqual(expected, unwrap(py_f(jit_x, jit_y)))
            self.assertEqual(expected, unwrap(jit_f(jit_x, jit_y)))

    def test_arithmetic_logical_inplace(self):

        # If __i*__ methods are not defined, should fall back to normal methods.
        JitIntWrapper = self.get_int_wrapper()
        JitFloatWrapper = self.get_float_wrapper()

        # mro()[1] is taken to be the original pure-Python class that the
        # jitclass was built from.
        PyIntWrapper = JitIntWrapper.mro()[1]
        PyFloatWrapper = JitFloatWrapper.mro()[1]

        @jitclass([("x", types.intp)])
        class JitIntUpdateWrapper(PyIntWrapper):
            def __init__(self, value):
                self.x = value

            def __ilshift__(self, other):
                return JitIntUpdateWrapper(self.x << other.x)

            def __irshift__(self, other):
                return JitIntUpdateWrapper(self.x >> other.x)

            def __iand__(self, other):
                return JitIntUpdateWrapper(self.x & other.x)

            def __ior__(self, other):
                return JitIntUpdateWrapper(self.x | other.x)

            def __ixor__(self, other):
                return JitIntUpdateWrapper(self.x ^ other.x)

        # The in-place methods below intentionally compute something other
        # than the plain binary operator, so the test can tell which of the
        # two was dispatched to.
        @jitclass({"x": types.float64})
        class JitFloatUpdateWrapper(PyFloatWrapper):

            def __init__(self, value):
                self.x = value

            def __iadd__(self, other):
                return JitFloatUpdateWrapper(self.x + 2.718 * other.x)

            def __ifloordiv__(self, other):
                return JitFloatUpdateWrapper(self.x * 2.718 // other.x)

            def __imod__(self, other):
                return JitFloatUpdateWrapper(self.x % (other.x + 1))

            def __imul__(self, other):
                return JitFloatUpdateWrapper(self.x * other.x + 1)

            def __ipow__(self, other):
                return JitFloatUpdateWrapper(self.x ** other.x + 1)

            def __isub__(self, other):
                return JitFloatUpdateWrapper(self.x - 3.1415 * other.x)

            def __itruediv__(self, other):
                return JitFloatUpdateWrapper((self.x + 1) / other.x)

        PyIntUpdateWrapper = JitIntUpdateWrapper.mro()[1]
        PyFloatUpdateWrapper = JitFloatUpdateWrapper.mro()[1]

        def get_update_func(op):
            # Build ``def f(x, y): x <op>= y; return x``. Augmented
            # assignment cannot be written generically with a lambda.
            template = f"""
def f(x, y):
    x {op}= y
    return x
"""
            namespace = {}
            exec(template, namespace)
            return namespace["f"]

        float_py_funcs = [get_update_func(op) for op in [
            "+", "//", "%", "*", "**", "-", "/",
        ]]
        int_py_funcs = [get_update_func(op) for op in [
            "<<", ">>", "&", "|", "^",
        ]]

        test_values = [
            (0.0, 2.0),
            (1.234, 3.1415),
            (13.1, 1.01),
        ]

        for jit_f, (py_cls, jit_cls), (x, y) in itertools.product(
                map(njit, float_py_funcs),
                [
                    (PyFloatWrapper, JitFloatWrapper),
                    (PyFloatUpdateWrapper, JitFloatUpdateWrapper)
                ],
                test_values):
            py_f = jit_f.py_func

            expected = py_f(py_cls(x), py_cls(y)).x
            self.assertAlmostEqual(expected, py_f(jit_cls(x), jit_cls(y)).x)
            self.assertAlmostEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x)

        for jit_f, (py_cls, jit_cls), (x, y) in itertools.product(
                map(njit, int_py_funcs),
                [
                    (PyIntWrapper, JitIntWrapper),
                    (PyIntUpdateWrapper, JitIntUpdateWrapper)
                ],
                test_values):
            x, y = int(x), int(y)
            py_f = jit_f.py_func

            expected = py_f(py_cls(x), py_cls(y)).x
            self.assertEqual(expected, py_f(jit_cls(x), jit_cls(y)).x)
            self.assertEqual(expected, jit_f(jit_cls(x), jit_cls(y)).x)

    def test_hash_eq_ne(self):

        class HashEqTest:
            x: int

            def __init__(self, x):
                self.x = x

            def __hash__(self):
                return self.x % 10

            def __eq__(self, o):
                return (self.x - o.x) % 20 == 0

        class HashEqNeTest(HashEqTest):
            def __ne__(self, o):
                return (self.x - o.x) % 20 > 1

        def py_hash(x):
            return hash(x)

        def py_eq(x, y):
            return x == y

        def py_ne(x, y):
            return x != y

        def identity_decorator(f):
            return f

        comparisons = [
            (0, 1),
            # (1 - 0) % 20 == 1, so the custom __ne__ returns False while the
            # default ``not __eq__`` fallback returns True. This is the pair
            # that detects whether a user-defined __ne__ is actually used.
            # ((0, 1) gives 19 > 1, i.e. the same answer from both paths.)
            (1, 0),
            (2, 22),
            (7, 10),
            (3, 3),
        ]

        for base_cls, use_jit in itertools.product(
            [HashEqTest, HashEqNeTest], [False, True]
        ):
            decorator = njit if use_jit else identity_decorator
            hash_func = decorator(py_hash)
            eq_func = decorator(py_eq)
            ne_func = decorator(py_ne)

            jit_cls = jitclass(base_cls)

            for v in [0, 2, 10, 24, -8]:
                self.assertEqual(hash_func(jit_cls(v)), v % 10)

            for x, y in comparisons:
                self.assertEqual(
                    eq_func(jit_cls(x), jit_cls(y)),
                    base_cls(x) == base_cls(y),
                )
                self.assertEqual(
                    ne_func(jit_cls(x), jit_cls(y)),
                    base_cls(x) != base_cls(y),
                )

    def test_bool_fallback_len(self):
        # Check that the fallback to using len(obj) to determine truth of an
        # object is implemented correctly as per
        # https://docs.python.org/3/library/stdtypes.html#truth-value-testing
        #
        # Relevant points:
        #
        # "By default, an object is considered true unless its class defines
        # either a __bool__() method that returns False or a __len__() method
        # that returns zero, when called with the object."
        #
        # and:
        #
        # "Operations and built-in functions that have a Boolean result always
        # return 0 or False for false and 1 or True for true, unless otherwise
        # stated."

        class NoBoolHasLen:
            def __init__(self, val):
                self.val = val

            def __len__(self):
                return self.val

            def get_bool(self):
                return bool(self)

        py_class = NoBoolHasLen
        jitted_class = jitclass([('val', types.int64)])(py_class)

        py_class_0_bool = py_class(0).get_bool()
        py_class_2_bool = py_class(2).get_bool()
        jitted_class_0_bool = jitted_class(0).get_bool()
        jitted_class_2_bool = jitted_class(2).get_bool()

        # Truth values from bool(obj) should be equal
        self.assertEqual(py_class_0_bool, jitted_class_0_bool)
        self.assertEqual(py_class_2_bool, jitted_class_2_bool)

        # Truth values from bool(obj) should be the same type
        self.assertEqual(type(py_class_0_bool), type(jitted_class_0_bool))
        self.assertEqual(type(py_class_2_bool), type(jitted_class_2_bool))

    def test_bool_fallback_default(self):
        # Similar to test_bool_fallback, but checks the case where there is no
        # __bool__() or __len__() defined, so the object should always be True.

        class NoBoolNoLen:
            def __init__(self):
                pass

            def get_bool(self):
                return bool(self)

        py_class = NoBoolNoLen
        jitted_class = jitclass([])(py_class)

        py_class_bool = py_class().get_bool()
        jitted_class_bool = jitted_class().get_bool()

        # Truth values from bool(obj) should be equal
        self.assertEqual(py_class_bool, jitted_class_bool)

        # Truth values from bool(obj) should be the same type
        self.assertEqual(type(py_class_bool), type(jitted_class_bool))

    def test_operator_reflection(self):
        # Rich comparisons with an operand lacking the operator must resolve
        # through the other operand's (reflected) comparison method.
        class OperatorsDefined:
            def __init__(self, x):
                self.x = x

            def __eq__(self, other):
                return self.x == other.x

            def __le__(self, other):
                return self.x <= other.x

            def __lt__(self, other):
                return self.x < other.x

            def __ge__(self, other):
                return self.x >= other.x

            def __gt__(self, other):
                return self.x > other.x

        class NoOperatorsDefined:
            def __init__(self, x):
                self.x = x

        spec = [('x', types.int32)]
        JitOperatorsDefined = jitclass(spec)(OperatorsDefined)
        JitNoOperatorsDefined = jitclass(spec)(NoOperatorsDefined)

        py_ops_defined = OperatorsDefined(2)
        py_ops_not_defined = NoOperatorsDefined(3)

        jit_ops_defined = JitOperatorsDefined(2)
        jit_ops_not_defined = JitNoOperatorsDefined(3)

        self.assertEqual(py_ops_not_defined == py_ops_defined,
                         jit_ops_not_defined == jit_ops_defined)

        self.assertEqual(py_ops_not_defined <= py_ops_defined,
                         jit_ops_not_defined <= jit_ops_defined)

        self.assertEqual(py_ops_not_defined < py_ops_defined,
                         jit_ops_not_defined < jit_ops_defined)

        self.assertEqual(py_ops_not_defined >= py_ops_defined,
                         jit_ops_not_defined >= jit_ops_defined)

        self.assertEqual(py_ops_not_defined > py_ops_defined,
                         jit_ops_not_defined > jit_ops_defined)

    @skip_unless_scipy
    def test_matmul_operator(self):
        class ArrayAt:
            def __init__(self, array):
                self.arr = array

            def __matmul__(self, other):
                return self.arr @ other.arr

            def __rmatmul__(self, other):
                return other.arr @ self.arr

            def __imatmul__(self, other):
                self.arr = self.arr @ other.arr
                return self

        class ArrayNoAt:
            def __init__(self, array):
                self.arr = array

        n = 3
        np.random.seed(1)
        vec = np.random.random(size=(n,))
        mat = np.random.random(size=(n, n))

        vec_spec = {"arr": float64[::1]}
        mat_spec = {"arr": float64[:, ::1]}

        vector_noat = ArrayNoAt(vec)
        vector_at = ArrayAt(vec)
        jit_vector_noat = jitclass(ArrayNoAt, spec=vec_spec)(vec)
        jit_vector_at = jitclass(ArrayAt, spec=vec_spec)(vec)

        matrix_noat = ArrayNoAt(mat)
        matrix_at = ArrayAt(mat)
        jit_matrix_noat = jitclass(ArrayNoAt, spec=mat_spec)(mat)
        jit_matrix_at = jitclass(ArrayAt, spec=mat_spec)(mat)

        # __matmul__
        np.testing.assert_allclose(vector_at @ vector_noat,
                                   jit_vector_at @ jit_vector_noat)
        np.testing.assert_allclose(vector_at @ matrix_noat,
                                   jit_vector_at @ jit_matrix_noat)
        np.testing.assert_allclose(matrix_at @ vector_noat,
                                   jit_matrix_at @ jit_vector_noat)
        np.testing.assert_allclose(matrix_at @ matrix_noat,
                                   jit_matrix_at @ jit_matrix_noat)

        # __rmatmul__
        np.testing.assert_allclose(vector_noat @ vector_at,
                                   jit_vector_noat @ jit_vector_at)
        np.testing.assert_allclose(vector_noat @ matrix_at,
                                   jit_vector_noat @ jit_matrix_at)
        np.testing.assert_allclose(matrix_noat @ vector_at,
                                   jit_matrix_noat @ jit_vector_at)
        np.testing.assert_allclose(matrix_noat @ matrix_at,
                                   jit_matrix_noat @ jit_matrix_at)

        # __imatmul__
        vector_at @= matrix_noat
        matrix_at @= matrix_noat
        jit_vector_at @= jit_matrix_noat
        jit_matrix_at @= jit_matrix_noat

        np.testing.assert_allclose(vector_at.arr, jit_vector_at.arr)
        np.testing.assert_allclose(matrix_at.arr, jit_matrix_at.arr)

    def test_arithmetic_logical_reflection(self):
        import operator

        class OperatorsDefined:
            def __init__(self, x):
                self.x = x

            def __radd__(self, other):
                return other.x + self.x

            def __rsub__(self, other):
                return other.x - self.x

            def __rmul__(self, other):
                return other.x * self.x

            def __rtruediv__(self, other):
                return other.x / self.x

            def __rfloordiv__(self, other):
                return other.x // self.x

            def __rmod__(self, other):
                return other.x % self.x

            def __rpow__(self, other):
                return other.x ** self.x

            def __rlshift__(self, other):
                return other.x << self.x

            def __rrshift__(self, other):
                return other.x >> self.x

            def __rand__(self, other):
                return other.x & self.x

            def __rxor__(self, other):
                return other.x ^ self.x

            def __ror__(self, other):
                return other.x | self.x

        class NoOperatorsDefined:
            def __init__(self, x):
                self.x = x

        # operator.add(a, b) performs exactly the same dispatch as ``a + b``
        # (including falling back to b.__radd__), so this replaces building
        # and eval()-ing expression strings.
        binary_ops = {
            "+": operator.add,
            "-": operator.sub,
            "*": operator.mul,
            "**": operator.pow,
            "/": operator.truediv,
            "//": operator.floordiv,
            "%": operator.mod,
            "<<": operator.lshift,
            ">>": operator.rshift,
            "&": operator.and_,
            "^": operator.xor,
            "|": operator.or_,
        }

        float_op = ["+", "-", "*", "**", "/", "//", "%"]
        int_op = [*float_op, "<<", ">>", "&", "^", "|"]

        for test_type, test_op, test_value in [
            (int32, int_op, (2, 4)),
            (float64, float_op, (2., 4.)),
            (float64[::1], float_op,
             (np.array([1., 2., 4.]), np.array([20., -24., 1.])))
        ]:
            spec = {"x": test_type}
            JitOperatorsDefined = jitclass(OperatorsDefined, spec)
            JitNoOperatorsDefined = jitclass(NoOperatorsDefined, spec)

            # The left operand defines no operators, so every operation must
            # be resolved through the right operand's reflected method.
            py_ops_defined = OperatorsDefined(test_value[0])
            py_ops_not_defined = NoOperatorsDefined(test_value[1])

            jit_ops_defined = JitOperatorsDefined(test_value[0])
            jit_ops_not_defined = JitNoOperatorsDefined(test_value[1])

            is_array = isinstance(test_type, types.Array)

            for op in test_op:
                func = binary_ops[op]
                expected = func(py_ops_not_defined, py_ops_defined)
                got = func(jit_ops_not_defined, jit_ops_defined)
                if is_array:
                    # Arrays are compared element-wise via tuples.
                    self.assertTupleEqual(tuple(expected), tuple(got))
                else:
                    self.assertEqual(expected, got, msg=op)

    def test_implicit_hash_compiles(self):
        # Ensure that classes with __hash__ implicitly defined as None due to
        # the presence of __eq__ are correctly handled by ignoring the __hash__
        # class member.
        class ImplicitHash:
            def __init__(self):
                pass

            def __eq__(self, other):
                return False

        jitted = jitclass([])(ImplicitHash)
        instance = jitted()

        self.assertFalse(instance == instance)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Allow running this test module directly as a script via unittest.
    unittest.main()
|