from functools import partial from . import functions from . import rpc_async import torch from .constants import UNSET_RPC_TIMEOUT from torch.futures import Future def _local_invoke(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) @functions.async_execution def _local_invoke_async_execution(rref, func_name, args, kwargs): return getattr(rref.local_value(), func_name)(*args, **kwargs) def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): def _rref_type_cont(rref_fut): rref_type = rref_fut.value() _invoke_func = _local_invoke # Bypass ScriptModules when checking for async function attribute. bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass( rref_type, torch._C.ScriptModule ) if not bypass_type: func = getattr(rref_type, func_name) if hasattr(func, "_wrapped_async_rpc_function"): _invoke_func = _local_invoke_async_execution return rpc_api( rref.owner(), _invoke_func, args=(rref, func_name, args, kwargs), timeout=timeout ) rref_fut = rref._get_type(timeout=timeout, blocking=False) if rpc_api != rpc_async: rref_fut.wait() return _rref_type_cont(rref_fut) else: # A little explanation on this. # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]` # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]` # To address that, we return a Future that is completed with the result of the async call. result: Future = Future() def _wrap_rref_type_cont(fut): try: _rref_type_cont(fut).then(_complete_op) except BaseException as ex: result.set_exception(ex) def _complete_op(fut): try: result.set_result(fut.value()) except BaseException as ex: result.set_exception(ex) rref_fut.then(_wrap_rref_type_cont) return result # This class manages proxied RPC API calls for RRefs. It is entirely used from # C++ (see python_rpc_handler.cpp). class RRefProxy: def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): self.rref = rref self.rpc_api = rpc_api self.rpc_timeout = timeout def __getattr__(self, func_name): return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout)