112 lines
3.2 KiB
Python
112 lines
3.2 KiB
Python
|
import re
|
||
|
|
||
|
from numba import njit
|
||
|
from numba.core.extending import overload
|
||
|
from numba.core.targetconfig import ConfigStack
|
||
|
from numba.core.compiler import Flags, DEFAULT_FLAGS
|
||
|
from numba.core import types
|
||
|
from numba.core.funcdesc import default_mangler
|
||
|
|
||
|
from numba.tests.support import TestCase, unittest
|
||
|
|
||
|
|
||
|
class TestCompilerFlagCachedOverload(TestCase):
    """Check that an ``@overload`` observes the flags of its caller."""

    def test_fastmath_in_overload(self):
        """The overload reports fastmath separately for each caller.

        ``foo`` is compiled without fastmath and ``set_fastmath`` with
        ``fastmath=True``; both call the same overloaded function and must
        get the answer matching their own flags.
        """

        def fastmath_status():
            # Placeholder only; the overload below supplies the behaviour.
            pass

        @overload(fastmath_status)
        def ov_fastmath_status():
            # Read the flags on top of the config stack at the time the
            # overload is resolved, and bake the answer into the closure.
            if ConfigStack().top().fastmath:
                message = "Has fastmath"
            else:
                message = "No fastmath"

            def impl():
                return message

            return impl

        @njit(fastmath=True)
        def set_fastmath():
            return fastmath_status()

        @njit()
        def foo():
            direct = fastmath_status()
            via_fastmath = set_fastmath()
            return (direct, via_fastmath)

        plain_result, fastmath_result = foo()
        self.assertEqual(plain_result, "No fastmath")
        self.assertEqual(fastmath_result, "Has fastmath")
|
||
|
|
||
|
|
||
|
class TestFlagMangling(TestCase):
    """Tests for the mangled string form of compiler ``Flags``.

    Exercises the round trip between ``Flags.get_mangle_string()`` and
    ``Flags.demangle()``, and the use of the mangled string as an ABI tag
    passed to ``default_mangler``.
    """

    def test_demangle(self):
        """Demangling a mangled flag string must reproduce ``summary()``."""

        def check(flags):
            mangled = flags.get_mangle_string()
            out = flags.demangle(mangled)
            # Demangle result MUST match summary()
            self.assertEqual(out, flags.summary())

        # test empty flags
        flags = Flags()
        check(flags)

        # test default
        check(DEFAULT_FLAGS)

        # test other
        flags = Flags()
        flags.no_cpython_wrapper = True
        flags.nrt = True
        flags.fastmath = True
        check(flags)

    def test_mangled_flags_is_shorter(self):
        """The mangled string is shorter than ``summary()``."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        self.assertLess(len(flags.get_mangle_string()), len(flags.summary()))

    def test_mangled_flags_with_fastmath_parfors_inline(self):
        """Mangling stays short and pointer-free with more options set."""
        # at least for these control cases
        flags = Flags()
        flags.nrt = True
        flags.auto_parallel = True
        flags.fastmath = True
        flags.inline = "always"
        # Mangle once so the length check and the demangle check are
        # guaranteed to look at the same string.
        mangled = flags.get_mangle_string()
        self.assertLess(len(mangled), len(flags.summary()))
        demangled = flags.demangle(mangled)
        # There should be no pointer value in the demangled string.
        self.assertNotIn("0x", demangled)

    def test_demangling_from_mangled_symbols(self):
        """Test demangling of flags from mangled symbol"""
        # Use default mangler to mangle the string
        fname = 'foo'
        argtypes = (types.int32,)
        flags = Flags()
        flags.nrt = True
        flags.target_backend = "myhardware"
        name = default_mangler(
            fname, argtypes, abi_tags=[flags.get_mangle_string()],
        )
        # Find the ABI-tag. Starts with "B"
        prefix = "_Z3fooB"
        # Fail with a readable message (instead of slicing at a wrong
        # offset) if the symbol does not have the expected layout.
        self.assertTrue(name.startswith(prefix), msg=name)
        # Find the length of the ABI-tag
        m = re.match(r"[0-9]+", name[len(prefix):])
        # re.match returns None when there is no match; report that as a
        # test failure rather than an AttributeError on m.group.
        self.assertIsNotNone(m, msg=name)
        size = m.group(0)
        # Extract the ABI tag
        base = len(prefix) + len(size)
        abi_mangled = name[base:base + int(size)]
        # Demangle and check
        demangled = Flags.demangle(abi_mangled)
        self.assertEqual(demangled, flags.summary())
|
||
|
|
||
|
|
||
|
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|