"""Tests for Numba's on-disk caching of compiled functions (``cache=True``)."""

import inspect
import llvmlite.binding as ll
import multiprocessing
import numpy as np
import os
import stat
import shutil
import subprocess
import sys
import traceback
import unittest
import warnings

from numba import njit
from numba.core import codegen
from numba.core.caching import _UserWideCacheLocator
from numba.core.errors import NumbaWarning
from numba.parfors import parfor
from numba.tests.support import (
    TestCase,
    SerialMixin,
    capture_cache_log,
    import_dynamic,
    override_config,
    run_in_new_process_caching,
    skip_if_typeguard,
    skip_parfors_unsupported,
    temp_directory,
)

try:
    import ipykernel
except ImportError:
    ipykernel = None


def check_access_is_preventable():
    """Return True if ``chmod 500`` actually prevents writes for this user."""
    # This exists to check whether it is possible to prevent access to
    # a file/directory through the use of `chmod 500`. If a user has
    # elevated rights (e.g. root) then writes are likely to be possible
    # anyway. Tests that require functioning access prevention are
    # therefore skipped based on the result of this check.
    tempdir = temp_directory('test_cache')
    test_dir = (os.path.join(tempdir, 'writable_test'))
    os.mkdir(test_dir)
    # check a write is possible
    with open(os.path.join(test_dir, 'write_ok'), 'wt') as f:
        f.write('check1')
    # now forbid access
    os.chmod(test_dir, 0o500)
    try:
        with open(os.path.join(test_dir, 'write_forbidden'), 'wt') as f:
            f.write('check2')
        # access prevention is not possible
        return False
    except PermissionError:
        # Check that the cause of the exception is due to access/permission
        # as per
        # https://github.com/conda/conda/blob/4.5.0/conda/gateways/disk/permissions.py#L35-L37  # noqa: E501
        # errno reports access/perm fail so access prevention via
        # `chmod 500` works for this user.
        return True
    finally:
        # Restore write permission so the directory can be removed again.
        os.chmod(test_dir, 0o775)
        shutil.rmtree(test_dir)


_access_preventable = check_access_is_preventable()
_access_msg = "Cannot create a directory to which writes are preventable"
skip_bad_access = unittest.skipUnless(_access_preventable, _access_msg)


def constant_unicode_cache():
    c = "abcd"
    return hash(c), c


def check_constant_unicode_cache():
    """Compare the cached-jit result of `constant_unicode_cache` to Python."""
    pyfunc = constant_unicode_cache
    cfunc = njit(cache=True)(pyfunc)
    exp_hv, exp_str = pyfunc()
    got_hv, got_str = cfunc()
    assert exp_hv == got_hv
    assert exp_str == got_str


def dict_cache():
    return {'a': 1, 'b': 2}


def check_dict_cache():
    """Compare the cached-jit result of `dict_cache` to Python."""
    pyfunc = dict_cache
    cfunc = njit(cache=True)(pyfunc)
    exp = pyfunc()
    got = cfunc()
    assert exp == got


def generator_cache():
    for v in (1, 2, 3):
        yield v


def check_generator_cache():
    """Compare the cached-jit output of `generator_cache` to Python."""
    pyfunc = generator_cache
    cfunc = njit(cache=True)(pyfunc)
    exp = list(pyfunc())
    got = list(cfunc())
    assert exp == got


class TestCaching(SerialMixin, TestCase):
    """Run the `check_*` functions in-process and again in a new process."""

    def run_test(self, func):
        # First call runs in this process; the second runs `func` through
        # `run_in_new_process_caching` and must exit cleanly.
        func()
        res = run_in_new_process_caching(func)
        self.assertEqual(res['exitcode'], 0)

    def test_constant_unicode_cache(self):
        self.run_test(check_constant_unicode_cache)

    def test_dict_cache(self):
        self.run_test(check_dict_cache)

    def test_generator_cache(self):
        self.run_test(check_generator_cache)

    def test_omitted(self):
        # Test in a new directory
        cache_dir = temp_directory(self.__class__.__name__)
        ctx = multiprocessing.get_context()
        result_queue = ctx.Queue()

        # The first child is expected to miss the cache, the second to hit
        # it (see `omitted_child_test_wrapper`).
        for second_call in (False, True):
            proc = ctx.Process(
                target=omitted_child_test_wrapper,
                args=(result_queue, cache_dir, second_call),
            )
            proc.start()
            proc.join()
            success, output = result_queue.get()

            # Ensure the child process is completed before checking its
            # output
            if not success:
                self.fail(output)

            self.assertEqual(
                output,
                1000,
                "Omitted function returned an incorrect output"
            )


