ai-content-maker/.venv/Lib/site-packages/numba/tests/test_sets.py

869 lines
23 KiB
Python

import unittest
from collections import namedtuple
import contextlib
import itertools
import random
from numba.core.errors import TypingError
import numpy as np
from numba import jit, njit
from numba.tests.support import (TestCase, enable_pyobj_flags, MemoryLeakMixin,
compile_function)
# Two-field namedtuple; the unboxing error tests use it to nest a set
# inside a namedtuple argument.
Point = namedtuple('Point', ('a', 'b'))
def _build_set_literal_usecase(code, args):
    """
    Substitute the repr of every item of *args* (comma-separated) for the
    ``%(initializer)s`` placeholder of *code*, then compile the result and
    return the ``build_set`` function it defines.
    """
    initializer = ', '.join(map(repr, args))
    source = code % {'initializer': initializer}
    return compile_function('build_set', source, globals())
def set_literal_return_usecase(args):
    """
    Return a freshly compiled ``build_set()`` function that returns a set
    literal whose elements are the items of *args*.
    """
    # The leading ``if 1:`` lets the template be written indented; the
    # ``def`` and its body must therefore be indented below it, otherwise
    # the generated source cannot be compiled.
    code = """if 1:
    def build_set():
        return {%(initializer)s}
    """
    return _build_set_literal_usecase(code, args)
def set_literal_convert_usecase(args):
    """
    Return a freshly compiled ``build_set()`` function that builds a set
    literal from the items of *args* and returns it converted to a list.
    """
    # The leading ``if 1:`` lets the template be written indented; the
    # ``def`` and its body must therefore be indented below it, otherwise
    # the generated source cannot be compiled.
    code = """if 1:
    def build_set():
        my_set = {%(initializer)s}
        return list(my_set)
    """
    return _build_set_literal_usecase(code, args)
def empty_constructor_usecase():
    """
    Create an empty set, insert a single item and return the set's size.
    """
    items = set()
    items.add(1)
    size = len(items)
    return size
def constructor_usecase(arg):
    """
    Return the number of distinct items in the iterable *arg*.
    """
    return len(set(arg))
def iterator_usecase(arg):
    """
    Build a set from *arg* and return its items, collected by iterating
    over the set, as a list.
    """
    unique_items = set(arg)
    collected = []
    for item in unique_items:
        collected.append(item)
    return collected
def update_usecase(a, b, c):
    """
    Start from an empty set, merge in *a*, *b* and *c* (in that order)
    with ``set.update`` and return the resulting items as a list.
    """
    merged = set()
    merged.update(a)
    merged.update(b)
    merged.update(c)
    return list(merged)
def bool_usecase(arg):
    """
    Return the truth value of the set built from *arg* minus its first item.
    """
    # Dropping the first element makes an empty set reachable from a
    # one-element input.
    return bool(set(arg[1:]))
def remove_usecase(a, b):
    """
    Build a set from *a*, ``remove()`` every item of *b* from it and return
    what is left as a list.
    """
    remaining = set(a)
    for item in b:
        remaining.remove(item)
    return list(remaining)
def discard_usecase(a, b):
    """
    Build a set from *a*, ``discard()`` every item of *b* from it and
    return what is left as a list.
    """
    remaining = set(a)
    for item in b:
        remaining.discard(item)
    return list(remaining)
def add_discard_usecase(a, u, v):
    """
    Build a set from *a*, then 1000 times add *u* and discard *v*; return
    the final items as a list.
    """
    work = set(a)
    for _ in range(1000):
        work.add(u)
        work.discard(v)
    return list(work)
def pop_usecase(a):
    """
    Build a set from *a* and pop items until it is empty, returning the
    popped items as a list.
    """
    pending = set(a)
    popped = []
    while len(pending) > 0:
        popped.append(pending.pop())
    return popped
def contains_usecase(a, b):
    """
    Return, for each item of *b* in order, whether it belongs to the set
    built from *a*.
    """
    members = set(a)
    flags = []
    for candidate in b:
        flags.append(candidate in members)
    return flags
