ai-content-maker/.venv/Lib/site-packages/torch/_dynamo/external_utils.py

104 lines
2.6 KiB
Python

# This module contains functions that *will be allowed* by dynamo
import functools
import torch
import torch.utils._pytree as pytree
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
def is_compiling() -> bool:
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
If need to check specifically that TorchDynamo is used, then use
torch.compiler.is_dynamo_compiling().
TODO(khabinov): we should deprecate this function and use one of these two:
* torch.compiler.is_compiling(),
* torch.compiler.is_dynamo_compiling().
It will depend on the context where to use what.
"""
return torch.compiler.is_compiling()
def wrap_inline(fn):
    """
    Return a pass-through wrapper around ``fn``.

    Calling the wrapper adds one extra frame (``inner``) that forwards every
    positional and keyword argument to ``fn`` unchanged and returns its
    result. The purpose of the extra frame is that it is not in skipfiles.
    """

    def inner(*call_args, **call_kwargs):
        return fn(*call_args, **call_kwargs)

    # Copy fn's metadata (__name__, __doc__, ...) onto the wrapper and
    # return the wrapper itself.
    return functools.update_wrapper(inner, fn)
def call_hook(hook, *args):
    """
    Invoke ``hook(*args)`` and return its result, falling back to the first
    positional argument when the hook returns None.

    Used by compiled autograd to handle hook returning None.
    """
    hook_output = hook(*args)
    return args[0] if hook_output is None else hook_output
def wrap_numpy(f):
    r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function
    from ``torch.Tensor``s to ``torch.Tensor``s.

    Every ``torch.Tensor`` found in the call's positional or keyword arguments
    is converted with ``.numpy()`` before ``f`` runs, and every ``np.ndarray``
    found in the result of ``f`` is converted back with ``torch.as_tensor``.
    Other values pass through untouched. When numpy is unavailable, ``f`` is
    returned unchanged.
    """
    if np is None:
        return f

    @functools.wraps(f)
    def wrap(*args, **kwargs):
        # Convert tensors to arrays across both containers in a single pass.
        as_arrays = pytree.tree_map_only(
            torch.Tensor, lambda tensor: tensor.numpy(), (args, kwargs)
        )
        np_args, np_kwargs = as_arrays
        result = f(*np_args, **np_kwargs)
        # Convert arrays in the (possibly nested) result back to tensors.
        return pytree.tree_map_only(
            np.ndarray, lambda array: torch.as_tensor(array), result
        )

    return wrap
class FakeContext:
    """
    Minimal context object that only carries ``saved_tensors``.
    """

    def __init__(self, saved_tensors):
        # Keeping the value as a plain instance attribute caches the results
        # of saved_tensors, so later reads no longer call into the c++
        # binding.
        self.saved_tensors = saved_tensors
def call_backward(backward_fn, saved_tensors, *args):
    """
    Run ``backward_fn`` with a ``FakeContext`` wrapping ``saved_tensors``
    followed by ``*args``, and return the gradients as a tuple.
    """
    ctx = FakeContext(saved_tensors)
    grad_outputs = backward_fn(ctx, *args)

    # in eager, we wrap in a tuple when there's only one grad output.
    # The check is an exact type match: anything whose type is not precisely
    # ``tuple`` gets wrapped.
    if type(grad_outputs) is tuple:
        return grad_outputs
    return (grad_outputs,)
def untyped_storage_size(x: torch.Tensor):
    """Return the size reported by the untyped storage backing ``x``."""
    storage = x.untyped_storage()
    return storage.size()
def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
    """
    Look up the attribute named ``hook_name`` on ``bw_state`` and call it
    with the remaining positional and keyword arguments, returning whatever
    the hook returns.
    """
    hook = getattr(bw_state, hook_name)
    return hook(*args, **kwargs)
def call_module_hooks_from_backward_state(
    _, result, *args, bw_state, hooks_name: str, module_name: str
):
    """
    Run a sequence of module hooks stored on ``bw_state`` over ``result``.

    The module is read from ``bw_state.<module_name>`` and the iterable of
    hooks from ``bw_state.<hooks_name>``. Each hook is called as
    ``hook(module, result, *args)`` in order; a non-None return value replaces
    ``result`` for the following hooks, while None leaves it unchanged. The
    first positional argument is ignored. Returns the final ``result``.
    """
    target_module = getattr(bw_state, module_name)
    for module_hook in getattr(bw_state, hooks_name):
        hook_output = module_hook(target_module, result, *args)
        result = result if hook_output is None else hook_output
    return result