import collections import contextlib import copy import functools import itertools import logging import operator import re import sys import traceback import weakref from dataclasses import dataclass from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union import sympy import torch._guards import torch._logging import torch.nn import torch.utils._pytree as pytree from torch import fx from torch._guards import ( Checkpointable, GlobalContextCheckpointState, GuardsCheckpointState, Source, TracingContext, ) from torch._utils_internal import signpost_event from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined] from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.sym_node import SymNode from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.reference import PythonReferenceAnalysis from torch.utils.weak import WeakTensorKeyDictionary from . import config, logging as torchdynamo_logging, variables from .backends.registry import CompiledFn, CompilerFn from .bytecode_transformation import ( create_call_function, create_instruction, Instruction, unique_id, ) from .code_context import code_context from .codegen import PyCodegen from .current_scope_id import enter_new_scope from .exc import ( BackendCompilerFailed, exceptions_allowed_to_be_fallback, SkipFrame, unimplemented, unimplemented_with_warning, ) from .guards import GuardBuilder, install_guard from .mutation_guard import is_dynamic_nn_module from .side_effects import SideEffects from .source import ( AttrSource, BackwardStateSource, ConstantSource, GlobalStateSource, is_constant_source, is_from_local_source, LocalSource, ParamBufferSource, ShapeEnvSource, TensorProperty, TensorPropertySource, ) from .utils import ( checkpoint_params, CleanupHook, clone_inputs, count_calls, counters, dynamo_timed, get_instruction_source_311, get_static_address_type, graph_break_reasons, increment_op_count, lazy_format_graph_code, lazy_format_graph_tabular, LazyString, same, ) from .variables.base import VariableTracker from .variables.builder import ( BackwardStateGraphArg, GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy, ) from .variables.nn_module import NNModuleVariable from .variables.tensor import ( NumpyNdarrayVariable, SymNodeVariable, TensorVariable, UnspecializedPythonVariable, ) from .variables.torch_function import TensorWithTFOverrideVariable log = logging.getLogger(__name__) graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes") trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") class OutputGraphState(NamedTuple): input_source_to_var: Dict[Source, VariableTracker] tracked_fakes: List[TrackedFake] guard_state: GuardsCheckpointState nn_modules: Optional[Dict[str, torch.nn.Module]] register_finalizer_fns: List[Callable[[fx.GraphModule], None]] global_state: Optional[Dict[str, bool]] param_name_to_source: Optional[Dict[str, Source]] side_effects: SideEffects timestamp: int non_compliant_ops: Set[torch._ops.OpOverload] compliant_custom_ops: Set[torch._ops.OpOverload] def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]: for k in self._fields: if k == "guard_state": r = self.guard_state.diff(other.guard_state) if r is not None: return r continue elif k == "side_effects": r = self.side_effects.diff(other.side_effects) if r is not None: return r continue sv = getattr(self, k) ov = getattr(other, k) if sv != ov: return f"{prefix}{k} mismatch: {sv} != {ov}" return None # Back compat .guards api @property def guards(self): return self.guard_state.dynamo_guards @functools.lru_cache(None) def _step_logger(): return torchdynamo_logging.get_step_logger(log) @dataclass class GraphCompileReason: """Stores why a given output graph was compiled; i.e. what caused the graph break.""" reason: str user_stack: List[traceback.FrameSummary] # Indicates if this was a graph compile reason due to graph break. graph_break: bool = True def __post_init__(self): if self.graph_break: graph_break_reasons.append(self) def _get_gen_rand_values_fn(random_calls): def _gen_rand_values(): return [fn(*args, **kwargs) for fn, args, kwargs in random_calls] return _gen_rand_values class FakeRootModule(torch.nn.Module): """Trick the constructor of fx.GraphModule""" def __init__(self, nn_modules: Dict[str, torch.nn.Module]): super().__init__() for k, v in nn_modules.items(): setattr(self, k, v) def __repr__(self): return "FakeRootModule(...)" class WrapperBackend: def __init__(self, backend: CompilerFn): self.backend: CompilerFn = backend def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): self.restore = checkpoint_params(gm) self.gm = gm copy_gm = copy.deepcopy(self.gm) self.candidate = self.backend(copy_gm, example_inputs) if self.candidate is None or self.candidate is self.gm.forward: return self.gm.forward if not config.verify_correctness: return self.candidate # if verify_correctness=True try: correct = self.gm.forward(*clone_inputs(example_inputs)) result = self.candidate(*clone_inputs(example_inputs)) # TODO: replace `same` function with the one in testing if same(correct, result): return self.candidate raise RuntimeError(f"incorrect results of backend {self}") return self.gm.forward except Exception: log.exception("error in verify_correctness") raise finally: self.restore() Scope = Dict[str, object] class OutputGraph(Checkpointable[OutputGraphState]): """ Wrapper class to hold outputs of InstructionTranslator. Mainly the generated fx.Graph. OutputGraph is 1:1 with a frame being processed. Each frame is associated with some root InstructionTranslator. When user code calls a function, we construct a InliningInstructionTranslator that continues to write into the root InstructionTranslator's OutputGraph. """ def __init__( self, code_options: Dict[str, Any], compiler_fn: Optional[CompilerFn], root_tx, export: bool, export_constraints, frame_state, local_scope: Scope, global_scope: Scope, f_code, ): super().__init__() self.tracers = [SubgraphTracer(self, export_root=export)] # Map from graph input's `Source` to its `VariableTracker` to # de-duplicate graph inputs by source and reuse the tracker self.input_source_to_var: Dict[Source, VariableTracker] = {} self.export = export self.export_constraints = export_constraints self.frame_state = frame_state self.tensor_weakref_to_sizes_strides = WeakTensorKeyDictionary() self.cleanup_hooks: List[Callable[[], Any]] = [] # compile_id is an id number for the current torch.compile self.compile_id: int = next(_compile_id_counter) # Set of globals installed via install_global* APIs self.installed_globals: Set[str] = set() # TODO: maybe should just pass the entire f_code in here? Not # sure... self.co_fields = { "co_name": f_code.co_name, "co_filename": f_code.co_filename, "co_firstlineno": f_code.co_firstlineno, } # tracked_fakes says where any tensor that was wrapped to fake came # from. It is similar to GraphArg, in that all GraphArgs will get # will get added to TrackedFakes, but TrackedFakes also contains # GraphArgs that got pruned, and things like Tensor attributes which # aren't explicit graph inputs. Used by shape guard self.tracked_fakes: List[TrackedFake] = [] # List of symbols for which we have exact bindings in the arguments # already self.bound_symbols: Set[sympy.Symbol] = set() shape_env = ShapeEnv( # Reference Cycle! # Share a reference to the list of TrackedFake. # # ShapeEnv needs this in order to be able to reproduce the call # to produce_guards at an arbitrary time point. That is because # TrackedFake instances may have its metadata changed throughout # the program execution. tracked_fakes=self.tracked_fakes, allow_scalar_outputs=config.capture_scalar_outputs, allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, co_fields=self.co_fields, ) # In export mode, we force the shape_env to strictly disallow any constraining # of the user marked dynamic dims fake_mode = torch._subclasses.FakeTensorMode( shape_env=shape_env, # TODO (tmanlaibaatar) Remove this once we always lift params and buffers allow_non_fake_inputs=True if self.export else False, ) self.tracing_context: TracingContext = TracingContext(fake_mode) self.init_ambient_guards() # Map each tensor id to a list of sources. This is necessary because # tensor ids cannot be recovered from tracked fakes (in general). # We use this map to interpret (i.e., check for violations of) constraints, # specifically equality constraints, which have shared tensor ids in them. # This map should also be generally useful, e.g., for (de)serialization. self.tracked_fakes_id_to_source: Dict[ int, List[Source] ] = collections.defaultdict(list) # Stores the full fqn of a param or buffer to the relevant source. self.param_name_to_source: Optional[Dict[str, Source]] = dict() self.side_effects = SideEffects() self.code_options = dict(code_options) self.output_instructions: List[Instruction] = [] # used to track nodes that are added between calls of copy_graphstate # and restore_graphstate self.timestamp = 0 # A list of register_finalizer_fns to apply to the output graph module self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = [] # Not checkpointed self.compiler_fn: Optional[CompilerFn] = compiler_fn self.global_scope = global_scope self.local_scope = local_scope self.root_tx = root_tx from torch._dynamo.symbolic_convert import InstructionTranslatorBase # Given a source, what are the user stacks of all locations that # accessed it? # # For efficiency, we only populate this: # - During export, and # - If the source could potentially lead to a spurious export input # # Feel free to populate this more frequently if other use-cases arise, # but be aware that we have to generate full stacks for each # recording! self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {} self._current_tx: List[InstructionTranslatorBase] = [] self.cleanups: List[CleanupHook] = [] self.should_exit = False self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {} self.torch_function_enabled = torch._C._is_torch_function_enabled() # Tracks if the output graph has a user defined allowed function in the # graph. This is used later to determine if we should fallback to eager # for certain exceptions. THe idea is that if the user has applied # allow_in_graph, they would like to see the error instead of falling # back for backend errors. self.has_user_defined_allowed_in_graph = False # Tracks a list of called ops that were not tagged with "pt2_compliant_tag". # This information is useful for logging. self.non_compliant_ops: Set[torch._ops.OpOverload] = set({}) # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag". # This information is useful for logging. self.compliant_custom_ops: Set[torch._ops.OpOverload] = set({}) # We save the global torch state here to be restored in case of graph # breaks. The relevant issue is seen here # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086 # where inlining of a function changes the global state (because of the # presence of torch.no_grad) and there is a graph break. self.save_global_state() # Tracks the original FQNs of the constant tensors from the original graph, # i.e. buffers and parameters. self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {} # All calls to random() are replaced with a single call to __gen_rand_values # functions that returns a tuple of random values for each original call. # random_calls tracks calls to random() and random_values_var stores the name of # the variable that stores __gen_rand_values results. self.random_calls: List[ Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] ] = [] self.random_values_var = None # Bytecode to insert right before we call the graph self.pregraph_bytecode: List[Instruction] = [] # Use to pass values to backward hooks when using compiled autograd self.backward_state: Dict[str, VariableTracker] = {} self.backward_state_proxy: Optional[torch.fx.Proxy] = None self.backward_state_var: Optional[str] = None def add_backward_state_hook(self, hook: VariableTracker): name = f"hook{len(self.backward_state)}" assert name not in self.backward_state self.backward_state[name] = hook return name, self.get_backward_state_proxy() def get_backward_state_proxy(self): if self.backward_state_proxy is None: if self.export: unimplemented("backward_state does not support export") self.backward_state_proxy = self.root_tracer.create_graph_input( "dynamo_backward_state", BackwardState, source=BackwardStateSource() ) self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg() self.backward_state_proxy.node.meta["example_value"] = BackwardState() self.backward_state_var = self.new_var() return self.backward_state_proxy # This gets its own helper function so guards DEBUG logs are more informative def init_ambient_guards(self): # Register a SHAPE_ENV guard to make sure we setup shape guards # that show up in ShapeEnv self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) self.guards.add( GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS) ) self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE)) self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE)) self.guards.add( GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE) ) self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH)) def add_cleanup_hook(self, fn: Callable[[], Any]): self.cleanup_hooks.append(fn) def call_cleanup_hooks(self): for hook in reversed(self.cleanup_hooks): hook() self.cleanup_hooks.clear() @property def root_tracer(self): return self.tracers[0] @property def current_tracer(self): return self.tracers[-1] def is_root_tracer(self): # Helper to tell if we are inside the higher order operator tracing. return len(self.tracers) == 1 @property def graph(self): return self.current_tracer.graph # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer. @graph.setter def graph(self, value): self.current_tracer.graph = value @property def input_name_to_proxy(self): return self.current_tracer.input_name_to_proxy @property def real_value_cache(self): return self.current_tracer.real_value_cache # If you are here, and you're looking for create_graph_input, # to avoid ambiguity, please call one of the following: # - self.current_tracer.create_graph_input # - self.root_tracer.create_graph_input # See NOTE [HigherOrderOperator tracing design] for more context. def create_proxy(self, *args, **kwargs): return self.current_tracer.create_proxy(*args, **kwargs) def create_node(self, *args, **kwargs): return self.current_tracer.create_node(*args, **kwargs) def remove_node(self, *args, **kwargs): return self.current_tracer.remove_node(*args, **kwargs) @contextlib.contextmanager def subtracer(self, source_target, prior_tracer): new_scope_ctx = enter_new_scope() try: if prior_tracer: # Lineage MUST stay preserved assert prior_tracer.parent is self.current_tracer new_scope_ctx.__enter__() tracer = ( prior_tracer if prior_tracer else SubgraphTracer( self, parent=self.current_tracer, source_target=source_target ) ) self.tracers.append(tracer) yield tracer finally: new_scope_ctx.__exit__(None, None, None) self.tracers.pop() @property def output(self): return self @property def fake_mode(self): return self.tracing_context.fake_mode @property def shape_env(self): return self.tracing_context.fake_mode.shape_env @property def guards(self) -> torch._guards.GuardsSet: return self.tracing_context.guards_context.dynamo_guards @property def nn_modules(self) -> Dict[str, Any]: return self.tracing_context.module_context.nn_modules def save_global_state(self, out=None): """ Saves to out if it is provided. Else saves to the tracing context's global_state. """ global_state = ( out if out is not None else self.tracing_context.global_context.global_state ) # TODO - Consider having a torch level API for torch_function_state. As # of now, we create a ref cycle by passing the # output.set_torch_function_state to # output.tracing_context.global_context.global_state. In the interim, # the problem can be solved by manually set # output.tracing_context.global_context.global_state to None at cleanup. global_state["torch_function_enabled"] = ( self.set_torch_function_state, self.torch_function_enabled, ) global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled()) global_state["autocast_enabled"] = ( torch.set_autocast_enabled, torch.is_autocast_enabled(), ) global_state["autocast_cpu_enabled"] = ( torch.set_autocast_cpu_enabled, torch.is_autocast_cpu_enabled(), ) global_state["autocast_gpu_dtype"] = ( torch.set_autocast_gpu_dtype, torch.get_autocast_gpu_dtype(), ) global_state["autocast_cpu_dtype"] = ( torch.set_autocast_cpu_dtype, torch.get_autocast_cpu_dtype(), ) global_state["autocast_cache_enabled"] = ( torch.set_autocast_cache_enabled, torch.is_autocast_cache_enabled(), ) def push_tx(self, tx): self._current_tx.append(tx) def pop_tx(self): return self._current_tx.pop() @property def current_tx(self): return self.root_tx if not self._current_tx else self._current_tx[-1] def copy_graphstate(self) -> OutputGraphState: """Create a checkpoint of the current state by copying everything""" assert self.param_name_to_source is not None guards_graph_state = self.tracing_context.guards_context.copy_graphstate() module_state = self.tracing_context.module_context.copy_graphstate() global_state = self.tracing_context.global_context.copy_graphstate() state = OutputGraphState( dict(self.input_source_to_var), list(self.tracked_fakes), guards_graph_state, module_state, list(self.register_finalizer_fns), global_state, dict(self.param_name_to_source), self.side_effects.clone(), self.timestamp, set(self.non_compliant_ops), set(self.compliant_custom_ops), ) self.timestamp += 1 return state def restore_graphstate(self, state: OutputGraphState): """Restore a checkpoint created by self.copy_graphstate()""" ( self.input_source_to_var, self.tracked_fakes, guards_state, module_state, self.register_finalizer_fns, global_state, self.param_name_to_source, self.side_effects, self.timestamp, self.non_compliant_ops, self.compliant_custom_ops, ) = state self.tracing_context.guards_context.restore_graphstate(guards_state) self.tracing_context.module_context.restore_graphstate(module_state) self.tracing_context.global_context.restore_graphstate(global_state) # FX deepcopy doesn't work for a partially created graph, so just remove new nodes removed_nodes = 0 for node in reversed(list(self.graph.nodes)): if ( node.meta["creation_timestamp"] > self.timestamp # placeholders here may have been lazily added by existing objects and node.op != "placeholder" ): # Erasing node alone does not remove the meta information # So, remove the help tensor explicitly if "example_value" in node.meta: del node.meta["example_value"] self.remove_node(node) self.real_value_cache.pop(node, None) removed_nodes += 1 log.debug("restore_graphstate: removed %s nodes", removed_nodes) def add_symbol_bindings(self, arg: GraphArg): # Insert implicit size vars as necessary. With dynamic shapes, we # maintain the invariant that every sizevar gets a direct SymInt input # into the graph. This means downstream graph transforms can assume # every size variable is explicitly bound and accessible, instead of # having to pull it out implicitly from tensors. if self.export: return assert arg.fake_tensor is not None def bind_symint(s, prop): if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)): return s0 = s.node.expr if s0 in self.bound_symbols: return self.bound_symbols.add(s0) log.debug("bind_symint %s %s", s, prop.name()) # TODO: don't readd symint if we already have it in graph # (this is harmless because we do remove the unused ones later) proxy = self.root_tracer.create_graph_input( str(s0), torch.SymInt, before=True, source=prop, ) proxy.node.meta["example_value"] = s proxy.node.meta["grapharg"] = GraphArg( prop, s, is_unspecialized=False, fake_tensor=None, is_tensor=False, ) def handle_tensor(t, src): for i, s in enumerate(t.size()): bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i)) for i, s in enumerate(t.stride()): bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i)) bind_symint( t.storage_offset(), TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), ) if is_traceable_wrapper_subclass(t): attrs, ctx = t.__tensor_flatten__() for attr in attrs: inner_t = getattr(t, attr) handle_tensor(inner_t, AttrSource(src, attr)) handle_tensor(arg.fake_tensor, arg.source) def count_calls(self): return count_calls(self.graph) def is_empty_graph(self): return len(list(self.graph.nodes)) == 0 def get_submodule(self, keys): assert keys obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules for k in keys.split("."): if isinstance(obj, dict): obj = obj[k] else: obj = getattr(obj, k) return obj def new_var(self, name="tmp"): existing = set(self.code_options["co_varnames"]) for i in itertools.count(): var = f"{name}_{i}" if var not in existing: self.code_options["co_varnames"] += (var,) return var def update_co_names(self, name): """Ensure self.code_options.co_names contains name""" if name not in self.code_options["co_names"]: self.code_options["co_names"] += (name,) @staticmethod def module_key_name(*names): # create a new unique name name = "_".join(map(str, names)) # Strip the guard lookup L/G access name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name) # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv name = re.sub(r"\[(\d+)\]", r"_\g<1>", name) # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv name = re.sub(r"[^a-zA-Z0-9]", "_", name) if not name or not name[0].isalpha(): name = "sub" + name return name def register_attr_or_module( self, target: Union[torch.nn.Module, torch.Tensor, Any], *names, **options, ): if is_dynamic_nn_module(target): return variables.UnspecializedNNModuleVariable(target, **options) options = dict(options) assert "source" in options source = options["source"] assert not isinstance(source, ParamBufferSource) if isinstance(target, torch.Tensor): tracer = self.current_tracer if not self.is_root_tracer(): # For higher order ops, we don't want to insert the get_attr in # innermost graph. Instead, we want to raise the params/buffers # as inputs to the higher-order graph, and register them as # get_attrs in the root tracer. # Note that Dynamo will still call lift_tracked_freevar_to_input # when these inputs are encountered for the inner graph. The # only difference is what happens at the root tracer for # nn.Parameters vs free inputs. The free inputs are registered # as placeholders in the root graph, whereas the nn.Parameters # are registered as get_attr nodes in the root graph. tracer = self.root_tracer if not is_constant_source(source): install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH)) if get_static_address_type(target) == "guarded": install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH)) def wrap_name(module_key): assert self.param_name_to_source is not None self.param_name_to_source[module_key] = source return wrap_fx_proxy( self.root_tx, tracer.create_proxy("get_attr", module_key, tuple(), {}), example_value=target, **options, ) elif isinstance(target, torch.nn.Module): assert isinstance(target, torch.nn.Module) install_guard(source.make_guard(GuardBuilder.NN_MODULE)) def wrap_name(module_key): return NNModuleVariable(type(target), module_key, target, **options) elif isinstance(target, (torch.SymInt, torch.SymFloat)): # HACKY CODE REGION BEGIN # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS # This ultimately gets written to self.nn_modules, which is unfortunate # Attrs that are tenors and symints and such need to be migrated to have their # own storage # alas, this is like this for now def wrap_name(module_key): return SymNodeVariable.create( self, self.create_proxy("get_attr", module_key, tuple(), {}), sym_num=target, **options, ) # HACKY CODE REGION END else: def wrap_name(module_key): self.output.update_co_names(module_key) self.global_scope[module_key] = target return VariableBuilder(self, ConstantSource(source_name=module_key))( target ) for k, v in self.nn_modules.items(): if v is target: # it already exists return wrap_name(k) name = OutputGraph.module_key_name(*names) base = name for i in itertools.count(): if name not in self.nn_modules: self.nn_modules[name] = target if isinstance(target, torch.nn.Module): def register_leaf_name(leaf_name): assert self.param_name_to_source is not None new_source = ParamBufferSource(source, leaf_name) new_name = f"{name}.{leaf_name}" self.param_name_to_source[new_name] = new_source if isinstance(source, LocalSource): self.dynamo_flat_name_to_original_fqn[ OutputGraph.module_key_name(new_source.name()) ] = leaf_name # annoying, but there are cases when we do not have parameters # see test_nn_moduledict_contains if hasattr(target, "_parameters"): for leaf_name, _ in target.named_parameters(): register_leaf_name(leaf_name) if hasattr(target, "_buffers"): for leaf_name, _ in target.named_buffers(): register_leaf_name(leaf_name) return wrap_name(name) name = f"{base}_{i}" raise AssertionError("unreachable") def compile_subgraph( self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None ): """ Generate a subgraph to continue execution on user code. Automatically restore live variables. """ assert reason is not None from .decorators import disable self.partial_convert = partial_convert self.compile_subgraph_reason = reason self.should_exit = True log.debug("COMPILING GRAPH due to %s", reason) if not all(block.can_restore() for block in tx.block_stack): unimplemented("compile_subgraph with block_depth != 0") prefix_insts: List[Instruction] = [] if sys.version_info >= (3, 11): # prefix instructions (Python 3.11+) for inst in tx.prefix_insts: if inst.opname == "MAKE_CELL": prefix_insts.append( create_instruction("MAKE_CELL", argval=inst.argval) ) elif inst.opname == "COPY_FREE_VARS": prefix_insts.append( create_instruction( "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"]) ) ) else: prefix_insts.append(copy.copy(inst)) assert not ( self.pregraph_bytecode and self.export ), "export does not support pregraph_bytecode" prefix_insts.extend(self.pregraph_bytecode) def append_prefix_insts(): self.add_output_instructions(prefix_insts) prefix_insts.clear() for block in reversed(tx.block_stack): block.exit(tx) self.cleanup_graph() tx.prune_dead_locals() stack_values = list(tx.stack) root = FakeRootModule(self.nn_modules) # Add all the local vars to the "stack" so restore at the end restore_vars = [] val_to_names: Dict[VariableTracker, List[str]] = {} if stack_values: val_to_names[stack_values[-1]] = list() # NB: Typically (i.e., for graph compile from RETURN_VALUE), # symbolic_locals will be empty at this point, as prune_dead_locals # will clear out all of symbolic_locals because RETURN_VALUE is the # last instruction and no more locals are used. The fanciness here # is only needed for partial graphs. for k, v in tx.symbolic_locals.items(): # Note! this explicitly uses .local_name for matching # Failure to do so will cause spurious registrations in val_to_names. # This will in turn result in spurious variables showing up in the graph. # This was very tricky to debug. For an example, dump the graph at call_user_compiler # while running test_subgraphs.py if isinstance(v.source, LocalSource) and v.source.local_name == k: continue # no need to restore initial state if v not in val_to_names: val_to_names[v] = list() val_to_names[v].append(k) for v in val_to_names.keys(): restore_vars.extend(val_to_names[v]) stack_values.extend([v] * len(val_to_names[v])) # to handle random calls if len(self.random_calls) > 0: append_prefix_insts() random_calls_instructions = [] self.random_values_var = self.new_var("random_values") rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) rand_fn_name = self.install_global("__gen_rand_values", rand_fn) codegen = PyCodegen(tx, root) random_calls_instructions.extend( codegen.load_function_name(rand_fn_name, True) ) random_calls_instructions.extend(create_call_function(0, False)) random_calls_instructions.append( codegen.create_store(tx.output.random_values_var), ) self.add_output_instructions(random_calls_instructions) if ( stack_values and all( not isinstance( v, ( UnspecializedPythonVariable, NumpyNdarrayVariable, TensorWithTFOverrideVariable, ), ) for v in stack_values ) and all(isinstance(x, TensorVariable) for x in stack_values) and len(set(stack_values)) == len(stack_values) and self.side_effects.is_empty() and not len(tx.debug_locals) != 0 and not self.backward_state ): append_prefix_insts() # optimization to generate better code in a common case self.add_output_instructions( self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))] ) else: graph_output_var = self.new_var("graph_out") pass1 = PyCodegen(tx, root, graph_output_var) self.codegen_suffix(tx, stack_values, pass1) # one more time now that we have established tempvars pass2 = PyCodegen( tx, root, graph_output_var, tempvars={val: None for val, count in pass1.uses.items() if count > 1}, ) self.codegen_suffix(tx, stack_values, pass2) output = [] if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: output.extend( self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) ) if len(pass2.graph_outputs) != 0: output.append(pass2.create_store(graph_output_var)) else: output.append(create_instruction("POP_TOP")) append_prefix_insts() self.add_output_instructions(output + pass2.get_instructions()) # restore all the live local vars self.add_output_instructions( [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] ) def codegen_suffix(self, tx, stack_values, cg): if self.backward_state: assert not self.export for name, val in self.backward_state.items(): cg(val) cg.append_output(cg.create_load(self.backward_state_var)) cg.store_attr(name) self.side_effects.codegen_hooks(cg) self.side_effects.codegen_save_tempvars(cg) # Return variables used for logging at the end for debug_var, args in tx.debug_locals: cg(debug_var) for arg in args: cg(arg) cg.extend_output(create_call_function(len(args), True)) cg.restore_stack(stack_values, value_from_source=not tx.export) self.side_effects.codegen_update_mutated(cg) def cleanup_graph(self): """ Remove "creation_timestamp" from node meta Remove this pattern from the graph: torch._C._set_grad_enabled(False) torch._C._set_grad_enabled(True) """ assert self.should_exit nodes = list(self.graph.nodes) for node in nodes: node.meta.pop("creation_timestamp", None) grad_enabled = torch.is_grad_enabled() for node1, node2 in zip(nodes, nodes[1:]): if ( node1.target is torch._C._set_grad_enabled and tuple(node1.args) == (not grad_enabled,) and not node1._erased ): grad_enabled = node1.args[0] if ( node2.target is torch._C._set_grad_enabled and tuple(node2.args) == (not grad_enabled,) and not node2._erased ): grad_enabled = node2.args[0] self.graph.erase_node(node1) self.graph.erase_node(node2) def get_graph_sizes_structured(self): ret = {} for node in self.graph.nodes: example_value = node.meta.get("example_value", None) if isinstance(example_value, torch._subclasses.FakeTensor): size = example_value.size() ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] return ret def get_graph_sizes(self, name: str): graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" graph_sizes_str += f"===== {name} =====\n" for node in self.graph.nodes: example_value = node.meta.get("example_value", None) if isinstance(example_value, torch._subclasses.FakeTensor): size = example_value.size() graph_sizes_str += f"{node.name}: {tuple(size)}\n" concrete_size = [] has_symint = False for sz in size: if isinstance(sz, int): concrete_size.append(sz) elif isinstance(sz, torch.SymInt): has_symint = True concrete_size.append(sz.node.hint) else: break else: if has_symint: graph_sizes_str += ( f"{node.name} (concrete): {tuple(concrete_size)}\n" ) return graph_sizes_str @contextlib.contextmanager def restore_global_state(self): """ Momentarily restores the global state to what it was prior to tracing the current output """ prior_global_state = self.tracing_context.global_context.copy_graphstate() current_global_state: Dict[str, Tuple[Any, bool]] = {} self.save_global_state(out=current_global_state) try: # Set to state prior to tracing the graph self.tracing_context.global_context.restore_graphstate(prior_global_state) yield finally: # Reset to state at the current time (e.g. before calling the user compiler) self.tracing_context.global_context.restore_graphstate( GlobalContextCheckpointState(current_global_state) ) @torch._guards.TracingContext.clear_frame() def compile_and_call_fx_graph(self, tx, rv, root): """ Generate code from self.graph and return the Instruction()s to call that generated code. """ from .decorators import disable assert self.should_exit name = unique_id("__compiled_fn") assert isinstance(rv, list) assert isinstance(root, FakeRootModule) self.create_node( "output", "output", (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), {}, ) self.insert_deferred_runtime_asserts(root, name) # NB: deferred runtime asserts can keep graphargs live, so make sure # those are inserted before pruning self.remove_unused_graphargs() ncalls = count_calls(self.graph) counters["stats"]["calls_captured"] += ncalls # free a bit of memory self.real_value_cache.clear() gm = _make_graph_module(root, self.graph) for register_finalizer in self.register_finalizer_fns: register_finalizer(gm) gm.compile_subgraph_reason = self.compile_subgraph_reason gm.meta[ "dynamo_flat_name_to_original_fqn" ] = self.dynamo_flat_name_to_original_fqn.copy() graph_code_log.debug("%s", lazy_format_graph_code(name, gm)) torch._logging.trace_structured( "dynamo_output_graph", lambda: {"sizes": self.get_graph_sizes_structured()}, payload_fn=lambda: gm.print_readable(print_output=False), ) graph_tabular_log.debug("%s", lazy_format_graph_tabular(name, gm)) graph_sizes_log.debug("%s", LazyString(lambda: self.get_graph_sizes(name))) self.call_cleanup_hooks() old_fake_mode = self.tracing_context.fake_mode if not self.export: # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting backend_fake_mode = torch._subclasses.FakeTensorMode( shape_env=old_fake_mode.shape_env, ) # TODO(voz): Ostensibily, this should be scoped and # restore back to old_fake_mode, but doing so currently violates # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode self.tracing_context.fake_mode = backend_fake_mode with self.restore_global_state(): compiled_fn = self.call_user_compiler(gm) compiled_fn = disable(compiled_fn) counters["stats"]["unique_graphs"] += 1 # This is safe because we pre-process name to be unique self.install_global_unsafe(name, compiled_fn) cg = PyCodegen(tx) cg.make_call_generated_code(name) return cg.get_instructions() @property def placeholders(self) -> List[fx.Node]: r = [] for node in self.graph.nodes: if node.op == "placeholder": r.append(node) continue break return r @property def graphargs(self) -> List[GraphArg]: return [node.meta["grapharg"] for node in self.placeholders] @dynamo_timed(phase_name="backend_compile") def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: assert self.compiler_fn is not None tot = 0 placeholders = [] for node in gm.graph.nodes: if node.op in ("call_function", "call_method", "call_module"): tot += 1 if node.op == "placeholder": placeholders.append(node) increment_op_count(tot) for pl in placeholders: arg = pl.meta["grapharg"] # TODO: Why isn't this stored in meta :think: pl._dynamo_source = arg.source gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment] gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment] try: name = ( self.compiler_fn.__name__ if hasattr(self.compiler_fn, "__name__") else "" ) _step_logger()(logging.INFO, f"calling compiler function {name}") compiler_fn = self.compiler_fn if config.verify_correctness: compiler_fn = WrapperBackend(compiler_fn) compiled_fn = compiler_fn(gm, self.example_inputs()) _step_logger()(logging.INFO, f"done compiler function {name}") assert callable(compiled_fn), "compiler_fn did not return callable" except exceptions_allowed_to_be_fallback as e: if self.has_user_defined_allowed_in_graph: raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( e.__traceback__ ) from None msg = ( "Backend compiler failed with a fake tensor exception at \n" f"{self.root_tx.format_frame_summary()}" "Adding a graph break." ) unimplemented_with_warning(e, self.root_tx.f_code, msg) except SkipFrame as e: # The backend compiler has requested that we skip the frame, instead of # aborting execution. raise e except Exception as e: raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( e.__traceback__ ) from None signpost_event( "dynamo", "OutputGraph.call_user_compiler", { **self.co_fields, "op_count": tot, "node_count": len(gm.graph.nodes), "input_count": len(placeholders), }, ) return compiled_fn def example_inputs(self) -> List[torch.Tensor]: result = [] for arg in self.graphargs: result.append(arg.example) return result def remove_unused_graphargs(self) -> None: assert self.should_exit # Miniature DCE pass, but only for obviously trivial operations for node in reversed(list(self.graph.nodes)): if len(list(node.users)) == 0: if node.op == "get_attr": self.remove_node(node) elif node.op == "call_function" and node.target is operator.getitem: self.remove_node(node) def placeholder_binds_symbol(node): arg = node.meta["grapharg"] example = arg.example if isinstance(example, torch.SymInt) and isinstance( example.node.expr, sympy.Symbol ): return example.node.expr return None def remove_unused(node): log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) # I'm not really sure why you need to delete these from the # node since the node is going to get removed del node.meta["grapharg"] self.remove_node(node) self.real_value_cache.pop(node, None) used_symbols = set() recheck_placeholders = [] for node in self.placeholders: binds_symbol = placeholder_binds_symbol(node) is not None # Don't delete symbol bindings yet if binds_symbol: if not node.users: recheck_placeholders.append(node) else: if not node.users and not isinstance( node.meta["grapharg"], BackwardStateGraphArg ): remove_unused(node) else: # Register the free symbols as uses arg = node.meta["grapharg"] if isinstance(arg, BackwardStateGraphArg): continue fake = ( arg.fake_tensor if arg.fake_tensor is not None else arg.example ) used_symbols |= free_symbols(fake) # After removing unused graphargs, prune unused binds_symbol for node in recheck_placeholders: symbol = placeholder_binds_symbol(node) if symbol is not None: if symbol not in used_symbols: remove_unused(node) else: # Make sure we delete later occurrences of the same symbol used_symbols.remove(symbol) # TODO: this is a generic pass that should live outside of Dynamo def insert_deferred_runtime_asserts(self, root, name) -> None: """ During tracing, we may have discovered that some data-dependent values had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime that x.item() >= 0. This asserts can happen unpredictably during fake tensor propagation, so we cannot conveniently insert them into the FX graph when they occur. Instead, we accumulate them in the ShapeEnv, and in this pass insert them into the graph as proper tests. """ # TODO: Request simplification on runtime asserts before emitting them ras_by_symbol = self.shape_env.deferred_runtime_asserts.copy() if not any(ras for ras in ras_by_symbol.values()): return gm = fx.GraphModule(root, self.graph) graph_code_log.debug( "%s", lazy_format_graph_code(f"pre insert_deferred_runtime_asserts {name}", gm), ) # We are going to mutate the dict symbol_to_proxy = {} placeholders = set() last_placeholder = None for node in self.graph.nodes: if node.op != "placeholder": last_placeholder = node break placeholders.add(node) assert last_placeholder is not None # Identify what symbols we need to reify. This isn't strictly needed # but helps reduce churn on the graph needed_symbols: Set[sympy.Symbol] = set() for ras in ras_by_symbol.values(): for ra in ras: needed_symbols.update(free_symbols(ra.expr)) log.debug("needed_symbols = %s", needed_symbols) for node in self.graph.nodes: # Placeholders can match symbols, but when we destructure them # with size we have to make sure we insert the nodes after all # the placeholders with self.graph.inserting_before( node.next if node not in placeholders else last_placeholder.next ): if "example_value" not in node.meta: continue defs = [] # For every new unbacked symbol, we need an fx.Node representing # precisely this value. There are a few places where the unbacked # symbol could have come from, and we will check them to setup # these nodes. # # For a case like item(), this is trivial (no new node is added.) # # For nonzero(), we need to add something like i0 = out.size(0) # # We could end up with duplicate nodes this way but it is not a # big deal. # # We also do this to setup backed SymInts, but those are all going # to be matched from placeholders def match_symbol(symint, cb): if ( isinstance(symint, torch.SymInt) and isinstance(symint.node, SymNode) and isinstance(s := symint.node.expr, sympy.Symbol) and s not in symbol_to_proxy and s in needed_symbols ): symbol_to_proxy[s] = fx.Proxy(cb()) log.debug("symbol_to_proxy[%s] = %s", s, symbol_to_proxy[s]) defs.append(s) match_symbol(node.meta["example_value"], lambda: node) if isinstance(t := node.meta["example_value"], torch.Tensor): for i, s in enumerate(t.size()): match_symbol( s, lambda: self.graph.call_method("size", (node, i)) ) for i, s in enumerate(t.stride()): match_symbol( s, lambda: self.graph.call_method("stride", (node, i)) ) match_symbol( t.storage_offset(), lambda: self.graph.call_method("storage_offset", (node,)), ) for i0 in defs: ras = ras_by_symbol.pop(i0, []) # Before we perform any asserts, first apply range # refinement. This is important, because if we are going # to retrace the graph (and we typically are if we send # the graph to AOTAutograd), we need to make sure we apply # range refinement (ala _check_is_size) first, BEFORE we # run any of the asserts. Otherwise, we may decide to # perform substitutions based on the asserts which we then # can't back out, because value ranges can only be applied # to asserts.) # # A perhaps better long term plan is to avoid this order # dependence by making it possible to refine ranges on # arbitrary expressions, not just symbols. But it is not # so easy to make use of this information, see # https://twitter.com/ezyang/status/1745801370299482492 # We actually made an attempt at this in # https://github.com/pytorch/pytorch/pull/119043 # which didn't work. # # Another ideas for how to do this: # - Have bound_sympy be the source of truth of the ranges of any expression # - Cache intermediate results for every subexpression of bound_sympy # - This cache should be possible to edit to refine ranges # # One issue with this proposal is that if # we have a bound on 2x, we are not going to be able to # apply it for 4x. Similarly, we may have bounds for an # equivalent expression that we are not applying because # it's not a perfect match (e.g. x < y vs y > x)". # # The first issue we already have it and it's impossible # to solve in general, so any implementation on a best # effort basis should do. # # The second issue is a preexisting one. It can be mitigated # with a normalisation algorithm. In general, it may also # be on a best effort basis, but since our grammar is not # terribly difficult, chances are we could even fully # normalise SymPy expressions... who knows. if i0 in self.shape_env.size_like: self.graph.call_function( torch._check_is_size, (symbol_to_proxy[i0].node,) ) vr = self.shape_env.var_to_range[i0] if not self.shape_env._default_unspecified_value_range().issubset( vr ): # The runtime range is constrained, so add a runtime # assert and also explicitly refine the range # (refinement should not be necessary once runtime # asserts cause refinement, but that's NYI) def convert(s): try: return int(s) except TypeError: return None self.graph.call_function( torch._constrain_as_value, ( symbol_to_proxy[i0].node, convert(vr.lower), convert(vr.upper), ), ) for ra in ras: log.debug("inserting runtime assert %s", ra.expr) # Need to process ALL free symbols, not just unbacked ones fvs = free_symbols(ra.expr) missing = fvs - symbol_to_proxy.keys() if missing: i1 = sorted(missing)[0] # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689 # assert self.shape_env.is_unbacked_symint(i1), i1 ras_by_symbol.setdefault(i1, []).append(ra) else: # Convert the sympy expression into a sequence of FX # nodes res = sympy_interp( PythonReferenceAnalysis, symbol_to_proxy, ra.expr ).node self.graph.call_function( torch.ops.aten._assert_scalar.default, # TODO: use ra.msg here, but it's pretty # useless right now ( res, f"Deferred runtime assertion failed {ra.expr}", ), ) def add_output_instructions(self, prefix: List[Instruction]) -> None: """ We call this on the creation of a new compiled subgraph that is inserted before user code. """ self.output_instructions.extend(prefix) self.should_exit = True def install_global_unsafe(self, name, value) -> None: """ WARNING: prefer the safer `install_global_by_id/install_global`. torch.compile instances should be independent of each other; one footgun is to have one instance depend on the existence of a global installed by another instance. This can happen if we mangle a global the same way across both instances. """ assert name not in self.installed_globals self.installed_globals.add(name) self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) def install_global_by_id(self, prefix, value) -> str: """ Installs a global if it hasn't been installed already. This is determined by (prefix, id(value)) pair. Returns the name of the newly installed global. """ # NB: need self.compile_id to distinguish this global # from another global created in a different torch.compile instance name = f"{prefix}_{id(value)}_c{self.compile_id}" if name in self.installed_globals: return name self.install_global_unsafe(name, value) return name def install_global(self, prefix, value) -> str: """ Installs a global, generating a unique name for it. Returns the name of the newly installed global. """ # NB: unique_id is unique, even across torch.compile instances name = unique_id(prefix) self.install_global_unsafe(name, value) return name def cleanup(self) -> None: # There is a reference cycle between tracer and OutputGraph, causing # some of the tensor objects to be held alive for longer than necessary. self.root_tx = None self.nn_modules.clear() self.param_name_to_source = None for node in self.graph.nodes: if "grapharg" in node.meta: del node.meta["grapharg"] self.real_value_cache.clear() self.input_name_to_proxy.clear() self.side_effects.clear() self.register_finalizer_fns.clear() self.dynamo_flat_name_to_original_fqn.clear() self.tracing_context.clear() def set_torch_function_state(self, enabled: bool) -> None: self.torch_function_enabled = enabled def add_graph_finalizer( self, register_finalizer: Callable[[fx.GraphModule], None] ) -> None: self.register_finalizer_fns.append(register_finalizer) def example_value_from_input_node(self, node: torch.fx.Node): """Extract the non-fake example tensor""" if node.op == "placeholder": return node.meta["grapharg"].example assert node.op == "get_attr" return self.nn_modules[node.target] # type: ignore[index] err_epilogue = ( "With the current config, we will graph break " "(and fall back to eager-mode PyTorch) on all ops " "that have do not have the 'pt2_compliant_tag'. " "Please see the following doc for how to mark this op as PT2 compliant " "https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ" ) def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): if kind != "call_function": return def encountered_compliant_op(target): if target.namespace in {"prim", "prims", "aten"}: return output_graph.compliant_custom_ops.add(target) def encountered_non_compliant_op(target, msg): output_graph.non_compliant_ops.add(target) if config.only_allow_pt2_compliant_ops: unimplemented(msg + " " + err_epilogue) if isinstance(target, torch._ops.OpOverload): if torch.Tag.pt2_compliant_tag in target.tags: encountered_compliant_op(target) return encountered_non_compliant_op( target, f"Encountered the torch.ops.OpOverload {target} " f"that is not PT2 compliant.", ) return if isinstance(target, torch._ops.OpOverloadPacket): overloads = tuple(target.overloads()) # Optimization: Overload resolution is expensive. # If there's only one overload, we know what it will resolve to. if len(overloads) == 1: op = getattr(target, overloads[0]) if torch.Tag.pt2_compliant_tag in op.tags: encountered_compliant_op(op) return encountered_non_compliant_op( op, f"Encountered the non-overloaded " f"torch.ops.OpOverloadPacket {target} " f"that is not PT2 compliant. ", ) return args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes( output_graph.current_tx, (args, kwargs), False ) try: overload = torch._C._jit_resolve_packet( target._qualified_op_name, *args, **kwargs ) except RuntimeError as e: unimplemented(str(e)) op = getattr(target, overload) if torch.Tag.pt2_compliant_tag in op.tags: encountered_compliant_op(op) else: encountered_non_compliant_op( op, f"Encountered the torch.ops.OpOverloadPacket {target} " f"which resolves to the overload ({overload}) that is " f"not PT2 compliant.", ) _compile_id_counter = itertools.count() class SubgraphTracer(fx.Tracer): """ Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer and the separation of responsibilities is that SubgraphTracer is responsible for building the graph while OutputGraph is responsible for compiling and executing the graph. """ def __init__( self, output_graph, parent=None, export_root=False, source_target=None ): super().__init__() self.output_graph = weakref.proxy(output_graph) self.graph = torch.fx.Graph() # The export is only ever set for the ROOT tracer. It controls # whether or not certain inputs are allowed to be added or not. # Look at call sites of create_graph_input to see how it is used. if export_root: assert parent is None self.export_root = export_root # Map from graph input name to its placeholder proxy object, where the # map's keys give all current placeholder node names and can be used to # create unique node names self.input_name_to_proxy: Dict[str, fx.Proxy] = {} # Node => computed real value (see utils.get_real_value) self.real_value_cache: Dict[fx.Node, torch.Tensor] = {} # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design] self.parent = parent # A dict mapping previously free variables (Proxy objects) # to new Proxy objects that wrap inputs to this subgraph. # # This dict serves two purposes: # - Proxies are associated with VariableTrackers. If we see # the same VariableTracker twice (and it is a free variable), # then we want to use the same Proxy in the current subgraph to # record the tracing. # - If we are tracing a HigherOrderOperator's body_fn, then we # need to keep track of what free variables were lifted so we can # rewrite the HigherOrderOperator call using the traced body_fn. # Dicts maintain the order of args for the HigherOrderOperator call. self.lifted_freevars = {} self.prev_inst = None self._cur_code = None self._orig_gm_meta = None self._orig_gm_lineno_map = None self._orig_gm_firstlineno = None # Each SubgraphTracer is associated with a source target, which indicates # which operator this subgraph is attached to. We compute a source_fn_stack # based on the source target. For the root tracer, it's set to []. # This is useful for debugging and transforming the exported graph. if self.parent is None: self.source_fn_stack = [] else: self.source_fn_stack = self.parent.source_fn_stack + [ (self.graph._target_to_str(source_target), source_target) ] def create_proxy( self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None, ): # NOTE: [Nested SubgraphTracer and free_variable handling] # -------------------------------------------------------- # Read NOTE [HigherOrderOperator tracing design] first. # # Let's say we're in the middle of introspecting the body of a possibly # nested HigherOrderOperator, and we see a free variable. # # There are two cases: # 1. We see a free variable that is already tracked by Dynamo. # 2. We see a free variable that has not been tracked by Dynamo # # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below) # which will lift the freevar to be an input of this subgraph # and also recursively lift it to be an input on the parent(s). # # In case 2, before the call to `create_proxy`, the InstructionTranslator # will see the freevar when it gets loaded by Python bytecode. # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or # LOAD_GLOBAL. # There, the InstructionTranslator asks Dynamo to begin tracking the # freevar by building a new Variable. # Building a new Variable automatically lifts the freevar to be an # input of the root SubgraphTracer. # # The implications for the code below are: # - We will always be in Case 1 when we get to this code. # - Any "free variable" we encounter here is guaranteed to already be # bound, that is, it is either a graph input of the root graph, or # some local variable of the root graph or a subgraph. # - The additional work we need to do here is *only* that we need to # lift this free variable into inputs (recursively) of each nested # higher-order-op subgraph until we hit the subgraph where the free # variable is bound if self.parent is not None: flat_args, tree_spec = pytree.tree_flatten((args, kwargs)) new_flat_args = [] for arg in flat_args: maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg) new_flat_args.append(maybe_new_arg) args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) rv = super().create_proxy( kind, target, args, kwargs, name, type_expr, proxy_factory_fn ) # append stack trace to fx node tx = self.output_graph.current_tx # log detailed location of line of code in 3.11 if sys.version_info >= (3, 11) and kind in ( "call_function", "call_method", "call_module", ): cur_inst = tx.current_instruction if ( cur_inst is not self.prev_inst and cur_inst.positions is not None and cur_inst.positions.lineno is not None ): tx_code = tx.f_code header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) def get_trace_call_log_str(): line = get_instruction_source_311(tx_code, cur_inst).rstrip() return f"TRACE FX call {rv.node.name} from {header}\n{line}" trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) self.prev_inst = cur_inst # update reference to original meta if we're tracing a new code object is_retracing = False if tx.f_code is not self._cur_code: orig_graphmodule_maybe = code_context.get_context(tx.f_code).get( "orig_graphmodule", lambda: None )() if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule): is_retracing = True self._orig_gm_meta = [ nd.meta for nd in orig_graphmodule_maybe.graph.nodes ] self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map self._orig_gm_firstlineno = ( orig_graphmodule_maybe.forward.__code__.co_firstlineno ) else: self._orig_gm_meta = None self._orig_gm_lineno_map = None self._orig_gm_firstlineno = None nn_module_stack = tx.nn_module_stack if nn_module_stack: rv.node.meta["nn_module_stack"] = nn_module_stack.copy() if kind in {"call_function", "call_method"}: rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ (rv.node.name, target) ] elif kind == "call_module": if self.parent is not None: unimplemented("Invoking an nn.Module inside HigherOrderOperator") # For modules we store the class rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ ( rv.node.name, rv.node.meta["nn_module_stack"][target][1], ) ] # preserve original meta if it is available if ( self._orig_gm_meta and self._orig_gm_lineno_map and self._orig_gm_firstlineno ): lineno = tx.current_instruction.starts_line node_idx = None if lineno is not None: node_idx = self._orig_gm_lineno_map.get( lineno - self._orig_gm_firstlineno, None ) if node_idx is not None: meta = self._orig_gm_meta[node_idx] for field in fx.proxy._COPY_META_FIELDS: if field in meta: rv.node.meta[field] = meta[field] if "stack_trace" in meta: rv.node.meta["stack_trace"] = meta["stack_trace"] if not is_retracing: if "nn_module_stack" not in rv.node.meta: nn_module_stack = tx.nn_module_stack if nn_module_stack: rv.node.meta["nn_module_stack"] = nn_module_stack.copy() if "source_fn_stack" not in rv.node.meta: if kind in {"call_function", "call_method"}: rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ (rv.node.name, target) ] elif kind == "call_module": if self.parent is not None: unimplemented( "Invoking an nn.Module inside HigherOrderOperator" ) # For modules we store the class rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ ( rv.node.name, rv.node.meta["nn_module_stack"][target][1], ) ] if "stack_trace" not in rv.node.meta: frame_summaries: List[traceback.FrameSummary] = [] while tx: frame_summaries.append(tx.frame_summary()) tx = getattr(tx, "parent", None) # Reverse the frame_summaries, such that the innermost frame is at the last frame_summaries.reverse() # official from_list stub doesn't have new-style type msgs = traceback.StackSummary.from_list(frame_summaries).format() rv.node.stack_trace = "".join(msgs) return rv def create_node( self, op, target, args=None, kwargs=None, name=None, type_expr=None ): check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) if self.parent is not None: flat_args = pytree.arg_tree_leaves(*args, **kwargs) for arg in flat_args: if not isinstance(arg, torch.fx.Node): continue assert ( arg.graph == self.graph ), "create_node using arg not from this SubgraphTracer" node = super().create_node(op, target, args, kwargs, name, type_expr) node.meta["creation_timestamp"] = self.output_graph.timestamp return node # Note: we did not override erase_node since # we call self.graph.erase_node elsewhere def remove_node(self, node): if len(node.users) > 0: user_graph_nodes: List[torch.fx.Node] = [] for user in node.users.keys(): # For the case where user.graph == self.graph, that is a real bug and will raise # properly. if user.graph != self.graph: # This is a nested graph, which needs to be deleted. # If we do not do this, we will raise on attempting to remove this. # As we only get here during restoration cleanup, this is sound. user_graph_nodes.extend(reversed(list(user.graph.nodes))) for other_graph_node in user_graph_nodes: other_graph_node.graph.erase_node(other_graph_node) self.graph.erase_node(node) self.input_name_to_proxy.pop(node.name, None) # when before=True, we will insert this input before the most recent # inserted proxy. This is a hack to get around an ordering problem, # where we first insert a tensor argument, and then insert bindings # for SymInts that may occur in the tensor argument. # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets # fixed. def create_graph_input(self, name, type_expr=None, before=False, source=None): log.debug( "create_graph_input %s %s", name, source.name() if source is not None else "(none)", ) if source is None: assert ( self.parent is not None ), "you are required to provide a source for inputs on the root tracer" # In eager, we are generally OK with adding graph inputs whenever we # want, because we take care of writing the bytecode that knows how # to source all the inputs. # # In export, this is bad, because you want a self-contained export # object which only depends on the inputs you explicitly passed to it. # So we are a bit more strict about what sources can become inputs # in export if self.export_root: if not is_from_local_source(source, allow_cell_or_freevar=False): self.output_graph.source_to_user_stacks.setdefault(source, []).append( TracingContext.extract_stack() ) # unique if name in self.input_name_to_proxy: for i in itertools.count(): candidate_name = f"{name}_{i}" if candidate_name not in self.input_name_to_proxy: name = candidate_name break if self.input_name_to_proxy: prev_name = next(reversed(self.input_name_to_proxy)) node = self.input_name_to_proxy[prev_name].node if before: ctx = self.graph.inserting_before(node) else: ctx = self.graph.inserting_after(node) else: ctx = self.graph.inserting_before(None) with ctx: proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) if self.input_name_to_proxy and before: k, v = self.input_name_to_proxy.popitem() self.input_name_to_proxy[name] = proxy self.input_name_to_proxy[k] = v else: self.input_name_to_proxy[name] = proxy return proxy # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details def lift_tracked_freevar_to_input(self, proxy): # You're doing something wrong if we are the root SubgraphTracer because # Dynamo adds tensors to graph inputs before creating a proxy for them. assert ( self.parent is not None ), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer" # Proxys are associated with VariableTracker. # It is possible that we've already lifted the Proxy to be an input. # If that is the case, just return the already lifted Proxy. if proxy in self.lifted_freevars: return self.lifted_freevars[proxy] new_proxy = self.create_graph_input(proxy.node.name) new_proxy.node.meta["example_value"] = proxy.node.meta["example_value"] self.lifted_freevars[proxy] = new_proxy if self.parent is not None and proxy.tracer != self.parent: self.parent.lift_tracked_freevar_to_input(proxy) return new_proxy def maybe_lift_tracked_freevar_to_input(self, arg): """ If arg is a free variable, then lift it to be an input. Returns the new lifted arg (if arg was a freevar), else the original arg. """ if not isinstance(arg, torch.fx.Proxy): return arg elif arg.tracer == self: return arg return self.lift_tracked_freevar_to_input(arg) # NOTE: [HigherOrderOperator tracing design] # Ignoring HigherOrderOperators for a moment, # OutputGraph represents the graph being built by Dynamo that may be compiled # and executed. It holds a root SubgraphTracer where the FX graph is built. # # HigherOrderOperators are operators that take functions as their arguments. # When Dynamo encounters a HigherOrderOperator, then it attempts to introspect # the function passed to it (call this the "body function"), capture it into a # GraphModule, and rewrite the call to the HigherOrderOperator to use the # GraphModule. # # The way we handle the capture of body functions is through having # (possibly nested) SubgraphTracers, one per body function. # # Mechanically, we do the introspection by: # - Creating a new SubgraphTracer via OutputGraph.subtracer # - Executing the body function. # This constructs the graph of the body function in the new SubgraphTracer # while modifying the state of the OutputGraph. For example: # - the OutputGraph can receive new GraphArgs (if we discover any new # untracked Tensors) # - side effects from the body function get accumulated into # OutputGraph.side_effects # - guards produced by the body function get accumulated into OutputGraph.guards # # The traced function has some special properties that make it easier for us # to transform later down the line: # - we lift all free variables to being inputs. # # If the introspection fails (due to the existence of graph breaks), then # we roll back the current OutputGraph state and graph break on the # HigherOrderOperator.