""" Tests for numba.core.codegen. """ import warnings import base64 import ctypes import pickle import re import subprocess import sys import weakref import llvmlite.binding as ll import unittest from numba import njit from numba.core.codegen import JITCPUCodegen from numba.core.compiler_lock import global_compiler_lock from numba.tests.support import TestCase asm_sum = r""" define i32 @sum(i32 %.1, i32 %.2) { %.3 = add i32 %.1, %.2 ret i32 %.3 } """ # Note we're using a rather mangled function name to check that it # is compatible with object serialization. asm_sum_inner = """ define i32 @"__main__.ising_element_update$1.array(int8,_2d,_C).int64.int64"(i32 %.1, i32 %.2) { %.3 = add i32 %.1, %.2 ret i32 %.3 } """ # noqa: E501 asm_sum_outer = """ declare i32 @"__main__.ising_element_update$1.array(int8,_2d,_C).int64.int64"(i32 %.1, i32 %.2) define i32 @sum(i32 %.1, i32 %.2) { %.3 = call i32 @"__main__.ising_element_update$1.array(int8,_2d,_C).int64.int64"(i32 %.1, i32 %.2) ret i32 %.3 } """ # noqa: E501 ctypes_sum_ty = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int, ctypes.c_int) class JITCPUCodegenTestCase(TestCase): """ Test the JIT code generation. """ def setUp(self): global_compiler_lock.acquire() self.codegen = JITCPUCodegen('test_codegen') def tearDown(self): del self.codegen global_compiler_lock.release() def compile_module(self, asm, linking_asm=None): library = self.codegen.create_library('compiled_module') ll_module = ll.parse_assembly(asm) ll_module.verify() library.add_llvm_module(ll_module) if linking_asm: linking_library = self.codegen.create_library('linking_module') ll_module = ll.parse_assembly(linking_asm) ll_module.verify() linking_library.add_llvm_module(ll_module) library.add_linking_library(linking_library) return library @classmethod def _check_unserialize_sum(cls, state): codegen = JITCPUCodegen('other_codegen') library = codegen.unserialize_library(state) ptr = library.get_pointer_to_function("sum") assert ptr, ptr cfunc = ctypes_sum_ty(ptr) res = cfunc(2, 3) assert res == 5, res def test_get_pointer_to_function(self): library = self.compile_module(asm_sum) ptr = library.get_pointer_to_function("sum") self.assertIsInstance(ptr, int) cfunc = ctypes_sum_ty(ptr) self.assertEqual(cfunc(2, 3), 5) # Note: With llvm3.9.1, deleting `library` will cause memory error in # the following code during running of optimization passes in # LLVM. The reason of the error is unclear. The error is known to # replicate on osx64 and linux64. # Same, but with dependency on another library library2 = self.compile_module(asm_sum_outer, asm_sum_inner) ptr = library2.get_pointer_to_function("sum") self.assertIsInstance(ptr, int) cfunc = ctypes_sum_ty(ptr) self.assertEqual(cfunc(2, 3), 5) def test_magic_tuple(self): tup = self.codegen.magic_tuple() pickle.dumps(tup) cg2 = JITCPUCodegen('xxx') self.assertEqual(cg2.magic_tuple(), tup) # Serialization tests. def _check_serialize_unserialize(self, state): self._check_unserialize_sum(state) def _check_unserialize_other_process(self, state): arg = base64.b64encode(pickle.dumps(state, -1)) code = """if 1: import base64 import pickle import sys from numba.tests.test_codegen import %(test_class)s state = pickle.loads(base64.b64decode(sys.argv[1])) %(test_class)s._check_unserialize_sum(state) """ % dict(test_class=self.__class__.__name__) subprocess.check_call([sys.executable, '-c', code, arg.decode()]) def test_serialize_unserialize_bitcode(self): library = self.compile_module(asm_sum_outer, asm_sum_inner) state = library.serialize_using_bitcode() self._check_serialize_unserialize(state) def test_unserialize_other_process_bitcode(self): library = self.compile_module(asm_sum_outer, asm_sum_inner) state = library.serialize_using_bitcode() self._check_unserialize_other_process(state) def test_serialize_unserialize_object_code(self): library = self.compile_module(asm_sum_outer, asm_sum_inner) library.enable_object_caching() state = library.serialize_using_object_code() self._check_serialize_unserialize(state) def test_unserialize_other_process_object_code(self): library = self.compile_module(asm_sum_outer, asm_sum_inner) library.enable_object_caching() state = library.serialize_using_object_code() self._check_unserialize_other_process(state) def test_cache_disabled_inspection(self): """ """ library = self.compile_module(asm_sum_outer, asm_sum_inner) library.enable_object_caching() state = library.serialize_using_object_code() # exercise the valid behavior with warnings.catch_warnings(record=True) as w: old_llvm = library.get_llvm_str() old_asm = library.get_asm_str() library.get_function_cfg('sum') self.assertEqual(len(w), 0) # unserialize codegen = JITCPUCodegen('other_codegen') library = codegen.unserialize_library(state) # the inspection methods would warn and give incorrect result with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") self.assertNotEqual(old_llvm, library.get_llvm_str()) self.assertEqual(len(w), 1) self.assertIn("Inspection disabled", str(w[0].message)) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") self.assertNotEqual(library.get_asm_str(), old_asm) self.assertEqual(len(w), 1) self.assertIn("Inspection disabled", str(w[0].message)) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") with self.assertRaises(NameError) as raises: library.get_function_cfg('sum') self.assertEqual(len(w), 1) self.assertIn("Inspection disabled", str(w[0].message)) self.assertIn("sum", str(raises.exception)) # Lifetime tests @unittest.expectedFailure # MCJIT removeModule leaks and it is disabled def test_library_lifetime(self): library = self.compile_module(asm_sum_outer, asm_sum_inner) # Exercise code generation library.enable_object_caching() library.serialize_using_bitcode() library.serialize_using_object_code() u = weakref.ref(library) v = weakref.ref(library._final_module) del library # Both the library and its backing LLVM module are collected self.assertIs(u(), None) self.assertIs(v(), None) class TestWrappers(TestCase): def test_noinline_on_main_call(self): # Checks that the cpython and cfunc wrapper produces a call with the # "noinline" attr present for the decorated function. @njit def foo(): pass foo() sig = foo.signatures[0] ol = foo.overloads[sig] name = ol.fndesc.mangled_name.replace("$", r"\$") p1 = r".*call.*{}".format(name) p2 = r".*(#[0-9]+).*" call_site = re.compile(p1 + p2) lines = foo.inspect_llvm(sig).splitlines() meta_data_idx = [] for l in lines: matched = call_site.match(l) if matched: meta_data_idx.append(matched.groups()[0]) # should be 2 calls, one from cpython wrapper one from cfunc wrapper self.assertEqual(len(meta_data_idx), 2) # both calls should refer to the same metadata item self.assertEqual(meta_data_idx[0], meta_data_idx[1]) p1 = r"^attributes\s+{}".format(meta_data_idx[0]) p2 = r"\s+=\s+{(.*)}.*$" attr_site = re.compile(p1 + p2) for l in reversed(lines): matched = attr_site.match(l) if matched: meta_data = matched.groups()[0] lmeta = meta_data.strip().split(' ') for x in lmeta: if 'noinline' in x: break else: continue break else: return self.fail("Metadata did not match 'noinline'") if __name__ == '__main__': unittest.main()