import re

from numba import njit
from numba.core.extending import overload
from numba.core.targetconfig import ConfigStack
from numba.core.compiler import Flags, DEFAULT_FLAGS
from numba.core import types
from numba.core.funcdesc import default_mangler

from numba.tests.support import TestCase, unittest


class TestCompilerFlagCachedOverload(TestCase):
    """Tests that @overload implementations observe the active compiler flags."""

    def test_fastmath_in_overload(self):
        """The same overload, resolved from two functions compiled with
        different ``fastmath`` settings, must report the flags of the
        function being compiled rather than those of its caller."""

        def fastmath_status():
            pass

        @overload(fastmath_status)
        def ov_fastmath_status():
            # Read the flags of the compilation currently in progress; the
            # returned implementation bakes that observation in as a constant.
            flags = ConfigStack().top()
            val = "Has fastmath" if flags.fastmath else "No fastmath"

            def codegen():
                return val

            return codegen

        @njit(fastmath=True)
        def set_fastmath():
            return fastmath_status()

        @njit()
        def foo():
            # Resolved under foo's flags (no fastmath).
            a = fastmath_status()
            # Resolved inside set_fastmath, which was compiled with fastmath.
            b = set_fastmath()
            return (a, b)

        a, b = foo()
        self.assertEqual(a, "No fastmath")
        self.assertEqual(b, "Has fastmath")


class TestFlagMangling(TestCase):
    """Tests for ``Flags.get_mangle_string`` / ``Flags.demangle``."""

    def test_demangle(self):
        """Mangling then demangling a Flags object must reproduce
        ``Flags.summary()`` exactly."""

        def check(flags):
            mangled = flags.get_mangle_string()
            out = flags.demangle(mangled)
            # Demangle result MUST match summary()
            self.assertEqual(out, flags.summary())

        # test empty flags
        flags = Flags()
        check(flags)

        # test default
        check(DEFAULT_FLAGS)

        # test other
        flags = Flags()
        flags.no_cpython_wrapper = True
        flags.nrt = True
        flags.fastmath = True
        check(flags)

    def test_mangled_flags_is_shorter(self):
        """The mangled form should be more compact than the summary."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        self.assertLess(len(flags.get_mangle_string()), len(flags.summary()))

    def test_mangled_flags_with_fastmath_parfors_inline(self):
        """Compactness and pointer-free demangling with fastmath,
        auto_parallel and inline set together."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        flags.fastmath = True
        flags.inline = "always"
        self.assertLess(len(flags.get_mangle_string()), len(flags.summary()))
        demangled = flags.demangle(flags.get_mangle_string())
        # There should be no pointer value in the demangled string.
        self.assertNotIn("0x", demangled)

    def test_demangling_from_mangled_symbols(self):
        """Test demangling of flags from mangled symbol"""
        # Use default mangler to mangle the string
        fname = 'foo'
        argtypes = (types.int32,)
        flags = Flags()
        flags.nrt = True
        flags.target_backend = "myhardware"
        name = default_mangler(
            fname, argtypes, abi_tags=[flags.get_mangle_string()],
        )
        # Find the ABI-tag. It follows the length-prefixed identifier
        # ("3foo") and starts with "B". Check the layout explicitly so a
        # change in mangling fails with a clear message instead of slicing
        # the wrong region of the symbol.
        prefix = "_Z3fooB"
        self.assertTrue(name.startswith(prefix), msg=name)
        # Find the length of the ABI-tag (decimal digits after the "B").
        m = re.match(r"[0-9]+", name[len(prefix):])
        self.assertIsNotNone(m, msg=name)
        size = m.group(0)
        # Extract the ABI tag
        base = len(prefix) + len(size)
        abi_mangled = name[base:base + int(size)]
        # Demangle and check
        demangled = Flags.demangle(abi_mangled)
        self.assertEqual(demangled, flags.summary())


if __name__ == "__main__":
    unittest.main()