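"""
Tests for on-disk caching of compiled functions (e.g. ``@njit(cache=True)``):
cache reuse across processes, invalidation when the source changes, fallback
cache locations, cache-file naming, and caching of cfuncs and parfors.
"""
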
import inspect
import llvmlite.binding as ll
import multiprocessing
import numpy as np
import os
import stat
import shutil
import subprocess
import sys
import traceback
import unittest
import warnings

from numba import njit
from numba.core import codegen
from numba.core.caching import _UserWideCacheLocator
from numba.core.errors import NumbaWarning
from numba.parfors import parfor
from numba.tests.support import (
    TestCase,
    SerialMixin,
    capture_cache_log,
    import_dynamic,
    override_config,
    run_in_new_process_caching,
    skip_if_typeguard,
    skip_parfors_unsupported,
    temp_directory,
)

try:
    import ipykernel
except ImportError:
    ipykernel = None


def check_access_is_preventable():
    # This exists to check whether it is possible to prevent access to
    # a file/directory through the use of `chmod 500`. If a user has
    # elevated rights (e.g. root) then writes are likely to be possible
    # anyway. Tests that require functioning access prevention are
    # therefore skipped based on the result of this check.
    tempdir = temp_directory('test_cache')
    test_dir = os.path.join(tempdir, 'writable_test')
    os.mkdir(test_dir)
    # check a write is possible
    with open(os.path.join(test_dir, 'write_ok'), 'wt') as f:
        f.write('check1')
    # now forbid access
    os.chmod(test_dir, 0o500)
    try:
        with open(os.path.join(test_dir, 'write_forbidden'), 'wt') as f:
            f.write('check2')
        # access prevention is not possible
        return False
    except PermissionError:
        # Check that the cause of the exception is due to access/permission
        # as per
        # https://github.com/conda/conda/blob/4.5.0/conda/gateways/disk/permissions.py#L35-L37 # noqa: E501
        # errno reports access/perm fail so access prevention via
        # `chmod 500` works for this user.
        return True
    finally:
        os.chmod(test_dir, 0o775)
        shutil.rmtree(test_dir)


_access_preventable = check_access_is_preventable()
_access_msg = "Cannot create a directory to which writes are preventable"
skip_bad_access = unittest.skipUnless(_access_preventable, _access_msg)


def constant_unicode_cache():
    c = "abcd"
    return hash(c), c


def check_constant_unicode_cache():
    pyfunc = constant_unicode_cache
    cfunc = njit(cache=True)(pyfunc)
    exp_hv, exp_str = pyfunc()
    got_hv, got_str = cfunc()
    assert exp_hv == got_hv
    assert exp_str == got_str


def dict_cache():
    return {'a': 1, 'b': 2}


def check_dict_cache():
    pyfunc = dict_cache
    cfunc = njit(cache=True)(pyfunc)
    exp = pyfunc()
    got = cfunc()
    assert exp == got


def generator_cache():
    for v in (1, 2, 3):
        yield v


def check_generator_cache():
    pyfunc = generator_cache
    cfunc = njit(cache=True)(pyfunc)
    exp = list(pyfunc())
    got = list(cfunc())
    assert exp == got


class TestCaching(SerialMixin, TestCase):
    def run_test(self, func):
        func()
        res = run_in_new_process_caching(func)
        self.assertEqual(res['exitcode'], 0)

    def test_constant_unicode_cache(self):
        self.run_test(check_constant_unicode_cache)

    def test_dict_cache(self):
        self.run_test(check_dict_cache)

    def test_generator_cache(self):
        self.run_test(check_generator_cache)

    def test_omitted(self):
        # Test in a new directory
        cache_dir = temp_directory(self.__class__.__name__)
        ctx = multiprocessing.get_context()
        result_queue = ctx.Queue()
        proc = ctx.Process(
            target=omitted_child_test_wrapper,
            args=(result_queue, cache_dir, False),
        )
        proc.start()
        proc.join()
        success, output = result_queue.get()

        # Ensure the child process is completed before checking its output
        if not success:
            self.fail(output)

        self.assertEqual(
            output,
            1000,
            "Omitted function returned an incorrect output",
        )

        proc = ctx.Process(
            target=omitted_child_test_wrapper,
            args=(result_queue, cache_dir, True),
        )
        proc.start()
        proc.join()
        success, output = result_queue.get()

        # Ensure the child process is completed before checking its output
        if not success:
            self.fail(output)

        self.assertEqual(
            output,
            1000,
            "Omitted function returned an incorrect output",
        )


def omitted_child_test_wrapper(result_queue, cache_dir, second_call):
    with override_config("CACHE_DIR", cache_dir):
        @njit(cache=True)
        def test(num=1000):
            return num

        try:
            output = test()
            # If we have a second call, we should have a cache hit.
            # Otherwise, we expect a cache miss.
            if second_call:
                assert test._cache_hits[test.signatures[0]] == 1, \
                    "Cache did not hit as expected"
                assert test._cache_misses[test.signatures[0]] == 0, \
                    "Cache has an unexpected miss"
            else:
                assert test._cache_misses[test.signatures[0]] == 1, \
                    "Cache did not miss as expected"
                assert test._cache_hits[test.signatures[0]] == 0, \
                    "Cache has an unexpected hit"
            success = True
        # Catch anything raised so it can be propagated
        except:  # noqa: E722
            output = traceback.format_exc()
            success = False
        result_queue.put((success, output))
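
# Note: the `_cache_hits` / `_cache_misses` mappings asserted on above are
# internal per-signature counters on the dispatcher object;
# `test.signatures[0]` is the sole signature compiled in the child process.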


class BaseCacheTest(TestCase):
    # The source file that will be copied
    usecases_file = None
    # Make sure this doesn't conflict with another module
    modname = None

    def setUp(self):
        self.tempdir = temp_directory('test_cache')
        sys.path.insert(0, self.tempdir)
        self.modfile = os.path.join(self.tempdir, self.modname + ".py")
        self.cache_dir = os.path.join(self.tempdir, "__pycache__")
        shutil.copy(self.usecases_file, self.modfile)
        os.chmod(self.modfile, stat.S_IREAD | stat.S_IWRITE)
        self.maxDiff = None

    def tearDown(self):
        sys.modules.pop(self.modname, None)
        sys.path.remove(self.tempdir)

    def import_module(self):
        # Import a fresh version of the test module. All jitted functions
        # in the test module will start anew and load overloads from
        # the on-disk cache if possible.
        old = sys.modules.pop(self.modname, None)
        if old is not None:
            # Make sure cached bytecode is removed
            cached = [old.__cached__]
            for fn in cached:
                try:
                    os.unlink(fn)
                except FileNotFoundError:
                    pass
        mod = import_dynamic(self.modname)
        self.assertEqual(mod.__file__.rstrip('co'), self.modfile)
        return mod

    def cache_contents(self):
        try:
            return [fn for fn in os.listdir(self.cache_dir)
                    if not fn.endswith(('.pyc', '.pyo'))]
        except FileNotFoundError:
            return []

    def get_cache_mtimes(self):
        return dict((fn, os.path.getmtime(os.path.join(self.cache_dir, fn)))
                    for fn in sorted(self.cache_contents()))

    def check_pycache(self, n):
        c = self.cache_contents()
        self.assertEqual(len(c), n, c)

    def dummy_test(self):
        pass
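
# A note on the `check_pycache(n)` counts used throughout the tests below:
# each cached function gets one index file (`.nbi`) plus one data file
# (`.nbc`) per compiled signature, so "1 index, 2 data" means three files
# in __pycache__ (compare the "1 nbi, 4 nbc" comment in `test_closure`).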


class DispatcherCacheUsecasesTest(BaseCacheTest):
    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cache_usecases.py")
    modname = "dispatcher_caching_test_fodder"

    def run_in_separate_process(self, *, envvars={}):
        # Cached functions can be run from a distinct process.
        # Also stresses issue #1603: uncached function calling cached function
        # shouldn't fail compiling.
        code = """if 1:
            import sys

            sys.path.insert(0, %(tempdir)r)
            mod = __import__(%(modname)r)
            mod.self_test()
            """ % dict(tempdir=self.tempdir, modname=self.modname)

        subp_env = os.environ.copy()
        subp_env.update(envvars)
        popen = subprocess.Popen([sys.executable, "-c", code],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE,
                                 env=subp_env)
        out, err = popen.communicate()
        if popen.returncode != 0:
            raise AssertionError(
                "process failed with code %s: \n"
                "stdout follows\n%s\n"
                "stderr follows\n%s\n"
                % (popen.returncode, out.decode(), err.decode()),
            )

    def check_hits(self, func, hits, misses=None):
        st = func.stats
        self.assertEqual(sum(st.cache_hits.values()), hits, st.cache_hits)
        if misses is not None:
            self.assertEqual(sum(st.cache_misses.values()), misses,
                             st.cache_misses)


class TestCache(DispatcherCacheUsecasesTest):

    def test_caching(self):
        self.check_pycache(0)
        mod = self.import_module()
        self.check_pycache(0)

        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(2)  # 1 index, 1 data
        self.assertPreciseEqual(f(2.5, 3), 6.5)
        self.check_pycache(3)  # 1 index, 2 data
        self.check_hits(f, 0, 2)

        f = mod.add_objmode_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(5)  # 2 index, 3 data
        self.assertPreciseEqual(f(2.5, 3), 6.5)
        self.check_pycache(6)  # 2 index, 4 data
        self.check_hits(f, 0, 2)

        f = mod.record_return
        rec = f(mod.aligned_arr, 1)
        self.assertPreciseEqual(tuple(rec), (2, 43.5))
        rec = f(mod.packed_arr, 1)
        self.assertPreciseEqual(tuple(rec), (2, 43.5))
        self.check_pycache(9)  # 3 index, 6 data
        self.check_hits(f, 0, 2)

        # Check the code runs ok from another process
        self.run_in_separate_process()

    def test_caching_nrt_pruned(self):
        self.check_pycache(0)
        mod = self.import_module()
        self.check_pycache(0)

        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(2)  # 1 index, 1 data
        # NRT pruning may affect cache
        self.assertPreciseEqual(f(2, np.arange(3)), 2 + np.arange(3) + 1)
        self.check_pycache(3)  # 1 index, 2 data
        self.check_hits(f, 0, 2)

    def test_inner_then_outer(self):
        # Caching inner then outer function is ok
        mod = self.import_module()
        self.assertPreciseEqual(mod.inner(3, 2), 6)
        self.check_pycache(2)  # 1 index, 1 data
        # Uncached outer function shouldn't fail (issue #1603)
        f = mod.outer_uncached
        self.assertPreciseEqual(f(3, 2), 2)
        self.check_pycache(2)  # 1 index, 1 data
        mod = self.import_module()
        f = mod.outer_uncached
        self.assertPreciseEqual(f(3, 2), 2)
        self.check_pycache(2)  # 1 index, 1 data
        # Cached outer will create new cache entries
        f = mod.outer
        self.assertPreciseEqual(f(3, 2), 2)
        self.check_pycache(4)  # 2 index, 2 data
        self.assertPreciseEqual(f(3.5, 2), 2.5)
        self.check_pycache(6)  # 2 index, 4 data

    def test_outer_then_inner(self):
        # Caching outer then inner function is ok
        mod = self.import_module()
        self.assertPreciseEqual(mod.outer(3, 2), 2)
        self.check_pycache(4)  # 2 index, 2 data
        self.assertPreciseEqual(mod.outer_uncached(3, 2), 2)
        self.check_pycache(4)  # same
        mod = self.import_module()
        f = mod.inner
        self.assertPreciseEqual(f(3, 2), 6)
        self.check_pycache(4)  # same
        self.assertPreciseEqual(f(3.5, 2), 6.5)
        self.check_pycache(5)  # 2 index, 3 data

    def test_no_caching(self):
        mod = self.import_module()

        f = mod.add_nocache_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(0)

    def test_looplifted(self):
        # Loop-lifted functions can't be cached and raise a warning
        mod = self.import_module()

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always', NumbaWarning)

            f = mod.looplifted
            self.assertPreciseEqual(f(4), 6)
            self.check_pycache(0)

        self.assertEqual(len(w), 1)
        self.assertIn('Cannot cache compiled function "looplifted" '
                      'as it uses lifted code', str(w[0].message))

    def test_big_array(self):
        # Code referencing big-array globals cannot be cached
        mod = self.import_module()
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always', NumbaWarning)

            f = mod.use_big_array
            np.testing.assert_equal(f(), mod.biggie)
            self.check_pycache(0)

        self.assertEqual(len(w), 1)
        self.assertIn('Cannot cache compiled function "use_big_array" '
                      'as it uses dynamic globals', str(w[0].message))

    def test_ctypes(self):
        # Functions using a ctypes pointer can't be cached and raise
        # a warning.
        mod = self.import_module()

        for f in [mod.use_c_sin, mod.use_c_sin_nest1, mod.use_c_sin_nest2]:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always', NumbaWarning)

                self.assertPreciseEqual(f(0.0), 0.0)
                self.check_pycache(0)

            self.assertEqual(len(w), 1)
            self.assertIn(
                'Cannot cache compiled function "{}"'.format(f.__name__),
                str(w[0].message),
            )

    def test_closure(self):
        mod = self.import_module()

        with warnings.catch_warnings():
            warnings.simplefilter('error', NumbaWarning)

            f = mod.closure1
            self.assertPreciseEqual(f(3), 6)  # 3 + 3 = 6
            f = mod.closure2
            self.assertPreciseEqual(f(3), 8)  # 3 + 5 = 8
            f = mod.closure3
            self.assertPreciseEqual(f(3), 10)  # 3 + 7 = 10
            f = mod.closure4
            self.assertPreciseEqual(f(3), 12)  # 3 + 9 = 12
            self.check_pycache(5)  # 1 nbi, 4 nbc

    def test_first_class_function(self):
        mod = self.import_module()
        f = mod.first_class_function_usecase
        self.assertEqual(f(mod.first_class_function_mul, 1), 1)
        self.assertEqual(f(mod.first_class_function_mul, 10), 100)
        self.assertEqual(f(mod.first_class_function_add, 1), 2)
        self.assertEqual(f(mod.first_class_function_add, 10), 20)
        # 1 + 1 + 1 nbi, 1 + 1 + 2 nbc - a separate cache for each call to `f`
        # with a different callback.
        self.check_pycache(7)

    def test_cache_reuse(self):
        mod = self.import_module()
        mod.add_usecase(2, 3)
        mod.add_usecase(2.5, 3.5)
        mod.add_objmode_usecase(2, 3)
        mod.outer_uncached(2, 3)
        mod.outer(2, 3)
        mod.record_return(mod.packed_arr, 0)
        mod.record_return(mod.aligned_arr, 1)
        mtimes = self.get_cache_mtimes()
        # Two signatures compiled
        self.check_hits(mod.add_usecase, 0, 2)

        mod2 = self.import_module()
        self.assertIsNot(mod, mod2)
        f = mod2.add_usecase
        f(2, 3)
        self.check_hits(f, 1, 0)
        f(2.5, 3.5)
        self.check_hits(f, 2, 0)
        f = mod2.add_objmode_usecase
        f(2, 3)
        self.check_hits(f, 1, 0)

        # The files haven't changed
        self.assertEqual(self.get_cache_mtimes(), mtimes)

        self.run_in_separate_process()
        self.assertEqual(self.get_cache_mtimes(), mtimes)

    def test_cache_invalidate(self):
        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)

        # This should change the functions' results
        with open(self.modfile, "a") as f:
            f.write("\nZ = 10\n")

        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 15)
        f = mod.add_objmode_usecase
        self.assertPreciseEqual(f(2, 3), 15)

    def test_recompile(self):
        # Explicit call to recompile() should overwrite the cache
        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)

        mod = self.import_module()
        f = mod.add_usecase
        mod.Z = 10
        self.assertPreciseEqual(f(2, 3), 6)
        f.recompile()
        self.assertPreciseEqual(f(2, 3), 15)

        # Freshly recompiled version is re-used from other imports
        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 15)

    def test_same_names(self):
        # Functions with the same name should still disambiguate
        mod = self.import_module()
        f = mod.renamed_function1
        self.assertPreciseEqual(f(2), 4)
        f = mod.renamed_function2
        self.assertPreciseEqual(f(2), 8)

    def test_frozen(self):
        from .dummy_module import function
        old_code = function.__code__
        code_obj = compile('pass', 'tests/dummy_module.py', 'exec')
        try:
            function.__code__ = code_obj

            source = inspect.getfile(function)
            # doesn't return anything, since it cannot find the module
            # fails unless the executable is frozen
            locator = _UserWideCacheLocator.from_function(function, source)
            self.assertIsNone(locator)

            sys.frozen = True
            # returns a cache locator object, only works when the executable
            # is frozen
            locator = _UserWideCacheLocator.from_function(function, source)
            self.assertIsInstance(locator, _UserWideCacheLocator)

        finally:
            function.__code__ = old_code
            del sys.frozen

    def _test_pycache_fallback(self):
        """
        With a disabled __pycache__, test there is a working fallback
        (e.g. on the user-wide cache dir)
        """
        mod = self.import_module()
        f = mod.add_usecase
        # Remove this function's cache files at the end, to avoid accumulation
        # across test calls.
        self.addCleanup(shutil.rmtree, f.stats.cache_path, ignore_errors=True)

        self.assertPreciseEqual(f(2, 3), 6)
        # It's a cache miss since the file was copied to a new temp location
        self.check_hits(f, 0, 1)

        # Test re-use
        mod2 = self.import_module()
        f = mod2.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_hits(f, 1, 0)

        # The __pycache__ is empty (otherwise the test's preconditions
        # wouldn't be met)
        self.check_pycache(0)

    @skip_bad_access
    @unittest.skipIf(os.name == "nt",
                     "cannot easily make a directory read-only on Windows")
    def test_non_creatable_pycache(self):
        # Make it impossible to create the __pycache__ directory
        old_perms = os.stat(self.tempdir).st_mode
        os.chmod(self.tempdir, 0o500)
        self.addCleanup(os.chmod, self.tempdir, old_perms)

        self._test_pycache_fallback()

    @skip_bad_access
    @unittest.skipIf(os.name == "nt",
                     "cannot easily make a directory read-only on Windows")
    def test_non_writable_pycache(self):
        # Make it impossible to write to the __pycache__ directory
        pycache = os.path.join(self.tempdir, '__pycache__')
        os.mkdir(pycache)
        old_perms = os.stat(pycache).st_mode
        os.chmod(pycache, 0o500)
        self.addCleanup(os.chmod, pycache, old_perms)

        self._test_pycache_fallback()

    def test_ipython(self):
        # Test caching in an IPython session
        base_cmd = [sys.executable, '-m', 'IPython']
        base_cmd += ['--quiet', '--quick', '--no-banner', '--colors=NoColor']
        try:
            ver = subprocess.check_output(base_cmd + ['--version'])
        except subprocess.CalledProcessError as e:
            self.skipTest("ipython not available: return code %d"
                          % e.returncode)
        ver = ver.strip().decode()
        # Create test input
        inputfn = os.path.join(self.tempdir, "ipython_cache_usecase.txt")
        with open(inputfn, "w") as f:
            f.write(r"""
                import os
                import sys

                from numba import jit

                # IPython 5 does not support multiline input if stdin isn't
                # a tty (https://github.com/ipython/ipython/issues/9752)
                f = jit(cache=True)(lambda: 42)

                res = f()
                # IPython writes on stdout, so use stderr instead
                sys.stderr.write(u"cache hits = %d\n" % f.stats.cache_hits[()])

                # IPython hijacks sys.exit(), bypass it
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(res)
                """)

        def execute_with_input():
            # Feed the test input as stdin, to execute it in REPL context
            with open(inputfn, "rb") as stdin:
                p = subprocess.Popen(base_cmd, stdin=stdin,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.PIPE,
                                     universal_newlines=True)
                out, err = p.communicate()
                if p.returncode != 42:
                    self.fail("unexpected return code %d\n"
                              "-- stdout:\n%s\n"
                              "-- stderr:\n%s\n"
                              % (p.returncode, out, err))
                return err

        execute_with_input()
        # Run a second time and check caching
        err = execute_with_input()
        self.assertIn("cache hits = 1", err.strip())

    @unittest.skipIf((ipykernel is None) or (ipykernel.version_info[0] < 6),
                     "requires ipykernel >= 6")
    def test_ipykernel(self):
        # Test caching in an IPython session using ipykernel

        base_cmd = [sys.executable, '-m', 'IPython']
        base_cmd += ['--quiet', '--quick', '--no-banner', '--colors=NoColor']
        try:
            ver = subprocess.check_output(base_cmd + ['--version'])
        except subprocess.CalledProcessError as e:
            self.skipTest("ipython not available: return code %d"
                          % e.returncode)
        ver = ver.strip().decode()
        # Create test input
        from ipykernel import compiler
        inputfn = compiler.get_tmp_directory()
        with open(inputfn, "w") as f:
            f.write(r"""
                import os
                import sys

                from numba import jit

                # IPython 5 does not support multiline input if stdin isn't
                # a tty (https://github.com/ipython/ipython/issues/9752)
                f = jit(cache=True)(lambda: 42)

                res = f()
                # IPython writes on stdout, so use stderr instead
                sys.stderr.write(u"cache hits = %d\n" % f.stats.cache_hits[()])

                # IPython hijacks sys.exit(), bypass it
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(res)
                """)

        def execute_with_input():
            # Feed the test input as stdin, to execute it in REPL context
            with open(inputfn, "rb") as stdin:
                p = subprocess.Popen(base_cmd, stdin=stdin,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.PIPE,
                                     universal_newlines=True)
                out, err = p.communicate()
                if p.returncode != 42:
                    self.fail("unexpected return code %d\n"
                              "-- stdout:\n%s\n"
                              "-- stderr:\n%s\n"
                              % (p.returncode, out, err))
                return err

        execute_with_input()
        # Run a second time and check caching
        err = execute_with_input()
        self.assertIn("cache hits = 1", err.strip())


@skip_parfors_unsupported
class TestSequentialParForsCache(DispatcherCacheUsecasesTest):
    def setUp(self):
        super(TestSequentialParForsCache, self).setUp()
        # Turn on sequential parfor lowering
        parfor.sequential_parfor_lowering = True

    def tearDown(self):
        super(TestSequentialParForsCache, self).tearDown()
        # Turn off sequential parfor lowering
        parfor.sequential_parfor_lowering = False

    def test_caching(self):
        mod = self.import_module()
        self.check_pycache(0)
        f = mod.parfor_usecase
        ary = np.ones(10)
        self.assertPreciseEqual(f(ary), ary * ary + ary)
        dynamic_globals = [cres.library.has_dynamic_globals
                           for cres in f.overloads.values()]
        self.assertEqual(dynamic_globals, [False])
        self.check_pycache(2)  # 1 index, 1 data


class TestCacheWithCpuSetting(DispatcherCacheUsecasesTest):
    # Disable parallel testing due to envvars modification
    _numba_parallel_test_ = False

    def check_later_mtimes(self, mtimes_old):
        match_count = 0
        for k, v in self.get_cache_mtimes().items():
            if k in mtimes_old:
                self.assertGreaterEqual(v, mtimes_old[k])
                match_count += 1
        self.assertGreater(match_count, 0,
                           msg='nothing to compare')
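
    # In the two tests below, entries in the cache index are keyed in part by
    # the codegen environment: key[1][1] holds the CPU name and key[1][2] the
    # CPU features, which is what the `key_host` / `key_generic` assertions
    # inspect.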

    def test_user_set_cpu_name(self):
        self.check_pycache(0)
        mod = self.import_module()
        mod.self_test()
        cache_size = len(self.cache_contents())

        mtimes = self.get_cache_mtimes()
        # Change CPU name to generic
        self.run_in_separate_process(envvars={'NUMBA_CPU_NAME': 'generic'})

        self.check_later_mtimes(mtimes)
        self.assertGreater(len(self.cache_contents()), cache_size)
        # Check cache index
        cache = mod.add_usecase._cache
        cache_file = cache._cache_file
        cache_index = cache_file._load_index()
        self.assertEqual(len(cache_index), 2)
        [key_a, key_b] = cache_index.keys()
        if key_a[1][1] == ll.get_host_cpu_name():
            key_host, key_generic = key_a, key_b
        else:
            key_host, key_generic = key_b, key_a
        self.assertEqual(key_host[1][1], ll.get_host_cpu_name())
        self.assertEqual(key_host[1][2], codegen.get_host_cpu_features())
        self.assertEqual(key_generic[1][1], 'generic')
        self.assertEqual(key_generic[1][2], '')

    def test_user_set_cpu_features(self):
        self.check_pycache(0)
        mod = self.import_module()
        mod.self_test()
        cache_size = len(self.cache_contents())

        mtimes = self.get_cache_mtimes()
        # Change CPU feature
        my_cpu_features = '-sse;-avx'

        system_features = codegen.get_host_cpu_features()

        self.assertNotEqual(system_features, my_cpu_features)
        self.run_in_separate_process(
            envvars={'NUMBA_CPU_FEATURES': my_cpu_features},
        )
        self.check_later_mtimes(mtimes)
        self.assertGreater(len(self.cache_contents()), cache_size)
        # Check cache index
        cache = mod.add_usecase._cache
        cache_file = cache._cache_file
        cache_index = cache_file._load_index()
        self.assertEqual(len(cache_index), 2)
        [key_a, key_b] = cache_index.keys()

        if key_a[1][2] == system_features:
            key_host, key_generic = key_a, key_b
        else:
            key_host, key_generic = key_b, key_a

        self.assertEqual(key_host[1][1], ll.get_host_cpu_name())
        self.assertEqual(key_host[1][2], system_features)
        self.assertEqual(key_generic[1][1], ll.get_host_cpu_name())
        self.assertEqual(key_generic[1][2], my_cpu_features)


class TestMultiprocessCache(BaseCacheTest):

    # Nested multiprocessing.Pool raises AssertionError:
    # "daemonic processes are not allowed to have children"
    _numba_parallel_test_ = False

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cache_usecases.py")
    modname = "dispatcher_caching_test_fodder"

    def test_multiprocessing(self):
        # Check caching works from multiple processes at once (#2028)
        mod = self.import_module()
        # Calling a pure Python caller of the JIT-compiled function is
        # necessary to reproduce the issue.
        f = mod.simple_usecase_caller
        n = 3
        try:
            ctx = multiprocessing.get_context('spawn')
        except AttributeError:
            ctx = multiprocessing
        pool = ctx.Pool(n)
        try:
            res = sum(pool.imap(f, range(n)))
        finally:
            pool.close()
        self.assertEqual(res, n * (n - 1) // 2)


@skip_if_typeguard
class TestCacheFileCollision(unittest.TestCase):
    _numba_parallel_test_ = False

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cache_usecases.py")
    modname = "caching_file_loc_fodder"
    source_text_1 = """
from numba import njit
@njit(cache=True)
def bar():
    return 123
"""
    source_text_2 = """
from numba import njit
@njit(cache=True)
def bar():
    return 321
"""

    def setUp(self):
        self.tempdir = temp_directory('test_cache_file_loc')
        sys.path.insert(0, self.tempdir)
        self.modname = 'module_name_that_is_unlikely'
        self.assertNotIn(self.modname, sys.modules)
        self.modname_bar1 = self.modname
        self.modname_bar2 = '.'.join([self.modname, 'foo'])
        foomod = os.path.join(self.tempdir, self.modname)
        os.mkdir(foomod)
        with open(os.path.join(foomod, '__init__.py'), 'w') as fout:
            print(self.source_text_1, file=fout)
        with open(os.path.join(foomod, 'foo.py'), 'w') as fout:
            print(self.source_text_2, file=fout)

    def tearDown(self):
        sys.modules.pop(self.modname_bar1, None)
        sys.modules.pop(self.modname_bar2, None)
        sys.path.remove(self.tempdir)

    def import_bar1(self):
        return import_dynamic(self.modname_bar1).bar

    def import_bar2(self):
        return import_dynamic(self.modname_bar2).bar

    def test_file_location(self):
        bar1 = self.import_bar1()
        bar2 = self.import_bar2()
        # Check that the cache file is named correctly
        idxname1 = bar1._cache._cache_file._index_name
        idxname2 = bar2._cache._cache_file._index_name
        self.assertNotEqual(idxname1, idxname2)
        self.assertTrue(idxname1.startswith("__init__.bar-3.py"))
        self.assertTrue(idxname2.startswith("foo.bar-3.py"))

    @unittest.skipUnless(hasattr(multiprocessing, 'get_context'),
                         'Test requires multiprocessing.get_context')
    def test_no_collision(self):
        bar1 = self.import_bar1()
        bar2 = self.import_bar2()
        with capture_cache_log() as buf:
            res1 = bar1()
        cachelog = buf.getvalue()
        # bar1 should save a new index and data
        self.assertEqual(cachelog.count('index saved'), 1)
        self.assertEqual(cachelog.count('data saved'), 1)
        self.assertEqual(cachelog.count('index loaded'), 0)
        self.assertEqual(cachelog.count('data loaded'), 0)
        with capture_cache_log() as buf:
            res2 = bar2()
        cachelog = buf.getvalue()
        # bar2 should save a new index and data
        self.assertEqual(cachelog.count('index saved'), 1)
        self.assertEqual(cachelog.count('data saved'), 1)
        self.assertEqual(cachelog.count('index loaded'), 0)
        self.assertEqual(cachelog.count('data loaded'), 0)
        self.assertNotEqual(res1, res2)

        try:
            # Make sure we can spawn a new process without inheriting
            # the parent context.
            mp = multiprocessing.get_context('spawn')
        except ValueError:
            # Without the spawn context, `mp` would be unbound below
            self.skipTest("missing spawn context")

        q = mp.Queue()
        # Start a new process that calls `cache_file_collision_tester`
        proc = mp.Process(target=cache_file_collision_tester,
                          args=(q, self.tempdir,
                                self.modname_bar1,
                                self.modname_bar2))
        proc.start()
        # Get results from the process
        log1 = q.get()
        got1 = q.get()
        log2 = q.get()
        got2 = q.get()
        proc.join()

        # The remote execution results of bar1() and bar2() should match
        # the ones executed locally.
        self.assertEqual(got1, res1)
        self.assertEqual(got2, res2)

        # The remote should have loaded bar1 from the cache
        self.assertEqual(log1.count('index saved'), 0)
        self.assertEqual(log1.count('data saved'), 0)
        self.assertEqual(log1.count('index loaded'), 1)
        self.assertEqual(log1.count('data loaded'), 1)

        # The remote should have loaded bar2 from the cache
        self.assertEqual(log2.count('index saved'), 0)
        self.assertEqual(log2.count('data saved'), 0)
        self.assertEqual(log2.count('index loaded'), 1)
        self.assertEqual(log2.count('data loaded'), 1)


def cache_file_collision_tester(q, tempdir, modname_bar1, modname_bar2):
    sys.path.insert(0, tempdir)
    bar1 = import_dynamic(modname_bar1).bar
    bar2 = import_dynamic(modname_bar2).bar
    with capture_cache_log() as buf:
        r1 = bar1()
    q.put(buf.getvalue())
    q.put(r1)
    with capture_cache_log() as buf:
        r2 = bar2()
    q.put(buf.getvalue())
    q.put(r2)


class TestCacheMultipleFilesWithSignature(unittest.TestCase):
    # Regression test for https://github.com/numba/numba/issues/3658

    _numba_parallel_test_ = False

    source_text_file1 = """
from file2 import function2
"""
    source_text_file2 = """
from numba import njit

@njit('float64(float64)', cache=True)
def function1(x):
    return x

@njit('float64(float64)', cache=True)
def function2(x):
    return x
"""

    def setUp(self):
        self.tempdir = temp_directory('test_cache_file_loc')

        self.file1 = os.path.join(self.tempdir, 'file1.py')
        with open(self.file1, 'w') as fout:
            print(self.source_text_file1, file=fout)

        self.file2 = os.path.join(self.tempdir, 'file2.py')
        with open(self.file2, 'w') as fout:
            print(self.source_text_file2, file=fout)

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def test_caching_multiple_files_with_signature(self):
        # Execute file1.py
        popen = subprocess.Popen([sys.executable, self.file1],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
        out, err = popen.communicate()
        msg = f"stdout:\n{out.decode()}\n\nstderr:\n{err.decode()}"
        self.assertEqual(popen.returncode, 0, msg=msg)

        # Execute file2.py
        popen = subprocess.Popen([sys.executable, self.file2],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
        out, err = popen.communicate()
        msg = f"stdout:\n{out.decode()}\n\nstderr:\n{err.decode()}"
        self.assertEqual(popen.returncode, 0, msg=msg)


class TestCFuncCache(BaseCacheTest):

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cfunc_cache_usecases.py")
    modname = "cfunc_caching_test_fodder"

    def run_in_separate_process(self):
        # Cached functions can be run from a distinct process.
        code = """if 1:
            import sys

            sys.path.insert(0, %(tempdir)r)
            mod = __import__(%(modname)r)
            mod.self_test()

            f = mod.add_usecase
            assert f.cache_hits == 1
            f = mod.outer
            assert f.cache_hits == 1
            f = mod.div_usecase
            assert f.cache_hits == 1
            """ % dict(tempdir=self.tempdir, modname=self.modname)

        popen = subprocess.Popen([sys.executable, "-c", code],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
        out, err = popen.communicate()
        if popen.returncode != 0:
            raise AssertionError(f"process failed with code "
                                 f"{popen.returncode}:\n"
                                 f"stderr follows\n{err.decode()}\n")

    def check_module(self, mod):
        mod.self_test()

    def test_caching(self):
        self.check_pycache(0)
        mod = self.import_module()
        self.check_pycache(6)  # 3 index, 3 data

        self.assertEqual(mod.add_usecase.cache_hits, 0)
        self.assertEqual(mod.outer.cache_hits, 0)
        self.assertEqual(mod.add_nocache_usecase.cache_hits, 0)
        self.assertEqual(mod.div_usecase.cache_hits, 0)
        self.check_module(mod)

        # Reload module to hit the cache
        mod = self.import_module()
        self.check_pycache(6)  # 3 index, 3 data

        self.assertEqual(mod.add_usecase.cache_hits, 1)
        self.assertEqual(mod.outer.cache_hits, 1)
        self.assertEqual(mod.add_nocache_usecase.cache_hits, 0)
        self.assertEqual(mod.div_usecase.cache_hits, 1)
        self.check_module(mod)

        self.run_in_separate_process()


if __name__ == '__main__':
    unittest.main()