# ai-content-maker/.venv/Lib/site-packages/numba/tests/test_serialize.py
#
# Source-viewer header: 332 lines, 11 KiB, Python.
# Last change: 2024-05-03 04:18:51 +03:00
import contextlib
import gc
import pickle
import runpy
import subprocess
import sys
import unittest
from multiprocessing import get_context
import numba
from numba.core.errors import TypingError
from numba.tests.support import TestCase
from numba.core.target_extension import resolve_dispatcher_from_str
from numba.cloudpickle import dumps, loads
class TestDispatcherPickling(TestCase):
    """Check that jitted functions survive a pickle round-trip.

    Most tests pickle a function from ``serialize_usecases``, unpickle it
    while the CPU dispatcher's target descriptor is replaced by a fresh one,
    and verify the reloaded function produces the expected result (or raises
    the expected exception).
    """

    def run_with_protocols(self, meth, *args, **kwargs):
        """Call ``meth(proto, *args, **kwargs)`` once per pickle protocol.

        Protocols ``0`` through ``pickle.HIGHEST_PROTOCOL`` inclusive are
        exercised.
        """
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            meth(proto, *args, **kwargs)

    @contextlib.contextmanager
    def simulate_fresh_target(self):
        """Context manager installing a new target descriptor on the CPU
        dispatcher class for the duration of the ``with`` block.

        The previous descriptor is always restored on exit.
        """
        hwstr = 'cpu'
        dispatcher_cls = resolve_dispatcher_from_str(hwstr)
        old_descr = dispatcher_cls.targetdescr
        # Simulate fresh targetdescr
        dispatcher_cls.targetdescr = type(dispatcher_cls.targetdescr)(hwstr)
        try:
            yield
        finally:
            # Be sure to reinstantiate old descriptor, otherwise other
            # objects may be out of sync.
            dispatcher_cls.targetdescr = old_descr

    def check_call(self, proto, func, expected_result, args):
        """Check ``func(*args)`` before and after pickling with ``proto``.

        If *expected_result* is an exception class, the call must raise it;
        otherwise the return value must be precisely equal to
        *expected_result*.
        """
        def check_result(func):
            if (isinstance(expected_result, type)
                    and issubclass(expected_result, Exception)):
                self.assertRaises(expected_result, func, *args)
            else:
                self.assertPreciseEqual(func(*args), expected_result)

        # Control
        check_result(func)
        pickled = pickle.dumps(func, proto)
        with self.simulate_fresh_target():
            new_func = pickle.loads(pickled)
            check_result(new_func)

    def test_call_with_sig(self):
        from .serialize_usecases import add_with_sig
        self.run_with_protocols(self.check_call, add_with_sig, 5, (1, 4))
        # Compilation has been disabled => float inputs will be coerced to int
        self.run_with_protocols(self.check_call, add_with_sig, 5, (1.2, 4.2))

    def test_call_without_sig(self):
        from .serialize_usecases import add_without_sig
        self.run_with_protocols(self.check_call, add_without_sig, 5, (1, 4))
        self.run_with_protocols(self.check_call, add_without_sig, 5.5,
                                (1.2, 4.3))
        # Object mode is enabled
        self.run_with_protocols(self.check_call, add_without_sig, "abc",
                                ("a", "bc"))

    def test_call_nopython(self):
        from .serialize_usecases import add_nopython
        self.run_with_protocols(self.check_call, add_nopython, 5.5, (1.2, 4.3))
        # Object mode is disabled
        self.run_with_protocols(self.check_call, add_nopython, TypingError,
                                (object(), object()))

    def test_call_nopython_fail(self):
        from .serialize_usecases import add_nopython_fail
        # Compilation fails
        self.run_with_protocols(self.check_call, add_nopython_fail,
                                TypingError, (1, 2))

    def test_call_objmode_with_global(self):
        from .serialize_usecases import get_global_objmode
        self.run_with_protocols(self.check_call, get_global_objmode, 7.5,
                                (2.5,))

    def test_call_closure(self):
        from .serialize_usecases import closure
        inner = closure(1)
        self.run_with_protocols(self.check_call, inner, 6, (2, 3))

    def check_call_closure_with_globals(self, **jit_args):
        """Shared body for the ``closure_with_globals`` tests.

        *jit_args* are forwarded to ``closure_with_globals``.
        """
        from .serialize_usecases import closure_with_globals
        inner = closure_with_globals(3.0, **jit_args)
        self.run_with_protocols(self.check_call, inner, 7.0, (4.0,))

    def test_call_closure_with_globals_nopython(self):
        self.check_call_closure_with_globals(nopython=True)

    def test_call_closure_with_globals_objmode(self):
        self.check_call_closure_with_globals(forceobj=True)

    def test_call_closure_calling_other_function(self):
        from .serialize_usecases import closure_calling_other_function
        inner = closure_calling_other_function(3.0)
        self.run_with_protocols(self.check_call, inner, 11.0, (4.0, 6.0))

    def test_call_closure_calling_other_closure(self):
        from .serialize_usecases import closure_calling_other_closure
        inner = closure_calling_other_closure(3.0)
        self.run_with_protocols(self.check_call, inner, 8.0, (4.0,))

    def test_call_dyn_func(self):
        from .serialize_usecases import dyn_func
        # Check serializing a dynamically-created function
        self.run_with_protocols(self.check_call, dyn_func, 36, (6,))

    def test_call_dyn_func_objmode(self):
        from .serialize_usecases import dyn_func_objmode
        # Same with an object mode function
        self.run_with_protocols(self.check_call, dyn_func_objmode, 36, (6,))

    def test_renamed_module(self):
        from .serialize_usecases import get_renamed_module
        # Issue #1559: using a renamed module (e.g. `import numpy as np`)
        # should not fail serializing
        expected = get_renamed_module(0.0)
        self.run_with_protocols(self.check_call, get_renamed_module,
                                expected, (0.0,))

    def test_other_process(self):
        """
        Check that reconstructing doesn't depend on resources already
        instantiated in the original process.
        """
        from .serialize_usecases import closure_calling_other_closure
        func = closure_calling_other_closure(3.0)
        pickled = pickle.dumps(func)
        # The leading "if 1:" lets the script body stay indented inside this
        # literal.  Only ``pickled`` is interpolated, so pass it explicitly
        # instead of exposing every local name to ``format``.
        code = """if 1:
            import pickle

            data = {pickled!r}
            func = pickle.loads(data)
            res = func(4.0)
            assert res == 8.0, res
            """.format(pickled=pickled)
        subprocess.check_call([sys.executable, "-c", code])

    def test_reuse(self):
        """
        Check that deserializing the same function multiple times re-uses
        the same dispatcher object.

        Note that "same function" is intentionally under-specified.
        """
        from .serialize_usecases import closure
        func = closure(5)
        pickled = pickle.dumps(func)
        func2 = closure(6)
        pickled2 = pickle.dumps(func2)

        f = pickle.loads(pickled)
        g = pickle.loads(pickled)
        h = pickle.loads(pickled2)
        self.assertIs(f, g)
        self.assertEqual(f(2, 3), 10)
        g.disable_compile()
        self.assertEqual(g(2, 4), 11)

        self.assertIsNot(f, h)
        self.assertEqual(h(2, 3), 11)

        # Now make sure the original object doesn't exist when deserializing
        func = closure(7)
        func(42, 43)
        pickled = pickle.dumps(func)
        del func
        gc.collect()

        f = pickle.loads(pickled)
        g = pickle.loads(pickled)
        self.assertIs(f, g)
        self.assertEqual(f(2, 3), 12)
        g.disable_compile()
        self.assertEqual(g(2, 4), 13)

    def test_imp_deprecation(self):
        """
        The imp module was deprecated in v3.4 in favour of importlib
        """
        # Entries recorded by ``catch_warnings(record=True)`` expose the
        # warning text through ``.message`` (there is no ``.msg`` attribute),
        # so the script inspects ``str(x.message)``.
        code = """if 1:
            import pickle
            import warnings
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always', DeprecationWarning)
                from numba import njit
                @njit
                def foo(x):
                    return x + 1
                foo(1)
                serialized_foo = pickle.dumps(foo)
            for x in w:
                if 'serialize.py' in x.filename:
                    assert "the imp module is deprecated" not in str(x.message)
        """
        subprocess.check_call([sys.executable, "-c", code])
