218 lines
7.3 KiB
Python
218 lines
7.3 KiB
Python
import sys
|
|
import os
|
|
import os.path as op
|
|
import tempfile
|
|
from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError
|
|
from srsly.cloudpickle.compat import pickle
|
|
from contextlib import contextmanager
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
|
|
import psutil
|
|
from srsly.cloudpickle import dumps
|
|
from subprocess import TimeoutExpired
|
|
|
|
loads = pickle.loads
|
|
TIMEOUT = 60
|
|
TEST_GLOBALS = "a test value"
|
|
|
|
|
|
def make_local_function():
|
|
def g(x):
|
|
# this function checks that the globals are correctly handled and that
|
|
# the builtins are available
|
|
assert TEST_GLOBALS == "a test value"
|
|
return sum(range(10))
|
|
|
|
return g
|
|
|
|
|
|
def _make_cwd_env():
|
|
"""Helper to prepare environment for the child processes"""
|
|
cloudpickle_repo_folder = op.normpath(
|
|
op.join(op.dirname(__file__), '..'))
|
|
env = os.environ.copy()
|
|
pythonpath = "{src}{sep}tests{pathsep}{src}".format(
|
|
src=cloudpickle_repo_folder, sep=os.sep, pathsep=os.pathsep)
|
|
env['PYTHONPATH'] = pythonpath
|
|
return cloudpickle_repo_folder, env
|
|
|
|
|
|
def subprocess_pickle_string(input_data, protocol=None, timeout=TIMEOUT,
|
|
add_env=None):
|
|
"""Retrieve pickle string of an object generated by a child Python process
|
|
|
|
Pickle the input data into a buffer, send it to a subprocess via
|
|
stdin, expect the subprocess to unpickle, re-pickle that data back
|
|
and send it back to the parent process via stdout for final unpickling.
|
|
|
|
>>> testutils.subprocess_pickle_string([1, 'a', None], protocol=2)
|
|
b'\x80\x02]q\x00(K\x01X\x01\x00\x00\x00aq\x01Ne.'
|
|
|
|
"""
|
|
# run then pickle_echo(protocol=protocol) in __main__:
|
|
|
|
# Protect stderr from any warning, as we will assume an error will happen
|
|
# if it is not empty. A concrete example is pytest using the imp module,
|
|
# which is deprecated in python 3.8
|
|
cmd = [sys.executable, '-W ignore', __file__, "--protocol", str(protocol)]
|
|
cwd, env = _make_cwd_env()
|
|
if add_env:
|
|
env.update(add_env)
|
|
proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env,
|
|
bufsize=4096)
|
|
pickle_string = dumps(input_data, protocol=protocol)
|
|
try:
|
|
comm_kwargs = {}
|
|
comm_kwargs['timeout'] = timeout
|
|
out, err = proc.communicate(pickle_string, **comm_kwargs)
|
|
if proc.returncode != 0 or len(err):
|
|
message = "Subprocess returned %d: " % proc.returncode
|
|
message += err.decode('utf-8')
|
|
raise RuntimeError(message)
|
|
return out
|
|
except TimeoutExpired as e:
|
|
proc.kill()
|
|
out, err = proc.communicate()
|
|
message = "\n".join([out.decode('utf-8'), err.decode('utf-8')])
|
|
raise RuntimeError(message) from e
|
|
|
|
|
|
def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT,
|
|
add_env=None):
|
|
"""Echo function with a child Python process
|
|
Pickle the input data into a buffer, send it to a subprocess via
|
|
stdin, expect the subprocess to unpickle, re-pickle that data back
|
|
and send it back to the parent process via stdout for final unpickling.
|
|
>>> subprocess_pickle_echo([1, 'a', None])
|
|
[1, 'a', None]
|
|
"""
|
|
out = subprocess_pickle_string(input_data,
|
|
protocol=protocol,
|
|
timeout=timeout,
|
|
add_env=add_env)
|
|
return loads(out)
|
|
|
|
|
|
def _read_all_bytes(stream_in, chunk_size=4096):
|
|
all_data = b""
|
|
while True:
|
|
data = stream_in.read(chunk_size)
|
|
all_data += data
|
|
if len(data) < chunk_size:
|
|
break
|
|
return all_data
|
|
|
|
|
|
def pickle_echo(stream_in=None, stream_out=None, protocol=None):
|
|
"""Read a pickle from stdin and pickle it back to stdout"""
|
|
if stream_in is None:
|
|
stream_in = sys.stdin
|
|
if stream_out is None:
|
|
stream_out = sys.stdout
|
|
|
|
# Force the use of bytes streams under Python 3
|
|
if hasattr(stream_in, 'buffer'):
|
|
stream_in = stream_in.buffer
|
|
if hasattr(stream_out, 'buffer'):
|
|
stream_out = stream_out.buffer
|
|
|
|
input_bytes = _read_all_bytes(stream_in)
|
|
stream_in.close()
|
|
obj = loads(input_bytes)
|
|
repickled_bytes = dumps(obj, protocol=protocol)
|
|
stream_out.write(repickled_bytes)
|
|
stream_out.close()
|
|
|
|
|
|
def call_func(payload, protocol):
|
|
"""Remote function call that uses cloudpickle to transport everthing"""
|
|
func, args, kwargs = loads(payload)
|
|
try:
|
|
result = func(*args, **kwargs)
|
|
except BaseException as e:
|
|
result = e
|
|
return dumps(result, protocol=protocol)
|
|
|
|
|
|
class _Worker:
|
|
def __init__(self, protocol=None):
|
|
self.protocol = protocol
|
|
self.pool = ProcessPoolExecutor(max_workers=1)
|
|
self.pool.submit(id, 42).result() # start the worker process
|
|
|
|
def run(self, func, *args, **kwargs):
|
|
"""Synchronous remote function call"""
|
|
|
|
input_payload = dumps((func, args, kwargs), protocol=self.protocol)
|
|
result_payload = self.pool.submit(
|
|
call_func, input_payload, self.protocol).result()
|
|
result = loads(result_payload)
|
|
|
|
if isinstance(result, BaseException):
|
|
raise result
|
|
return result
|
|
|
|
def memsize(self):
|
|
workers_pids = [p.pid if hasattr(p, "pid") else p
|
|
for p in list(self.pool._processes)]
|
|
num_workers = len(workers_pids)
|
|
if num_workers == 0:
|
|
return 0
|
|
elif num_workers > 1:
|
|
raise RuntimeError("Unexpected number of workers: %d"
|
|
% num_workers)
|
|
return psutil.Process(workers_pids[0]).memory_info().rss
|
|
|
|
def close(self):
|
|
self.pool.shutdown(wait=True)
|
|
|
|
|
|
@contextmanager
|
|
def subprocess_worker(protocol=None):
|
|
worker = _Worker(protocol=protocol)
|
|
yield worker
|
|
worker.close()
|
|
|
|
|
|
def assert_run_python_script(source_code, timeout=TIMEOUT):
|
|
"""Utility to help check pickleability of objects defined in __main__
|
|
|
|
The script provided in the source code should return 0 and not print
|
|
anything on stderr or stdout.
|
|
"""
|
|
fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py')
|
|
os.close(fd)
|
|
try:
|
|
with open(source_file, 'wb') as f:
|
|
f.write(source_code.encode('utf-8'))
|
|
cmd = [sys.executable, '-W ignore', source_file]
|
|
cwd, env = _make_cwd_env()
|
|
kwargs = {
|
|
'cwd': cwd,
|
|
'stderr': STDOUT,
|
|
'env': env,
|
|
}
|
|
# If coverage is running, pass the config file to the subprocess
|
|
coverage_rc = os.environ.get("COVERAGE_PROCESS_START")
|
|
if coverage_rc:
|
|
kwargs['env']['COVERAGE_PROCESS_START'] = coverage_rc
|
|
kwargs['timeout'] = timeout
|
|
try:
|
|
try:
|
|
out = check_output(cmd, **kwargs)
|
|
except CalledProcessError as e:
|
|
raise RuntimeError("script errored with output:\n%s"
|
|
% e.output.decode('utf-8')) from e
|
|
if out != b"":
|
|
raise AssertionError(out.decode('utf-8'))
|
|
except TimeoutExpired as e:
|
|
raise RuntimeError("script timeout, output so far:\n%s"
|
|
% e.output.decode('utf-8')) from e
|
|
finally:
|
|
os.unlink(source_file)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
protocol = int(sys.argv[sys.argv.index('--protocol') + 1])
|
|
pickle_echo(protocol=protocol)
|