# Listing metadata: 409 lines, 13 KiB, Python.
"""
Tests for @cfunc and friends.
"""
|
|
|
|
|
|
import ctypes
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from collections import namedtuple
|
|
|
|
import numpy as np
|
|
|
|
from numba import cfunc, carray, farray, njit
|
|
from numba.core import types, typing, utils
|
|
import numba.core.typing.cffi_utils as cffi_support
|
|
from numba.tests.support import (TestCase, skip_unless_cffi, tag,
|
|
captured_stderr)
|
|
import unittest
|
|
from numba.np import numpy_support
|
|
|
|
|
|
def add_usecase(a, b):
    """Return the sum of *a* and *b*."""
    # Kept as a single binary addition: the tests in this module compile
    # it with ``add_sig`` and look for "fadd double" in the generated IR.
    return a + b
|
|
|
|
def div_usecase(a, b):
    """Return the quotient of *a* and *b* via an intermediate local."""
    # The intermediate must stay bound to the name ``c``: test_locals
    # overrides its type with ``locals={'c': types.int64}``.
    c = a / b
    return c
|
|
|
|
def square_usecase(a):
    """Return *a* raised to the power of two."""
    return a ** 2
|
|
|
|
# Signature strings passed to cfunc() for the use cases above.
add_sig = "float64(float64, float64)"

# Integer arguments, floating-point result.
div_sig = "float64(int64, int64)"

square_sig = "float64(float64)"
|
|
|
|
def objmode_usecase(a, b):
    """Return ``a + b`` after making a bare ``object()`` call."""
    # The call below is what test_object_mode relies on: compiling this
    # function with cfunc is expected to fail with a typing error that
    # mentions "Untyped global name 'object'".
    object()
    return a + b
|
|
|
|
# Test functions for carray() and farray()

# Number of elements in the output buffer of the cfarray use cases
# below; they write to indices 0 through 7.
CARRAY_USECASE_OUT_LEN = 8
|
|
|
|
def make_cfarray_usecase(func):
    """
    Build a use case that views two raw pointers as arrays through *func*.

    *func* is called as ``func(pointer, shape)``; this module passes
    ``carray`` and ``farray``.  The returned function takes an input
    pointer, an output pointer and the two input dimensions, and stores
    a summary of the input view into the output buffer.
    """

    def cfarray_usecase(in_ptr, out_ptr, m, n):
        # Tuple shape
        in_ = func(in_ptr, (m, n))
        # Integer shape
        out = func(out_ptr, CARRAY_USECASE_OUT_LEN)
        # Slots 0-6 describe the layout of the input view.
        out[0] = in_.ndim
        out[1:3] = in_.shape
        out[3:5] = in_.strides
        out[5] = in_.flags.c_contiguous
        out[6] = in_.flags.f_contiguous
        # Slot 7 is a sum of the elements weighted by (row - column), so
        # the result depends on where each value sits in the view.
        s = 0
        for i, j in np.ndindex(m, n):
            s += in_[i, j] * (i - j)
        out[7] = s

    return cfarray_usecase
|
|
|
|
# Concrete use cases with the element type implied by the pointer.
carray_usecase = make_cfarray_usecase(carray)
farray_usecase = make_cfarray_usecase(farray)
|
|
|
|
|
|
def make_cfarray_dtype_usecase(func):
    """
    Build the explicit-dtype variant of the cfarray use case.

    *func* is called as ``func(pointer, shape, dtype)`` with
    ``np.float32``; this module passes ``carray`` and ``farray``.
    """
    # Same as make_cfarray_usecase(), but with explicit dtype.

    def cfarray_usecase(in_ptr, out_ptr, m, n):
        # Tuple shape (dtype given by keyword)
        in_ = func(in_ptr, (m, n), dtype=np.float32)
        # Integer shape (dtype given positionally)
        out = func(out_ptr, CARRAY_USECASE_OUT_LEN, np.float32)
        # Slots 0-6 describe the layout of the input view.
        out[0] = in_.ndim
        out[1:3] = in_.shape
        out[3:5] = in_.strides
        out[5] = in_.flags.c_contiguous
        out[6] = in_.flags.f_contiguous
        # Slot 7 is a sum of the elements weighted by (row - column).
        s = 0
        for i, j in np.ndindex(m, n):
            s += in_[i, j] * (i - j)
        out[7] = s

    return cfarray_usecase
|
|
|
|
# Concrete use cases that pass np.float32 explicitly as the dtype.
carray_dtype_usecase = make_cfarray_dtype_usecase(carray)
farray_dtype_usecase = make_cfarray_dtype_usecase(farray)
|
|
|
|
# Formal signatures for the cfarray use cases:
# (input pointer, output pointer, m, n) -> void.  The two ``intp`` size
# arguments are the ones TestCArray.make_carray_sigs() substitutes.
carray_float32_usecase_sig = types.void(types.CPointer(types.float32),
                                        types.CPointer(types.float32),
                                        types.intp, types.intp)

carray_float64_usecase_sig = types.void(types.CPointer(types.float64),
                                        types.CPointer(types.float64),
                                        types.intp, types.intp)

# Untyped pointers: the use case has to supply the dtype itself.
carray_voidptr_usecase_sig = types.void(types.voidptr, types.voidptr,
                                        types.intp, types.intp)
|
|
|
|
|
|
class TestCFunc(TestCase):
    """
    Tests for callbacks produced by the cfunc decorator.
    """

    def test_basic(self):
        """
        A compiled cfunc exposes metadata, an address and a ctypes wrapper.
        """
        compiled = cfunc(add_sig)(add_usecase)

        # Python-level metadata mirrors the wrapped function.
        self.assertEqual(compiled.__name__, "add_usecase")
        self.assertEqual(compiled.__qualname__, "add_usecase")
        self.assertIs(compiled.__wrapped__, add_usecase)

        # The native symbol is a string derived from the function name.
        native_symbol = compiled.native_name
        self.assertIsInstance(native_symbol, str)
        self.assertIn("add_usecase", native_symbol)

        address = compiled.address
        self.assertIsInstance(address, int)

        # The ctypes wrapper points at the very same address...
        ctypes_func = compiled.ctypes
        wrapper_address = ctypes.cast(ctypes_func, ctypes.c_void_p).value
        self.assertEqual(wrapper_address, address)

        # ...and is callable.
        self.assertPreciseEqual(ctypes_func(2.0, 3.5), 5.5)

    @skip_unless_cffi
    def test_cffi(self):
        """
        The cffi function pointer can be handed to a C helper.
        """
        from numba.tests import cffi_usecases

        _ffi, lib = cffi_usecases.load_inline_module()
        square = cfunc(square_sig)(square_usecase)

        result = lib._numba_test_funcptr(square.cffi)
        self.assertPreciseEqual(result, 2.25)  # 1.5 ** 2

    def test_locals(self):
        """
        The ``locals`` option overrides the type of a local variable.
        """
        # Typing the intermediate quotient as an integer truncates it,
        # which in turn truncates the value the function returns.
        forced_types = {'c': types.int64}
        truncating_div = cfunc(div_sig, locals=forced_types)(div_usecase)
        self.assertPreciseEqual(truncating_div.ctypes(8, 3), 2.0)

    def test_errors(self):
        """
        An exception raised inside a cfunc is reported on stderr.
        """
        divide = cfunc(div_sig)(div_usecase)

        # A successful call leaves stderr untouched.
        with captured_stderr() as stream:
            self.assertPreciseEqual(divide.ctypes(5, 2), 2.5)
        self.assertEqual(stream.getvalue(), "")

        with captured_stderr() as stream:
            quotient = divide.ctypes(5, 0)
            # This is just a side effect of Numba zero-initializing
            # stack variables, and could change in the future.
            self.assertPreciseEqual(quotient, 0.0)
        report = stream.getvalue()
        for fragment in ("ZeroDivisionError:", "Exception ignored"):
            self.assertIn(fragment, report)

    def test_llvm_ir(self):
        """
        inspect_llvm() returns IR containing the symbol and the addition.
        """
        compiled = cfunc(add_sig)(add_usecase)
        llvm_text = compiled.inspect_llvm()
        self.assertIn(compiled.native_name, llvm_text)
        self.assertIn("fadd double", llvm_text)

    def test_object_mode(self):
        """
        Object mode is currently unsupported.
        """
        # Requesting object mode explicitly is rejected.
        with self.assertRaises(NotImplementedError):
            cfunc(add_sig, forceobj=True)(add_usecase)
        # A function that cannot be typed fails instead of falling back.
        with self.assertTypingError() as raises:
            cfunc(add_sig)(objmode_usecase)
        failure_text = str(raises.exception)
        self.assertIn("Untyped global name 'object'", failure_text)