def omitted_child_test_wrapper(result_queue, cache_dir, second_call):
    """Child-process body for `TestCaching.test_omitted`.

    Compiles a cached function with an omitted default argument using
    `cache_dir` as CACHE_DIR, checks the hit/miss counters, and puts a
    ``(success, output)`` tuple on `result_queue`. On failure `output` is
    the formatted traceback.
    """
    with override_config("CACHE_DIR", cache_dir):
        @njit(cache=True)
        def test(num=1000):
            return num

        try:
            output = test()
            # If we have a second call, we should have a cache hit.
            # Otherwise, we expect a cache miss.
            if second_call:
                assert test._cache_hits[test.signatures[0]] == 1, \
                    "Cache did not hit as expected"
                assert test._cache_misses[test.signatures[0]] == 0, \
                    "Cache has an unexpected miss"
            else:
                assert test._cache_misses[test.signatures[0]] == 1, \
                    "Cache did not miss as expected"
                assert test._cache_hits[test.signatures[0]] == 0, \
                    "Cache has an unexpected hit"
            success = True
        # Catch anything raised so it can be propagated
        except:  # noqa: E722
            output = traceback.format_exc()
            success = False
        result_queue.put((success, output))


class BaseCacheTest(TestCase):
    """Base class copying a usecases file to a temp dir importable module."""

    # The source file that will be copied
    usecases_file = None
    # Make sure this doesn't conflict with another module
    modname = None

    def setUp(self):
        self.tempdir = temp_directory('test_cache')
        sys.path.insert(0, self.tempdir)
        self.modfile = os.path.join(self.tempdir, self.modname + ".py")
        self.cache_dir = os.path.join(self.tempdir, "__pycache__")
        shutil.copy(self.usecases_file, self.modfile)
        os.chmod(self.modfile, stat.S_IREAD | stat.S_IWRITE)
        self.maxDiff = None

    def tearDown(self):
        sys.modules.pop(self.modname, None)
        sys.path.remove(self.tempdir)

    def import_module(self):
        # Import a fresh version of the test module.  All jitted functions
        # in the test module will start anew and load overloads from
        # the on-disk cache if possible.
        old = sys.modules.pop(self.modname, None)
        if old is not None:
            # Make sure cached bytecode is removed
            cached = [old.__cached__]
            for fn in cached:
                try:
                    os.unlink(fn)
                except FileNotFoundError:
                    pass
        mod = import_dynamic(self.modname)
        self.assertEqual(mod.__file__.rstrip('co'), self.modfile)
        return mod

    def cache_contents(self):
        """List files in the cache dir, ignoring ``.pyc``/``.pyo`` files."""
        try:
            return [fn for fn in os.listdir(self.cache_dir)
                    if not fn.endswith(('.pyc', ".pyo"))]
        except FileNotFoundError:
            return []

    def get_cache_mtimes(self):
        """Map each cache file name to its modification time."""
        return {fn: os.path.getmtime(os.path.join(self.cache_dir, fn))
                for fn in sorted(self.cache_contents())}

    def check_pycache(self, n):
        """Assert the cache dir holds exactly `n` (non-bytecode) files."""
        c = self.cache_contents()
        self.assertEqual(len(c), n, c)

    def dummy_test(self):
        pass


class DispatcherCacheUsecasesTest(BaseCacheTest):
    """Base class for tests built on ``cache_usecases.py``."""

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cache_usecases.py")
    modname = "dispatcher_caching_test_fodder"

    def run_in_separate_process(self, *, envvars=None):
        """Import the test module in a subprocess and run its self test.

        `envvars` is an optional mapping of environment variables added on
        top of a copy of ``os.environ`` for the subprocess. Raises
        AssertionError (including stdout/stderr) on a non-zero exit code.
        """
        # Cached functions can be run from a distinct process.
        # Also stresses issue #1603: uncached function calling cached function
        # shouldn't fail compiling.
        code = """if 1:
            import sys

            sys.path.insert(0, %(tempdir)r)
            mod = __import__(%(modname)r)
            mod.self_test()
            """ % dict(tempdir=self.tempdir, modname=self.modname)

        subp_env = os.environ.copy()
        if envvars:
            subp_env.update(envvars)
        popen = subprocess.Popen([sys.executable, "-c", code],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE,
                                 env=subp_env)
        out, err = popen.communicate()
        if popen.returncode != 0:
            raise AssertionError(
                "process failed with code %s: \n"
                "stdout follows\n%s\n"
                "stderr follows\n%s\n"
                % (popen.returncode, out.decode(), err.decode()),
            )

    def check_hits(self, func, hits, misses=None):
        """Assert the total cache hits (and optionally misses) of `func`."""
        st = func.stats
        self.assertEqual(sum(st.cache_hits.values()), hits, st.cache_hits)
        if misses is not None:
            self.assertEqual(sum(st.cache_misses.values()), misses,
                             st.cache_misses)