class TestSerializationMisc(TestCase):
    """Checks on helpers from ``numba.core.serialize``."""

    def test_numba_unpickle(self):
        """``_numba_unpickle`` must memoize the object it reconstructs."""
        from numba.core.serialize import _numba_unpickle

        original = object()
        payload = pickle.dumps(original)
        unpickle_args = (id(original), payload, hash(original))

        first = _numba_unpickle(*unpickle_args)
        # The reconstructed object is a copy, never the original itself.
        self.assertIsNot(first, original)

        second = _numba_unpickle(*unpickle_args)
        # A repeated call with the same arguments hands back the same object.
        self.assertIs(first, second)
class TestCloudPickleIssues(TestCase):
    """This test case includes issues specific to the cloudpickle implementation.
    """
    _numba_parallel_test_ = False

    def _run_in_spawned_process(self, target, args=()):
        """Run ``target(*args)`` in a 'spawn' child and require exit code 0.

        The child gets 60 seconds.  If it is still alive after that it is
        terminated and reaped so a hung check does not leave a stray process
        behind; the exit-code assertion below then fails as before.
        """
        mp = get_context('spawn')
        proc = mp.Process(target=target, args=args)
        proc.start()
        proc.join(timeout=60)
        if proc.is_alive():
            # join() timed out: stop the child and collect it.
            proc.terminate()
            proc.join()
        self.assertEqual(proc.exitcode, 0)

    def test_dynamic_class_reset_on_unpickle(self):
        # a dynamic class
        class Klass:
            classvar = None

        def mutator():
            Klass.classvar = 100

        def check():
            self.assertEqual(Klass.classvar, 100)

        saved = dumps(Klass)
        mutator()
        check()
        loads(saved)
        # Without the patch, each `loads(saved)` will reset `Klass.classvar`
        check()
        loads(saved)
        check()

    @unittest.skipIf(__name__ == "__main__",
                     "Test cannot run as when module is __main__")
    def test_main_class_reset_on_unpickle(self):
        self._run_in_spawned_process(check_main_class_reset_on_unpickle)

    def test_dynamic_class_reset_on_unpickle_new_proc(self):
        # a dynamic class
        class Klass:
            classvar = None

        # serialize Klass in this process
        saved = dumps(Klass)

        # Check the reset problem in a new process
        self._run_in_spawned_process(check_unpickle_dyn_class_new_proc,
                                     args=(saved,))

    def test_dynamic_class_issue_7356(self):
        cfunc = numba.njit(issue_7356)
        self.assertEqual(cfunc(), (100, 100))
