149 lines
6.1 KiB
Python
149 lines
6.1 KiB
Python
|
from itertools import product, permutations
|
||
|
from collections import defaultdict
|
||
|
|
||
|
import unittest
|
||
|
from numba.core.base import OverloadSelector
|
||
|
from numba.core.registry import cpu_target
|
||
|
from numba.core.imputils import builtin_registry, RegistryLoader
|
||
|
from numba.core import types
|
||
|
from numba.core.errors import NumbaNotImplementedError, NumbaTypeError
|
||
|
|
||
|
|
||
|
class TestOverloadSelector(unittest.TestCase):
|
||
|
def test_select_and_sort_1(self):
|
||
|
os = OverloadSelector()
|
||
|
os.append(1, (types.Any, types.Boolean))
|
||
|
os.append(2, (types.Boolean, types.Integer))
|
||
|
os.append(3, (types.Boolean, types.Any))
|
||
|
os.append(4, (types.Boolean, types.Boolean))
|
||
|
compats = os._select_compatible((types.boolean, types.boolean))
|
||
|
self.assertEqual(len(compats), 3)
|
||
|
ordered, scoring = os._sort_signatures(compats)
|
||
|
self.assertEqual(len(ordered), 3)
|
||
|
self.assertEqual(len(scoring), 3)
|
||
|
self.assertEqual(ordered[0], (types.Boolean, types.Boolean))
|
||
|
self.assertEqual(scoring[types.Boolean, types.Boolean], 0)
|
||
|
self.assertEqual(scoring[types.Boolean, types.Any], 1)
|
||
|
self.assertEqual(scoring[types.Any, types.Boolean], 1)
|
||
|
|
||
|
def test_select_and_sort_2(self):
|
||
|
os = OverloadSelector()
|
||
|
os.append(1, (types.Container,))
|
||
|
os.append(2, (types.Sequence,))
|
||
|
os.append(3, (types.MutableSequence,))
|
||
|
os.append(4, (types.List,))
|
||
|
compats = os._select_compatible((types.List,))
|
||
|
self.assertEqual(len(compats), 4)
|
||
|
ordered, scoring = os._sort_signatures(compats)
|
||
|
self.assertEqual(len(ordered), 4)
|
||
|
self.assertEqual(len(scoring), 4)
|
||
|
self.assertEqual(ordered[0], (types.List,))
|
||
|
self.assertEqual(scoring[(types.List,)], 0)
|
||
|
self.assertEqual(scoring[(types.MutableSequence,)], 1)
|
||
|
self.assertEqual(scoring[(types.Sequence,)], 2)
|
||
|
self.assertEqual(scoring[(types.Container,)], 3)
|
||
|
|
||
|
def test_match(self):
|
||
|
os = OverloadSelector()
|
||
|
self.assertTrue(os._match(formal=types.Boolean, actual=types.boolean))
|
||
|
self.assertTrue(os._match(formal=types.Boolean, actual=types.Boolean))
|
||
|
# test subclass
|
||
|
self.assertTrue(issubclass(types.Sequence, types.Container))
|
||
|
self.assertTrue(os._match(formal=types.Container,
|
||
|
actual=types.Sequence))
|
||
|
self.assertFalse(os._match(formal=types.Sequence,
|
||
|
actual=types.Container))
|
||
|
# test any
|
||
|
self.assertTrue(os._match(formal=types.Any, actual=types.Any))
|
||
|
self.assertTrue(os._match(formal=types.Any, actual=types.Container))
|
||
|
self.assertFalse(os._match(formal=types.Container, actual=types.Any))
|
||
|
|
||
|
def test_ambiguous_detection(self):
|
||
|
os = OverloadSelector()
|
||
|
# unambiguous signatures
|
||
|
os.append(1, (types.Any, types.Boolean))
|
||
|
os.append(2, (types.Integer, types.Boolean))
|
||
|
self.assertEqual(os.find((types.boolean, types.boolean)), 1)
|
||
|
# not implemented
|
||
|
with self.assertRaises(NumbaNotImplementedError) as raises:
|
||
|
os.find((types.boolean, types.int32))
|
||
|
# generic
|
||
|
os.append(3, (types.Any, types.Any))
|
||
|
self.assertEqual(os.find((types.boolean, types.int32)), 3)
|
||
|
self.assertEqual(os.find((types.boolean, types.boolean)), 1)
|
||
|
# add ambiguous signature; can match (bool, any) and (any, bool)
|
||
|
os.append(4, (types.Boolean, types.Any))
|
||
|
with self.assertRaises(NumbaTypeError) as raises:
|
||
|
os.find((types.boolean, types.boolean))
|
||
|
self.assertIn('2 ambiguous signatures', str(raises.exception))
|
||
|
# disambiguous
|
||
|
os.append(5, (types.boolean, types.boolean))
|
||
|
self.assertEqual(os.find((types.boolean, types.boolean)), 5)
|
||
|
|
||
|
def test_subclass_specialization(self):
|
||
|
os = OverloadSelector()
|
||
|
self.assertTrue(issubclass(types.Sequence, types.Container))
|
||
|
os.append(1, (types.Container, types.Container,))
|
||
|
lstty = types.List(types.boolean)
|
||
|
self.assertEqual(os.find((lstty, lstty)), 1)
|
||
|
os.append(2, (types.Container, types.Sequence,))
|
||
|
self.assertEqual(os.find((lstty, lstty)), 2)
|
||
|
|
||
|
def test_cache(self):
|
||
|
os = OverloadSelector()
|
||
|
self.assertEqual(len(os._cache), 0)
|
||
|
os.append(1, (types.Any,))
|
||
|
self.assertEqual(os.find((types.int32,)), 1)
|
||
|
self.assertEqual(len(os._cache), 1)
|
||
|
os.append(2, (types.Integer,))
|
||
|
self.assertEqual(len(os._cache), 0)
|
||
|
self.assertEqual(os.find((types.int32,)), 2)
|
||
|
self.assertEqual(len(os._cache), 1)
|
||
|
|
||
|
|
||
|
class TestAmbiguousOverloads(unittest.TestCase):
|
||
|
|
||
|
@classmethod
|
||
|
def setUpClass(cls):
|
||
|
# ensure all impls are loaded
|
||
|
cpu_target.target_context.refresh()
|
||
|
|
||
|
def create_overload_selector(self, kind):
|
||
|
os = OverloadSelector()
|
||
|
loader = RegistryLoader(builtin_registry)
|
||
|
for impl, sig in loader.new_registrations(kind):
|
||
|
os.append(impl, sig)
|
||
|
return os
|
||
|
|
||
|
def test_ambiguous_casts(self):
|
||
|
os = self.create_overload_selector(kind='casts')
|
||
|
all_types = set(t for sig, impl in os.versions for t in sig)
|
||
|
# ensure there are no ambiguous cast overloads
|
||
|
# note: using permutations to avoid testing cast to the same type
|
||
|
for sig in permutations(all_types, r=2):
|
||
|
try:
|
||
|
os.find(sig)
|
||
|
except NumbaNotImplementedError:
|
||
|
pass # ignore not implemented cast
|
||
|
|
||
|
def test_ambiguous_functions(self):
|
||
|
loader = RegistryLoader(builtin_registry)
|
||
|
selectors = defaultdict(OverloadSelector)
|
||
|
for impl, fn, sig in loader.new_registrations('functions'):
|
||
|
os = selectors[fn]
|
||
|
os.append(impl, sig)
|
||
|
|
||
|
for fn, os in selectors.items():
|
||
|
all_types = set(t for sig, impl in os.versions for t in sig)
|
||
|
# ensure there are no ambiguous overloads
|
||
|
for sig in product(all_types, all_types):
|
||
|
try:
|
||
|
os.find(sig)
|
||
|
except NumbaNotImplementedError:
|
||
|
pass # ignore not implemented cast
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
unittest.main()
|