""" Testing numba implementation of the numba dictionary. The tests here only check that the numba typing and codegen are working correctly. Detailed testing of the underlying dictionary operations is done in test_dictimpl.py. """ import sys import warnings import numpy as np from numba import njit, literally from numba import int32, int64, float32, float64 from numba import typeof from numba.typed import Dict, dictobject, List from numba.typed.typedobjectutils import _sentry_safe_cast from numba.core.errors import TypingError from numba.core import types from numba.tests.support import (TestCase, MemoryLeakMixin, unittest, override_config, forbid_codegen) from numba.experimental import jitclass from numba.extending import overload class TestDictObject(MemoryLeakMixin, TestCase): def test_dict_bool(self): """ Exercise bool(dict) """ @njit def foo(n): d = dictobject.new_dict(int32, float32) for i in range(n): d[i] = i + 1 return bool(d) # Insert nothing self.assertEqual(foo(n=0), False) # Insert 1 entry self.assertEqual(foo(n=1), True) # Insert 2 entries self.assertEqual(foo(n=2), True) # Insert 100 entries self.assertEqual(foo(n=100), True) def test_dict_create(self): """ Exercise dictionary creation, insertion and len """ @njit def foo(n): d = dictobject.new_dict(int32, float32) for i in range(n): d[i] = i + 1 return len(d) # Insert nothing self.assertEqual(foo(n=0), 0) # Insert 1 entry self.assertEqual(foo(n=1), 1) # Insert 2 entries self.assertEqual(foo(n=2), 2) # Insert 100 entries self.assertEqual(foo(n=100), 100) def test_dict_get(self): """ Exercise dictionary creation, insertion and get """ @njit def foo(n, targets): d = dictobject.new_dict(int32, float64) # insertion loop for i in range(n): d[i] = i # retrieval loop output = [] for t in targets: output.append(d.get(t)) return output self.assertEqual(foo(5, [0, 1, 9]), [0, 1, None]) self.assertEqual(foo(10, [0, 1, 9]), [0, 1, 9]) self.assertEqual(foo(10, [-1, 9, 1]), [None, 9, 1]) def test_dict_get_with_default(self): """ Exercise dict.get(k, d) where d is set """ @njit def foo(n, target, default): d = dictobject.new_dict(int32, float64) # insertion loop for i in range(n): d[i] = i # retrieval loop return d.get(target, default) self.assertEqual(foo(5, 3, -1), 3) self.assertEqual(foo(5, 5, -1), -1) def test_dict_getitem(self): """ Exercise dictionary __getitem__ """ @njit def foo(keys, vals, target): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v # lookup return d[target] keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual(foo(keys, vals, 1), 0.1) self.assertEqual(foo(keys, vals, 2), 0.2) self.assertEqual(foo(keys, vals, 3), 0.3) # check no leak so far self.assert_no_memory_leak() # disable leak check for exception test self.disable_leak_check() with self.assertRaises(KeyError): foo(keys, vals, 0) with self.assertRaises(KeyError): foo(keys, vals, 4) def test_dict_popitem(self): """ Exercise dictionary .popitem """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v # popitem return d.popitem() keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] for i in range(1, len(keys)): self.assertEqual( foo(keys[:i], vals[:i]), (keys[i - 1], vals[i - 1]), ) def test_dict_popitem_many(self): """ Exercise dictionary .popitem """ @njit def core(d, npop): # popitem keysum, valsum = 0, 0 for _ in range(npop): k, v = d.popitem() keysum += k valsum -= v return keysum, valsum @njit def foo(keys, vals, npop): d = dictobject.new_dict(int32, int32) # insertion for k, v in zip(keys, vals): d[k] = v return core(d, npop) keys = [1, 2, 3] vals = [10, 20, 30] for i in range(len(keys)): self.assertEqual( foo(keys, vals, npop=3), core.py_func(dict(zip(keys, vals)), npop=3), ) # check no leak so far self.assert_no_memory_leak() # disable leak check for exception test self.disable_leak_check() with self.assertRaises(KeyError): foo(keys, vals, npop=4) def test_dict_pop(self): """ Exercise dictionary .pop """ @njit def foo(keys, vals, target): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v # popitem return d.pop(target, None), len(d) keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual(foo(keys, vals, 1), (0.1, 2)) self.assertEqual(foo(keys, vals, 2), (0.2, 2)) self.assertEqual(foo(keys, vals, 3), (0.3, 2)) self.assertEqual(foo(keys, vals, 0), (None, 3)) # check no leak so far self.assert_no_memory_leak() # disable leak check for exception test self.disable_leak_check() @njit def foo(): d = dictobject.new_dict(int32, float64) # popitem return d.pop(0) with self.assertRaises(KeyError): foo() def test_dict_pop_many(self): """ Exercise dictionary .pop """ @njit def core(d, pops): total = 0 for k in pops: total += k + d.pop(k, 0.123) + len(d) total *= 2 return total @njit def foo(keys, vals, pops): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v # popitem return core(d, pops) keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] pops = [2, 3, 3, 1, 0, 2, 1, 0, -1] self.assertEqual( foo(keys, vals, pops), core.py_func(dict(zip(keys, vals)), pops), ) def test_dict_delitem(self): @njit def foo(keys, vals, target): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v del d[target] return len(d), d.get(target) keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual(foo(keys, vals, 1), (2, None)) self.assertEqual(foo(keys, vals, 2), (2, None)) self.assertEqual(foo(keys, vals, 3), (2, None)) # check no leak so far self.assert_no_memory_leak() # disable leak check for exception test self.disable_leak_check() with self.assertRaises(KeyError): foo(keys, vals, 0) def test_dict_clear(self): """ Exercise dict.clear """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v b4 = len(d) # clear d.clear() return b4, len(d) keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual(foo(keys, vals), (3, 0)) def test_dict_items(self): """ Exercise dict.items """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v out = [] for kv in d.items(): out.append(kv) return out keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual( foo(keys, vals), list(zip(keys, vals)), ) # Test .items() on empty dict @njit def foo(): d = dictobject.new_dict(int32, float64) out = [] for kv in d.items(): out.append(kv) return out self.assertEqual(foo(), []) def test_dict_keys(self): """ Exercise dict.keys """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v out = [] for k in d.keys(): out.append(k) return out keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual( foo(keys, vals), keys, ) def test_dict_keys_len(self): """ Exercise len(dict.keys()) """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v return len(d.keys()) keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual( foo(keys, vals), len(keys), ) def test_dict_values(self): """ Exercise dict.values """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v out = [] for v in d.values(): out.append(v) return out keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual( foo(keys, vals), vals, ) def test_dict_values_len(self): """ Exercise len(dict.values()) """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v return len(d.values()) keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual( foo(keys, vals), len(vals), ) def test_dict_items_len(self): """ Exercise len(dict.items()) """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v return len(d.items()) keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertPreciseEqual( foo(keys, vals), len(vals), ) def test_dict_iter(self): """ Exercise iter(dict) """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v out = [] for k in d: out.append(k) return out keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual( foo(keys, vals), [1, 2, 3] ) def test_dict_contains(self): """ Exercise operator.contains """ @njit def foo(keys, vals, checklist): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v out = [] for k in checklist: out.append(k in d) return out keys = [1, 2, 3] vals = [0.1, 0.2, 0.3] self.assertEqual( foo(keys, vals, [2, 3, 4, 1, 0]), [True, True, False, True, False], ) def test_dict_copy(self): """ Exercise dict.copy """ @njit def foo(keys, vals): d = dictobject.new_dict(int32, float64) # insertion for k, v in zip(keys, vals): d[k] = v return list(d.copy().items()) keys = list(range(20)) vals = [x + i / 100 for i, x in enumerate(keys)] out = foo(keys, vals) self.assertEqual(out, list(zip(keys, vals))) def test_dict_setdefault(self): """ Exercise dict.setdefault """ @njit def foo(): d = dictobject.new_dict(int32, float64) d.setdefault(1, 1.2) # used because key is not in a = d.get(1) d[1] = 2.3 b = d.get(1) d[2] = 3.4 d.setdefault(2, 4.5) # not used because key is in c = d.get(2) return a, b, c self.assertEqual(foo(), (1.2, 2.3, 3.4)) def test_dict_equality(self): """ Exercise dict.__eq__ and .__ne__ """ @njit def foo(na, nb, fa, fb): da = dictobject.new_dict(int32, float64) db = dictobject.new_dict(int32, float64) for i in range(na): da[i] = i * fa for i in range(nb): db[i] = i * fb return da == db, da != db # Same keys and values self.assertEqual(foo(10, 10, 3, 3), (True, False)) # Same keys and diff values self.assertEqual(foo(10, 10, 3, 3.1), (False, True)) # LHS has more keys self.assertEqual(foo(11, 10, 3, 3), (False, True)) # RHS has more keys self.assertEqual(foo(10, 11, 3, 3), (False, True)) def test_dict_equality_more(self): """ Exercise dict.__eq__ """ @njit def foo(ak, av, bk, bv): # The key-value types are different in the two dictionaries da = dictobject.new_dict(int32, float64) db = dictobject.new_dict(int64, float32) for i in range(len(ak)): da[ak[i]] = av[i] for i in range(len(bk)): db[bk[i]] = bv[i] return da == db # Simple equal case ak = [1, 2, 3] av = [2, 3, 4] bk = [1, 2, 3] bv = [2, 3, 4] self.assertTrue(foo(ak, av, bk, bv)) # Equal with replacement ak = [1, 2, 3] av = [2, 3, 4] bk = [1, 2, 2, 3] bv = [2, 1, 3, 4] self.assertTrue(foo(ak, av, bk, bv)) # Diff values ak = [1, 2, 3] av = [2, 3, 4] bk = [1, 2, 3] bv = [2, 1, 4] self.assertFalse(foo(ak, av, bk, bv)) # Diff keys ak = [0, 2, 3] av = [2, 3, 4] bk = [1, 2, 3] bv = [2, 3, 4] self.assertFalse(foo(ak, av, bk, bv)) def test_dict_equality_diff_type(self): """ Exercise dict.__eq__ """ @njit def foo(na, b): da = dictobject.new_dict(int32, float64) for i in range(na): da[i] = i return da == b # dict != int self.assertFalse(foo(10, 1)) # dict != tuple[int] self.assertFalse(foo(10, (1,))) def test_dict_to_from_meminfo(self): """ Exercise dictobject.{_as_meminfo, _from_meminfo} """ @njit def make_content(nelem): for i in range(nelem): yield i, i + (i + 1) / 100 @njit def boxer(nelem): d = dictobject.new_dict(int32, float64) for k, v in make_content(nelem): d[k] = v return dictobject._as_meminfo(d) dcttype = types.DictType(int32, float64) @njit def unboxer(mi): d = dictobject._from_meminfo(mi, dcttype) return list(d.items()) mi = boxer(10) self.assertEqual(mi.refcount, 1) got = unboxer(mi) expected = list(make_content.py_func(10)) self.assertEqual(got, expected) def test_001_cannot_downcast_key(self): @njit def foo(n): d = dictobject.new_dict(int32, float64) for i in range(n): d[i] = i + 1 # bad key type z = d.get(1j) return z with self.assertRaises(TypingError) as raises: foo(10) self.assertIn( 'cannot safely cast complex128 to int32', str(raises.exception), ) def test_002_cannot_downcast_default(self): @njit def foo(n): d = dictobject.new_dict(int32, float64) for i in range(n): d[i] = i + 1 # bad default type z = d.get(2 * n, 1j) return z with self.assertRaises(TypingError) as raises: foo(10) self.assertIn( 'cannot safely cast complex128 to float64', str(raises.exception), ) def test_003_cannot_downcast_key(self): @njit def foo(n): d = dictobject.new_dict(int32, float64) for i in range(n): d[i] = i + 1 # bad cast!? z = d.get(2.4) return z # should raise with self.assertRaises(TypingError) as raises: foo(10) self.assertIn( 'cannot safely cast float64 to int32', str(raises.exception), ) def test_004_cannot_downcast_key(self): @njit def foo(): d = dictobject.new_dict(int32, float64) # should raise TypingError d[1j] = 7. with self.assertRaises(TypingError) as raises: foo() self.assertIn( 'cannot safely cast complex128 to int32', str(raises.exception), ) def test_005_cannot_downcast_value(self): @njit def foo(): d = dictobject.new_dict(int32, float64) # should raise TypingError d[1] = 1j with self.assertRaises(TypingError) as raises: foo() self.assertIn( 'cannot safely cast complex128 to float64', str(raises.exception), ) def test_006_cannot_downcast_key(self): @njit def foo(): d = dictobject.new_dict(int32, float64) # raise TypingError d[11.5] with self.assertRaises(TypingError) as raises: foo() self.assertIn( 'cannot safely cast float64 to int32', str(raises.exception), ) @unittest.skipUnless(sys.maxsize > 2 ** 32, "64 bit test only") def test_007_collision_checks(self): # this checks collisions in real life for 64bit systems @njit def foo(v1, v2): d = dictobject.new_dict(int64, float64) c1 = np.uint64(2 ** 61 - 1) c2 = np.uint64(0) assert hash(c1) == hash(c2) d[c1] = v1 d[c2] = v2 return (d[c1], d[c2]) a, b = 10., 20. x, y = foo(a, b) self.assertEqual(x, a) self.assertEqual(y, b) def test_008_lifo_popitem(self): # check that (keys, vals) are LIFO .popitem() @njit def foo(n): d = dictobject.new_dict(int32, float64) for i in range(n): d[i] = i + 1 keys = [] vals = [] for i in range(n): tmp = d.popitem() keys.append(tmp[0]) vals.append(tmp[1]) return keys, vals z = 10 gk, gv = foo(z) self.assertEqual(gk, [x for x in reversed(range(z))]) self.assertEqual(gv, [x + 1 for x in reversed(range(z))]) def test_010_cannot_downcast_default(self): @njit def foo(): d = dictobject.new_dict(int32, float64) d[0] = 6. d[1] = 7. # pop'd default must have same type as value d.pop(11, 12j) with self.assertRaises(TypingError) as raises: foo() self.assertIn( "cannot safely cast complex128 to float64", str(raises.exception), ) def test_011_cannot_downcast_key(self): @njit def foo(): d = dictobject.new_dict(int32, float64) d[0] = 6. d[1] = 7. # pop'd key must have same type as key d.pop(11j) with self.assertRaises(TypingError) as raises: foo() self.assertIn( "cannot safely cast complex128 to int32", str(raises.exception), ) def test_012_cannot_downcast_key(self): @njit def foo(): d = dictobject.new_dict(int32, float64) d[0] = 6. # invalid key type return 1j in d with self.assertRaises(TypingError) as raises: foo() self.assertIn( "cannot safely cast complex128 to int32", str(raises.exception), ) def test_013_contains_empty_dict(self): @njit def foo(): d = dictobject.new_dict(int32, float64) # contains on empty dict return 1 in d self.assertFalse(foo()) def test_014_not_contains_empty_dict(self): @njit def foo(): d = dictobject.new_dict(int32, float64) # not contains empty dict return 1 not in d self.assertTrue(foo()) def test_015_dict_clear(self): @njit def foo(n): d = dictobject.new_dict(int32, float64) for i in range(n): d[i] = i + 1 x = len(d) d.clear() y = len(d) return x, y m = 10 self.assertEqual(foo(m), (m, 0)) def test_016_cannot_downcast_key(self): @njit def foo(): d = dictobject.new_dict(int32, float64) # key is wrong type d.setdefault(1j, 12.) with self.assertRaises(TypingError) as raises: foo() self.assertIn( "cannot safely cast complex128 to int32", str(raises.exception), ) def test_017_cannot_downcast_default(self): @njit def foo(): d = dictobject.new_dict(int32, float64) # default value is wrong type d.setdefault(1, 12.j) with self.assertRaises(TypingError) as raises: foo() self.assertIn( "cannot safely cast complex128 to float64", str(raises.exception), ) def test_018_keys_iter_are_views(self): # this is broken somewhere in llvmlite, intent of test is to check if # keys behaves like a view or not @njit def foo(): d = dictobject.new_dict(int32, float64) d[11] = 12. k1 = d.keys() d[22] = 9. k2 = d.keys() rk1 = [x for x in k1] rk2 = [x for x in k2] return rk1, rk2 a, b = foo() self.assertEqual(a, b) self.assertEqual(a, [11, 22]) # Not implemented yet @unittest.expectedFailure def test_019(self): # should keys/vals be set-like? @njit def foo(): d = dictobject.new_dict(int32, float64) d[11] = 12. d[22] = 9. k2 = d.keys() & {12, } return k2 print(foo()) def test_020_string_key(self): @njit def foo(): d = dictobject.new_dict(types.unicode_type, float64) d['a'] = 1. d['b'] = 2. d['c'] = 3. d['d'] = 4. out = [] for x in d.items(): out.append(x) return out, d['a'] items, da = foo() self.assertEqual(items, [('a', 1.), ('b', 2.), ('c', 3.), ('d', 4)]) self.assertEqual(da, 1.) def test_021_long_str_key(self): @njit def foo(): d = dictobject.new_dict(types.unicode_type, float64) tmp = [] for i in range(10000): tmp.append('a') s = ''.join(tmp) d[s] = 1. out = list(d.items()) return out self.assertEqual(foo(), [('a' * 10000, 1)]) def test_022_references_juggle(self): @njit def foo(): d = dictobject.new_dict(int32, float64) e = d d[1] = 12. e[2] = 14. e = dictobject.new_dict(int32, float64) e[1] = 100. e[2] = 1000. f = d d = e k1 = [x for x in d.items()] k2 = [x for x in e.items()] k3 = [x for x in f.items()] return k1, k2, k3 k1, k2, k3 = foo() self.assertEqual(k1, [(1, 100.0), (2, 1000.0)]) self.assertEqual(k2, [(1, 100.0), (2, 1000.0)]) self.assertEqual(k3, [(1, 12), (2, 14)]) def test_023_closure(self): @njit def foo(): d = dictobject.new_dict(int32, float64) def bar(): d[1] = 12. d[2] = 14. bar() return [x for x in d.keys()] self.assertEqual(foo(), [1, 2]) def test_024_unicode_getitem_keys(self): # See issue #6135 @njit def foo(): s = 'a\u1234' d = {s[0] : 1} return d['a'] self.assertEqual(foo(), foo.py_func()) @njit def foo(): s = 'abc\u1234' d = {s[:1] : 1} return d['a'] self.assertEqual(foo(), foo.py_func()) def test_issue6570_alignment_padding(self): # Create a key type that is 12-bytes long on a 8-byte aligned system # so that the a 4-byte padding is needed. # If the 4-byte padding is not zero-filled, it will have garbage data # that affects key matching in the lookup. keyty = types.Tuple([types.uint64, types.float32]) @njit def foo(): d = dictobject.new_dict(keyty, float64) t1 = np.array([3], dtype=np.uint64) t2 = np.array([5.67], dtype=np.float32) v1 = np.array([10.23], dtype=np.float32) d[(t1[0], t2[0])] = v1[0] return (t1[0], t2[0]) in d self.assertTrue(foo()) def test_dict_update(self): """ Tests dict.update works with various dictionaries. """ n = 10 def f1(n): """ Test update with a regular dictionary. """ d1 = {i: i + 1 for i in range(n)} d2 = {3 * i: i for i in range(n)} d1.update(d2) return d1 py_func = f1 cfunc = njit()(f1) a = py_func(n) b = cfunc(n) self.assertEqual(a, b) def f2(n): """ Test update where one of the dictionaries is created as a Python literal. """ d1 = { 1: 2, 3: 4, 5: 6 } d2 = {3 * i: i for i in range(n)} d1.update(d2) return d1 py_func = f2 cfunc = njit()(f2) a = py_func(n) b = cfunc(n) self.assertEqual(a, b) class TestDictTypeCasting(TestCase): def check_good(self, fromty, toty): _sentry_safe_cast(fromty, toty) def check_bad(self, fromty, toty): with self.assertRaises(TypingError) as raises: _sentry_safe_cast(fromty, toty) self.assertIn( 'cannot safely cast {fromty} to {toty}'.format(**locals()), str(raises.exception), ) def test_cast_int_to(self): self.check_good(types.int32, types.float32) self.check_good(types.int32, types.float64) self.check_good(types.int32, types.complex128) self.check_good(types.int64, types.complex128) self.check_bad(types.int32, types.complex64) self.check_good(types.int8, types.complex64) def test_cast_float_to(self): self.check_good(types.float32, types.float64) self.check_good(types.float32, types.complex64) self.check_good(types.float64, types.complex128) def test_cast_bool_to(self): self.check_good(types.boolean, types.int32) self.check_good(types.boolean, types.float64) self.check_good(types.boolean, types.complex128) class TestTypedDict(MemoryLeakMixin, TestCase): def test_basic(self): d = Dict.empty(int32, float32) # len self.assertEqual(len(d), 0) # setitems d[1] = 1 d[2] = 2.3 d[3] = 3.4 self.assertEqual(len(d), 3) # keys self.assertEqual(list(d.keys()), [1, 2, 3]) # values for x, y in zip(list(d.values()), [1, 2.3, 3.4]): self.assertAlmostEqual(x, y, places=4) # getitem self.assertAlmostEqual(d[1], 1) self.assertAlmostEqual(d[2], 2.3, places=4) self.assertAlmostEqual(d[3], 3.4, places=4) # deltiem del d[2] self.assertEqual(len(d), 2) # get self.assertIsNone(d.get(2)) # setdefault d.setdefault(2, 100) d.setdefault(3, 200) self.assertEqual(d[2], 100) self.assertAlmostEqual(d[3], 3.4, places=4) # update d.update({4: 5, 5: 6}) self.assertAlmostEqual(d[4], 5) self.assertAlmostEqual(d[5], 6) # contains self.assertTrue(4 in d) # items pyd = dict(d.items()) self.assertEqual(len(pyd), len(d)) # pop self.assertAlmostEqual(d.pop(4), 5) # popitem nelem = len(d) k, v = d.popitem() self.assertEqual(len(d), nelem - 1) self.assertTrue(k not in d) # __eq__ & copy copied = d.copy() self.assertEqual(copied, d) self.assertEqual(list(copied.items()), list(d.items())) def test_copy_from_dict(self): expect = {k: float(v) for k, v in zip(range(10), range(10, 20))} nbd = Dict.empty(int32, float64) for k, v in expect.items(): nbd[k] = v got = dict(nbd) self.assertEqual(got, expect) def test_compiled(self): @njit def producer(): d = Dict.empty(int32, float64) d[1] = 1.23 return d @njit def consumer(d): return d[1] d = producer() val = consumer(d) self.assertEqual(val, 1.23) def test_gh7908(self): d = Dict.empty( key_type=types.Tuple([types.uint32, types.uint32]), value_type=int64) d[(1, 1)] = 12345 self.assertEqual(d[(1, 1)], d.get((1, 1))) def check_stringify(self, strfn, prefix=False): nbd = Dict.empty(int32, int32) d = {} nbd[1] = 2 d[1] = 2 checker = self.assertIn if prefix else self.assertEqual checker(strfn(d), strfn(nbd)) nbd[2] = 3 d[2] = 3 checker(strfn(d), strfn(nbd)) for i in range(10, 20): nbd[i] = i + 1 d[i] = i + 1 checker(strfn(d), strfn(nbd)) if prefix: self.assertTrue(strfn(nbd).startswith('DictType')) def test_repr(self): self.check_stringify(repr, prefix=True) def test_str(self): self.check_stringify(str) class DictIterableCtor: def test_iterable_type_constructor(self): # https://docs.python.org/3/library/stdtypes.html#dict @njit def func1(a, b): d = Dict(zip(a, b)) return d @njit def func2(a_, b): a = range(3) return Dict(zip(a, b)) @njit def func3(a_, b): a = [0, 1, 2] return Dict(zip(a, b)) @njit def func4(a, b): c = zip(a, b) return Dict(zip(a, zip(c, a))) @njit def func5(a, b): return Dict(zip(zip(a, b), b)) @njit def func6(items): return Dict(items) @njit def func7(k, v): return Dict({k: v}) # mapping - not supported @njit def func8(k, v): d = Dict() d[k] = v return d def _get_dict(py_dict): d = Dict() for k, v in py_dict.items(): d[k] = v return d vals = ( (func1, [(0, 1, 2), 'abc'], _get_dict({0: 'a', 1: 'b', 2: 'c'})), (func2, [(0, 1, 2), 'abc'], _get_dict({0: 'a', 1: 'b', 2: 'c'})), (func3, [(0, 1, 2), 'abc'], _get_dict({0: 'a', 1: 'b', 2: 'c'})), (func4, [(0, 1, 2), 'abc'], _get_dict( {0: ((0, 'a'), 0), 1: ((1, 'b'), 1), 2: ((2, 'c'), 2)})), (func5, [(0, 1, 2), 'abc'], _get_dict( {(0, 'a'): 'a', (1, 'b'): 'b', (2, 'c'): 'c'})), # (func6, [(),], Dict({})), (func6, [((1, 'a'), (3, 'b')),], _get_dict({1: 'a', 3: 'b'})), (func1, ['key', _get_dict({1: 'abc'})], _get_dict({'k': 1})), (func8, ['key', _get_dict({1: 'abc'})], _get_dict( {'key': _get_dict({1: 'abc'})})), (func8, ['key', List([1, 2, 3])], _get_dict( {'key': List([1, 2, 3])})), ) for func, args, expected in vals: if self.jit_enabled: got = func(*args) else: got = func.py_func(*args) self.assertPreciseEqual(expected, got) class TestDictIterableCtorJit(TestCase, DictIterableCtor): def setUp(self): self.jit_enabled = True def test_exception_no_iterable_arg(self): @njit def ctor(): return Dict(3) msg = ".*No implementation of function.*" with self.assertRaisesRegex(TypingError, msg): ctor() def test_exception_dict_mapping(self): @njit def ctor(): return Dict({1: 2, 3: 4}) msg = ".*No implementation of function.*" with self.assertRaisesRegex(TypingError, msg): ctor() def test_exception_setitem(self): @njit def ctor(): return Dict(((1, 'a'), (2, 'b', 3))) msg = ".*No implementation of function.*" with self.assertRaisesRegex(TypingError, msg): ctor() class TestDictIterableCtorNoJit(TestCase, DictIterableCtor): def setUp(self): self.jit_enabled = False def test_exception_nargs(self): msg = 'Dict expect at most 1 argument, got 2' with self.assertRaisesRegex(TypingError, msg): Dict(1, 2) def test_exception_mapping_ctor(self): msg = r'.*dict\(mapping\) is not supported.*' # noqa: W605 with self.assertRaisesRegex(TypingError, msg): Dict({1: 2}) def test_exception_non_iterable_arg(self): msg = '.*object is not iterable.*' with self.assertRaisesRegex(TypingError, msg): Dict(3) def test_exception_setitem(self): msg = ".*dictionary update sequence element #1 has length 3.*" with self.assertRaisesRegex(ValueError, msg): Dict(((1, 'a'), (2, 'b', 3))) class TestDictRefctTypes(MemoryLeakMixin, TestCase): def test_str_key(self): @njit def foo(): d = Dict.empty( key_type=types.unicode_type, value_type=types.int32, ) d["123"] = 123 d["321"] = 321 return d d = foo() self.assertEqual(d['123'], 123) self.assertEqual(d['321'], 321) expect = {'123': 123, '321': 321} self.assertEqual(dict(d), expect) # Test insert replacement d['123'] = 231 expect['123'] = 231 self.assertEqual(d['123'], 231) self.assertEqual(dict(d), expect) # Test dictionary growth nelem = 100 for i in range(nelem): d[str(i)] = i expect[str(i)] = i for i in range(nelem): self.assertEqual(d[str(i)], i) self.assertEqual(dict(d), expect) def test_str_val(self): @njit def foo(): d = Dict.empty( key_type=types.int32, value_type=types.unicode_type, ) d[123] = "123" d[321] = "321" return d d = foo() self.assertEqual(d[123], '123') self.assertEqual(d[321], '321') expect = {123: '123', 321: '321'} self.assertEqual(dict(d), expect) # Test insert replacement d[123] = "231" expect[123] = "231" self.assertEqual(dict(d), expect) # Test dictionary growth nelem = 1 for i in range(nelem): d[i] = str(i) expect[i] = str(i) for i in range(nelem): self.assertEqual(d[i], str(i)) self.assertEqual(dict(d), expect) def test_str_key_array_value(self): np.random.seed(123) d = Dict.empty( key_type=types.unicode_type, value_type=types.float64[:], ) expect = [] expect.append(np.random.random(10)) d['mass'] = expect[-1] expect.append(np.random.random(20)) d['velocity'] = expect[-1] for i in range(100): expect.append(np.random.random(i)) d[str(i)] = expect[-1] self.assertEqual(len(d), len(expect)) self.assertPreciseEqual(d['mass'], expect[0]) self.assertPreciseEqual(d['velocity'], expect[1]) # Ordering is kept for got, exp in zip(d.values(), expect): self.assertPreciseEqual(got, exp) # Try deleting self.assertTrue('mass' in d) self.assertTrue('velocity' in d) del d['mass'] self.assertFalse('mass' in d) del d['velocity'] self.assertFalse('velocity' in d) del expect[0:2] for i in range(90): k, v = d.popitem() w = expect.pop() self.assertPreciseEqual(v, w) # Trigger a resize expect.append(np.random.random(10)) d["last"] = expect[-1] # Ordering is kept for got, exp in zip(d.values(), expect): self.assertPreciseEqual(got, exp) def test_dict_of_dict_int_keyval(self): def inner_numba_dict(): d = Dict.empty( key_type=types.intp, value_type=types.intp, ) return d d = Dict.empty( key_type=types.intp, value_type=types.DictType(types.intp, types.intp), ) def usecase(d, make_inner_dict): for i in range(100): mid = make_inner_dict() for j in range(i + 1): mid[j] = j * 10000 d[i] = mid return d got = usecase(d, inner_numba_dict) expect = usecase({}, dict) self.assertIsInstance(expect, dict) self.assertEqual(dict(got), expect) # Delete items for where in [12, 3, 6, 8, 10]: del got[where] del expect[where] self.assertEqual(dict(got), expect) def test_dict_of_dict_npm(self): inner_dict_ty = types.DictType(types.intp, types.intp) @njit def inner_numba_dict(): d = Dict.empty( key_type=types.intp, value_type=types.intp, ) return d @njit def foo(count): d = Dict.empty( key_type=types.intp, value_type=inner_dict_ty, ) for i in range(count): d[i] = inner_numba_dict() for j in range(i + 1): d[i][j] = j return d d = foo(100) ct = 0 for k, dd in d.items(): ct += 1 self.assertEqual(len(dd), k + 1) for kk, vv in dd.items(): self.assertEqual(kk, vv) self.assertEqual(ct, 100) def test_delitem(self): d = Dict.empty(types.int64, types.unicode_type) d[1] = 'apple' @njit def foo(x, k): del x[1] foo(d, 1) self.assertEqual(len(d), 0) self.assertFalse(d) def test_getitem_return_type(self): # Dict.__getitem__ must return non-optional type. d = Dict.empty(types.int64, types.int64[:]) d[1] = np.arange(10, dtype=np.int64) @njit def foo(d): d[1] += 100 return d[1] foo(d) # Return type is an array, not optional retty = foo.nopython_signatures[0].return_type self.assertIsInstance(retty, types.Array) self.assertNotIsInstance(retty, types.Optional) # Value is correctly updated self.assertPreciseEqual(d[1], np.arange(10, dtype=np.int64) + 100) def test_storage_model_mismatch(self): # https://github.com/numba/numba/issues/4520 # check for storage model mismatch in refcount ops generation dct = Dict() ref = [ ("a", True, "a"), ("b", False, "b"), ("c", False, "c"), ] # populate for x in ref: dct[x] = x # test for i, x in enumerate(ref): self.assertEqual(dct[x], x) class TestDictForbiddenTypes(TestCase): def assert_disallow(self, expect, callable): with self.assertRaises(TypingError) as raises: callable() msg = str(raises.exception) self.assertIn(expect, msg) def assert_disallow_key(self, ty): msg = '{} as key is forbidden'.format(ty) self.assert_disallow(msg, lambda: Dict.empty(ty, types.intp)) @njit def foo(): Dict.empty(ty, types.intp) self.assert_disallow(msg, foo) def assert_disallow_value(self, ty): msg = '{} as value is forbidden'.format(ty) self.assert_disallow(msg, lambda: Dict.empty(types.intp, ty)) @njit def foo(): Dict.empty(types.intp, ty) self.assert_disallow(msg, foo) def test_disallow_list(self): self.assert_disallow_key(types.List(types.intp)) self.assert_disallow_value(types.List(types.intp)) def test_disallow_set(self): self.assert_disallow_key(types.Set(types.intp)) self.assert_disallow_value(types.Set(types.intp)) class TestDictInferred(TestCase): def test_simple_literal(self): @njit def foo(): d = Dict() d[123] = 321 return d k, v = 123, 321 d = foo() self.assertEqual(dict(d), {k: v}) self.assertEqual(typeof(d).key_type, typeof(k)) self.assertEqual(typeof(d).value_type, typeof(v)) def test_simple_args(self): @njit def foo(k, v): d = Dict() d[k] = v return d k, v = 123, 321 d = foo(k, v) self.assertEqual(dict(d), {k: v}) self.assertEqual(typeof(d).key_type, typeof(k)) self.assertEqual(typeof(d).value_type, typeof(v)) def test_simple_upcast(self): @njit def foo(k, v, w): d = Dict() d[k] = v d[k] = w return d k, v, w = 123, 32.1, 321 d = foo(k, v, w) self.assertEqual(dict(d), {k: w}) self.assertEqual(typeof(d).key_type, typeof(k)) self.assertEqual(typeof(d).value_type, typeof(v)) def test_conflicting_value_type(self): @njit def foo(k, v, w): d = Dict() d[k] = v d[k] = w return d k, v, w = 123, 321, 32.1 with self.assertRaises(TypingError) as raises: foo(k, v, w) self.assertIn( 'cannot safely cast float64 to {}'.format(typeof(v)), str(raises.exception), ) def test_conflicting_key_type(self): @njit def foo(k, h, v): d = Dict() d[k] = v d[h] = v return d k, h, v = 123, 123.1, 321 with self.assertRaises(TypingError) as raises: foo(k, h, v) self.assertIn( 'cannot safely cast float64 to {}'.format(typeof(v)), str(raises.exception), ) def test_conflict_key_type_non_number(self): # Allow non-number types to cast unsafely @njit def foo(k1, v1, k2): d = Dict() d[k1] = v1 return d, d[k2] # k2 will unsafely downcast typeof(k1) k1 = (np.int8(1), np.int8(2)) k2 = (np.int32(1), np.int32(2)) v1 = np.intp(123) with warnings.catch_warnings(record=True) as w: d, dk2 = foo(k1, v1, k2) self.assertEqual(len(w), 1) # Make sure the warning is about unsafe cast msg = 'unsafe cast from UniTuple(int32 x 2) to UniTuple(int8 x 2)' self.assertIn(msg, str(w[0])) keys = list(d.keys()) self.assertEqual(keys[0], (1, 2)) self.assertEqual(dk2, d[(np.int32(1), np.int32(2))]) def test_ifelse_filled_both_branches(self): @njit def foo(k, v): d = Dict() if k: d[k] = v else: d[0xdead] = v + 1 return d k, v = 123, 321 d = foo(k, v) self.assertEqual(dict(d), {k: v}) k, v = 0, 0 d = foo(k, v) self.assertEqual(dict(d), {0xdead: v + 1}) def test_ifelse_empty_one_branch(self): @njit def foo(k, v): d = Dict() if k: d[k] = v return d k, v = 123, 321 d = foo(k, v) self.assertEqual(dict(d), {k: v}) k, v = 0, 0 d = foo(k, v) self.assertEqual(dict(d), {}) self.assertEqual(typeof(d).key_type, typeof(k)) self.assertEqual(typeof(d).value_type, typeof(v)) def test_loop(self): @njit def foo(ks, vs): d = Dict() for k, v in zip(ks, vs): d[k] = v return d vs = list(range(4)) ks = list(map(lambda x : x + 100, vs)) d = foo(ks, vs) self.assertEqual(dict(d), dict(zip(ks, vs))) def test_unused(self): @njit def foo(): d = Dict() return d with self.assertRaises(TypingError) as raises: foo() self.assertIn( "imprecise type", str(raises.exception) ) def test_define_after_use(self): @njit def foo(define): d = Dict() ct = len(d) for k, v in d.items(): ct += v if define: # This will set the type d[1] = 2 return ct, d, len(d) ct, d, n = foo(True) self.assertEqual(ct, 0) self.assertEqual(n, 1) self.assertEqual(dict(d), {1: 2}) ct, d, n = foo(False) self.assertEqual(ct, 0) self.assertEqual(dict(d), {}) self.assertEqual(n, 0) def test_dict_of_dict(self): @njit def foo(k1, k2, v): d = Dict() z1 = Dict() z1[k1 + 1] = v + k1 z2 = Dict() z2[k2 + 2] = v + k2 d[k1] = z1 d[k2] = z2 return d k1, k2, v = 100, 200, 321 d = foo(k1, k2, v) self.assertEqual( dict(d), { k1: {k1 + 1: k1 + v}, k2: {k2 + 2: k2 + v}, }, ) def test_comprehension_basic(self): @njit def foo(): return {i: 2 * i for i in range(10)} self.assertEqual(foo(), foo.py_func()) def test_comprehension_basic_mixed_type(self): @njit def foo(): return {i: float(j) for i, j in zip(range(10), range(10, 0, -1))} self.assertEqual(foo(), foo.py_func()) def test_comprehension_involved(self): @njit def foo(): a = {0: 'A', 1: 'B', 2: 'C'} return {3 + i: a[i] for i in range(3)} self.assertEqual(foo(), foo.py_func()) def test_comprehension_fail_mixed_type(self): @njit def foo(): a = {0: 'A', 1: 'B', 2: 1j} return {3 + i: a[i] for i in range(3)} with self.assertRaises(TypingError) as e: foo() excstr = str(e.exception) self.assertIn("Cannot cast complex128 to unicode_type", excstr) class TestNonCompiledInfer(TestCase): def test_check_untyped_dict_ops(self): # Check operation on untyped dictionary d = Dict() self.assertFalse(d._typed) self.assertEqual(len(d), 0) self.assertEqual(str(d), str({})) self.assertEqual(list(iter(d)), []) # Test __getitem__ with self.assertRaises(KeyError) as raises: d[1] self.assertEqual(str(raises.exception), str(KeyError(1))) # Test __delitem__ with self.assertRaises(KeyError) as raises: del d[1] self.assertEqual(str(raises.exception), str(KeyError(1))) # Test .pop with self.assertRaises(KeyError): d.pop(1) self.assertEqual(str(raises.exception), str(KeyError(1))) # Test .pop self.assertIs(d.pop(1, None), None) # Test .get self.assertIs(d.get(1), None) # Test .popitem with self.assertRaises(KeyError) as raises: d.popitem() self.assertEqual(str(raises.exception), str(KeyError('dictionary is empty'))) # Test setdefault(k) with self.assertRaises(TypeError) as raises: d.setdefault(1) self.assertEqual( str(raises.exception), str(TypeError('invalid operation on untyped dictionary')), ) # Test __contains__ self.assertFalse(1 in d) # It's untyped self.assertFalse(d._typed) def test_getitem(self): # Test __getitem__ d = Dict() d[1] = 2 # It's typed now self.assertTrue(d._typed) self.assertEqual(d[1], 2) def test_setdefault(self): # Test setdefault(k, d) d = Dict() d.setdefault(1, 2) # It's typed now self.assertTrue(d._typed) self.assertEqual(d[1], 2) @jitclass(spec=[('a', types.intp)]) class Bag(object): def __init__(self, a): self.a = a def __hash__(self): return hash(self.a) class TestDictWithJitclass(TestCase): def test_jitclass_as_value(self): @njit def foo(x): d = Dict() d[0] = x d[1] = Bag(101) return d d = foo(Bag(a=100)) self.assertEqual(d[0].a, 100) self.assertEqual(d[1].a, 101) class TestNoJit(TestCase): """Exercise dictionary creation with JIT disabled. """ def test_dict_create_no_jit_using_new_dict(self): with override_config('DISABLE_JIT', True): with forbid_codegen(): d = dictobject.new_dict(int32, float32) self.assertEqual(type(d), dict) def test_dict_create_no_jit_using_Dict(self): with override_config('DISABLE_JIT', True): with forbid_codegen(): d = Dict() self.assertEqual(type(d), dict) def test_dict_create_no_jit_using_empty(self): with override_config('DISABLE_JIT', True): with forbid_codegen(): d = Dict.empty(types.int32, types.float32) self.assertEqual(type(d), dict) class TestDictIterator(TestCase): def test_dict_iterator(self): @njit def fun1(): dd = Dict.empty(key_type=types.intp, value_type=types.intp) dd[0] = 10 dd[1] = 20 dd[2] = 30 return list(dd.keys()), list(dd.values()) @njit def fun2(): dd = Dict.empty(key_type=types.intp, value_type=types.intp) dd[4] = 77 dd[5] = 88 dd[6] = 99 return list(dd.keys()), list(dd.values()) res1 = fun1() res2 = fun2() self.assertEqual([0,1,2], res1[0]) self.assertEqual([10,20,30], res1[1]) self.assertEqual([4,5,6], res2[0]) self.assertEqual([77,88,99], res2[1]) class TestTypedDictInitialValues(MemoryLeakMixin, TestCase): """Tests that typed dictionaries carry their initial value if present""" def test_homogeneous_and_literal(self): def bar(d): ... @overload(bar) def ol_bar(d): if d.initial_value is None: return lambda d: literally(d) self.assertTrue(isinstance(d, types.DictType)) self.assertEqual(d.initial_value, {'a': 1, 'b': 2, 'c': 3}) self.assertEqual(hasattr(d, 'literal_value'), False) return lambda d: d @njit def foo(): # keys and values all have literal representation x = {'a': 1, 'b': 2, 'c': 3} bar(x) foo() def test_heterogeneous_but_castable_to_homogeneous(self): def bar(d): ... @overload(bar) def ol_bar(d): self.assertTrue(isinstance(d, types.DictType)) self.assertEqual(d.initial_value, None) self.assertEqual(hasattr(d, 'literal_value'), False) return lambda d: d @njit def foo(): # This dictionary will be typed based on 1j, i.e. complex128 # as the values are not all literals, there's no "initial_value" # available irrespective of whether it's possible to rip this # information out of the bytecode. x = {'a': 1j, 'b': 2, 'c': 3} bar(x) foo() def test_heterogeneous_but_not_castable_to_homogeneous(self): def bar(d): ... @overload(bar) def ol_bar(d): a = {'a': 1, 'b': 2j, 'c': 3} def specific_ty(z): return types.literal(z) if types.maybe_literal(z) else typeof(z) expected = {types.literal(x): specific_ty(y) for x, y in a.items()} self.assertTrue(isinstance(d, types.LiteralStrKeyDict)) self.assertEqual(d.literal_value, expected) self.assertEqual(hasattr(d, 'initial_value'), False) return lambda d: d @njit def foo(): # This dictionary will be typed based on 1, i.e. intp, as the values # cannot all be cast to this type, but the keys are literal strings # this is a LiteralStrKey[Dict], there's no initial_value but there # is a literal_value. x = {'a': 1, 'b': 2j, 'c': 3} bar(x) foo() def test_mutation_not_carried(self): def bar(d): ... @overload(bar) def ol_bar(d): if d.initial_value is None: return lambda d: literally(d) self.assertTrue(isinstance(d, types.DictType)) self.assertEqual(d.initial_value, {'a': 1, 'b': 2, 'c': 3}) return lambda d: d @njit def foo(): # This dictionary is mutated, check the initial_value carries # correctly and is not mutated x = {'a': 1, 'b': 2, 'c': 3} x['d'] = 4 bar(x) foo() def test_mutation_not_carried_single_function(self): # this is another pattern for using literally @njit def nop(*args): pass for fn, iv in (nop, None), (literally, {'a': 1, 'b': 2, 'c': 3}): @njit def baz(x): pass def bar(z): pass @overload(bar) def ol_bar(z): def impl(z): fn(z) baz(z) return impl @njit def foo(): x = {'a': 1, 'b': 2, 'c': 3} bar(x) x['d'] = 4 return x foo() # baz should be specialised based on literally being invoked and # the literal/unliteral arriving at the call site larg = baz.signatures[0][0] self.assertEqual(larg.initial_value, iv) def test_unify_across_function_call(self): @njit def bar(x): o = {1: 2} if x: o = {2: 3} return o @njit def foo(x): if x: d = {3: 4} else: d = bar(x) return d e1 = Dict() e1[3] = 4 e2 = Dict() e2[1] = 2 self.assertEqual(foo(True), e1) self.assertEqual(foo(False), e2) class TestLiteralStrKeyDict(MemoryLeakMixin, TestCase): """ Tests for dictionaries with string keys that can map to anything!""" def test_basic_const_lowering_boxing(self): @njit def foo(): ld = {'a': 1, 'b': 2j, 'c': 'd'} return (ld['a'], ld['b'], ld['c']) self.assertEqual(foo(), (1, 2j, 'd')) def test_basic_nonconst_in_scope(self): @njit def foo(x): y = x + 5 e = True if y > 2 else False ld = {'a': 1, 'b': 2j, 'c': 'd', 'non_const': e} return ld['non_const'] # Recall that key non_const has a value of a known type, bool, and it's # value is stuffed in at run time, this is permitted as the dictionary # is immutable in type self.assertTrue(foo(34)) self.assertFalse(foo(-100)) def test_basic_nonconst_freevar(self): e = 5 def bar(x): pass @overload(bar) def ol_bar(x): self.assertEqual(x.literal_value, {types.literal('a'): types.literal(1), types.literal('b'): typeof(2j), types.literal('c'): types.literal('d'), types.literal('d'): types.literal(5)}) def impl(x): pass return impl @njit def foo(): ld = {'a': 1, 'b': 2j, 'c': 'd', 'd': e} bar(ld) foo() def test_literal_value(self): def bar(x): pass @overload(bar) def ol_bar(x): self.assertEqual(x.literal_value, {types.literal('a'): types.literal(1), types.literal('b'): typeof(2j), types.literal('c'): types.literal('d')}) def impl(x): pass return impl @njit def foo(): ld = {'a': 1, 'b': 2j, 'c': 'd'} bar(ld) foo() def test_list_and_array_as_value(self): def bar(x): pass @overload(bar) def ol_bar(x): self.assertEqual(x.literal_value, {types.literal('a'): types.literal(1), types.literal('b'): types.List(types.intp, initial_value=[1,2,3]), types.literal('c'): typeof(np.zeros(5))}) def impl(x): pass return impl @njit def foo(): b = [1, 2, 3] ld = {'a': 1, 'b': b, 'c': np.zeros(5)} bar(ld) foo() def test_repeated_key_literal_value(self): def bar(x): pass @overload(bar) def ol_bar(x): # order is important, 'a' was seen first, but updated later self.assertEqual(x.literal_value, {types.literal('a'): types.literal('aaaa'), types.literal('b'): typeof(2j), types.literal('c'): types.literal('d')}) def impl(x): pass return impl @njit def foo(): ld = {'a': 1, 'a': 10, 'b': 2j, 'c': 'd', 'a': 'aaaa'} # noqa #F601 bar(ld) foo() def test_read_only(self): def _len(): ld = {'a': 1, 'b': 2j, 'c': 'd'} return len(ld) def static_getitem(): ld = {'a': 1, 'b': 2j, 'c': 'd'} return ld['b'] def contains(): ld = {'a': 1, 'b': 2j, 'c': 'd'} return 'b' in ld, 'f' in ld def copy(): ld = {'a': 1, 'b': 2j, 'c': 'd'} new = ld.copy() return ld == new rdonlys = (_len, static_getitem, contains, copy) for test in rdonlys: with self.subTest(test.__name__): self.assertPreciseEqual(njit(test)(), test()) def test_mutation_failure(self): def setitem(): ld = {'a': 1, 'b': 2j, 'c': 'd'} ld['a'] = 12 def delitem(): ld = {'a': 1, 'b': 2j, 'c': 'd'} del ld['a'] def popitem(): ld = {'a': 1, 'b': 2j, 'c': 'd'} ld.popitem() def pop(): ld = {'a': 1, 'b': 2j, 'c': 'd'} ld.pop() def clear(): ld = {'a': 1, 'b': 2j, 'c': 'd'} ld.clear() def setdefault(): ld = {'a': 1, 'b': 2j, 'c': 'd'} ld.setdefault('f', 1) illegals = (setitem, delitem, popitem, pop, clear, setdefault) for test in illegals: with self.subTest(test.__name__): with self.assertRaises(TypingError) as raises: njit(test)() expect = "Cannot mutate a literal dictionary" self.assertIn(expect, str(raises.exception)) def test_get(self): @njit def get(x): ld = {'a': 2j, 'c': 'd'} return ld.get(x) @njit def getitem(x): ld = {'a': 2j, 'c': 'd'} return ld[x] for test in (get, getitem): with self.subTest(test.__name__): with self.assertRaises(TypingError) as raises: test('a') expect = "Cannot get{item}() on a literal dictionary" self.assertIn(expect, str(raises.exception)) def test_dict_keys(self): @njit def foo(): ld = {'a': 2j, 'c': 'd'} return [x for x in ld.keys()] self.assertEqual(foo(), ['a', 'c']) def test_dict_values(self): @njit def foo(): ld = {'a': 2j, 'c': 'd'} return ld.values() self.assertEqual(foo(), (2j, 'd')) def test_dict_items(self): @njit def foo(): ld = {'a': 2j, 'c': 'd', 'f': np.zeros((5))} return ld.items() self.assertPreciseEqual(foo(), (('a', 2j), ('c', 'd'), ('f', np.zeros((5))))) def test_dict_return(self): @njit def foo(): ld = {'a': 2j, 'c': 'd'} return ld # escaping heterogeneous dictionary is not supported with self.assertRaises(TypeError) as raises: foo() excstr = str(raises.exception) self.assertIn("cannot convert native LiteralStrKey", excstr) def test_dict_unify(self): @njit def foo(x): if x + 7 > 4: a = {'a': 2j, 'c': 'd', 'e': np.zeros(4)} else: # Note the use of a different literal str for key 'c' a = {'a': 5j, 'c': 'CAT', 'e': np.zeros((5,))} return a['c'] self.assertEqual(foo(100), 'd') self.assertEqual(foo(-100), 'CAT') self.assertEqual(foo(100), foo.py_func(100)) self.assertEqual(foo(-100), foo.py_func(-100)) def test_dict_not_unify(self): @njit def key_mismatch(x): if x + 7 > 4: a = {'BAD_KEY': 2j, 'c': 'd', 'e': np.zeros(4)} else: a = {'a': 5j, 'c': 'CAT', 'e': np.zeros((5,))} # prevents inline of return on py310 py310_defeat1 = 1 # noqa py310_defeat2 = 2 # noqa py310_defeat3 = 3 # noqa py310_defeat4 = 4 # noqa return a['a'] with self.assertRaises(TypingError) as raises: key_mismatch(100) self.assertIn("Cannot unify LiteralStrKey", str(raises.exception)) @njit def value_type_mismatch(x): if x + 7 > 4: a = {'a': 2j, 'c': 'd', 'e': np.zeros((4, 3))} else: a = {'a': 5j, 'c': 'CAT', 'e': np.zeros((5,))} # prevents inline of return on py310 py310_defeat1 = 1 # noqa py310_defeat2 = 2 # noqa py310_defeat3 = 3 # noqa py310_defeat4 = 4 # noqa return a['a'] with self.assertRaises(TypingError) as raises: value_type_mismatch(100) self.assertIn("Cannot unify LiteralStrKey", str(raises.exception)) def test_dict_value_coercion(self): # checks that things coerce or not! p = {# safe and no conversion: TypedDict (np.int32, np.int32): types.DictType, # safe and convertible: TypedDict (np.int32, np.int8): types.DictType, # safe convertible: TypedDict (np.complex128, np.int32): types.DictType, # unsafe not convertible: LiteralStrKey (np.int32, np.complex128): types.LiteralStrKeyDict, # unsafe not convertible: LiteralStrKey (np.int32, np.array): types.LiteralStrKeyDict, # unsafe not convertible: LiteralStrKey (np.array, np.int32): types.LiteralStrKeyDict, # unsafe not convertible: LiteralStrKey (np.int8, np.int32): types.LiteralStrKeyDict, # unsafe not convertible: LiteralStrKey (issue #6420 case) (np.int64, np.float64): types.LiteralStrKeyDict,} def bar(x): pass for dts, container in p.items(): @overload(bar) def ol_bar(x): self.assertTrue(isinstance(x, container)) def impl(x): pass return impl ty1, ty2 = dts @njit def foo(): d = {'a': ty1(1), 'b': ty2(2)} bar(d) foo() def test_build_map_op_code(self): # tests building dictionaries via `build_map`, which, for statically # determinable str key->things cases is just a single key:value # any other build_map would either end up as being non-const str keys # or keys of some non-string type and therefore not considered. def bar(x): pass @overload(bar) def ol_bar(x): def impl(x): pass return impl @njit def foo(): a = {'a': {'b1': 10, 'b2': 'string'}} bar(a) foo() def test_dict_as_arg(self): @njit def bar(fake_kwargs=None): if fake_kwargs is not None: # Add 10 to array in key 'd' fake_kwargs['d'][:] += 10 @njit def foo(): a = 1 b = 2j c = 'string' d = np.zeros(3) e = {'a': a, 'b': b, 'c': c, 'd': d} bar(fake_kwargs=e) return e['d'] np.testing.assert_allclose(foo(), np.ones(3) * 10) def test_dict_with_single_literallist_value(self): #see issue #6094 @njit def foo(): z = {"A": [lambda a: 2 * a, "B"]} return z["A"][0](5) self.assertPreciseEqual(foo(), foo.py_func()) def test_tuple_not_in_mro(self): # Related to #6094, make sure that LiteralStrKey does not inherit from # types.BaseTuple as this breaks isinstance checks. def bar(x): pass @overload(bar) def ol_bar(x): self.assertFalse(isinstance(x, types.BaseTuple)) self.assertTrue(isinstance(x, types.LiteralStrKeyDict)) return lambda x: ... @njit def foo(): d = {'a': 1, 'b': 'c'} bar(d) foo() def test_const_key_not_in_dict(self): @njit def foo(): a = {'not_a': 2j, 'c': 'd', 'e': np.zeros(4)} return a['a'] with self.assertRaises(TypingError) as raises: foo() self.assertIn("Key 'a' is not in dict.", str(raises.exception)) def test_uncommon_identifiers(self): # Tests uncommon identifiers like numerical values and operators in # the key fields. See #6518 and #7416. # Numerical values in keys @njit def foo(): d = {'0': np.ones(5), '1': 4} return len(d) self.assertPreciseEqual(foo(), foo.py_func()) # operators in keys @njit def bar(): d = {'+': np.ones(5), 'x--': 4} return len(d) self.assertPreciseEqual(bar(), bar.py_func()) def test_update_error(self): # Tests that dict.update produces a reasonable # error with a LiteralStrKeyDict input. @njit def foo(): d1 = { 'a': 2, 'b': 4, 'c': 'a' } d1.update({'x': 3}) return d1 with self.assertRaises(TypingError) as raises: foo() self.assertIn( "Cannot mutate a literal dictionary", str(raises.exception) ) if __name__ == '__main__': unittest.main()