"""
Test helper functions from numba.numpy_support.
"""

import sys
from itertools import product

import numpy as np

import unittest
from numba.core import types
from numba.core.errors import NumbaNotImplementedError
from numba.tests.support import TestCase
from numba.tests.enum_usecases import Shake, RequestError
from numba.np import numpy_support


class TestFromDtype(TestCase):
    """
    Tests for the numpy_support.from_dtype() / as_dtype() conversions.
    """

    def test_number_types(self):
        """
        Test from_dtype() and as_dtype() with the various scalar number types.
        """
        f = numpy_support.from_dtype

        def check(typechar, numba_type):
            # Only native ordering and alignment is supported
            dtype = np.dtype(typechar)
            self.assertIs(f(dtype), numba_type)
            self.assertIs(f(np.dtype('=' + typechar)), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('?', types.bool_)
        check('f', types.float32)
        check('f4', types.float32)
        check('d', types.float64)
        check('f8', types.float64)

        check('F', types.complex64)
        check('c8', types.complex64)
        check('D', types.complex128)
        check('c16', types.complex128)

        check('O', types.pyobject)

        check('b', types.int8)
        check('i1', types.int8)
        check('B', types.uint8)
        check('u1', types.uint8)

        check('h', types.int16)
        check('i2', types.int16)
        check('H', types.uint16)
        check('u2', types.uint16)

        check('i', types.int32)
        check('i4', types.int32)
        check('I', types.uint32)
        check('u4', types.uint32)

        check('q', types.int64)
        check('Q', types.uint64)

        for name in ('int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
                     'int64', 'uint64', 'intp', 'uintp'):
            self.assertIs(f(np.dtype(name)), getattr(types, name))

        # Non-native alignments are unsupported (except for 1-byte types)
        foreign_align = '>' if sys.byteorder == 'little' else '<'
        for letter in 'hHiIlLqQfdFD':
            self.assertRaises(NumbaNotImplementedError, f,
                              np.dtype(foreign_align + letter))

    def test_string_types(self):
        """
        Test from_dtype() and as_dtype() with the character string types.
        """
        def check(typestring, numba_type):
            # Only native ordering and alignment is supported
            dtype = np.dtype(typestring)
            self.assertEqual(numpy_support.from_dtype(dtype), numba_type)
            self.assertEqual(dtype, numpy_support.as_dtype(numba_type))

        check('S10', types.CharSeq(10))
        check('a11', types.CharSeq(11))
        check('U12', types.UnicodeCharSeq(12))

    def check_datetime_types(self, letter, nb_class):
        """
        Check the from_dtype() / as_dtype() round trip for the unit-less
        dtype designated by *letter*, which must map to *nb_class*('').
        """
        def check(dtype, numba_type, code):
            tp = numpy_support.from_dtype(dtype)
            self.assertEqual(tp, numba_type)
            self.assertEqual(tp.unit_code, code)
            self.assertEqual(numpy_support.as_dtype(numba_type), dtype)
            self.assertEqual(numpy_support.as_dtype(tp), dtype)

        # Unit-less ("generic") type
        check(np.dtype(letter), nb_class(''), 14)

    def test_datetime_types(self):
        """
        Test from_dtype() and as_dtype() with the datetime types.
        """
        self.check_datetime_types('M', types.NPDatetime)

    def test_timedelta_types(self):
        """
        Test from_dtype() and as_dtype() with the timedelta types.
        """
        self.check_datetime_types('m', types.NPTimedelta)

    def test_struct_types(self):
        """
        Test from_dtype() with structured dtypes, both packed and aligned.
        """
        def check(dtype, fields, size, aligned):
            tp = numpy_support.from_dtype(dtype)
            self.assertIsInstance(tp, types.Record)
            # Only check for dtype equality, as the Numba type may be interned
            self.assertEqual(tp.dtype, dtype)
            self.assertEqual(tp.fields, fields)
            self.assertEqual(tp.size, size)
            self.assertEqual(tp.aligned, aligned)

        dtype = np.dtype([('a', np.int16), ('b', np.int32)])
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 2, None, None)},
              size=6, aligned=False)

        dtype = np.dtype([('a', np.int16), ('b', np.int32)], align=True)
        check(dtype,
              fields={'a': (types.int16, 0, None, None),
                      'b': (types.int32, 4, None, None)},
              size=8, aligned=True)

        dtype = np.dtype([('m', np.int32), ('n', 'S5')])
        check(dtype,
              fields={'m': (types.int32, 0, None, None),
                      'n': (types.CharSeq(5), 4, None, None)},
              size=9, aligned=False)

    def test_enum_type(self):
        """
        Test that as_dtype() on an enum member type gives back the dtype
        the member type was built from.
        """
        def check(base_inst, enum_def, type_class):
            np_dt = np.dtype(base_inst)
            nb_ty = numpy_support.from_dtype(np_dt)
            inst = type_class(enum_def, nb_ty)
            recovered = numpy_support.as_dtype(inst)
            self.assertEqual(np_dt, recovered)

        dts = [np.float64, np.int32, np.complex128, np.bool_]
        enums = [Shake, RequestError]

        for dt, enum in product(dts, enums):
            check(dt, enum, types.EnumMember)

        for dt, enum in product(dts, enums):
            check(dt, enum, types.IntEnumMember)