class DynClass:
    """Module-level class used by ``issue_7356`` (issue #7356)."""

    # Starts out unset; ``issue_7356`` assigns 100 to it.
    a = None
def issue_7356():
    """Write ``DynClass.a`` in one objmode block and read it in another.

    Returns the tuple ``(before, after)``: ``before`` is read in the same
    objmode block that assigns 100, ``after`` is read in a second, separate
    objmode block.
    """
    with numba.objmode(before="intp"):
        # Mutate the class attribute, then read it back in the same block.
        DynClass.a = 100
        before = DynClass.a
    with numba.objmode(after="intp"):
        # Read the attribute again from a distinct objmode block.
        after = DynClass.a
    return before, after
def check_main_class_reset_on_unpickle():
    """Check that unpickling a ``__main__`` class does not reset its state.

    Runs ``numba.tests.cloudpickle_main_class`` as ``__main__``, serializes
    its ``Klass``, mutates ``Klass.classvar`` and delegates the reload checks
    to ``_check_dyn_class``.

    Raises:
        AssertionError: if a precondition or a reload check fails.  Explicit
            raises are used (as in ``_check_dyn_class``) so the checks are
            not stripped when Python runs with ``-O``.
    """
    # Load module and get its global dictionary
    glbs = runpy.run_module(
        "numba.tests.cloudpickle_main_class",
        run_name="__main__",
    )
    # Get the Klass and check it is from __main__
    Klass = glbs['Klass']
    if Klass.__module__ != "__main__":
        raise AssertionError(
            "Expected Klass from __main__, got %r" % (Klass.__module__,)
        )
    if Klass.classvar == 100:
        raise AssertionError("Klass.classvar is already 100 before mutation")
    saved = dumps(Klass)
    # mutate
    Klass.classvar = 100
    # check
    _check_dyn_class(Klass, saved)
def check_unpickle_dyn_class_new_proc(saved):
    """Check that re-loading a dynamic class does not reset its state.

    Args:
        saved: serialized class, as produced by ``dumps``.

    Raises:
        AssertionError: if a precondition or a reload check fails.  Explicit
            raises are used (as in ``_check_dyn_class``) so the checks are
            not stripped when Python runs with ``-O``.
    """
    Klass = loads(saved)
    if Klass.classvar == 100:
        raise AssertionError("Klass.classvar is already 100 before mutation")
    # mutate
    Klass.classvar = 100
    # check
    _check_dyn_class(Klass, saved)
def _check_dyn_class(Klass, saved):
    """Verify that re-loading ``saved`` reuses ``Klass`` without resetting it.

    ``Klass.classvar`` must be 100 on entry and must still be 100 after each
    of two ``loads(saved)`` calls, each of which must return ``Klass`` itself.
    Raises ``AssertionError`` on the first violation.
    """
    def verify_not_reset():
        if Klass.classvar != 100:
            raise AssertionError("Check failed. Klass reset.")

    verify_not_reset()
    # Without the patch, each `loads(saved)` will reset `Klass.classvar`
    for _ in range(2):
        reloaded = loads(saved)
        if reloaded is not Klass:
            raise AssertionError("Expected reuse")
        verify_not_reset()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()