# mypy: ignore-errors

import dataclasses
import functools
from importlib import import_module
from typing import Any, List, Optional

from functorch.compile import min_cut_rematerialization_partition

import torch
from torch import _guards
from torch._functorch.compilers import ts_compile

from .common import aot_autograd
from .registry import register_debug_backend as register_backend

"""
This file contains TorchDynamo backends intended for debugging uses.
"""


@register_backend
def eager(gm, fake_tensor_inputs):
    return gm


@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs):
    from torch.fx.experimental.proxy_tensor import make_fx

    def runnable_gm(*args):
        return torch.fx.Interpreter(gm).run(*args)

    pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
    pre_dispatch_gm.print_readable()

    return pre_dispatch_gm


@register_backend
def eager_debug(gm, fake_tensor_inputs):
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    # We could add more debugging bits here.
    # Right now, this backend can be used to check for and error on
    # custom dispatcher ops that have incorrect schemas.
    def inner(*args):
        with SchemaCheckMode():
            return torch.fx.Interpreter(gm).run(*args)

    return inner


@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    return torch.jit.script(gm)


# Uses the boxed calling convention so that inputs can be discarded as soon as
# they are no longer needed.
def boxed_nop(fx_g, example_inputs):
    def run(args):
        return torch.fx.Interpreter(fx_g).boxed_run(args)

    run._boxed_call = True
    return run


# aot_eager uses the AOT Autograd backend with a nop compiler, which is useful
# for isolating AOT Autograd bugs from backend-compiler bugs.
aot_eager = aot_autograd(
    fw_compiler=boxed_nop, partition_fn=min_cut_rematerialization_partition
)
register_backend(name="aot_eager", compiler_fn=aot_eager)

aot_eager_default_partitioner = aot_autograd(fw_compiler=boxed_nop)
register_backend(
    name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)

# Uses TorchInductor's AOT Autograd decomps and partitioner, but replaces the
# Inductor compiler with a nop, to help isolate Inductor problems from
# aot_eager problems.
aot_eager_decomp_partition = aot_autograd(
    # these are taken from memory_efficient_fusion()
    fw_compiler=boxed_nop,
    bw_compiler=boxed_nop,
    # NB: lambda here is to delay import of inductor
    decompositions=lambda: import_module(
        "torch._inductor.compile_fx"
    ).select_decomp_table(),
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
)
register_backend(
    name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)

# AOT Autograd with the TorchScript backend and the default partitioner.
# aot_ts can be used with both nnc and nvfuser by selecting the relevant
# fuser with torch.jit.fuser(...).
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(name="aot_ts", compiler_fn=aot_ts)
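# Example (illustrative sketch, not part of this module): with the profiling
# executor, TorchScript fusion happens while the scripted graph runs, so the
# fuser context manager wraps the call. Here "fuser1" selects nnc and
# "fuser2" selects nvfuser; `fn` and `x` are placeholders.
#
#   compiled_fn = torch.compile(fn, backend="aot_ts")
#   with torch.jit.fuser("fuser1"):  # or "fuser2" for nvfuser
#       compiled_fn(x)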
# These intentionally buggy backends are used to induce failures so that we
# can test our repro extraction / minifier scripts.


class ReluCompileError(Exception):
    pass


class TestingOnlyCompileError(Exception):
    pass


@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    # Fail at compile time whenever the graph contains a relu.
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            raise ReluCompileError()
    return gm


@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    # Rewrite every relu into an assertion that fails at runtime.
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch._assert
            node.args = (False, "ReluRuntimeError")
    gm.recompile()
    return gm


@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    # Rewrite every relu into add(x, 1), silently producing wrong results.
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch.add
            node.args = (node.args[0], 1)
    gm.recompile()
    return gm


@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    # Require at least one non-trivial thing in the graph,
    # see https://github.com/pytorch/pytorch/issues/102898
    for node in gm.graph.nodes:
        if node.op == "call_function":
            break
    else:
        return gm
    for t in example_inputs:
        if not t.is_leaf:
            raise TestingOnlyCompileError()
    return gm


@dataclasses.dataclass
class ExplainOutput:
    """
    This is the output of :func:`torch._dynamo.explain()`.
    There is no reason to create this class directly.
    """

    graphs: List[torch.fx.GraphModule]
    graph_count: int
    graph_break_count: int
    break_reasons: List[Any]  # Type is GraphCompileReason, but it doesn't matter here
    op_count: int
    ops_per_graph: Optional[List[torch.fx.Node]] = None
    out_guards: Optional[List[_guards.Guard]] = None
    compile_times: Optional[str] = None

    def __str__(self):
        output = f"Graph Count: {self.graph_count}\n"
        output += f"Graph Break Count: {self.graph_break_count}\n"
        output += f"Op Count: {self.op_count}\n"

        output += "Break Reasons:\n"
        for idx, break_reason in enumerate(self.break_reasons):
            output += f"  Break Reason {idx+1}:\n"
            output += f"    Reason: {break_reason.reason}\n"
            output += "    User Stack:\n"
            for frame_summary in break_reason.user_stack:
                output += f"      {frame_summary}\n"

        if self.ops_per_graph is not None:
            output += "Ops per Graph:\n"
            for idx, ops in enumerate(self.ops_per_graph):
                output += f"  Ops {idx+1}:\n"
                for op in ops:
                    output += f"    {op}\n"

        if self.out_guards is not None:
            output += "Out Guards:\n"
            for i, guard in enumerate(self.out_guards):
                output += f"  Guard {i+1}:\n"
                output += f"    {str(guard)}"

        if self.compile_times is not None:
            output += f"Compile Times: {self.compile_times}\n"

        return output
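# Example (illustrative sketch): an ExplainOutput is normally obtained from
# torch._dynamo.explain() rather than constructed directly. The call
# convention has varied across releases, so treat this as an approximation;
# `fn` is a placeholder.
#
#   def fn(x):
#       x = torch.sin(x)
#       torch._dynamo.graph_break()
#       return torch.cos(x)
#
#   explanation = torch._dynamo.explain(fn)(torch.randn(4))
#   print(explanation)  # formatted by ExplainOutput.__str__ above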
def _explain_graph_detail(
    gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
    """
    This function is a utility which processes a torch.fx.GraphModule and
    accumulates information about its ops, graph breaks, and other details. It
    is intended to be used by the ExplainWithBackend class and
    `torch._dynamo.explain()` to provide details from Dynamo's graph capture.

    Parameters:
        gm (torch.fx.GraphModule): The GraphModule to be processed.
        graphs (list): A list that accumulates all the GraphModules processed.
        op_count (int): The total count of operations in all GraphModules processed so far.
        ops_per_graph (list): A list that accumulates the operations of each GraphModule.
        break_reasons (list): A list that accumulates the reasons for breaks in each GraphModule.

    Returns:
        tuple: A tuple containing the processed GraphModule, the updated list of
        graphs, the updated operation count, the updated list of operations per
        graph, and the updated list of break reasons.
    """
    graphs.append(gm)
    ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
    op_count += len(ops)
    ops_per_graph.append(ops)
    if gm.compile_subgraph_reason.graph_break:
        break_reasons.append(gm.compile_subgraph_reason)

    return gm, graphs, op_count, ops_per_graph, break_reasons


class ExplainWithBackend:
    """
    This class is intended to be used as a backend for `torch.compile`. It is
    composable with other backends. When used in this way, it accumulates
    information about graph breaks, ops, and other info and provides a string
    representation summarizing this information.

    Attributes:
        backend: The compiler function resolved from the backend name passed in.
        graphs (list): A list of the graphs captured by TorchDynamo.
        op_count (int): The total number of operations in all optimized graphs.
        break_reasons (list): A list of graph break reasons with stack traces.

    Example Usage:
        def fn(x):
            x = torch.sigmoid(x)
            return x

        torch._dynamo.reset()
        eb = ExplainWithBackend("inductor")
        optimized_fn = torch.compile(fn, backend=eb)
        result = optimized_fn(torch.randn(5))
        print(eb.output())
    """

    def __init__(self, backend):
        from .registry import lookup_backend

        self.backend = lookup_backend(backend)
        self.graphs = []
        self.op_count = 0
        self.break_reasons = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
            gm, self.graphs, self.op_count, [], self.break_reasons
        )
        return self.backend(gm, example_inputs)

    def output(self) -> ExplainOutput:
        graph_count = len(self.graphs)
        output = ExplainOutput(
            self.graphs,
            graph_count,
            graph_count - 1,
            self.break_reasons,
            self.op_count,
        )
        return output
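# Example (illustrative sketch): since ExplainWithBackend wraps another
# backend, it composes with the debug backends defined in this file, e.g.
# counting graph breaks while each graph runs through AOT Autograd's nop
# compiler; `fn` is a placeholder.
#
#   eb = ExplainWithBackend("aot_eager")
#   compiled_fn = torch.compile(fn, backend=eb)
#   compiled_fn(torch.randn(8))
#   print(eb.output())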