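"""
Compiled autograd: capture the autograd engine's backward pass as an FX graph.

AutogradCompilerInstance is driven by the C++ autograd engine (installed via
torch._C._dynamo.compiled_autograd.set_autograd_compiler in enable() below).
It records the backward computation, including backward hooks and custom
backward functions, into a torch.fx graph that is then handed to a
user-supplied compiler_fn (for example, a torch.compile wrapper).
"""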
import contextlib
import functools
from typing import List, Optional

import torch
from torch._dynamo.external_utils import call_backward, call_hook
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import (
    decompose,
    disable_autocast_cache,
    disable_proxy_modes_tracing,
    fetch_object_proxy,
    ProxyTorchDispatchMode,
    PythonKeyTracer,
    track_tensor_tree,
)
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.proxy import Proxy

compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")


def maybe_clone(x):
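    """Return a stride-preserving clone of x, passing None through unchanged."""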
    if x is not None:
        return clone_preserve_strides(x)
    return x


class AutogradCompilerInstance:
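    """
    Traces a single backward pass into an FX graph.

    The autograd engine calls begin_capture() with the graph inputs and sizes,
    then the proxy_* / *_hook methods as it walks the autograd graph, and
    finally end_capture(), which compiles the recorded graph with compiler_fn.
    """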
    def __init__(self, compiler_fn) -> None:
        self.compiler_fn = compiler_fn
        self.stack = contextlib.ExitStack()
        self.close = self.stack.close
        self.shape_env = ShapeEnv()
        self.fake_tensor_mode = FakeTensorMode(
            allow_fallback_kernels=True,
            allow_non_fake_inputs=True,
            shape_env=self.shape_env,
        )
        self.fx_tracer = PythonKeyTracer()
        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
        self.hooks_proxy: Optional[Proxy] = None

    def wrap_fake(self, x, source):
        assert isinstance(x, torch.Tensor)
        return self.fake_tensor_mode.from_tensor(x, source=source)

    @staticmethod
    def source(name, idx) -> GetItemSource:
        return GetItemSource(LocalSource(name), idx)

    def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]):
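        """
        Start a capture: create placeholder nodes for the runtime inputs,
        sizes, and hooks, convert the tensor inputs to fake tensors bound to
        proxies, turn the integer sizes into dynamic SymInts, and enter the
        tracing/fake-tensor modes. Returns the fakeified inputs and symbolic
        sizes for the engine to trace with.
        """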
        counters["compiled_autograd"]["captures"] += 1
        self.fx_tracer.root = torch.nn.Module()
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        args_proxy = self.fx_tracer.create_proxy("placeholder", "inputs", (), {})
        sizes_proxy = self.fx_tracer.create_proxy("placeholder", "sizes", (), {})
        self.hooks_proxy = self.fx_tracer.create_proxy("placeholder", "hooks", (), {})

        # tensor inputs to fake tensors
        inputs = [
            self.wrap_fake(x, self.source("inputs", idx))
            for idx, x in enumerate(inputs)
        ]
        proxies = [args_proxy[i] for i in range(len(inputs))]
        self.bind_tensors_to_proxies(inputs, proxies)

        # size inputs to symints
        sizes = [
            self.shape_env.create_unspecified_symint_and_symbol(
                val,
                self.source("sizes", idx),
                DimDynamic.DYNAMIC,
            )
            for idx, val in enumerate(sizes)
        ]
        self.bind_tensors_to_proxies(sizes, sizes_proxy)

        # TODO(jansel): are all these modes needed?
        self.stack.enter_context(decompose({}))
        self.stack.enter_context(self.fake_tensor_mode)
        self.stack.enter_context(self.proxy_mode.sym_mode)
        self.stack.enter_context(self.proxy_mode)
        self.stack.enter_context(disable_autocast_cache())
        return inputs, sizes

    def proxy_call_backward(
        self,
        inputs,
        output_metadatas,
        saved_tensors,
        backward_idx: int,
    ):
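        """
        Record a call to a user-defined backward (looked up from the hooks
        placeholder at backward_idx) via call_backward, and return fake
        gradient tensors matching output_metadatas so tracing can continue
        through the node.
        """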
        assert self.hooks_proxy is not None
        backward_fn = self.hooks_proxy[backward_idx]  # type: ignore[index]
        proxies = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_backward,
            args=(
                backward_fn,
                self.to_proxy(saved_tensors),
                *self.to_proxy(inputs),
            ),
            kwargs={},
        )

        with disable_proxy_modes_tracing():
            # create fake Tensors
            grad_ins: List[Optional[torch.Tensor]] = []
            for output_metadata in output_metadatas:
                if output_metadata is None:
                    grad_ins.append(None)
                    continue

                layout, device, dtype, size = output_metadata
                grad_ins.append(
                    torch.empty(size=size, dtype=dtype, layout=layout, device=device)
                )
            self.bind_tensors_to_proxies(grad_ins, proxies)
        return tuple(grad_ins)

    def proxy_call_hook(self, hook, *args):
        return self.fx_tracer.create_proxy(
            "call_function",
            call_hook,
            (
                hook,
                *[self.to_proxy(x) for x in args],
            ),
            {},
        )

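    # The four methods below trace backward hooks into the graph: tensor
    # pre-hooks, node pre-hooks, node post-hooks, and post-accumulate-grad
    # hooks. Each records a call_hook node, then rebinds stride-preserving
    # clones of the eager values to the hook's output proxies so later nodes
    # observe the hook's effect.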
    def tensor_pre_hook(self, inputs, hook_id, i: int):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxy = self.proxy_call_hook(
            hook,
            inputs[i],
        )
        with disable_proxy_modes_tracing():
            inputs[i] = maybe_clone(inputs[i])
            self.bind_tensors_to_proxies([inputs[i]], [proxy])
        return inputs

    def pre_hook(self, inputs, hook_id):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            inputs,
        )
        with disable_proxy_modes_tracing():
            inputs = [maybe_clone(x) for x in inputs]
            self.bind_tensors_to_proxies(inputs, proxies)
        return inputs

    def post_hook(self, outputs, inputs, hook_id):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            outputs,
            inputs,
        )
        with disable_proxy_modes_tracing():
            outputs = [maybe_clone(x) for x in outputs]
            self.bind_tensors_to_proxies(outputs, proxies)
        return outputs

    def post_acc_grad_hook(self, input, hook_id):
        assert isinstance(input, torch.Tensor)
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            input,
        )
        with disable_proxy_modes_tracing():
            input = [maybe_clone(input)]
            self.bind_tensors_to_proxies(input, proxies)
        return input

    def end_capture(self, outputs):
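        """
        Finish the capture: emit the output node, wrap the traced graph in a
        GraphModule named "CompiledAutograd", log it, and return the result of
        running compiler_fn on it.
        """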
        self.stack.close()
        self.fx_tracer.create_node(
            "output",
            "output",
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
        graph = GraphModule(
            self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
        )
        compiled_autograd_log.info(
            "%s", lazy_format_graph_code("Compiled autograd graph", graph)
        )
        trace_structured(
            "compiled_autograd_graph",
            payload_fn=lambda: graph.print_readable(print_output=False),
        )
        return self.compiler_fn(graph)

    def to_proxy(self, t):
        if t is None:
            return None
        if isinstance(t, list):
            return [self.to_proxy(x) for x in t]
        if isinstance(t, tuple):
            return tuple(self.to_proxy(x) for x in t)
        assert isinstance(t, (torch.Tensor, torch.SymInt))
        return fetch_object_proxy(self.fx_tracer)(t).proxy

    def bind_tensors_to_proxies(self, tensors, proxies):
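        """Associate each (fake) tensor with its FX proxy via track_tensor_tree."""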
        if isinstance(proxies, torch.fx.Proxy):
            proxies = [proxies[i] for i in range(len(tensors))]
        assert len(tensors) == len(proxies)
        track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)

    def bind_backward_state(self, index: int):
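        """
        Create a BackwardState tracked against the hooks placeholder at
        `index`, used to pass Dynamo-managed hook state from the forward into
        the captured backward graph.
        """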
        assert self.hooks_proxy is not None
        proxy = self.hooks_proxy[index]  # type: ignore[index]
        bw_state = BackwardState()
        track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
        return bw_state


compiled_autograd_enabled = False

# We may have code like:
#   with enable(compiler_fn):
#       ...
#       with disable():
#           ...
#       ...
# The disable() call only disables compiled autograd temporarily; overall the
# feature is still enabled.
#
# Code running under the disable() context manager has no way to tell whether
# compiled autograd is enabled overall, so a second variable,
# compiled_autograd_enabled_count, tracks how many times compiled autograd has
# been enabled on the call stack.
compiled_autograd_enabled_count = 0


@contextlib.contextmanager
def enable(compiler_fn):
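    """
    Context manager that routes backward execution through compiled autograd:
    AutogradCompilerInstance captures the backward pass into an FX graph and
    compiler_fn compiles it. Multithreaded autograd is disabled while the
    context is active.

    A minimal usage sketch (compiler_fn is any callable taking a GraphModule;
    the "eager" backend here is purely illustrative):

        def compiler_fn(gm):
            return torch.compile(gm, backend="eager")

        with torch._dynamo.compiled_autograd.enable(compiler_fn):
            loss.backward()
    """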
    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
        functools.partial(AutogradCompilerInstance, compiler_fn)
    )
    global compiled_autograd_enabled, compiled_autograd_enabled_count
    compiled_autograd_enabled = True
    compiled_autograd_enabled_count += 1
    try:
        with torch.autograd.set_multithreading_enabled(False):
            yield
    finally:
        compiled_autograd_enabled_count -= 1
        if not prior:
            compiled_autograd_enabled = False
        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)


@contextlib.contextmanager
def disable():
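    """Temporarily disable compiled autograd (see the note on nesting above)."""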
    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
    try:
        yield
    finally:
        if prior:
            compiled_autograd_enabled = True
        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)