import itertools import numpy as np import unittest from numba import jit, typeof, njit from numba.core import types from numba.core.errors import TypingError from numba.tests.support import MemoryLeakMixin, TestCase def getitem_usecase(a, b): return a[b] def setitem_usecase(a, idx, b): a[idx] = b def np_take(A, indices): return np.take(A, indices) def np_take_kws(A, indices, axis): return np.take(A, indices, axis=axis) class TestFancyIndexing(MemoryLeakMixin, TestCase): def generate_advanced_indices(self, N, many=True): choices = [np.int16([0, N - 1, -2])] if many: choices += [np.uint16([0, 1, N - 1]), np.bool_([0, 1, 1, 0])] return choices def generate_basic_index_tuples(self, N, maxdim, many=True): """ Generate basic index tuples with 0 to *maxdim* items. """ # Note integers can be considered advanced indices in certain # cases, so we avoid them here. # See "Combining advanced and basic indexing" # in http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html if many: choices = [slice(None, None, None), slice(1, N - 1, None), slice(0, None, 2), slice(N - 1, None, -2), slice(-N + 1, -1, None), slice(-1, -N, -2), ] else: choices = [slice(0, N - 1, None), slice(-1, -N, -2)] for ndim in range(maxdim + 1): for tup in itertools.product(choices, repeat=ndim): yield tup def generate_advanced_index_tuples(self, N, maxdim, many=True): """ Generate advanced index tuples by generating basic index tuples and adding a single advanced index item. """ # (Note Numba doesn't support advanced indices with more than # one advanced index array at the moment) choices = list(self.generate_advanced_indices(N, many=many)) for i in range(maxdim + 1): for tup in self.generate_basic_index_tuples(N, maxdim - 1, many): for adv in choices: yield tup[:i] + (adv,) + tup[i:] def generate_advanced_index_tuples_with_ellipsis(self, N, maxdim, many=True): """ Same as generate_advanced_index_tuples(), but also insert an ellipsis at various points. """ for tup in self.generate_advanced_index_tuples(N, maxdim, many): for i in range(len(tup) + 1): yield tup[:i] + (Ellipsis,) + tup[i:] def check_getitem_indices(self, arr, indices): pyfunc = getitem_usecase cfunc = jit(nopython=True)(pyfunc) orig = arr.copy() orig_base = arr.base or arr for index in indices: expected = pyfunc(arr, index) # Sanity check: if a copy wasn't made, this wasn't advanced # but basic indexing, and shouldn't be tested here. assert expected.base is not orig_base got = cfunc(arr, index) # Note Numba may not return the same array strides and # contiguity as Numpy self.assertEqual(got.shape, expected.shape) self.assertEqual(got.dtype, expected.dtype) np.testing.assert_equal(got, expected) # Check a copy was *really* returned by Numba if got.size: got.fill(42) np.testing.assert_equal(arr, orig) def test_getitem_tuple(self): # Test many variations of advanced indexing with a tuple index N = 4 ndim = 3 arr = np.arange(N ** ndim).reshape((N,) * ndim).astype(np.int32) indices = self.generate_advanced_index_tuples(N, ndim) self.check_getitem_indices(arr, indices) def test_getitem_tuple_and_ellipsis(self): # Same, but also insert an ellipsis at a random point N = 4 ndim = 3 arr = np.arange(N ** ndim).reshape((N,) * ndim).astype(np.int32) indices = self.generate_advanced_index_tuples_with_ellipsis(N, ndim, many=False) self.check_getitem_indices(arr, indices) def test_ellipsis_getsetitem(self): # See https://github.com/numba/numba/issues/3225 @jit(nopython=True) def foo(arr, v): arr[..., 0] = arr[..., 1] arr = np.arange(2) foo(arr, 1) self.assertEqual(arr[0], arr[1]) def test_getitem_array(self): # Test advanced indexing with a single array index N = 4 ndim = 3 arr = np.arange(N ** ndim).reshape((N,) * ndim).astype(np.int32) indices = self.generate_advanced_indices(N) self.check_getitem_indices(arr, indices) def check_setitem_indices(self, arr, indices): pyfunc = setitem_usecase cfunc = jit(nopython=True)(pyfunc) for index in indices: src = arr[index] expected = np.zeros_like(arr) got = np.zeros_like(arr) pyfunc(expected, index, src) cfunc(got, index, src) # Note Numba may not return the same array strides and # contiguity as Numpy self.assertEqual(got.shape, expected.shape) self.assertEqual(got.dtype, expected.dtype) np.testing.assert_equal(got, expected) def test_setitem_tuple(self): # Test many variations of advanced indexing with a tuple index N = 4 ndim = 3 arr = np.arange(N ** ndim).reshape((N,) * ndim).astype(np.int32) indices = self.generate_advanced_index_tuples(N, ndim) self.check_setitem_indices(arr, indices) def test_setitem_tuple_and_ellipsis(self): # Same, but also insert an ellipsis at a random point N = 4 ndim = 3 arr = np.arange(N ** ndim).reshape((N,) * ndim).astype(np.int32) indices = self.generate_advanced_index_tuples_with_ellipsis(N, ndim, many=False) self.check_setitem_indices(arr, indices) def test_setitem_array(self): # Test advanced indexing with a single array index N = 4 ndim = 3 arr = np.arange(N ** ndim).reshape((N,) * ndim).astype(np.int32) + 10 indices = self.generate_advanced_indices(N) self.check_setitem_indices(arr, indices) def test_setitem_0d(self): # Test setitem with a 0d-array pyfunc = setitem_usecase cfunc = jit(nopython=True)(pyfunc) inps = [ (np.zeros(3), np.array(3.14)), (np.zeros(2), np.array(2)), (np.zeros(3, dtype=np.int64), np.array(3, dtype=np.int64)), (np.zeros(3, dtype=np.float64), np.array(1, dtype=np.int64)), (np.zeros(5, dtype='