112 lines
3.2 KiB
Python
112 lines
3.2 KiB
Python
import re
|
|
|
|
from numba import njit
|
|
from numba.core.extending import overload
|
|
from numba.core.targetconfig import ConfigStack
|
|
from numba.core.compiler import Flags, DEFAULT_FLAGS
|
|
from numba.core import types
|
|
from numba.core.funcdesc import default_mangler
|
|
|
|
from numba.tests.support import TestCase, unittest
|
|
|
|
|
|
class TestCompilerFlagCachedOverload(TestCase):
    """Checks that ``@overload`` implementations see the caller's flags."""

    def test_fastmath_in_overload(self):
        """The overload of ``fastmath_status`` is specialized per caller's flags.

        ``foo`` (compiled without fastmath) calls ``fastmath_status`` directly
        and also calls ``set_fastmath`` (compiled with ``fastmath=True``),
        which calls it too. Each call must resolve to an implementation built
        under its own caller's flags, so the two results must differ.
        """
        def fastmath_status():
            pass

        @overload(fastmath_status)
        def ov_fastmath_status():
            # Read the compiler flags from the top of the config stack at
            # overload-resolution time; the assertions below expect these to
            # be the flags of the function doing the calling.
            if ConfigStack().top().fastmath:
                status = "Has fastmath"
            else:
                status = "No fastmath"

            def impl():
                return status

            return impl

        @njit(fastmath=True)
        def set_fastmath():
            return fastmath_status()

        @njit()
        def foo():
            without_flag = fastmath_status()
            with_flag = set_fastmath()
            return (without_flag, with_flag)

        without_flag, with_flag = foo()
        self.assertEqual(without_flag, "No fastmath")
        self.assertEqual(with_flag, "Has fastmath")
|
|
|
|
|
|
class TestFlagMangling(TestCase):
    """Tests for ``Flags.get_mangle_string`` / ``Flags.demangle``."""

    def test_demangle(self):
        """Demangling a flag set's mangle string reproduces ``summary()``."""

        def check(flags):
            mangled = flags.get_mangle_string()
            out = flags.demangle(mangled)
            # Demangle result MUST match summary()
            self.assertEqual(out, flags.summary())

        # test empty flags
        flags = Flags()
        check(flags)

        # test default
        check(DEFAULT_FLAGS)

        # test other
        flags = Flags()
        flags.no_cpython_wrapper = True
        flags.nrt = True
        flags.fastmath = True
        check(flags)

    def test_mangled_flags_is_shorter(self):
        """The mangle string is shorter than the plain summary."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        self.assertLess(len(flags.get_mangle_string()), len(flags.summary()))

    def test_mangled_flags_with_fastmath_parfors_inline(self):
        """Mangling stays compact and pointer-free with richer flag values."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        flags.fastmath = True
        flags.inline = "always"
        mangled = flags.get_mangle_string()
        self.assertLess(len(mangled), len(flags.summary()))
        demangled = flags.demangle(mangled)
        # There should be no pointer value in the demangled string.
        self.assertNotIn("0x", demangled)

    def test_demangling_from_mangled_symbols(self):
        """Test demangling of flags from mangled symbol"""
        # Use default mangler to mangle the string
        fname = 'foo'
        argtypes = (types.int32,)
        flags = Flags()
        flags.nrt = True
        flags.target_backend = "myhardware"
        name = default_mangler(
            fname, argtypes, abi_tags=[flags.get_mangle_string()],
        )
        # Expected symbol layout: "_Z" + "3foo" (length-prefixed function
        # name) + "B" (start of the ABI tag) + <decimal length><tag>.
        # Anything after the tag (e.g. the mangled argument types) is
        # ignored here.
        prefix = "_Z3fooB"
        # Fail with a clear message rather than slicing a wrong offset if
        # the mangler's output layout ever changes.
        self.assertTrue(
            name.startswith(prefix),
            msg="unexpected mangled name: {!r}".format(name),
        )
        # Find the length of the ABI-tag
        m = re.match(r"[0-9]+", name[len(prefix):])
        # Without this check a missing length would raise an AttributeError
        # on ``m.group`` instead of reporting a test failure.
        self.assertIsNotNone(
            m, msg="no ABI-tag length in mangled name: {!r}".format(name),
        )
        size = m.group(0)
        # Extract the ABI tag
        base = len(prefix) + len(size)
        abi_mangled = name[base:base + int(size)]
        # Demangle and check
        demangled = Flags.demangle(abi_mangled)
        self.assertEqual(demangled, flags.summary())
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script via unittest.
    unittest.main()
|