class ValueTypingTestBase(object):
    """
    Common tests for the typing of values.  Also used by test_special.
    """

    def check_number_values(self, func):
        """
        Test *func*() with scalar numeric values.
        """
        f = func
        # Standard Python types get inferred by numpy
        self.assertIn(f(1), (types.int32, types.int64))
        self.assertIn(f(2**31 - 1), (types.int32, types.int64))
        self.assertIn(f(-2**31), (types.int32, types.int64))
        self.assertIs(f(1.0), types.float64)
        self.assertIs(f(1.0j), types.complex128)
        self.assertIs(f(True), types.bool_)
        self.assertIs(f(False), types.bool_)
        # Numpy scalar types get converted by from_dtype()
        for name in ('int8', 'uint8', 'int16', 'uint16', 'int32', 'uint32',
                     'int64', 'uint64', 'intc', 'uintc', 'intp', 'uintp',
                     'float32', 'float64', 'complex64', 'complex128',
                     'bool_'):
            val = getattr(np, name)()
            self.assertIs(f(val), getattr(types, name))

    def _base_check_datetime_values(self, func, np_type, nb_type):
        """
        Check that *func*() maps *np_type* scalars of every unit to
        *nb_type*(unit).
        """
        f = func
        for unit in [
            '', 'Y', 'M', 'D', 'h', 'm', 's',
            'ms', 'us', 'ns', 'ps', 'fs', 'as',
        ]:
            if unit:
                t = np_type(3, unit)
            else:
                # "generic" datetime / timedelta
                t = np_type('Nat')
            tp = f(t)
            # This ensures the unit hasn't been lost
            self.assertEqual(tp, nb_type(unit))

    def check_datetime_values(self, func):
        """
        Test *func*() with np.datetime64 values.
        """
        self._base_check_datetime_values(func, np.datetime64,
                                         types.NPDatetime)

    def check_timedelta_values(self, func):
        """
        Test *func*() with np.timedelta64 values.
        """
        self._base_check_datetime_values(func, np.timedelta64,
                                         types.NPTimedelta)


class TestArrayScalars(ValueTypingTestBase, TestCase):
    """
    Tests for numpy_support.map_arrayscalar_type().
    """

    def test_number_values(self):
        """
        Test map_arrayscalar_type() with scalar number values.
        """
        self.check_number_values(numpy_support.map_arrayscalar_type)

    def test_datetime_values(self):
        """
        Test map_arrayscalar_type() with np.datetime64 values.
        """
        f = numpy_support.map_arrayscalar_type
        self.check_datetime_values(f)
        # datetime64s with a non-one factor shouldn't be supported
        t = np.datetime64('2014', '10Y')
        with self.assertRaises(NotImplementedError):
            f(t)

    def test_timedelta_values(self):
        """
        Test map_arrayscalar_type() with np.timedelta64 values.
        """
        f = numpy_support.map_arrayscalar_type
        self.check_timedelta_values(f)
        # timedelta64s with a non-one factor shouldn't be supported
        t = np.timedelta64(10, '10Y')
        with self.assertRaises(NotImplementedError):
            f(t)


class FakeUFunc(object):
    """
    Minimal stand-in for a Numpy ufunc, exposing only the attributes
    *nin*, *nout*, *types* and *ntypes*.

    *types* is a non-empty sequence of signature strings of the form
    'inputs->outputs' with one type character per operand (e.g. 'dd->d').
    All signatures must have the same number of inputs and outputs.
    """
    __slots__ = ('nin', 'nout', 'types', 'ntypes')

    __name__ = "fake ufunc"

    def __init__(self, types):
        self.types = types
        # The first signature determines the arity...
        in_, out = self.types[0].split('->')
        self.nin = len(in_)
        self.nout = len(out)
        self.ntypes = len(types)
        # ... and every signature must agree with it.  Split *tp* itself
        # (not types[0] again), otherwise this check can never fail.
        for tp in types:
            in_, out = tp.split('->')
            assert len(in_) == self.nin
            assert len(out) == self.nout


