import numpy as np from numba.core.errors import TypingError from numba import njit from numba.core import types import struct import unittest def float_to_int(x): return types.int32(x) def int_to_float(x): return types.float64(x) / 2 def float_to_unsigned(x): return types.uint32(x) def float_to_complex(x): return types.complex128(x) def numpy_scalar_cast_error(): np.int32(np.zeros((4,))) class TestCasting(unittest.TestCase): def test_float_to_int(self): pyfunc = float_to_int cfunc = njit((types.float32,))(pyfunc) self.assertEqual(cfunc.nopython_signatures[0].return_type, types.int32) self.assertEqual(cfunc(12.3), pyfunc(12.3)) self.assertEqual(cfunc(12.3), int(12.3)) self.assertEqual(cfunc(-12.3), pyfunc(-12.3)) self.assertEqual(cfunc(-12.3), int(-12.3)) def test_int_to_float(self): pyfunc = int_to_float cfunc = njit((types.int64,))(pyfunc) self.assertEqual(cfunc.nopython_signatures[0].return_type, types.float64) self.assertEqual(cfunc(321), pyfunc(321)) self.assertEqual(cfunc(321), 321. / 2) def test_float_to_unsigned(self): pyfunc = float_to_unsigned cfunc = njit((types.float32,))(pyfunc) self.assertEqual(cfunc.nopython_signatures[0].return_type, types.uint32) self.assertEqual(cfunc(3.21), pyfunc(3.21)) self.assertEqual(cfunc(3.21), struct.unpack('I', struct.pack('i', 3))[0]) def test_float_to_complex(self): pyfunc = float_to_complex cfunc = njit((types.float64,))(pyfunc) self.assertEqual(cfunc.nopython_signatures[0].return_type, types.complex128) self.assertEqual(cfunc(-3.21), pyfunc(-3.21)) self.assertEqual(cfunc(-3.21), -3.21 + 0j) def test_array_to_array(self): """Make sure this compiles. Cast C to A array """ @njit("f8(f8[:])") def inner(x): return x[0] inner.disable_compile() @njit("f8(f8[::1])") def driver(x): return inner(x) x = np.array([1234], dtype=np.float64) self.assertEqual(driver(x), x[0]) self.assertEqual(len(inner.overloads), 1) def test_0darrayT_to_T(self): @njit def inner(x): return x.dtype.type(x) inputs = [ (np.bool_, True), (np.float32, 12.3), (np.float64, 12.3), (np.int64, 12), (np.complex64, 2j+3), (np.complex128, 2j+3), (np.timedelta64, np.timedelta64(3, 'h')), (np.datetime64, np.datetime64('2016-01-01')), ('