194 lines
6.3 KiB
Python
194 lines
6.3 KiB
Python
|
import torch
|
||
|
from typing import List
|
||
|
|
||
|
# Public surface of this module: `from torch.compiler import *` exposes
# exactly these names, and they define the documented torch.compiler API.
__all__ = [
    "compile",
    "assume_constant_result",
    "reset",
    "allow_in_graph",
    "list_backends",
    "disable",
    "cudagraph_mark_step_begin",
    "wrap_numpy",
    "is_compiling",
    "is_dynamo_compiling",
]
|
||
|
|
||
|
def compile(*args, **kwargs):
    """
    Thin convenience alias for :func:`torch.compile`.

    All positional and keyword arguments are forwarded verbatim; see
    :func:`torch.compile` for their meaning.
    """
    return torch.compile(*args, **kwargs)
|
||
|
|
||
|
def reset() -> None:
    """
    Clear all compilation caches and restore the system to its initial state.

    Calling this is recommended after using ``torch.compile(...)`` so that a
    subsequent, unrelated compilation starts from a clean state.
    """
    # Imported lazily so that merely importing ``torch.compiler`` does not
    # pull in the dynamo machinery.
    from torch._dynamo import reset as _dynamo_reset

    _dynamo_reset()
|
||
|
|
||
|
def allow_in_graph(fn):
    """
    Customize which functions compilation will include in the generated graph.

    It bypasses all introspection of the symbolic python code in favor of
    directly writing it to the graph. If ``fn`` is a list or tuple of
    callables, :func:`allow_in_graph` is applied recursively to each element
    and a new list or tuple of the modified functions is returned.

    Args:
        fn: A callable representing the function to be included in the graph.

    .. warning::

        :func:`allow_in_graph` skips TorchDynamo completely on the decorated
        function, skipping all TorchDynamo safety checks (graph breaks,
        handling closures, etc). Be very careful: subsystems like AOT Autograd
        rely on TorchDynamo, and misuse can lead to soundness and very
        hard-to-debug issues.
    """
    # Lazy import keeps ``torch.compiler`` cheap to import.
    from torch._dynamo import allow_in_graph as _allow_in_graph

    return _allow_in_graph(fn)
|
||
|
|
||
|
|
||
|
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return valid strings that can be passed to ``torch.compile(..., backend="name")``.

    Args:
        exclude_tags(optional): A tuple of strings representing tags to exclude.
    """
    # Lazy import keeps ``torch.compiler`` cheap to import.
    from torch._dynamo import list_backends as _list_backends

    return _list_backends(exclude_tags)
|
||
|
|
||
|
def assume_constant_result(fn):
    """
    Mark a function ``fn`` as having a constant result.

    This allows the compiler to optimize away your function.

    Args:
        fn: The function to be marked as having a constant result.

    Returns:
        The same function ``fn``.

    .. warning::
        `assume_constant_result` can if invalid cause safety and soundness
        issues; :func:`torch.compile` will not attempt to validate whether
        the constant assumption is true or not.
    """
    # Lazy import keeps ``torch.compiler`` cheap to import.
    from torch._dynamo import assume_constant_result as _assume_constant_result

    return _assume_constant_result(fn)
|
||
|
|
||
|
def disable(fn=None, recursive=True):
|
||
|
"""
|
||
|
This function provides both a decorator and a context manager to disable compilation on a function
|
||
|
It also provides the option of recursively disabling called functions
|
||
|
|
||
|
Args:
|
||
|
fn (optional): The function to disable
|
||
|
recursive (optional): A boolean value indicating whether the disabling should be recursive.
|
||
|
"""
|
||
|
import torch._dynamo
|
||
|
|
||
|
return torch._dynamo.disable(fn, recursive)
|
||
|
|
||
|
def cudagraph_mark_step_begin():
|
||
|
"""
|
||
|
Indicates that a new iteration of inference or training is about to begin.
|
||
|
|
||
|
CUDA Graphs will free tensors of a prior iteration. A new iteration is started on each invocation of
|
||
|
torch.compile, so long as there is not a pending backward that has not been called.
|
||
|
|
||
|
If that heuristic is wrong, such as in the following example, manually mark it with this api.
|
||
|
|
||
|
.. code-block:: python
|
||
|
|
||
|
@torch.compile(mode="reduce-overhead")
|
||
|
def rand_foo():
|
||
|
return torch.rand([4], device="cuda")
|
||
|
|
||
|
for _ in range(5):
|
||
|
torch.compiler.cudagraph_mark_step_begin()
|
||
|
rand_foo() + rand_foo()
|
||
|
|
||
|
For more details, see `torch.compiler_cudagraph_trees <https://pytorch.org/docs/main/torch.compiler_cudagraph_trees.html>`__
|
||
|
"""
|
||
|
from torch._inductor import cudagraph_trees
|
||
|
|
||
|
cudagraph_trees.mark_step_begin()
|
||
|
|
||
|
def wrap_numpy(fn):
|
||
|
r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
|
||
|
from ``torch.Tensor``s to ``torch.Tensor``s.
|
||
|
|
||
|
It is designed to be used with :func:`torch.compile` with ``fullgraph=True``. It allows to
|
||
|
compile a NumPy function as if it were a PyTorch function. This allows you to run NumPy code
|
||
|
on CUDA or compute its gradients.
|
||
|
|
||
|
.. note::
|
||
|
|
||
|
This decorator does not work without :func:`torch.compile`.
|
||
|
|
||
|
Example::
|
||
|
|
||
|
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||
|
>>> # Compile a NumPy function as a Tensor -> Tensor function
|
||
|
>>> @torch.compile(fullgraph=True)
|
||
|
>>> @torch.compiler.wrap_numpy
|
||
|
>>> def fn(a: np.ndarray):
|
||
|
>>> return np.sum(a * a)
|
||
|
>>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
|
||
|
>>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
|
||
|
>>> out = fn(x)
|
||
|
>>> out.backward()
|
||
|
>>> print(x.grad)
|
||
|
tensor([ 0., 2., 4., 6., 8., 10.], device='cuda:0')
|
||
|
"""
|
||
|
from torch._dynamo.external_utils import wrap_numpy as wrap
|
||
|
return wrap(fn)
|
||
|
|
||
|
# Module-level flag flipped elsewhere (e.g. by export/compile machinery) to
# signal that tracing/compilation is in progress.
_is_compiling_flag: bool = False


def is_compiling() -> bool:
    """
    Indicates whether a graph is executed/traced as part of torch.compile() or torch.export().

    Note that there are 2 other related flags that should be deprecated eventually:
      * torch._dynamo.external_utils.is_compiling()
      * torch._utils.is_compiling()

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_compiling():
        >>>        ...logic that is not needed in a compiled/traced graph...
        >>>
        >>>     ...rest of the function...
    """
    # Under TorchScript this is always False; the flag only applies to
    # compile/export tracing.
    if torch.jit.is_scripting():
        return False
    return _is_compiling_flag
|
||
|
|
||
|
def is_dynamo_compiling() -> bool:
    """
    Indicates whether a graph is traced via TorchDynamo.

    It's stricter than is_compiling() flag, as it would only be set to True when
    TorchDynamo is used.

    Example::

        >>> def forward(self, x):
        >>>     if not torch.compiler.is_dynamo_compiling():
        >>>        ...logic that is not needed in a TorchDynamo-traced graph...
        >>>
        >>>     ...rest of the function...
    """
    # Always False in eager execution; Dynamo substitutes its own result
    # when tracing this call inside a compiled region.
    return False
|