43 lines
1.8 KiB
Python
43 lines
1.8 KiB
Python
import torch
|
|
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
|
|
from torch.overrides import TorchFunctionMode
|
|
|
|
|
|
class AutogradStateOpsFailSafeguard(TorchFunctionMode):
|
|
"""
|
|
Detect grad state ops during exporting the graph and fail the process by
|
|
raising an error, to avoid unexpected behavior. Those grad mode ops could be:
|
|
`torch.no_grad`
|
|
`torch.enable_grad`
|
|
`torch.set_grad_enabled`
|
|
|
|
Export with predispatch mode is exempted.
|
|
"""
|
|
|
|
def __torch_function__(self, func, types, args=(), kwargs=None):
|
|
kwargs = kwargs or {}
|
|
unsupported_grad_mode_ops = [
|
|
torch._C._set_grad_enabled,
|
|
]
|
|
# It's only enabled while tracing, by confirming the torch dispatch mode is
|
|
# any active PROXY. This is to allow the autograd ops out of tracing.
|
|
current_state = torch._C.is_grad_enabled()
|
|
if func in unsupported_grad_mode_ops:
|
|
assert len(args) == 1
|
|
changed_state = args[0]
|
|
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
|
|
# Intend to check if it's not the pre_dispatch mode. It's allowed to use
|
|
# autograd ops in pre_dispatch mode, e.g. `torch.no_grad`
|
|
if (
|
|
mode
|
|
and isinstance(mode, ProxyTorchDispatchMode)
|
|
and not mode.pre_dispatch
|
|
and changed_state != current_state
|
|
):
|
|
raise RuntimeError(
|
|
f"Encountered autograd state manager op {func} trying to change global autograd state "
|
|
"while exporting. This is unsafe because we don't capture this op in torch.export "
|
|
"today, hence we can't reflect the user intention soundly."
|
|
)
|
|
return func(*args, **kwargs)
|