class TestCache(DispatcherCacheUsecasesTest):

    def test_caching(self):
        self.check_pycache(0)
        mod = self.import_module()
        self.check_pycache(0)

        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(2)  # 1 index, 1 data
        self.assertPreciseEqual(f(2.5, 3), 6.5)
        self.check_pycache(3)  # 1 index, 2 data
        self.check_hits(f, 0, 2)

        f = mod.add_objmode_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(5)  # 2 index, 3 data
        self.assertPreciseEqual(f(2.5, 3), 6.5)
        self.check_pycache(6)  # 2 index, 4 data
        self.check_hits(f, 0, 2)

        f = mod.record_return
        rec = f(mod.aligned_arr, 1)
        self.assertPreciseEqual(tuple(rec), (2, 43.5))
        rec = f(mod.packed_arr, 1)
        self.assertPreciseEqual(tuple(rec), (2, 43.5))
        self.check_pycache(9)  # 3 index, 6 data
        self.check_hits(f, 0, 2)

        # Check the code runs ok from another process
        self.run_in_separate_process()

    def test_caching_nrt_pruned(self):
        self.check_pycache(0)
        mod = self.import_module()
        self.check_pycache(0)

        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(2)  # 1 index, 1 data
        # NRT pruning may affect cache
        self.assertPreciseEqual(f(2, np.arange(3)), 2 + np.arange(3) + 1)
        self.check_pycache(3)  # 1 index, 2 data
        self.check_hits(f, 0, 2)

    def test_inner_then_outer(self):
        # Caching inner then outer function is ok
        mod = self.import_module()
        self.assertPreciseEqual(mod.inner(3, 2), 6)
        self.check_pycache(2)  # 1 index, 1 data
        # Uncached outer function shouldn't fail (issue #1603)
        f = mod.outer_uncached
        self.assertPreciseEqual(f(3, 2), 2)
        self.check_pycache(2)  # 1 index, 1 data
        mod = self.import_module()
        f = mod.outer_uncached
        self.assertPreciseEqual(f(3, 2), 2)
        self.check_pycache(2)  # 1 index, 1 data
        # Cached outer will create new cache entries
        f = mod.outer
        self.assertPreciseEqual(f(3, 2), 2)
        self.check_pycache(4)  # 2 index, 2 data
        self.assertPreciseEqual(f(3.5, 2), 2.5)
        self.check_pycache(6)  # 2 index, 4 data

    def test_outer_then_inner(self):
        # Caching outer then inner function is ok
        mod = self.import_module()
        self.assertPreciseEqual(mod.outer(3, 2), 2)
        self.check_pycache(4)  # 2 index, 2 data
        self.assertPreciseEqual(mod.outer_uncached(3, 2), 2)
        self.check_pycache(4)  # same
        mod = self.import_module()
        f = mod.inner
        self.assertPreciseEqual(f(3, 2), 6)
        self.check_pycache(4)  # same
        self.assertPreciseEqual(f(3.5, 2), 6.5)
        self.check_pycache(5)  # 2 index, 3 data

    def test_no_caching(self):
        mod = self.import_module()

        f = mod.add_nocache_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_pycache(0)

    def test_looplifted(self):
        # Loop-lifted functions can't be cached and raise a warning
        mod = self.import_module()

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always', NumbaWarning)

            f = mod.looplifted
            self.assertPreciseEqual(f(4), 6)
            self.check_pycache(0)

        self.assertEqual(len(w), 1)
        self.assertIn('Cannot cache compiled function "looplifted" '
                      'as it uses lifted code', str(w[0].message))

    def test_big_array(self):
        # Code references big array globals cannot be cached
        mod = self.import_module()

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always', NumbaWarning)

            f = mod.use_big_array
            np.testing.assert_equal(f(), mod.biggie)
            self.check_pycache(0)

        self.assertEqual(len(w), 1)
        self.assertIn('Cannot cache compiled function "use_big_array" '
                      'as it uses dynamic globals', str(w[0].message))

    def test_ctypes(self):
        # Functions using a ctypes pointer can't be cached and raise
        # a warning.
        mod = self.import_module()

        for f in [mod.use_c_sin, mod.use_c_sin_nest1, mod.use_c_sin_nest2]:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always', NumbaWarning)

                self.assertPreciseEqual(f(0.0), 0.0)
                self.check_pycache(0)

            self.assertEqual(len(w), 1)
            self.assertIn(
                'Cannot cache compiled function "{}"'.format(f.__name__),
                str(w[0].message),
            )

    def test_closure(self):
        mod = self.import_module()

        with warnings.catch_warnings():
            # Any NumbaWarning raised here becomes a test error.
            warnings.simplefilter('error', NumbaWarning)

            f = mod.closure1
            self.assertPreciseEqual(f(3), 6)  # 3 + 3 = 6
            f = mod.closure2
            self.assertPreciseEqual(f(3), 8)  # 3 + 5 = 8
            f = mod.closure3
            self.assertPreciseEqual(f(3), 10)  # 3 + 7 = 10
            f = mod.closure4
            self.assertPreciseEqual(f(3), 12)  # 3 + 9 = 12
            self.check_pycache(5)  # 1 nbi, 4 nbc

    def test_first_class_function(self):
        mod = self.import_module()
        f = mod.first_class_function_usecase
        self.assertEqual(f(mod.first_class_function_mul, 1), 1)
        self.assertEqual(f(mod.first_class_function_mul, 10), 100)
        self.assertEqual(f(mod.first_class_function_add, 1), 2)
        self.assertEqual(f(mod.first_class_function_add, 10), 20)
        # 1 + 1 + 1 nbi, 1 + 1 + 2 nbc - a separate cache for each call to `f`
        # with a different callback.
        self.check_pycache(7)

    def test_cache_reuse(self):
        mod = self.import_module()
        mod.add_usecase(2, 3)
        mod.add_usecase(2.5, 3.5)
        mod.add_objmode_usecase(2, 3)
        mod.outer_uncached(2, 3)
        mod.outer(2, 3)
        mod.record_return(mod.packed_arr, 0)
        mod.record_return(mod.aligned_arr, 1)
        mtimes = self.get_cache_mtimes()
        # Two signatures compiled
        self.check_hits(mod.add_usecase, 0, 2)

        mod2 = self.import_module()
        self.assertIsNot(mod, mod2)
        f = mod2.add_usecase
        f(2, 3)
        self.check_hits(f, 1, 0)
        f(2.5, 3.5)
        self.check_hits(f, 2, 0)
        f = mod2.add_objmode_usecase
        f(2, 3)
        self.check_hits(f, 1, 0)

        # The files haven't changed
        self.assertEqual(self.get_cache_mtimes(), mtimes)

        self.run_in_separate_process()
        self.assertEqual(self.get_cache_mtimes(), mtimes)

    def test_cache_invalidate(self):
        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)

        # This should change the functions' results
        with open(self.modfile, "a") as f:
            f.write("\nZ = 10\n")

        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 15)
        f = mod.add_objmode_usecase
        self.assertPreciseEqual(f(2, 3), 15)

    def test_recompile(self):
        # Explicit call to recompile() should overwrite the cache
        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)

        mod = self.import_module()
        f = mod.add_usecase
        mod.Z = 10
        self.assertPreciseEqual(f(2, 3), 6)
        f.recompile()
        self.assertPreciseEqual(f(2, 3), 15)

        # Freshly recompiled version is re-used from other imports
        mod = self.import_module()
        f = mod.add_usecase
        self.assertPreciseEqual(f(2, 3), 15)

    def test_same_names(self):
        # Function with the same names should still disambiguate
        mod = self.import_module()
        f = mod.renamed_function1
        self.assertPreciseEqual(f(2), 4)
        f = mod.renamed_function2
        self.assertPreciseEqual(f(2), 8)

    def test_frozen(self):
        from .dummy_module import function
        old_code = function.__code__
        code_obj = compile('pass', 'tests/dummy_module.py', 'exec')
        try:
            function.__code__ = code_obj

            source = inspect.getfile(function)
            # doesn't return anything, since it cannot find the module
            # fails unless the executable is frozen
            locator = _UserWideCacheLocator.from_function(function, source)
            self.assertIsNone(locator)

            sys.frozen = True
            # returns a cache locator object, only works when the executable
            # is frozen
            locator = _UserWideCacheLocator.from_function(function, source)
            self.assertIsInstance(locator, _UserWideCacheLocator)

        finally:
            function.__code__ = old_code
            # Only remove the attribute if it exists: when the test fails
            # before `sys.frozen` is set, an unconditional `del` would raise
            # AttributeError and mask the original failure.
            if hasattr(sys, 'frozen'):
                del sys.frozen

    def _test_pycache_fallback(self):
        """
        With a disabled __pycache__, test there is a working fallback
        (e.g. on the user-wide cache dir)
        """
        mod = self.import_module()
        f = mod.add_usecase
        # Remove this function's cache files at the end, to avoid accumulation
        # across test calls.
        self.addCleanup(shutil.rmtree, f.stats.cache_path, ignore_errors=True)

        self.assertPreciseEqual(f(2, 3), 6)
        # It's a cache miss since the file was copied to a new temp location
        self.check_hits(f, 0, 1)

        # Test re-use
        mod2 = self.import_module()
        f = mod2.add_usecase
        self.assertPreciseEqual(f(2, 3), 6)
        self.check_hits(f, 1, 0)

        # The __pycache__ is empty (otherwise the test's preconditions
        # wouldn't be met)
        self.check_pycache(0)

    @skip_bad_access
    @unittest.skipIf(os.name == "nt",
                     "cannot easily make a directory read-only on Windows")
    def test_non_creatable_pycache(self):
        # Make it impossible to create the __pycache__ directory
        old_perms = os.stat(self.tempdir).st_mode
        os.chmod(self.tempdir, 0o500)
        self.addCleanup(os.chmod, self.tempdir, old_perms)

        self._test_pycache_fallback()

    @skip_bad_access
    @unittest.skipIf(os.name == "nt",
                     "cannot easily make a directory read-only on Windows")
    def test_non_writable_pycache(self):
        # Make it impossible to write to the __pycache__ directory
        pycache = os.path.join(self.tempdir, '__pycache__')
        os.mkdir(pycache)
        old_perms = os.stat(pycache).st_mode
        os.chmod(pycache, 0o500)
        self.addCleanup(os.chmod, pycache, old_perms)

        self._test_pycache_fallback()

    def test_ipython(self):
        # Test caching in an IPython session
        base_cmd = [sys.executable, '-m', 'IPython']
        base_cmd += ['--quiet', '--quick', '--no-banner', '--colors=NoColor']
        try:
            # Only the exit status matters: it tells whether IPython runs.
            subprocess.check_output(base_cmd + ['--version'])
        except subprocess.CalledProcessError as e:
            self.skipTest("ipython not available: return code %d"
                          % e.returncode)
        # Create test input
        inputfn = os.path.join(self.tempdir, "ipython_cache_usecase.txt")
        with open(inputfn, "w") as f:
            f.write(r"""
                import os
                import sys

                from numba import jit

                # IPython 5 does not support multiline input if stdin isn't
                # a tty (https://github.com/ipython/ipython/issues/9752)
                f = jit(cache=True)(lambda: 42)

                res = f()
                # IPython writes on stdout, so use stderr instead
                sys.stderr.write(u"cache hits = %d\n" % f.stats.cache_hits[()])

                # IPython hijacks sys.exit(), bypass it
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(res)
                """)

        def execute_with_input():
            # Feed the test input as stdin, to execute it in REPL context
            with open(inputfn, "rb") as stdin:
                p = subprocess.Popen(base_cmd, stdin=stdin,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.PIPE,
                                     universal_newlines=True)
                out, err = p.communicate()
                # The script exits with the jitted function's result (42).
                if p.returncode != 42:
                    self.fail("unexpected return code %d\n"
                              "-- stdout:\n%s\n"
                              "-- stderr:\n%s\n"
                              % (p.returncode, out, err))
                return err

        execute_with_input()
        # Run a second time and check caching
        err = execute_with_input()
        self.assertIn("cache hits = 1", err.strip())

    @unittest.skipIf((ipykernel is None) or (ipykernel.version_info[0] < 6),
                     "requires ipykernel >= 6")
    def test_ipykernel(self):
        # Test caching in an IPython session using ipykernel

        base_cmd = [sys.executable, '-m', 'IPython']
        base_cmd += ['--quiet', '--quick', '--no-banner', '--colors=NoColor']
        try:
            # Only the exit status matters: it tells whether IPython runs.
            subprocess.check_output(base_cmd + ['--version'])
        except subprocess.CalledProcessError as e:
            self.skipTest("ipython not available: return code %d"
                          % e.returncode)
        # Create test input
        from ipykernel import compiler
        # NOTE(review): the helper's name suggests a directory path, yet the
        # result is opened as a regular file below -- confirm against
        # ipykernel's behaviour.
        inputfn = compiler.get_tmp_directory()
        with open(inputfn, "w") as f:
            f.write(r"""
                import os
                import sys

                from numba import jit

                # IPython 5 does not support multiline input if stdin isn't
                # a tty (https://github.com/ipython/ipython/issues/9752)
                f = jit(cache=True)(lambda: 42)

                res = f()
                # IPython writes on stdout, so use stderr instead
                sys.stderr.write(u"cache hits = %d\n" % f.stats.cache_hits[()])

                # IPython hijacks sys.exit(), bypass it
                sys.stdout.flush()
                sys.stderr.flush()
                os._exit(res)
                """)

        def execute_with_input():
            # Feed the test input as stdin, to execute it in REPL context
            with open(inputfn, "rb") as stdin:
                p = subprocess.Popen(base_cmd, stdin=stdin,
                                     stdout=subprocess.PIPE,
                                     stderr=subprocess.PIPE,
                                     universal_newlines=True)
                out, err = p.communicate()
                # The script exits with the jitted function's result (42).
                if p.returncode != 42:
                    self.fail("unexpected return code %d\n"
                              "-- stdout:\n%s\n"
                              "-- stderr:\n%s\n"
                              % (p.returncode, out, err))
                return err

        execute_with_input()
        # Run a second time and check caching
        err = execute_with_input()
        self.assertIn("cache hits = 1", err.strip())