|
|
|
|
|
|
class TestCArray(TestCase):
    """
    Tests for carray() and farray(), both interpreted and compiled.
    """

    def run_carray_usecase(self, pointer_factory, func):
        """
        Run *func* on a fixed 2x3 float32 input and return the output
        buffer it filled.  *pointer_factory* turns an array into the
        pointer object handed to *func*.
        """
        source = np.arange(10, 16).reshape((2, 3)).astype(np.float32)
        result = np.empty(CARRAY_USECASE_OUT_LEN, dtype=np.float32)
        in_ptr = pointer_factory(source)
        out_ptr = pointer_factory(result)
        func(in_ptr, out_ptr, *source.shape)
        return result

    def check_carray_usecase(self, pointer_factory, pyfunc, cfunc):
        """
        Check that *cfunc* fills the output buffer exactly like *pyfunc*.
        """
        reference = self.run_carray_usecase(pointer_factory, pyfunc)
        actual = self.run_carray_usecase(pointer_factory, cfunc)
        self.assertPreciseEqual(reference, actual)

    def make_voidptr(self, arr):
        """Return the data of *arr* as a ctypes void pointer."""
        pointer_type = ctypes.c_void_p
        return arr.ctypes.data_as(pointer_type)

    def make_float32_pointer(self, arr):
        """Return the data of *arr* as a ctypes pointer to c_float."""
        pointer_type = ctypes.POINTER(ctypes.c_float)
        return arr.ctypes.data_as(pointer_type)

    def make_float64_pointer(self, arr):
        """Return the data of *arr* as a ctypes pointer to c_double."""
        pointer_type = ctypes.POINTER(ctypes.c_double)
        return arr.ctypes.data_as(pointer_type)

    def check_carray_farray(self, func, order):
        """
        Exercise the pure Python *func* on an array laid out in *order*.
        """
        def assert_same_view(view, reference):
            # Same layout, dtype, shape, etc.
            self.assertPreciseEqual(view, reference)
            # Same underlying data
            self.assertEqual(view.ctypes.data, reference.ctypes.data)

        values = np.arange(6).reshape((2, 3)).astype(np.float32)
        base = values.copy(order=order)
        shape = base.shape

        # With typed pointer and implied dtype
        view = func(self.make_float32_pointer(base), shape)
        assert_same_view(view, base)
        # Integer shape
        view = func(self.make_float32_pointer(base), base.size)
        assert_same_view(view, base.ravel('K'))

        # With typed pointer and explicit dtype
        view = func(self.make_float32_pointer(base), shape, base.dtype)
        assert_same_view(view, base)
        view = func(self.make_float32_pointer(base), shape, np.float32)
        assert_same_view(view, base)

        # With voidptr and explicit dtype
        view = func(self.make_voidptr(base), shape, base.dtype)
        assert_same_view(view, base)
        view = func(self.make_voidptr(base), shape, np.int32)
        assert_same_view(view, base.view(np.int32))

        # voidptr without dtype
        with self.assertRaises(TypeError):
            func(self.make_voidptr(base), shape)
        # Invalid pointer type
        with self.assertRaises(TypeError):
            func(base.ctypes.data, shape)
        # Mismatching dtype
        with self.assertRaises(TypeError) as raises:
            func(self.make_float32_pointer(base), shape, np.int32)
        failure_text = str(raises.exception)
        self.assertIn("mismatching dtype 'int32' for pointer", failure_text)

    def test_carray(self):
        """
        Test pure Python carray().
        """
        self.check_carray_farray(carray, 'C')

    def test_farray(self):
        """
        Test pure Python farray().
        """
        self.check_carray_farray(farray, 'F')

    def make_carray_sigs(self, formal_sig):
        """
        Generate a bunch of concrete signatures by varying the width
        and signedness of size arguments (see issue #1923).
        """
        size_types = (
            types.intp, types.int32, types.intc,
            types.uintp, types.uint32, types.uintc,
        )
        for size_type in size_types:
            concrete_args = []
            for formal in formal_sig.args:
                # Only the intp placeholders are substituted.
                if formal == types.intp:
                    concrete_args.append(size_type)
                else:
                    concrete_args.append(formal)
            yield formal_sig.return_type(*concrete_args)

    def check_numba_carray_farray(self, usecase, dtype_usecase):
        """
        Compile the use cases under every size-argument variant and
        compare each compiled result with the interpreted one.
        """
        def compare_all_variants(pyfunc, formal_sig):
            for sig in self.make_carray_sigs(formal_sig):
                compiled = cfunc(sig)(pyfunc)
                self.check_carray_usecase(self.make_float32_pointer,
                                          pyfunc, compiled.ctypes)

        # With typed pointers and implicit dtype
        compare_all_variants(usecase, carray_float32_usecase_sig)

        # With typed pointers and explicit (matching) dtype
        compare_all_variants(dtype_usecase, carray_float32_usecase_sig)

        # With typed pointers and mismatching dtype
        with self.assertTypingError() as raises:
            cfunc(carray_float64_usecase_sig)(dtype_usecase)
        failure_text = str(raises.exception)
        self.assertIn(
            "mismatching dtype 'float32' for pointer type 'float64*'",
            failure_text)

        # With voidptr
        # NOTE(review): the pointers passed in are still float32
        # pointers; only the compiled signature uses voidptr.
        compare_all_variants(dtype_usecase, carray_voidptr_usecase_sig)

    def test_numba_carray(self):
        """
        Test Numba-compiled carray() against pure Python carray()
        """
        self.check_numba_carray_farray(carray_usecase, carray_dtype_usecase)

    def test_numba_farray(self):
        """
        Test Numba-compiled farray() against pure Python farray()
        """
        self.check_numba_carray_farray(farray_usecase, farray_dtype_usecase)
|
|
|
|
|
|
@skip_unless_cffi
class TestCffiStruct(TestCase):
    """
    Tests for mapping cffi struct and function types to Numba types.
    """

    c_source = """
typedef struct _big_struct {
    int i1;
    float f2;
    double d3;
    float af4[9];
} big_struct;

typedef struct _error {
    int bits:4;
} error;

typedef double (*myfunc)(big_struct*, size_t);
"""

    def get_ffi(self, src=c_source):
        """Return a fresh FFI object with the declarations in *src*."""
        from cffi import FFI

        builder = FFI()
        builder.cdef(src)
        return builder

    def test_type_parsing(self):
        """
        Struct and function typedefs map to a Record and a Signature.
        """
        ffi = self.get_ffi()

        # Check struct typedef
        struct_ctype = ffi.typeof('big_struct')
        record_type = cffi_support.map_type(struct_ctype,
                                            use_record_dtype=True)
        self.assertIsInstance(record_type, types.Record)
        self.assertEqual(len(record_type), 4)
        expected_fields = (
            ('i1', types.int32),
            ('f2', types.float32),
            ('d3', types.float64),
            ('af4', types.NestedArray(dtype=types.float32, shape=(9,))),
        )
        for field_name, field_type in expected_fields:
            self.assertEqual(record_type.typeof(field_name), field_type)

        # Check function typedef
        func_ctype = ffi.typeof('myfunc')
        sig = cffi_support.map_type(func_ctype, use_record_dtype=True)
        self.assertIsInstance(sig, typing.Signature)
        first_arg, second_arg = sig.args[0], sig.args[1]
        self.assertEqual(first_arg, types.CPointer(record_type))
        self.assertEqual(second_arg, types.uintp)
        self.assertEqual(sig.return_type, types.float64)

    def test_cfunc_callback(self):
        """
        A cfunc taking a struct pointer agrees with the same computation
        run on a NumPy view of the identical memory.
        """
        ffi = self.get_ffi()
        struct_ctype = ffi.typeof('big_struct')
        nb_big_struct = cffi_support.map_type(struct_ctype,
                                              use_record_dtype=True)
        sig = cffi_support.map_type(ffi.typeof('myfunc'),
                                    use_record_dtype=True)

        @njit
        def calc(base):
            tmp = 0
            for i in range(base.size):
                elem = base[i]
                tmp += elem.i1 * elem.f2 / elem.d3
                tmp += base[i].af4.sum()
            return tmp

        @cfunc(sig)
        def foo(ptr, n):
            base = carray(ptr, n)
            return calc(base)

        # Make data
        record_count = 3
        mydata = ffi.new('big_struct[3]')
        ptr = ffi.cast('big_struct*', mydata)
        for idx in range(record_count):
            ptr[idx].i1 = idx * 123
            ptr[idx].f2 = idx * 213
            ptr[idx].d3 = (1 + idx) * 213
            for slot in range(9):
                ptr[idx].af4[slot] = idx * 10 + slot

        # Address of my data
        addr = int(ffi.cast('size_t', ptr))
        got = foo.ctypes(addr, record_count)

        # Make numpy array from the cffi buffer
        record_dtype = numpy_support.as_dtype(nb_big_struct)
        array = np.ndarray(
            buffer=ffi.buffer(mydata),
            dtype=record_dtype,
            shape=record_count,
        )
        expect = calc(array)
        self.assertEqual(got, expect)

    def test_unsupport_bitsize(self):
        """
        Mapping a struct that contains a bit-field raises ValueError.
        """
        ffi = self.get_ffi()
        bitfield_ctype = ffi.typeof('error')
        with self.assertRaises(ValueError) as raises:
            cffi_support.map_type(bitfield_ctype, use_record_dtype=True)
        # When bitsize is provided, bitshift defaults to 0.
        failure_text = str(raises.exception)
        self.assertEqual(
            "field 'bits' has bitshift, this is not supported",
            failure_text,
        )
|
|
|
|
|
|
# Run the tests in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
|