# Typical types for np.add, np.multiply, np.isnan
_add_types = ['??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
              'll->l', 'LL->L', 'qq->q', 'QQ->Q', 'ee->e', 'ff->f', 'dd->d',
              'gg->g', 'FF->F', 'DD->D', 'GG->G', 'Mm->M', 'mm->m', 'mM->M',
              'OO->O']

_mul_types = ['??->?', 'bb->b', 'BB->B', 'hh->h', 'HH->H', 'ii->i', 'II->I',
              'll->l', 'LL->L', 'qq->q', 'QQ->Q', 'ee->e', 'ff->f', 'dd->d',
              'gg->g', 'FF->F', 'DD->D', 'GG->G', 'mq->m', 'qm->m', 'md->m',
              'dm->m', 'OO->O']

# Those ones only have floating-point loops
_isnan_types = ['e->?', 'f->?', 'd->?', 'g->?', 'F->?', 'D->?', 'G->?']

_sqrt_types = ['e->e', 'f->f', 'd->d', 'g->g', 'F->F', 'D->D', 'G->G',
               'O->O']


class TestUFuncs(TestCase):
    """
    Test ufunc helpers.
    """

    def test_ufunc_find_matching_loop(self):
        """
        Test ufunc_find_matching_loop() against fake ufuncs built from
        the module-level signature lists.
        """
        f = numpy_support.ufunc_find_matching_loop
        np_add = FakeUFunc(_add_types)
        np_mul = FakeUFunc(_mul_types)
        np_isnan = FakeUFunc(_isnan_types)
        np_sqrt = FakeUFunc(_sqrt_types)

        def check(ufunc, input_types, sigs, output_types=()):
            """
            Check that ufunc_find_matching_loop() finds one of the given
            *sigs* for *ufunc*, *input_types* and optional *output_types*.
            """
            loop = f(ufunc, input_types + output_types)
            self.assertTrue(loop)
            if isinstance(sigs, str):
                sigs = (sigs,)
            self.assertIn(loop.ufunc_sig, sigs,
                          "inputs=%s and outputs=%s should have selected "
                          "one of %s, got %s"
                          % (input_types, output_types, sigs,
                             loop.ufunc_sig))
            self.assertEqual(len(loop.numpy_inputs), len(loop.inputs))
            self.assertEqual(len(loop.numpy_outputs), len(loop.outputs))
            if not output_types:
                # Add explicit outputs and check the result is the same
                loop_explicit = f(ufunc, list(input_types) + loop.outputs)
                self.assertEqual(loop_explicit, loop)
            else:
                self.assertEqual(loop.outputs, list(output_types))
            # Round-tripping inputs and outputs
            loop_rt = f(ufunc, loop.inputs + loop.outputs)
            self.assertEqual(loop_rt, loop)
            return loop

        def check_exact(ufunc, input_types, sigs, output_types=()):
            """
            Like check(), but also ensure no casting of inputs occurred.
            """
            loop = check(ufunc, input_types, sigs, output_types)
            self.assertEqual(loop.inputs, list(input_types))

        def check_no_match(ufunc, input_types):
            """
            Check that no loop is found for *ufunc* and *input_types*.
            """
            loop = f(ufunc, input_types)
            self.assertIs(loop, None)

        # Exact matching for number types
        check_exact(np_add, (types.bool_, types.bool_), '??->?')
        check_exact(np_add, (types.int8, types.int8), 'bb->b')
        check_exact(np_add, (types.uint8, types.uint8), 'BB->B')
        check_exact(np_add, (types.int64, types.int64), ('ll->l', 'qq->q'))
        check_exact(np_add, (types.uint64, types.uint64), ('LL->L', 'QQ->Q'))
        check_exact(np_add, (types.float32, types.float32), 'ff->f')
        check_exact(np_add, (types.float64, types.float64), 'dd->d')
        check_exact(np_add, (types.complex64, types.complex64), 'FF->F')
        check_exact(np_add, (types.complex128, types.complex128), 'DD->D')

        # Exact matching for datetime64 and timedelta64 types
        check_exact(np_add, (types.NPTimedelta('s'), types.NPTimedelta('s')),
                    'mm->m', output_types=(types.NPTimedelta('s'),))
        check_exact(np_add, (types.NPTimedelta('ms'), types.NPDatetime('s')),
                    'mM->M', output_types=(types.NPDatetime('ms'),))
        check_exact(np_add, (types.NPDatetime('s'), types.NPTimedelta('s')),
                    'Mm->M', output_types=(types.NPDatetime('s'),))
        check_exact(np_add, (types.NPDatetime('s'), types.NPTimedelta('')),
                    'Mm->M', output_types=(types.NPDatetime('s'),))
        check_exact(np_add, (types.NPDatetime('ns'), types.NPTimedelta('')),
                    'Mm->M', output_types=(types.NPDatetime('ns'),))
        check_exact(np_add, (types.NPTimedelta(''), types.NPDatetime('s')),
                    'mM->M', output_types=(types.NPDatetime('s'),))
        check_exact(np_add, (types.NPTimedelta(''), types.NPDatetime('ns')),
                    'mM->M', output_types=(types.NPDatetime('ns'),))

        check_exact(np_mul, (types.NPTimedelta('s'), types.int64),
                    'mq->m', output_types=(types.NPTimedelta('s'),))
        check_exact(np_mul, (types.float64, types.NPTimedelta('s')),
                    'dm->m', output_types=(types.NPTimedelta('s'),))

        # Mix and match number types, with casting
        check(np_add, (types.bool_, types.int8), 'bb->b')
        check(np_add, (types.uint8, types.bool_), 'BB->B')
        check(np_add, (types.int16, types.uint16), 'ii->i')
        check(np_add, (types.complex64, types.float64), 'DD->D')
        check(np_add, (types.float64, types.complex64), 'DD->D')
        # Integers, when used together with floating-point numbers,
        # should cast to any real or complex (see #2006)
        int_types = [types.int32, types.uint32, types.int64, types.uint64]
        for intty in int_types:
            check(np_add, (types.float32, intty), 'ff->f')
            check(np_add, (types.float64, intty), 'dd->d')
            check(np_add, (types.complex64, intty), 'FF->F')
            check(np_add, (types.complex128, intty), 'DD->D')
        # However, when used alone, they should cast only to
        # floating-point types of sufficient precision
        # (typical use case: np.sqrt(2) should give an accurate enough value)
        for intty in int_types:
            check(np_sqrt, (intty,), 'd->d')
            check(np_isnan, (intty,), 'd->?')

        # With some timedelta64 arguments as well
        check(np_mul, (types.NPTimedelta('s'), types.int32),
              'mq->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.NPTimedelta('s'), types.uint32),
              'mq->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.NPTimedelta('s'), types.float32),
              'md->m', output_types=(types.NPTimedelta('s'),))
        check(np_mul, (types.float32, types.NPTimedelta('s')),
              'dm->m', output_types=(types.NPTimedelta('s'),))

        # No match
        check_no_match(np_add, (types.NPDatetime('s'), types.NPDatetime('s')))
        # No implicit casting from int64 to timedelta64 (Numpy would allow
        # this).
        check_no_match(np_add, (types.NPTimedelta('s'), types.int64))

    def test_layout_checker(self):
        """
        Test is_contiguous() and is_fortran() against the contiguity
        flags Numpy reports for variously reshaped and sliced arrays.
        """
        def check_arr(arr):
            dims = arr.shape
            strides = arr.strides
            itemsize = arr.dtype.itemsize
            is_c = numpy_support.is_contiguous(dims, strides, itemsize)
            is_f = numpy_support.is_fortran(dims, strides, itemsize)
            expect_c = arr.flags['C_CONTIGUOUS']
            expect_f = arr.flags['F_CONTIGUOUS']
            self.assertEqual(is_c, expect_c)
            self.assertEqual(is_f, expect_f)

        arr = np.arange(24)

        # 1D
        check_arr(arr)
        # 2D
        check_arr(arr.reshape((3, 8)))
        check_arr(arr.reshape((3, 8)).T)
        check_arr(arr.reshape((3, 8))[::2])
        # 3D
        check_arr(arr.reshape((2, 3, 4)))
        check_arr(arr.reshape((2, 3, 4)).T)
        # middle axis is shape 1
        check_arr(arr.reshape((2, 3, 4))[:, ::3])
        check_arr(arr.reshape((2, 3, 4)).T[:, ::3])
        # leading axis is shape 1
        check_arr(arr.reshape((2, 3, 4))[::2])
        check_arr(arr.reshape((2, 3, 4)).T[:, :, ::2])
        # 2 leading axis are shape 1
        check_arr(arr.reshape((2, 3, 4))[::2, ::3])
        check_arr(arr.reshape((2, 3, 4)).T[:, ::3, ::2])
        # single item slices for all axis
        check_arr(arr.reshape((2, 3, 4))[::2, ::3, ::4])
        check_arr(arr.reshape((2, 3, 4)).T[::4, ::3, ::2])
        # 4D
        check_arr(arr.reshape((2, 2, 3, 2))[::2, ::2, ::3])
        check_arr(arr.reshape((2, 2, 3, 2)).T[:, ::3, ::2, ::2])
        # outer zero dims
        check_arr(arr.reshape((2, 2, 3, 2))[::5, ::2, ::3])
        check_arr(arr.reshape((2, 2, 3, 2)).T[:, ::3, ::2, ::5])


if __name__ == '__main__':
    unittest.main()