@skip_parfors_unsupported
class TestSequentialParForsCache(DispatcherCacheUsecasesTest):
    def setUp(self):
        super().setUp()
        # Turn on sequential parfor lowering
        parfor.sequential_parfor_lowering = True

    def tearDown(self):
        super().tearDown()
        # Turn off sequential parfor lowering
        parfor.sequential_parfor_lowering = False

    def test_caching(self):
        mod = self.import_module()
        self.check_pycache(0)
        f = mod.parfor_usecase
        ary = np.ones(10)
        self.assertPreciseEqual(f(ary), ary * ary + ary)
        dynamic_globals = [cres.library.has_dynamic_globals
                           for cres in f.overloads.values()]
        self.assertEqual(dynamic_globals, [False])
        self.check_pycache(2)  # 1 index, 1 data


class TestCacheWithCpuSetting(DispatcherCacheUsecasesTest):
    # Disable parallel testing due to envvars modification
    _numba_parallel_test_ = False

    def check_later_mtimes(self, mtimes_old):
        """Assert no cache file also in `mtimes_old` has an older mtime."""
        match_count = 0
        for k, v in self.get_cache_mtimes().items():
            if k in mtimes_old:
                self.assertGreaterEqual(v, mtimes_old[k])
                match_count += 1
        # At least one file must have been compared for the check to count.
        self.assertGreater(match_count, 0,
                           msg='nothing to compare')

    def test_user_set_cpu_name(self):
        self.check_pycache(0)
        mod = self.import_module()
        mod.self_test()
        cache_size = len(self.cache_contents())

        mtimes = self.get_cache_mtimes()
        # Change CPU name to generic
        self.run_in_separate_process(envvars={'NUMBA_CPU_NAME': 'generic'})

        self.check_later_mtimes(mtimes)
        self.assertGreater(len(self.cache_contents()), cache_size)
        # Check cache index
        cache = mod.add_usecase._cache
        cache_file = cache._cache_file
        cache_index = cache_file._load_index()
        self.assertEqual(len(cache_index), 2)
        [key_a, key_b] = cache_index.keys()
        # Index keys come in no guaranteed order; identify the host entry.
        if key_a[1][1] == ll.get_host_cpu_name():
            key_host, key_generic = key_a, key_b
        else:
            key_host, key_generic = key_b, key_a
        self.assertEqual(key_host[1][1], ll.get_host_cpu_name())
        self.assertEqual(key_host[1][2], codegen.get_host_cpu_features())
        self.assertEqual(key_generic[1][1], 'generic')
        self.assertEqual(key_generic[1][2], '')

    def test_user_set_cpu_features(self):
        self.check_pycache(0)
        mod = self.import_module()
        mod.self_test()
        cache_size = len(self.cache_contents())

        mtimes = self.get_cache_mtimes()
        # Change CPU feature
        my_cpu_features = '-sse;-avx'

        system_features = codegen.get_host_cpu_features()

        self.assertNotEqual(system_features, my_cpu_features)
        self.run_in_separate_process(
            envvars={'NUMBA_CPU_FEATURES': my_cpu_features},
        )
        self.check_later_mtimes(mtimes)
        self.assertGreater(len(self.cache_contents()), cache_size)
        # Check cache index
        cache = mod.add_usecase._cache
        cache_file = cache._cache_file
        cache_index = cache_file._load_index()
        self.assertEqual(len(cache_index), 2)
        [key_a, key_b] = cache_index.keys()

        # Index keys come in no guaranteed order; identify the host entry.
        if key_a[1][2] == system_features:
            key_host, key_generic = key_a, key_b
        else:
            key_host, key_generic = key_b, key_a

        self.assertEqual(key_host[1][1], ll.get_host_cpu_name())
        self.assertEqual(key_host[1][2], system_features)
        self.assertEqual(key_generic[1][1], ll.get_host_cpu_name())
        self.assertEqual(key_generic[1][2], my_cpu_features)


