import contextlib
import functools
from typing import List, Optional

import torch
from torch._dynamo.external_utils import call_backward, call_hook
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import (
    decompose,
    disable_autocast_cache,
    disable_proxy_modes_tracing,
    fetch_object_proxy,
    ProxyTorchDispatchMode,
    PythonKeyTracer,
    track_tensor_tree,
)
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.proxy import Proxy

compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")


def maybe_clone(x):
    if x is not None:
        return clone_preserve_strides(x)
    return x


class AutogradCompilerInstance:
    def __init__(self, compiler_fn) -> None:
        self.compiler_fn = compiler_fn
        self.stack = contextlib.ExitStack()
        self.close = self.stack.close
        self.shape_env = ShapeEnv()
        self.fake_tensor_mode = FakeTensorMode(
            allow_fallback_kernels=True,
            allow_non_fake_inputs=True,
            shape_env=self.shape_env,
        )
        self.fx_tracer = PythonKeyTracer()
        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
        self.hooks_proxy: Optional[Proxy] = None

    def wrap_fake(self, x, source):
        assert isinstance(x, torch.Tensor)
        return self.fake_tensor_mode.from_tensor(x, source=source)

    @staticmethod
    def source(name, idx) -> GetItemSource:
        return GetItemSource(LocalSource(name), idx)

    def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]):
        counters["compiled_autograd"]["captures"] += 1
        self.fx_tracer.root = torch.nn.Module()
        self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        self.fx_tracer.tensor_attrs = {}
        args_proxy = self.fx_tracer.create_proxy("placeholder", "inputs", (), {})
        sizes_proxy = self.fx_tracer.create_proxy("placeholder", "sizes", (), {})
        self.hooks_proxy = self.fx_tracer.create_proxy("placeholder", "hooks", (), {})

        # tensor inputs to fake tensors
        inputs = [
            self.wrap_fake(x, self.source("inputs", idx))
            for idx, x in enumerate(inputs)
        ]
        proxies = [args_proxy[i] for i in range(len(inputs))]
        self.bind_tensors_to_proxies(inputs, proxies)

        # size inputs to symints
        sizes = [
            self.shape_env.create_unspecified_symint_and_symbol(
                val,
                self.source("sizes", idx),
                DimDynamic.DYNAMIC,
            )
            for idx, val in enumerate(sizes)
        ]
        self.bind_tensors_to_proxies(sizes, sizes_proxy)

        # TODO(jansel): are all these modes needed?
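        # The contexts entered below stay active for the whole capture and are
        # closed again in end_capture() via self.stack.close():
        #   - decompose({}): install an (empty) decomposition table for proxy tracing
        #   - fake_tensor_mode: run autograd nodes on FakeTensors so no real
        #     compute or allocation happens while capturing
        #   - proxy_mode.sym_mode / proxy_mode: record dispatched (and symbolic int)
        #     ops into self.fx_tracer's graph
        #   - disable_autocast_cache(): avoid reusing autocast's cast cache across
        #     the capture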
        self.stack.enter_context(decompose({}))
        self.stack.enter_context(self.fake_tensor_mode)
        self.stack.enter_context(self.proxy_mode.sym_mode)
        self.stack.enter_context(self.proxy_mode)
        self.stack.enter_context(disable_autocast_cache())
        return inputs, sizes

    def proxy_call_backward(
        self,
        inputs,
        output_metadatas,
        saved_tensors,
        backward_idx: int,
    ):
        assert self.hooks_proxy is not None
        backward_fn = self.hooks_proxy[backward_idx]  # type: ignore[index]
        proxies = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_backward,
            args=(
                backward_fn,
                self.to_proxy(saved_tensors),
                *self.to_proxy(inputs),
            ),
            kwargs={},
        )

        with disable_proxy_modes_tracing():
            # create fake Tensors
            grad_ins: List[Optional[torch.Tensor]] = []
            for output_metadata in output_metadatas:
                if output_metadata is None:
                    grad_ins.append(None)
                    continue

                layout, device, dtype, size = output_metadata
                grad_ins.append(
                    torch.empty(size=size, dtype=dtype, layout=layout, device=device)
                )
            self.bind_tensors_to_proxies(grad_ins, proxies)
        return tuple(grad_ins)

    def proxy_call_hook(self, hook, *args):
        return self.fx_tracer.create_proxy(
            "call_function",
            call_hook,
            (
                hook,
                *[self.to_proxy(x) for x in args],
            ),
            {},
        )

    def tensor_pre_hook(self, inputs, hook_id, i: int):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxy = self.proxy_call_hook(
            hook,
            inputs[i],
        )
        with disable_proxy_modes_tracing():
            inputs[i] = maybe_clone(inputs[i])
            self.bind_tensors_to_proxies([inputs[i]], [proxy])
        return inputs

    def pre_hook(self, inputs, hook_id):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            inputs,
        )
        with disable_proxy_modes_tracing():
            inputs = [maybe_clone(x) for x in inputs]
            self.bind_tensors_to_proxies(inputs, proxies)
        return inputs

    def post_hook(self, outputs, inputs, hook_id):
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            outputs,
            inputs,
        )
        with disable_proxy_modes_tracing():
            outputs = [maybe_clone(x) for x in outputs]
            self.bind_tensors_to_proxies(outputs, proxies)
        return outputs

    def post_acc_grad_hook(self, input, hook_id):
        assert isinstance(input, torch.Tensor)
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(
            hook,
            input,
        )
        with disable_proxy_modes_tracing():
            input = [maybe_clone(input)]
            self.bind_tensors_to_proxies(input, proxies)
        return input

    def end_capture(self, outputs):
        self.stack.close()
        self.fx_tracer.create_node(
            "output",
            "output",
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
        graph = GraphModule(
            self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
        )
        compiled_autograd_log.info(
            "%s", lazy_format_graph_code("Compiled autograd graph", graph)
        )
        trace_structured(
            "compiled_autograd_graph",
            payload_fn=lambda: graph.print_readable(print_output=False),
        )
        return self.compiler_fn(graph)

    def to_proxy(self, t):
        if t is None:
            return None
        if isinstance(t, list):
            return [self.to_proxy(x) for x in t]
        if isinstance(t, tuple):
            return tuple(self.to_proxy(x) for x in t)
        assert isinstance(t, (torch.Tensor, torch.SymInt))
        return fetch_object_proxy(self.fx_tracer)(t).proxy

    def bind_tensors_to_proxies(self, tensors, proxies):
        if isinstance(proxies, torch.fx.Proxy):
            proxies = [proxies[i] for i in range(len(tensors))]
        assert len(tensors) == len(proxies)
        track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
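    # BackwardState carries hooks/state from a Dynamo-compiled forward into the
    # backward graph; the hooks[index] proxy created below is what represents it
    # in the captured graph.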
    def bind_backward_state(self, index: int):
        assert self.hooks_proxy is not None
        proxy = self.hooks_proxy[index]  # type: ignore[index]
        bw_state = BackwardState()
        track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
        return bw_state


compiled_autograd_enabled = False

# We may have code like:
#   with enable(compiler_fn):
#       ...
#       with disable():
#           ...
#       ...
# The disable() call just wants to disable compiled autograd temporarily,
# but overall the feature is enabled.
#
# The code covered by the disable context manager has no way to know if
# compiled autograd is enabled overall. Use another variable,
# compiled_autograd_enabled_count, to track how many times compiled
# autograd has been enabled in the call stack for this purpose.
compiled_autograd_enabled_count = 0


@contextlib.contextmanager
def enable(compiler_fn):
    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
        functools.partial(AutogradCompilerInstance, compiler_fn)
    )
    global compiled_autograd_enabled, compiled_autograd_enabled_count
    compiled_autograd_enabled = True
    compiled_autograd_enabled_count += 1
    try:
        with torch.autograd.set_multithreading_enabled(False):
            yield
    finally:
        compiled_autograd_enabled_count -= 1
        if not prior:
            compiled_autograd_enabled = False
        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)


@contextlib.contextmanager
def disable():
    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
    try:
        yield
    finally:
        if prior:
            compiled_autograd_enabled = True
        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
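

# Example usage (illustrative sketch only; compiler_fn is any callable that
# receives the captured fx.GraphModule and returns something callable, e.g. a
# torch.compile-wrapped module):
#
#     def compiler_fn(gm: torch.fx.GraphModule):
#         return torch.compile(gm, backend="eager")
#
#     model = torch.nn.Linear(4, 4)
#     loss = model(torch.randn(2, 4)).sum()
#     with enable(compiler_fn):
#         loss.backward()  # backward graph is captured by AutogradCompilerInstance
#
# Wrapping backward() in enable() routes the autograd engine through
# AutogradCompilerInstance above; disable() can be nested inside to opt out
# temporarily.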