|
from __future__ import annotations
|
||
|
|
||
|
from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union
|
||
|
|
||
|
import torch
|
||
|
|
||
|
__all__ = ['Future', 'collect_all', 'wait_all']
|
||
|
|
||
|
T = TypeVar("T")
|
||
|
S = TypeVar("S")
|
||
|
|
||
|
|
||
|
class _PyFutureMeta(type(torch._C.Future), type(Generic)):  # type: ignore[misc, no-redef]
    # Combined metaclass: ``Future`` below inherits from both the
    # pybind11-defined ``torch._C.Future`` (which carries its own metaclass)
    # and ``Generic[T]`` (whose metaclass is ``type(Generic)``). Python
    # requires a single metaclass deriving from both, hence this empty merge.
    pass
|
||
|
|
||
|
|
||
|
class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning:: GPU support is a beta feature, subject to changes.
    """

    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
        r"""
        Create an empty unset ``Future``. If the future is intended to hold
        values containing CUDA tensors, (a superset of) their CUDA devices must
        be specified at construction. (This is only supported if
        ``torch.cuda.is_available()`` returns ``True``). This is needed to
        ensure proper CUDA stream synchronization. The child futures, returned
        by the ``then`` method, will inherit these devices.

        Args:
            devices(``List[Union[int, str, torch.device]]``, optional): the set
                of devices on which tensors contained in this future's value are
                allowed to reside and on which callbacks are allowed to operate.
        """
        # Normalize each entry (int, str, or torch.device) to a torch.device
        # before handing the list to the C++ constructor; ``None`` means "no
        # devices", i.e. an empty list.
        normalized = [torch.device(entry) for entry in (devices if devices is not None else [])]
        super().__init__(normalized)

    def done(self) -> bool:
        r"""
        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
        holds either a result or an exception.

        If the value contains tensors that reside on GPUs, ``Future.done()``
        will return ``True`` even if the asynchronous kernels that are
        populating those tensors haven't yet completed running on the device,
        because at such stage the result is already usable, provided one
        performs the appropriate synchronizations (see :meth:`wait`).
        """
        is_complete = super().done()
        return is_complete

    def wait(self) -> T:
        r"""
        Block until the value of this ``Future`` is ready.

        If the value contains tensors that reside on GPUs, an additional
        synchronization is performed with the kernels (executing on the device)
        which may be asynchronously populating those tensors. Such sync is
        non-blocking: ``wait()`` inserts the necessary instructions in the
        current streams so that further operations enqueued on those streams
        are properly scheduled after the async kernels, and then returns, even
        if those kernels are still running. No further synchronization is
        required when accessing and using the values, as long as one doesn't
        change streams.

        Returns:
            The value held by this ``Future``. If the function (callback or
            RPC) creating the value has thrown an error, this ``wait`` method
            will also throw an error.
        """
        return super().wait()

    def value(self) -> T:
        r"""
        Obtain the value of an already-completed future.

        This method should only be called after a call to :meth:`wait` has
        completed, or inside a callback function passed to :meth:`then`. In
        other cases this ``Future`` may not yet hold a value and calling
        ``value()`` could fail.

        If the value contains tensors that reside on GPUs, this method will
        *not* perform any additional synchronization. This should be done
        beforehand, separately, through a call to :meth:`wait` (except within
        callbacks, for which it's already being taken care of by :meth:`then`).

        Returns:
            The value held by this ``Future``. If the function (callback or
            RPC) creating the value has thrown an error, this ``value()``
            method will also throw an error.
        """
        return super().value()

    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
        r"""
        Append the given callback function to this ``Future``, which will be
        run when the ``Future`` is completed. Multiple callbacks can be added
        to the same ``Future``, but the order in which they will be executed
        cannot be guaranteed (to enforce a certain order consider chaining:
        ``fut.then(cb1).then(cb2)``). The callback must take one argument,
        which is the reference to this ``Future``. The callback function can
        use the :meth:`value` method to get the value. Note that if this
        ``Future`` is already completed, the given callback will be run
        immediately inline.

        If the ``Future``'s value contains tensors that reside on GPUs, the
        callback might be invoked while the async kernels that are populating
        those tensors haven't yet finished executing on the device. However,
        the callback will be invoked with some dedicated streams set as current
        (fetched from a global pool) which will be synchronized with those
        kernels. Hence any operation performed by the callback on these tensors
        will be scheduled on the device after the kernels complete. In other
        words, as long as the callback doesn't switch streams, it can safely
        manipulate the result without any additional synchronization. This is
        similar to the non-blocking behavior of :meth:`wait`.

        Similarly, if the callback returns a value that contains tensors that
        reside on a GPU, it can do so even if the kernels that are producing
        these tensors are still running on the device, as long as the callback
        didn't change streams during its execution. If one wants to change
        streams, one must be careful to re-synchronize them with the original
        streams, that is, those that were current when the callback was
        invoked.

        Args:
            callback(``Callable``): a ``Callable`` that takes this ``Future``
                as the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback, the
            future returned by ``then`` will be marked appropriately with the
            encountered error. However, if this callback later completes
            additional futures, those futures are not marked as completed with
            an error and the user is responsible for handling
            completion/waiting on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print(f"RPC return value is {fut.wait()}.")
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the return value when
            >>> # receiving the response from "worker1"
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(
            ...     lambda x : print(f"Chained cb done. {x.wait()}")
            ... )
            >>> fut.set_result(5)
            RPC return value is 5.
            Chained cb done. None
        """
        # The C++ layer hands back a plain torch._C.Future; re-type it for
        # static checkers as the child Future[S].
        child_fut = super().then(callback)
        return cast(Future[S], child_fut)

    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
        r"""
        Append the given callback function to this ``Future``, which will be
        run when the ``Future`` is completed. Multiple callbacks can be added
        to the same ``Future``, but the order in which they will be executed
        cannot be guaranteed. The callback must take one argument, which is the
        reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run inline.

        We recommend that you use the :meth:`then` method as it provides a way
        to synchronize after your callback has completed. ``add_done_callback``
        can be cheaper if your callback does not return anything. But both
        :meth:`then` and ``add_done_callback`` use the same callback
        registration API under the hood.

        With respect to GPU tensors, this method behaves in the same way as
        :meth:`then`.

        Args:
            callback(``Future``): a ``Callable`` that takes in one argument,
                which is the reference to this ``Future``.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback,
            error handling must be carefully taken care of. For example, if
            this callback later completes additional futures, those futures are
            not marked as completed with an error and the user is responsible
            for handling completion/waiting on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print("This will run after the future has finished.")
            ...     print(fut.wait())
            >>> fut = torch.futures.Future()
            >>> fut.add_done_callback(callback)
            >>> fut.set_result(5)
            This will run after the future has finished.
            5
        """
        super().add_done_callback(callback)

    def set_result(self, result: T) -> None:
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        If the result contains tensors that reside on GPUs, this method can be
        called even if the asynchronous kernels that are populating those
        tensors haven't yet completed running on the device, provided that the
        streams on which those kernels were enqueued are set as the current
        ones when this method is called. Put simply, it's safe to call this
        method immediately after launching those kernels, without any
        additional synchronization, as long as one doesn't change streams in
        between. This method will record events on all the relevant current
        streams and will use them to ensure proper scheduling for all the
        consumers of this ``Future``.

        Args:
            result (object): the result object of this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> import threading
            >>> import time
            >>> def slow_set_future(fut, value):
            ...     time.sleep(0.5)
            ...     fut.set_result(value)
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            ...     target=slow_set_future,
            ...     args=(fut, torch.ones(2) * 3)
            ... )
            >>> t.start()
            >>> print(fut.wait())
            tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)

    def set_exception(self, result: T) -> None:
        r"""
        Set an exception for this ``Future``, which will mark this ``Future``
        as completed with an error and trigger all attached callbacks. Note
        that when calling wait()/value() on this ``Future``, the exception set
        here will be raised inline.

        Args:
            result (BaseException): the exception for this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_exception(ValueError("foo"))
            >>> fut.wait()
            Traceback (most recent call last):
            ...
            ValueError: foo
        """
        assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."

        def raise_error(fut_result):
            # Unwrapping the stored value re-raises it, so wait()/value()
            # surface the exception instead of returning it.
            raise fut_result

        # Install the unwrap hook first, then complete the future with the
        # exception object as its "result".
        super()._set_unwrap_func(raise_error)
        self.set_result(result)  # type: ignore[arg-type]
||
|
|
||
|
def collect_all(futures: List[Future]) -> Future[List[Future]]:
    r"""
    Collects the provided :class:`~torch.futures.Future` objects into a single
    combined :class:`~torch.futures.Future` that is completed when all of the
    sub-futures are completed.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        Returns a :class:`~torch.futures.Future` object to a list of the passed
        in Futures.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut = torch.futures.collect_all([fut0, fut1])
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> fut_list = fut.wait()
        >>> print(f"fut0 result = {fut_list[0].wait()}")
        fut0 result = 0
        >>> print(f"fut1 result = {fut_list[1].wait()}")
        fut1 result = 1
    """
    # Delegate to the C++ helper, which expects (and returns) the raw
    # torch._C.Future type; the casts only adjust the static types.
    raw_futures = cast(List[torch._C.Future], futures)
    combined = torch._C._collect_all(raw_futures)
    return cast("Future[List[Future]]", combined)
|
||
|
|
||
|
|
||
|
def wait_all(futures: List[Future]) -> List:
    r"""
    Waits for all provided futures to be complete, and returns
    the list of completed values. If any of the futures encounters an error,
    the method will exit early and report the error not waiting for other
    futures to complete.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` object.

    Returns:
        A list of the completed :class:`~torch.futures.Future` results. This
        method will throw an error if ``wait`` on any
        :class:`~torch.futures.Future` throws.
    """
    # Combine all futures into one, block until every sub-future is done,
    # then unwrap each sub-future's value (re-raising any stored exception).
    combined = torch._C._collect_all(cast(List[torch._C.Future], futures))
    return [sub_fut.wait() for sub_fut in combined.wait()]
|