def difference_update_usecase(a, b):
    """
    Return, as a list, the set of *a* after an in-place
    ``difference_update`` with the set of *b*.
    """
    target = set(a)
    other = set(b)
    target.difference_update(other)
    return list(target)
def intersection_update_usecase(a, b):
    """
    Return, as a list, the set of *a* after an in-place
    ``intersection_update`` with the set of *b*.
    """
    target = set(a)
    other = set(b)
    target.intersection_update(other)
    return list(target)
def symmetric_difference_update_usecase(a, b):
    """
    Return, as a list, the set of *a* after an in-place
    ``symmetric_difference_update`` with the set of *b*.
    """
    target = set(a)
    other = set(b)
    target.symmetric_difference_update(other)
    return list(target)
def isdisjoint_usecase(a, b):
    """
    Tell whether the sets built from *a* and *b* share no item.
    """
    left = set(a)
    right = set(b)
    return left.isdisjoint(right)
def issubset_usecase(a, b):
    """
    Tell whether the set built from *a* is a subset of the one built
    from *b*.
    """
    left = set(a)
    right = set(b)
    return left.issubset(right)
def issuperset_usecase(a, b):
    """
    Tell whether the set built from *a* is a superset of the one built
    from *b*.
    """
    left = set(a)
    right = set(b)
    return left.issuperset(right)
def clear_usecase(a):
    """
    Build a set from *a*, clear it, and return ``(size, items as list)``
    of the cleared set.
    """
    emptied = set(a)
    emptied.clear()
    return len(emptied), list(emptied)
def copy_usecase(a):
    """
    Copy the set built from *a*, pop one item from the source afterwards,
    and return ``(size, items as list)`` of the copy.
    """
    source = set(a)
    duplicate = source.copy()
    # Mutating the source after the copy must not affect the copy.
    source.pop()
    return len(duplicate), list(duplicate)
def copy_usecase_empty(a):
    """
    Copy a set that was built from *a* and then cleared; add ``a[0]`` to
    the source afterwards and return ``(size, items as list)`` of the copy.
    """
    source = set(a)
    source.clear()
    duplicate = source.copy()
    # Mutating the source after the copy must not affect the copy.
    source.add(a[0])
    return len(duplicate), list(duplicate)
def copy_usecase_deleted(a, b):
    """
    Copy the set built from *a* after removing *b* from it; pop one item
    from the source afterwards and return ``(size, items as list)`` of the
    copy.
    """
    source = set(a)
    source.remove(b)
    duplicate = source.copy()
    # Mutating the source after the copy must not affect the copy.
    source.pop()
    return len(duplicate), list(duplicate)
def difference_usecase(a, b):
    """
    Return, as a list, the items of the set of *a* that are not in the set
    of *b* (``set.difference``).
    """
    left = set(a)
    right = set(b)
    return list(left.difference(right))
def intersection_usecase(a, b):
    """
    Return, as a list, the items common to the sets of *a* and *b*
    (``set.intersection``).
    """
    left = set(a)
    right = set(b)
    return list(left.intersection(right))
def symmetric_difference_usecase(a, b):
    """
    Return, as a list, the items found in exactly one of the sets of *a*
    and *b* (``set.symmetric_difference``).
    """
    left = set(a)
    right = set(b)
    return list(left.symmetric_difference(right))
def union_usecase(a, b):
    """
    Return, as a list, the items found in either of the sets of *a* and
    *b* (``set.union``).
    """
    left = set(a)
    right = set(b)
    return list(left.union(right))
def set_return_usecase(a):
    """
    Return the set built from the iterable *a*.
    """
    return set(a)
def noop(x):
    """
    Accept one argument and do nothing with it.
    """
    return None
def unbox_usecase(x):
    """
    Sum the items of *x*, which is expected to be a set of numbers.
    """
    total = 0
    for item in x:
        total += item
    return total
def unbox_usecase2(x):
    """
    Sum the lengths of the items of *x*, which is expected to be a set of
    tuples.
    """
    total = 0
    for item in x:
        total += len(item)
    return total
def unbox_usecase3(x):
    """
    *x* is expected to be a ``(number, set of numbers)`` tuple; return the
    number plus the sum of the set's items.
    """
    start, numbers = x
    total = start
    for item in numbers:
        total += item
    return total
def unbox_usecase4(x):
    """
    *x* is expected to be a ``(number, set of tuples)`` tuple; return the
    number plus the summed lengths of the set's items.
    """
    start, tuples = x
    total = start
    for item in tuples:
        total += len(item)
    return total
def reflect_simple(sa, sb):
    """
    Mutate *sa* in place (add 42, then merge in *sb*) and return
    ``(sa, len(sa), len(sb))``.
    """
    sa.add(42)
    sa.update(sb)
    outcome = (sa, len(sa), len(sb))
    return outcome
def reflect_conditional(sa, sb):
    """
    Like a simple reflection, except that *sa* is replaced by a brand new
    set when *sb* holds more than one item, so the returned set may or may
    not be the caller's object.  Returns ``(sa, len(sa), len(sb))``.
    """
    # Depending on this branch, `sa` is either the caller's set or a
    # locally created one.
    if len(sb) > 1:
        sa = set((11., 22., 33., 44.))
    sa.add(42.)
    sa.update(sb)
    # Mix in a set that never came from the caller (checks method typing).
    local_only = set((55., 66.))
    sa.symmetric_difference_update(local_only)
    return sa, len(sa), len(sb)
def reflect_exception(s):
    """
    Add 42 to *s* in place, then always raise ``ZeroDivisionError``.
    """
    # The mutation happens first so that it is observable even though the
    # function never returns normally.
    s.add(42)
    raise ZeroDivisionError
def reflect_dual(sa, sb):
    """
    Move one item popped from *sb* into *sa* and return whether both
    arguments are the very same object.
    """
    moved = sb.pop()
    sa.add(moved)
    same_object = sa is sb
    return same_object
def unique_usecase(src):
    """
    Return the items of *src* with duplicates dropped, keeping the order
    of first occurrence.
    """
    already_seen = set()
    uniques = []
    for item in src:
        if item not in already_seen:
            already_seen.add(item)
            uniques.append(item)
    return uniques
class BaseTest(MemoryLeakMixin, TestCase):
    """
    Common base for the set tests: a seeded random source, input
    generators and order-insensitive comparison helpers.
    """

    def setUp(self):
        super(BaseTest, self).setUp()
        # Fixed seed so that generated inputs are identical on every run.
        self.rnd = random.Random(42)

    def _range(self, stop):
        """
        Return the pool of candidate keys ``0 .. int(stop) - 1`` as an
        array.  Subclasses override this to change the key type.
        """
        return np.arange(int(stop))

    def _random_choice(self, seq, n):
        """
        Choose *n* possibly duplicate items from sequence.

        An ndarray input yields an ndarray of the same dtype; any other
        sequence yields a plain list.
        """
        # Materialise the pool once instead of once per drawn item; the
        # sequence of random draws is unchanged.
        pool = list(seq)
        picked = [self.rnd.choice(pool) for _ in range(n)]
        if isinstance(seq, np.ndarray):
            return np.array(picked, dtype=seq.dtype)
        return picked

    def duplicates_array(self, n):
        """
        Get a 1d array with many duplicate values.
        """
        # Drawing n items from a pool of about sqrt(n) keys forces repeats.
        pool = self._range(np.sqrt(n))
        return self._random_choice(pool, n)

    def sparse_array(self, n):
        """
        Get a 1d array with values spread around.
        """
        # Note two calls to sparse_array() should generate reasonable overlap
        pool = self._range(n ** 1.3)
        return self._random_choice(pool, n)

    def _assert_equal_unordered(self, a, b):
        """
        Assert that *a* and *b* are equal, ignoring element order inside
        lists.  Tuples are compared item by item (recursively), lists are
        compared after sorting, anything else is compared directly.
        """
        if isinstance(a, tuple):
            self.assertIsInstance(b, tuple)
            # zip() silently stops at the shorter tuple, so a length
            # mismatch has to be rejected explicitly.
            self.assertEqual(len(a), len(b))
            for u, v in zip(a, b):
                self._assert_equal_unordered(u, v)
        elif isinstance(a, list):
            self.assertIsInstance(b, list)
            self.assertPreciseEqual(sorted(a), sorted(b))
        else:
            self.assertPreciseEqual(a, b)

    def unordered_checker(self, pyfunc):
        """
        Return a ``check(*args)`` callable that runs *pyfunc* both
        interpreted and compiled in nopython mode and asserts that the two
        results agree, ignoring list ordering.
        """
        cfunc = jit(nopython=True)(pyfunc)

        def check(*args):
            expected = pyfunc(*args)
            got = cfunc(*args)
            self._assert_equal_unordered(expected, got)

        return check