class TestMultiprocessCache(BaseCacheTest):

    # Nested multiprocessing.Pool raises AssertionError:
    # "daemonic processes are not allowed to have children"
    _numba_parallel_test_ = False

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cache_usecases.py")
    modname = "dispatcher_caching_test_fodder"

    def test_multiprocessing(self):
        # Check caching works from multiple processes at once (#2028)
        mod = self.import_module()
        # Calling a pure Python caller of the JIT-compiled function is
        # necessary to reproduce the issue.
        f = mod.simple_usecase_caller
        n = 3
        try:
            ctx = multiprocessing.get_context('spawn')
        except AttributeError:
            ctx = multiprocessing
        pool = ctx.Pool(n)
        try:
            res = sum(pool.imap(f, range(n)))
        finally:
            pool.close()
            # Wait for the workers to exit so none outlive the test.
            pool.join()
        self.assertEqual(res, n * (n - 1) // 2)


@skip_if_typeguard
class TestCacheFileCollision(unittest.TestCase):
    """Two same-named cached functions in a package and its submodule."""

    _numba_parallel_test_ = False

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cache_usecases.py")
    modname = "caching_file_loc_fodder"
    source_text_1 = """
from numba import njit
@njit(cache=True)
def bar():
    return 123
"""
    source_text_2 = """
from numba import njit
@njit(cache=True)
def bar():
    return 321
"""

    def setUp(self):
        self.tempdir = temp_directory('test_cache_file_loc')
        sys.path.insert(0, self.tempdir)
        self.modname = 'module_name_that_is_unlikely'
        self.assertNotIn(self.modname, sys.modules)
        self.modname_bar1 = self.modname
        self.modname_bar2 = '.'.join([self.modname, 'foo'])
        # Lay out a package whose __init__.py and foo.py both define `bar`.
        foomod = os.path.join(self.tempdir, self.modname)
        os.mkdir(foomod)
        with open(os.path.join(foomod, '__init__.py'), 'w') as fout:
            print(self.source_text_1, file=fout)
        with open(os.path.join(foomod, 'foo.py'), 'w') as fout:
            print(self.source_text_2, file=fout)

    def tearDown(self):
        sys.modules.pop(self.modname_bar1, None)
        sys.modules.pop(self.modname_bar2, None)
        sys.path.remove(self.tempdir)

    def import_bar1(self):
        return import_dynamic(self.modname_bar1).bar

    def import_bar2(self):
        return import_dynamic(self.modname_bar2).bar

    def test_file_location(self):
        bar1 = self.import_bar1()
        bar2 = self.import_bar2()
        # Check that the cache file is named correctly
        idxname1 = bar1._cache._cache_file._index_name
        idxname2 = bar2._cache._cache_file._index_name
        self.assertNotEqual(idxname1, idxname2)
        self.assertTrue(idxname1.startswith("__init__.bar-3.py"))
        self.assertTrue(idxname2.startswith("foo.bar-3.py"))

    @unittest.skipUnless(hasattr(multiprocessing, 'get_context'),
                         'Test requires multiprocessing.get_context')
    def test_no_collision(self):
        bar1 = self.import_bar1()
        bar2 = self.import_bar2()
        with capture_cache_log() as buf:
            res1 = bar1()
        cachelog = buf.getvalue()
        # bar1 should save new index and data
        self.assertEqual(cachelog.count('index saved'), 1)
        self.assertEqual(cachelog.count('data saved'), 1)
        self.assertEqual(cachelog.count('index loaded'), 0)
        self.assertEqual(cachelog.count('data loaded'), 0)
        with capture_cache_log() as buf:
            res2 = bar2()
        cachelog = buf.getvalue()
        # bar2 should save new index and data
        self.assertEqual(cachelog.count('index saved'), 1)
        self.assertEqual(cachelog.count('data saved'), 1)
        self.assertEqual(cachelog.count('index loaded'), 0)
        self.assertEqual(cachelog.count('data loaded'), 0)
        self.assertNotEqual(res1, res2)

        try:
            # Make sure we can spawn new process without inheriting
            # the parent context.
            mp = multiprocessing.get_context('spawn')
        except ValueError:
            # Without a spawn context the rest of the test cannot run;
            # skip rather than continue with `mp` unbound.
            self.skipTest("missing spawn context")

        q = mp.Queue()
        # Start new process that calls `cache_file_collision_tester`
        proc = mp.Process(target=cache_file_collision_tester,
                          args=(q, self.tempdir,
                                self.modname_bar1,
                                self.modname_bar2))
        proc.start()
        # Get results from the process
        log1 = q.get()
        got1 = q.get()
        log2 = q.get()
        got2 = q.get()
        proc.join()

        # The remote execution result of bar1() and bar2() should match
        # the one executed locally.
        self.assertEqual(got1, res1)
        self.assertEqual(got2, res2)

        # The remote should have loaded bar1 from cache
        self.assertEqual(log1.count('index saved'), 0)
        self.assertEqual(log1.count('data saved'), 0)
        self.assertEqual(log1.count('index loaded'), 1)
        self.assertEqual(log1.count('data loaded'), 1)

        # The remote should have loaded bar2 from cache
        self.assertEqual(log2.count('index saved'), 0)
        self.assertEqual(log2.count('data saved'), 0)
        self.assertEqual(log2.count('index loaded'), 1)
        self.assertEqual(log2.count('data loaded'), 1)


def cache_file_collision_tester(q, tempdir, modname_bar1, modname_bar2):
    """Child-process body for `TestCacheFileCollision.test_no_collision`.

    Imports both `bar` functions from `tempdir`, calls each under
    `capture_cache_log`, and puts (log1, result1, log2, result2) on `q`
    in that order.
    """
    sys.path.insert(0, tempdir)
    bar1 = import_dynamic(modname_bar1).bar
    bar2 = import_dynamic(modname_bar2).bar
    with capture_cache_log() as buf:
        r1 = bar1()
    q.put(buf.getvalue())
    q.put(r1)
    with capture_cache_log() as buf:
        r2 = bar2()
    q.put(buf.getvalue())
    q.put(r2)


class TestCacheMultipleFilesWithSignature(unittest.TestCase):
    # Regression test for https://github.com/numba/numba/issues/3658

    _numba_parallel_test_ = False

    source_text_file1 = """
from file2 import function2
"""
    source_text_file2 = """
from numba import njit

@njit('float64(float64)', cache=True)
def function1(x):
    return x

@njit('float64(float64)', cache=True)
def function2(x):
    return x
"""

    def setUp(self):
        self.tempdir = temp_directory('test_cache_file_loc')

        self.file1 = os.path.join(self.tempdir, 'file1.py')
        with open(self.file1, 'w') as fout:
            print(self.source_text_file1, file=fout)

        self.file2 = os.path.join(self.tempdir, 'file2.py')
        with open(self.file2, 'w') as fout:
            print(self.source_text_file2, file=fout)

    def tearDown(self):
        shutil.rmtree(self.tempdir)

    def test_caching_mutliple_files_with_signature(self):
        # Execute file1.py
        popen = subprocess.Popen([sys.executable, self.file1],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
        out, err = popen.communicate()
        msg = f"stdout:\n{out.decode()}\n\nstderr:\n{err.decode()}"
        self.assertEqual(popen.returncode, 0, msg=msg)

        # Execute file2.py
        popen = subprocess.Popen([sys.executable, self.file2],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
        out, err = popen.communicate()
        msg = f"stdout:\n{out.decode()}\n\nstderr:\n{err.decode()}"
        self.assertEqual(popen.returncode, 0, msg)


class TestCFuncCache(BaseCacheTest):
    """Cache tests built on ``cfunc_cache_usecases.py``."""

    here = os.path.dirname(__file__)
    usecases_file = os.path.join(here, "cfunc_cache_usecases.py")
    modname = "cfunc_caching_test_fodder"

    def run_in_separate_process(self):
        """Run the module's self test in a subprocess and expect cache hits.

        Raises AssertionError (including stderr) on a non-zero exit code.
        """
        # Cached functions can be run from a distinct process.
        code = """if 1:
            import sys

            sys.path.insert(0, %(tempdir)r)
            mod = __import__(%(modname)r)
            mod.self_test()

            f = mod.add_usecase
            assert f.cache_hits == 1
            f = mod.outer
            assert f.cache_hits == 1
            f = mod.div_usecase
            assert f.cache_hits == 1
            """ % dict(tempdir=self.tempdir, modname=self.modname)

        popen = subprocess.Popen([sys.executable, "-c", code],
                                 stdout=subprocess.PIPE,
                                 stderr=subprocess.PIPE)
        out, err = popen.communicate()
        if popen.returncode != 0:
            raise AssertionError(f"process failed with code {popen.returncode}:"
                                 f"stderr follows\n{err.decode()}\n")

    def check_module(self, mod):
        mod.self_test()

    def test_caching(self):
        self.check_pycache(0)
        mod = self.import_module()
        self.check_pycache(6)  # 3 index, 3 data

        self.assertEqual(mod.add_usecase.cache_hits, 0)
        self.assertEqual(mod.outer.cache_hits, 0)
        self.assertEqual(mod.add_nocache_usecase.cache_hits, 0)
        self.assertEqual(mod.div_usecase.cache_hits, 0)
        self.check_module(mod)

        # Reload module to hit the cache
        mod = self.import_module()
        self.check_pycache(6)  # 3 index, 3 data

        self.assertEqual(mod.add_usecase.cache_hits, 1)
        self.assertEqual(mod.outer.cache_hits, 1)
        self.assertEqual(mod.add_nocache_usecase.cache_hits, 0)
        self.assertEqual(mod.div_usecase.cache_hits, 1)
        self.check_module(mod)

        self.run_in_separate_process()


if __name__ == '__main__':
    unittest.main()