class TestSetLiterals(BaseTest):
    """
    Test functions that build sets from set literals.
    """

    def check(self, pyfunc):
        """
        Call *pyfunc* (no arguments) interpreted and njit-compiled, assert
        both results are precisely equal and return ``(got, expected)``.
        """
        compiled = njit(pyfunc)
        expected = pyfunc()
        got = compiled()
        self.assertPreciseEqual(expected, got)
        return got, expected

    def test_build_set(self):
        self.check(set_literal_return_usecase((1, 2, 3, 2)))

    def test_build_heterogeneous_set(self, flags=enable_pyobj_flags):
        self.check(set_literal_return_usecase((1, 2.0, 3j, 2)))
        # 2.0 and 2 compare equal, so only one survives in the set: the
        # surviving item must have the same type in both results.
        got, expected = self.check(set_literal_return_usecase((2.0, 2)))
        self.assertIs(type(got.pop()), type(expected.pop()))

    def test_build_set_nopython(self):
        values = list(self.sparse_array(50))
        pyfunc = set_literal_convert_usecase(values)
        compiled = jit(nopython=True)(pyfunc)
        expected = pyfunc()
        got = compiled()
        self.assertPreciseEqual(sorted(expected), sorted(got))
class TestSets(BaseTest):
    """
    Test set construction, mutation, queries and operators by comparing
    each use case run interpreted against the same use case compiled in
    nopython mode.  Subclasses re-run the suite with other key types by
    overriding ``_range``.
    """

    def test_constructor(self):
        pyfunc = empty_constructor_usecase
        cfunc = jit(nopython=True)(pyfunc)
        self.assertPreciseEqual(cfunc(), pyfunc())

        pyfunc = constructor_usecase
        cfunc = jit(nopython=True)(pyfunc)

        def check(arg):
            self.assertPreciseEqual(pyfunc(arg), cfunc(arg))

        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_set_return(self):
        pyfunc = set_return_usecase
        cfunc = jit(nopython=True)(pyfunc)

        arg = self.duplicates_array(200)
        self.assertEqual(cfunc(arg), set(arg))

    def test_iterator(self):
        pyfunc = iterator_usecase
        check = self.unordered_checker(pyfunc)

        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_update(self):
        pyfunc = update_usecase
        check = self.unordered_checker(pyfunc)

        a = self.sparse_array(50)
        b = self.duplicates_array(50)
        c = self.sparse_array(50)
        check(a, b, c)

    def test_remove(self):
        pyfunc = remove_usecase
        check = self.unordered_checker(pyfunc)

        a = self.sparse_array(50)
        # Every tenth item of `a` is guaranteed to be present in the set.
        b = a[::10]
        check(a, b)

    def test_remove_error(self):
        # References are leaked on exception
        self.disable_leak_check()

        pyfunc = remove_usecase
        cfunc = jit(nopython=True)(pyfunc)

        # ensure that there will be a key error
        items = tuple(set(self.sparse_array(3)))
        a = items[1:]
        b = (items[0],)
        with self.assertRaises(KeyError):
            cfunc(a, b)

    def test_discard(self):
        pyfunc = discard_usecase
        check = self.unordered_checker(pyfunc)

        a = self.sparse_array(50)
        b = self.sparse_array(50)
        check(a, b)

    def test_add_discard(self):
        """
        Check that the insertion logic does not create an infinite lookup
        chain with deleted entries (insertion should happen at the first
        deleted entry, not at the free entry at the end of the chain).
        See issue #1913.
        """
        pyfunc = add_discard_usecase
        check = self.unordered_checker(pyfunc)

        # ensure a and b are different
        a = b = None
        while a == b:
            a, b = self.sparse_array(2)
        check((a,), b, b)

    def test_pop(self):
        pyfunc = pop_usecase
        check = self.unordered_checker(pyfunc)

        check(self.sparse_array(50))

    def test_contains(self):
        pyfunc = contains_usecase
        cfunc = jit(nopython=True)(pyfunc)

        def check(a, b):
            self.assertPreciseEqual(pyfunc(a, b), cfunc(a, b))

        a = self.sparse_array(50)
        b = self.sparse_array(50)
        check(a, b)

    def _test_xxx_update(self, pyfunc):
        """
        Check an in-place two-set update use case over every combination
        of small, medium and large operand sizes.
        """
        check = self.unordered_checker(pyfunc)

        sizes = (1, 50, 500)
        for na, nb in itertools.product(sizes, sizes):
            a = self.sparse_array(na)
            b = self.sparse_array(nb)
            check(a, b)

    def test_difference_update(self):
        self._test_xxx_update(difference_update_usecase)

    def test_intersection_update(self):
        self._test_xxx_update(intersection_update_usecase)

    def test_symmetric_difference_update(self):
        self._test_xxx_update(symmetric_difference_update_usecase)

    def _test_comparator(self, pyfunc):
        """
        Check a two-set predicate use case on every ordered pair drawn
        from the intersection, difference, union and symmetric difference
        of two random sets.
        """
        cfunc = jit(nopython=True)(pyfunc)

        def check(a, b):
            self.assertPreciseEqual(pyfunc(a, b), cfunc(a, b))

        a, b = map(set, [self.sparse_array(10), self.sparse_array(15)])
        args = [a & b, a - b, a | b, a ^ b]
        args = [tuple(x) for x in args]
        for a, b in itertools.product(args, args):
            check(a, b)

    def test_isdisjoint(self):
        self._test_comparator(isdisjoint_usecase)

    def test_issubset(self):
        self._test_comparator(issubset_usecase)

    def test_issuperset(self):
        self._test_comparator(issuperset_usecase)

    def test_clear(self):
        pyfunc = clear_usecase
        check = self.unordered_checker(pyfunc)

        check(self.sparse_array(50))

    def test_copy(self):
        # Source set doesn't have any deleted entries
        pyfunc = copy_usecase
        check = self.unordered_checker(pyfunc)
        check(self.sparse_array(50))

        pyfunc = copy_usecase_empty
        check = self.unordered_checker(pyfunc)
        a = self.sparse_array(1)
        check(a)

        # Source set has deleted entries
        pyfunc = copy_usecase_deleted
        check = self.unordered_checker(pyfunc)
        check((1, 2, 4, 11), 2)
        a = self.sparse_array(50)
        check(a, a[len(a) // 2])

    def test_bool(self):
        pyfunc = bool_usecase
        check = self.unordered_checker(pyfunc)

        check(self.sparse_array(1))
        check(self.sparse_array(2))

    def _test_set_operator(self, pyfunc):
        """
        Check a two-operand set use case on one fixed pair of tuples and
        on every combination of small, medium and large random operands.
        """
        check = self.unordered_checker(pyfunc)

        a, b = (1, 2, 4, 11), (2, 3, 5, 11, 42)
        check(a, b)

        sizes = (1, 50, 500)
        for na, nb in itertools.product(sizes, sizes):
            a = self.sparse_array(na)
            b = self.sparse_array(nb)
            check(a, b)

    # In the three factories below, the leading ``if 1:`` lets the code
    # template be written indented; the ``def`` and its body must stay
    # indented below it or the generated source cannot be compiled.

    def make_operator_usecase(self, op):
        """
        Compile and return ``operator_usecase(a, b)``, which applies the
        binary operator *op* to ``set(a)`` and ``set(b)`` and returns the
        result as a list.
        """
        code = """if 1:
        def operator_usecase(a, b):
            s = set(a) %(op)s set(b)
            return list(s)
        """ % dict(op=op)
        return compile_function('operator_usecase', code, globals())

    def make_inplace_operator_usecase(self, op):
        """
        Compile and return ``inplace_operator_usecase(a, b)``, which applies
        the augmented-assignment operator *op* through an alias of
        ``set(a)`` and returns both the alias and the original as lists.
        """
        code = """if 1:
        def inplace_operator_usecase(a, b):
            sa = set(a)
            sb = set(b)
            sc = sa
            sc %(op)s sb
            return list(sc), list(sa)
        """ % dict(op=op)
        return compile_function('inplace_operator_usecase', code, globals())

    def make_comparison_usecase(self, op):
        """
        Compile and return ``comparison_usecase(a, b)``, which returns the
        result of comparing ``set(a)`` with ``set(b)`` using *op*.
        """
        code = """if 1:
        def comparison_usecase(a, b):
            return set(a) %(op)s set(b)
        """ % dict(op=op)
        return compile_function('comparison_usecase', code, globals())

    def test_difference(self):
        self._test_set_operator(difference_usecase)

    def test_intersection(self):
        self._test_set_operator(intersection_usecase)

    def test_symmetric_difference(self):
        self._test_set_operator(symmetric_difference_usecase)

    def test_union(self):
        self._test_set_operator(union_usecase)

    def test_and(self):
        self._test_set_operator(self.make_operator_usecase('&'))

    def test_or(self):
        self._test_set_operator(self.make_operator_usecase('|'))

    def test_sub(self):
        self._test_set_operator(self.make_operator_usecase('-'))

    def test_xor(self):
        self._test_set_operator(self.make_operator_usecase('^'))

    def test_eq(self):
        self._test_set_operator(self.make_comparison_usecase('=='))

    def test_ne(self):
        self._test_set_operator(self.make_comparison_usecase('!='))

    def test_le(self):
        self._test_set_operator(self.make_comparison_usecase('<='))

    def test_lt(self):
        self._test_set_operator(self.make_comparison_usecase('<'))

    def test_ge(self):
        self._test_set_operator(self.make_comparison_usecase('>='))

    def test_gt(self):
        self._test_set_operator(self.make_comparison_usecase('>'))

    def test_iand(self):
        self._test_set_operator(self.make_inplace_operator_usecase('&='))

    def test_ior(self):
        self._test_set_operator(self.make_inplace_operator_usecase('|='))

    def test_isub(self):
        self._test_set_operator(self.make_inplace_operator_usecase('-='))

    def test_ixor(self):
        self._test_set_operator(self.make_inplace_operator_usecase('^='))
class TestFloatSets(TestSets):
    """
    Test sets with floating-point keys.
    """

    # Only the key pool is overridden here: the sanity of most operations
    # doesn't depend on the key type.
    def _range(self, stop):
        step = np.float32(0.1)
        return np.arange(stop, dtype=np.float32) * step
class TestTupleSets(TestSets):
    """
    Test sets with tuple keys.
    """

    def _range(self, stop):
        # Each key is an (int64, int32, bool) tuple derived from one counter.
        base = np.arange(stop, dtype=np.int64)
        even_bits = base & 0x5555555555555555
        odd_bits = (base & 0xaaaaaaaa).astype(np.int32)
        bit_32 = ((base >> 32) & 1).astype(np.bool_)
        return list(zip(even_bits, odd_bits, bit_32))
class TestUnicodeSets(TestSets):
    """
    Test sets with unicode keys. For the purpose of testing refcounted sets.
    """

    def _range(self, stop):
        count = int(stop)
        return ['A{}'.format(idx) for idx in range(count)]
class TestSetsInvalidDtype(TestSets):
    """
    Re-run the two-operand set tests with operands of different key
    types and expect a typing error instead of a result.
    """

    def _test_set_operator(self, pyfunc):
        # it is invalid to apply some set operations on
        # sets with different dtype
        cfunc = jit(nopython=True)(pyfunc)

        int_keys = {1, 2, 4, 11}
        str_keys = {'a', 'b', 'c'}
        expected_msg = 'All Sets must be of the same type'
        with self.assertRaisesRegex(TypingError, expected_msg):
            cfunc(int_keys, str_keys)
class TestSetsInvalid(TestSets):
    """
    Re-run the two-operand set tests with a non-set second operand and
    expect a typing error instead of a result.
    """

    # The four use cases below take no ``self``: they are only ever fetched
    # through the class (``TestSetsInvalid.xxx_usecase``) and handed to the
    # compiler as plain functions operating directly on their arguments.

    def symmetric_difference_usecase(a, b):
        s = a.symmetric_difference(b)
        return list(s)

    def difference_usecase(a, b):
        s = a.difference(b)
        return list(s)

    def intersection_usecase(a, b):
        s = a.intersection(b)
        return list(s)

    def union_usecase(a, b):
        s = a.union(b)
        return list(s)

    def _test_set_operator(self, pyfunc):
        """
        Compile *pyfunc* and assert that calling it with a set and a tuple
        raises the expected typing error.
        """
        # it is invalid to apply some set operations on
        # a set and a non-set argument
        cfunc = jit(nopython=True)(pyfunc)

        a = set([1, 2, 4, 11])
        b = (1, 2, 3)
        msg = 'All arguments must be Sets'
        with self.assertRaisesRegex(TypingError, msg):
            cfunc(a, b)

    def test_difference(self):
        self._test_set_operator(TestSetsInvalid.difference_usecase)

    def test_intersection(self):
        self._test_set_operator(TestSetsInvalid.intersection_usecase)

    def test_symmetric_difference(self):
        self._test_set_operator(TestSetsInvalid.symmetric_difference_usecase)

    def test_union(self):
        self._test_set_operator(TestSetsInvalid.union_usecase)

    # In the three factories below, the leading ``if 1:`` lets the code
    # template be written indented; the ``def`` and its body must stay
    # indented below it or the generated source cannot be compiled.

    def make_operator_usecase(self, op):
        """
        Compile and return ``operator_usecase(a, b)``, which applies the
        binary operator *op* directly to its two arguments.
        """
        code = """if 1:
        def operator_usecase(a, b):
            s = a %(op)s b
            return list(s)
        """ % dict(op=op)
        return compile_function('operator_usecase', code, globals())

    def make_inplace_operator_usecase(self, op):
        """
        Compile and return ``inplace_operator_usecase(a, b)``, which applies
        the augmented-assignment operator *op* directly to its arguments
        through an alias of *a*.
        """
        code = """if 1:
        def inplace_operator_usecase(a, b):
            sa = a
            sb = b
            sc = sa
            sc %(op)s sb
            return list(sc), list(sa)
        """ % dict(op=op)
        return compile_function('inplace_operator_usecase', code, globals())

    def make_comparison_usecase(self, op):
        """
        Compile and return ``comparison_usecase(a, b)``, which compares
        ``set(a)`` with the unconverted *b* using *op*.
        """
        code = """if 1:
        def comparison_usecase(a, b):
            return set(a) %(op)s b
        """ % dict(op=op)
        return compile_function('comparison_usecase', code, globals())
class TestUnboxing(BaseTest):
    """
    Test unboxing of Python sets into native Numba sets.
    """

    @contextlib.contextmanager
    def assert_type_error(self, msg):
        """
        Context manager asserting that the body raises ``TypeError`` and,
        unless *msg* is None, that the error text matches the regex *msg*.
        """
        with self.assertRaises(TypeError) as raises:
            yield
        if msg is not None:
            self.assertRegex(str(raises.exception), msg)

    def check_unary(self, pyfunc):
        """
        Return a ``check(arg)`` callable comparing *pyfunc* run interpreted
        against *pyfunc* compiled in nopython mode.
        """
        compiled = jit(nopython=True)(pyfunc)

        def check(arg):
            expected = pyfunc(arg)
            got = compiled(arg)
            self.assertPreciseEqual(got, expected)

        return check

    def test_numbers(self):
        check = self.check_unary(unbox_usecase)
        check({1, 2})
        check({1j, 2.5j})
        # Check allocation and sizing
        check(set(range(100)))

    def test_tuples(self):
        check = self.check_unary(unbox_usecase2)
        check({(1, 2), (3, 4)})
        check({(1, 2j), (3, 4j)})

    def test_set_inside_tuple(self):
        check = self.check_unary(unbox_usecase3)
        check((1, {2, 3, 4}))

    def test_set_of_tuples_inside_tuple(self):
        check = self.check_unary(unbox_usecase4)
        check((1, {(2,), (3,)}))

    def test_errors(self):
        # Error checking should ensure the set is homogeneous
        msg = "can't unbox heterogeneous set"
        pyfunc = noop
        compiled = jit(nopython=True)(pyfunc)

        mixed = {1, 2.5}
        with self.assert_type_error(msg):
            compiled(mixed)
        # The set hasn't been changed (bogus reflecting)
        self.assertEqual(mixed, {1, 2.5})

        with self.assert_type_error(msg):
            compiled({1, 2j})

        # Same when the set is nested in a tuple or namedtuple
        with self.assert_type_error(msg):
            compiled((1, {1, 2j}))
        with self.assert_type_error(msg):
            compiled(Point(1, {1, 2j}))

        # Tuples of different size.
        # Note the check is really on the tuple side.
        ragged = {(1,), (2, 3)}
        # Depending on which tuple is examined first, we could get
        # a IndexError or a ValueError.
        with self.assertRaises((IndexError, ValueError)) as raises:
            compiled(ragged)
class TestSetReflection(BaseTest):
    """
    Test reflection of native Numba sets on Python set objects.
    """

    def check_reflection(self, pyfunc):
        """
        Run *pyfunc* interpreted and compiled on equal inputs and assert
        that results, in-place mutations and reference counts all agree.
        """
        cfunc = jit(nopython=True)(pyfunc)
        # One source with a single item and one with several items.
        short_case = ({1., 2., 3., 4.}, {0.})
        long_case = ({1., 2., 3., 4.}, {5., 6., 7., 8., 9.})
        for dest, src in [short_case, long_case]:
            expected = set(dest)
            got = set(dest)
            pyres = pyfunc(expected, src)
            with self.assertRefCount(got, src):
                cres = cfunc(got, src)
                self.assertPreciseEqual(cres, pyres)
                self.assertPreciseEqual(expected, got)
                self.assertEqual(pyres[0] is expected, cres[0] is got)
                # Drop the results before the refcount check on exit, as
                # they may hold references to the argument sets.
                del pyres, cres

    def test_reflect_simple(self):
        self.check_reflection(reflect_simple)

    def test_reflect_conditional(self):
        self.check_reflection(reflect_conditional)

    def test_reflect_exception(self):
        """
        When the function exits with an exception, sets should still be
        reflected.
        """
        cfunc = jit(nopython=True)(reflect_exception)
        s = {1, 2, 3}
        with self.assertRefCount(s):
            with self.assertRaises(ZeroDivisionError):
                cfunc(s)
            self.assertPreciseEqual(s, {1, 2, 3, 42})

    def test_reflect_same_set(self):
        """
        When the same set object is reflected twice, behaviour should
        be consistent.
        """
        pyfunc = reflect_dual
        cfunc = jit(nopython=True)(pyfunc)
        pyset = {1, 2, 3}
        cset = pyset.copy()
        expected = pyfunc(pyset, pyset)
        got = cfunc(cset, cset)
        self.assertPreciseEqual(expected, got)
        self.assertPreciseEqual(pyset, cset)
        self.assertRefCountEqual(pyset, cset)

    def test_reflect_clean(self):
        """
        When the set wasn't mutated, no reflection should take place.
        """
        cfunc = jit(nopython=True)(noop)
        # Use a complex, as Python integers can be cached
        s = {12.5j}
        ids_before = [id(item) for item in s]
        cfunc(s)
        ids_after = [id(item) for item in s]
        self.assertEqual(ids_after, ids_before)
class TestExamples(BaseTest):
    """
    Examples of using sets.
    """

    def test_unique(self):
        check = self.unordered_checker(unique_usecase)

        check(self.duplicates_array(200))
        check(self.sparse_array(200))

    def test_type_coercion_from_update(self):
        # see issue #6621
        def impl():
            extra = np.uint64(1)
            result = set()
            result.update({1, 2, 3})
            result.add(extra)
            return result

        check = self.unordered_checker(impl)
        check()
# Run this module's tests through unittest's command-line runner when the
# file is executed as a script.
if __name__ == '__main__':
    unittest.main()