import dataclasses
import importlib
import logging
import os

from typing import (
    Any,
    Callable,
    Dict,
    Final,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

from typing_extensions import TypeAlias

import torch
import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx._compatibility import compatibility
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree

try:
    # Use try-except to initialize package-dependent global variables.
    import onnx
    import onnxruntime  # type: ignore[import]
    from onnxruntime.capi import _pybind_state as ORTC  # type: ignore[import]

    # This is not used directly in DORT but needed by the underlying exporter,
    # so we still need to check if it exists.
    importlib.import_module("onnxscript")

    import torch.onnx
    import torch.onnx._internal
    import torch.onnx._internal.diagnostics
    import torch.onnx._internal.exporter
    import torch.onnx._internal.fx.decomposition_table
    import torch.onnx._internal.fx.passes
    from torch.onnx._internal.fx import fx_onnx_interpreter
    from torch.onnx._internal.fx.type_utils import (
        _TORCH_DTYPE_TO_NUMPY_DTYPE,
        _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
        from_python_type_to_onnx_tensor_element_type,
    )

    _SUPPORT_ONNXRT = True
except ImportError:
    _SUPPORT_ONNXRT = False

__all__ = [
    "is_onnxrt_backend_supported",
    "torch_compile_backend",
    "OrtExecutionProvider",
    "OrtBackendOptions",
    "OrtBackend",
]


def is_onnxrt_backend_supported() -> bool:
    """Returns ``True`` if ONNX Runtime dependencies are installed and usable
    to support TorchDynamo backend integration; ``False`` otherwise.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> if torch.onnx.is_onnxrt_backend_supported():
        ...     @torch.compile(backend="onnxrt")
        ...     def f(x):
        ...         return x * x
        ...     print(f(torch.randn(10)))
        ... else:
        ...     print("pip install onnx onnxscript onnxruntime")
        ...
    """
    return _SUPPORT_ONNXRT


_dumped_onnx_model: Dict[str, int] = {}


def _dump_onnx_model(
    model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None
) -> str:
    """Stores the onnx model into a file.

    The name is "{ONNXRT_DUMP_PATH}{N}.onnx" where *N* is the number of files
    already stored with this prefix. If graph_module is not None, the graph is
    stored as a string with the same filename except the extension (.txt).
    """
    prefix = os.environ.get("ONNXRT_DUMP_PATH", None)
    if not prefix:
        return ""
    n = _dumped_onnx_model.get(prefix, -1) + 1
    filename = f"{prefix}{n}.onnx"
    with open(filename, "wb") as f:
        f.write(model_string)
    _dumped_onnx_model[prefix] = n
    if graph_module is not None:
        filename_txt = f"{prefix}{n}.txt"
        with open(filename_txt, "w", encoding="utf-8") as f:
            f.write(str(graph_module.graph))
    return filename
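# Illustrative sketch (comment only, nothing here is executed): how the
# ONNXRT_DUMP_PATH environment variable consumed by _dump_onnx_model above is
# typically set. The prefix "dumped/dumped_model_" and the script name are
# placeholders.
#
#   ONNXRT_DUMP_PATH="dumped/dumped_model_" python my_script.py
#
# With the variable set, every ONNX model produced by this backend is written
# to dumped/dumped_model_0.onnx, dumped/dumped_model_1.onnx, ..., and a
# matching .txt file stores str(graph_module.graph) when a graph module is
# provided.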
""" if torch.cuda.is_available(): torch.cuda.nvtx.range_pop() def _get_ort_device_type(device_type: str): if device_type == "cuda": return ORTC.OrtDevice.cuda() if device_type == "cpu": return ORTC.OrtDevice.cpu() # ort pytorch device is mapped to NPU OrtDevice type if device_type == "ort": return ORTC.OrtDevice.npu() raise ValueError("Unsupported device type: " + device_type) logger = logging.getLogger(__name__) # Uncomment the following lines to print out development info. # logging.basicConfig(level=logging.WARNING) # logger.setLevel(logging.WARNING) class OrtOperatorSupport(OperatorSupport): """Operator support for ONNXRuntime backend. It has two-level of support decision. One is via support_dict and the other one is via extra_support_dict. The logic of using support_dict is implemented in OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported. """ def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]): # Use extra_support_dict[op_name] = None to indicate # we support op_name with all input types. Otherwise, # see support_dict (type: SupportDict) in operator_support.py # for specifying supported types. super().__init__(extra_support_dict) self._onnx_support_dict = support_dict def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: # OperatorSupport.is_node_supported returns True for non-callable nodes. # Since ORT can't execute them, we return False here to override the base # behavior. if node.op not in CALLABLE_NODE_OPS: return False # This is the and the only place to decide if aten op is supported. if node.op == "call_function" and node.target in self._onnx_support_dict: logger.warning( "support_dict supports node.target: %s (type: %s)", node.target, type(node.target), ) return True # If node.target is not in support_dict, we still want to check if torch.jit.script # can convert it to ONNX equivalence. Let's use base mechanism to do this. # See extra_support_dict for supported ops. if super().is_node_supported(submodules, node): logger.warning( "extra_support_dict supports node.target: %s (type: %s)", node.target, type(node.target), ) return True logger.warning( "support_dict and extra_support_dict don't support node.target: %s (type: %s)", node.target, type(node.target), ) return False def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: """ In torch.fx.Graph, placeholder is a special assignment node. If it's not executed in the beginning, it could overwrite values computed by upstream nodes. 
""" graph = graph_module.graph placeholders = [] first_not_placeholder = None for node in graph.nodes: if node.op == "placeholder": placeholders.append(node) if first_not_placeholder is None and node.op != "placeholder": first_not_placeholder = node if first_not_placeholder is None: return for placeholder in placeholders: first_not_placeholder.prepend(placeholder) def _infer_ep_from_device(*args) -> Tuple[str, ...]: """Return the first valid device (i.e., GPU or CPU) in argument list.""" eps = [] for arg in args: if hasattr(arg, "device"): device = arg.device if device.type == "cuda": eps.append("CUDAExecutionProvider") elif device.type == "cpu": eps.append("CPUExecutionProvider") return tuple(eps) def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]: placeholders = [] for node in graph_module.graph.nodes: if node.op == "placeholder": if hasattr(node, "meta") and "val" in node.meta: assert isinstance(node.meta["val"], torch.Tensor) placeholders.append(node) return tuple(placeholders) def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: """Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" for node in graph_module.graph.nodes: if node.op == "output": # Output node is unique. Let's retrieve output values from # this node's input list. And then just return. return node.args[0] raise ValueError("No output node found in this torch.fx.GraphModule.") def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]: """Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" flattened_output_args, _ = _pytree.tree_flatten( _extract_graph_module_outputs(graph_module) ) # Output arguments with example value (type: torch.Tensor) in the `graph_module`. selected_output_args = [ output_arg.meta["val"] for output_arg in flattened_output_args # output_arg must have tensor for its device information. # Otherwise, skip it. if (hasattr(output_arg, "meta") and "val" in output_arg.meta) ] return _infer_ep_from_device(*selected_output_args) def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]: """Sort execution providers in eps based on pre-set priority.""" def get_execution_provider_priority(ep: str) -> int: if ep == "CPUExecutionProvider": # Lowest priority. return 2 if ep == "CUDAExecutionProvider": # Higher priority than CPU but lower than # other specialized EPs. return 1 # Highest priority. 
def _get_onnx_devices(
    values: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ]
) -> Tuple["ORTC.OrtDevice", ...]:
    def _device_id_or_zero(device_id: int) -> int:
        return device_id or 0

    def _map_tensor_or_sym_to_device(
        value: Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
    ) -> "ORTC.OrtDevice":
        if isinstance(value, torch.Tensor):
            return ORTC.OrtDevice(
                _get_ort_device_type(value.device.type),
                ORTC.OrtDevice.default_memory(),
                _device_id_or_zero(value.device.index),
            )
        elif isinstance(
            value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool)
        ):
            return ORTC.OrtDevice(
                _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0
            )
        else:
            raise ValueError("Unsupported value type: " + str(type(value)))

    if len(values) > 0:
        ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values)
        return ort_devices
    else:
        return (_map_tensor_or_sym_to_device(1),)


def _get_ortvalues_from_torch_tensors(
    tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
) -> "ORTC.OrtValueVector":
    ortvalues = ORTC.OrtValueVector()
    ortvalues.reserve(len(tensors))
    dtypes = []
    shapes = []
    data_ptrs = []

    for tensor in tensors:
        dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
        shapes.append(tensor.size())
        data_ptrs.append(tensor.data_ptr())
    ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
    return ortvalues


def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
    if tensor.is_sparse:
        raise ValueError("sparse tensor is not yet supported.")
    out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
    return out


def _adjust_scalar_from_fx_to_onnx(
    dynamo_value: Union[
        torch.Tensor,
        int,
        float,
        bool,
    ],
    value_info: "onnx.ValueInfoProto",  # type: ignore[name-defined]
) -> torch.Tensor:
    """Helper function to wrap PyTorch variables as torch.Tensor"""
    if (
        isinstance(dynamo_value, torch.Tensor)
        and len(value_info.type.tensor_type.shape.dim) == 0
        and dynamo_value.shape == (1,)
    ):
        # ONNX expects a scalar with empty shape.
        # In contrast, PyTorch usually allows implicit
        # conversion between shape=() and shape=(1,).
        #
        # Below, PyTorch's shape (1,) is reshaped to ().
        return torch.squeeze(dynamo_value)
    elif isinstance(dynamo_value, int):
        return torch.tensor(dynamo_value, dtype=torch.int64)
    elif isinstance(dynamo_value, float):
        return torch.tensor(dynamo_value, dtype=torch.float32)
    elif isinstance(dynamo_value, bool):
        return torch.tensor(dynamo_value, dtype=torch.bool)
    else:
        assert isinstance(dynamo_value, torch.Tensor)
        return dynamo_value.contiguous()


def _adjust_scalar_from_onnx_to_fx(
    tensor: torch.Tensor,
    prim_value: Union[
        torch.Tensor,
        torch.SymInt,
        int,
        torch.SymFloat,
        float,
        torch.SymBool,
        bool,
    ],
) -> Union[torch.Tensor, int, float, bool]:
    """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
    assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
    if isinstance(
        prim_value,
        (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool),
    ):
        # Convert tensor back to scalar to match Dynamo's expectation.
        return tensor.item()
    return tensor
def _run_onnx_session_with_ortvaluevector(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    _nvtx_range_push("contiguous")
    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    _nvtx_range_pop()

    _nvtx_range_push("push_back_batch")
    ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)

    # Preallocate output PyTorch tensors and use their buffers (affined to the torch
    # device) for the output OrtValues. Because the output OrtValues are not allocated
    # and owned by ORT, there is no need to convert them to torch Tensors by
    # transferring ownership.
    if preallocate_output:
        pth_outputs = tuple(
            _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs
        )
        ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
    else:
        ort_outputs = ORTC.OrtValueVector()
    _nvtx_range_pop()

    _nvtx_range_push("run_with_ortvaluevector")
    run_options = onnxruntime.RunOptions()
    run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
    sess.run_with_ortvaluevector(
        run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices
    )
    _nvtx_range_pop()

    # Post-processing step:
    #  wrap ORT's outputs to the schema represented by
    #  `prim_output` (obtained by running the original
    #  torch.fx.GraphModule).
    if preallocate_output:
        # Profile the ORT-to-PyTorch type cast below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Outputs are stored on pre-allocated torch.Tensors' memory,
        # so this case doesn't need to convert ORTValue to torch.Tensor.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs
    else:
        # Profile the two ORT-to-PyTorch type casts below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Map ORTValue to torch.Tensor.
        pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(
            ort_outputs
        )
        # Change some torch.Tensor to int, float, bool.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs
def _run_onnx_session_with_fetch(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    feed = {
        name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy())
        for name, tensor in zip(input_names, inputs)
    }
    ort_outputs = sess.run(output_names, feed)
    pth_outputs = tuple(
        _adjust_scalar_from_onnx_to_fx(
            torch.from_numpy(value),
            prim_output,
        )
        for value, prim_output in zip(ort_outputs, normalized_prim_outputs)
    )
    return pth_outputs


class OrtExecutionInfoPerSession:
    """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""

    def __init__(
        self,
        session: "onnxruntime.InferenceSession",
        input_names: Tuple[str, ...],
        input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        output_names: Tuple[str, ...],
        output_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        input_devices: Tuple["ORTC.OrtDevice", ...],
        output_devices: Tuple["ORTC.OrtDevice", ...],
        example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
    ):
        # Carrier of ONNX model and its executor.
        self.session: onnxruntime.InferenceSession = session
        # For the ONNX model stored in self.session, self.input_names[i] is the
        # name of the i-th positional input.
        self.input_names: Tuple[str, ...] = input_names
        # self.input_names[i]'s type information is stored in self.input_value_infos[i].
        self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos  # type: ignore[name-defined]
        # Similar to self.input_names, but for outputs.
        self.output_names: Tuple[str, ...] = output_names
        # Similar to self.input_value_infos but for outputs.
        self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos  # type: ignore[name-defined]
        # For the ONNX model stored in self.session, self.input_devices[i] is the
        # i-th positional input's device.
        self.input_devices: Tuple["ORTC.OrtDevice", ...] = input_devices
        # Similar to self.input_devices, but for outputs.
        self.output_devices: Tuple["ORTC.OrtDevice", ...] = output_devices
        # These are the outputs of executing the original torch.fx.GraphModule with example inputs
        # (i.e., args passed into OrtBackend._ort_acclerated_call).
        self.example_outputs: Union[
            Tuple[torch.Tensor, ...], torch.Tensor
        ] = example_outputs

    def is_supported(self, *args):
        # Compare the args and the input schema in ONNX model and
        # return the first match.
        if len(args) != len(self.input_value_infos):
            return False
        for arg, value_info in zip(args, self.input_value_infos):
            if not isinstance(arg, (torch.Tensor, float, int)):
                return False

            # Check Python scalars such as int, float, and bool.
            if isinstance(arg, (int, float, bool)):
                # Map, e.g., float to onnx.TensorProto.FLOAT.
                onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg))
                if onnx_dtype != value_info.type.tensor_type.elem_type:
                    return False
                if len(value_info.type.tensor_type.shape.dim) != 0:
                    return False
                continue

            # Check tensor.
            onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype]
            if onnx_dtype != value_info.type.tensor_type.elem_type:
                return False
            for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
                if isinstance(dim, int) and (
                    onnx_dim.dim_value == dim or onnx_dim.dim_param
                ):
                    continue
                elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
                    continue
                else:
                    return False
        return True
@dataclasses.dataclass
class OrtExecutionInfoForAllGraphModules:
    def __init__(self):
        # All sessions (and their related information) created by exporting the same GraphModule
        # with different inputs.
        self.execution_info_per_graph_module: Dict[
            torch.fx.GraphModule, List[OrtExecutionInfoPerSession]
        ] = {}

    def search_reusable_session_execution_info(
        self, graph_module: torch.fx.GraphModule, *args
    ):
        if graph_module not in self.execution_info_per_graph_module:
            return None
        # All execution information for ONNX models exported from the same `graph_module`
        # with different inputs.
        candidates = self.execution_info_per_graph_module[graph_module]

        for candidate in candidates:
            if candidate.is_supported(*args):
                # Returns the first session that accepts this input schema.
                return candidate
        # No reusable session found.
        return None

    def cache_session_execution_info(
        self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession
    ):
        if graph_module not in self.execution_info_per_graph_module:
            self.execution_info_per_graph_module[graph_module] = [info]
        else:
            self.execution_info_per_graph_module[graph_module].append(info)


OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]]
"""Either the name of an ONNX Runtime execution provider as a string or
a 2-tuple of the name and a dictionary of execution provider options.

Examples::

    >>> "CPUExecutionProvider"

    >>> ("CUDAExecutionProvider", {"device_id": 3})

"""
@dataclasses.dataclass(frozen=True)
@compatibility(is_backward_compatible=False)
class OrtBackendOptions:
    """Options for constructing an ``OrtBackend``, the ONNX Runtime
    backend (``"onnxrt"``) for ``torch.compile``.

    Example::

        >>> @torch.compile(
        ...     backend="onnxrt",
        ...     options=torch.onnx._OrtBackendOptions(...),
        ... )
        ... def ort_function(x):
        ...     return x ** x
    """

    preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """An optional sequence of execution providers to be prioritized ahead of any
    execution providers that may be inferred (see ``infer_execution_providers``).
    """

    infer_execution_providers: bool = True
    """Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph."""

    default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """The default fallback execution providers. If not specified, one will be
    selected based on the host environment (most likely ``"CPUExecutionProvider"``).
    """

    # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession
    # in order to avoid internal allocation of output buffers in InferenceSession.
    # If the output ortvalue returned from InferenceSession is allocated internally,
    # it needs to be converted to a torch Tensor for return, and the torch Tensor should hold the ownership.
    # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor
    # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device.
    # This can be avoided by preallocating output buffers with the custom aten allocator and using those
    # preallocated buffers for InferenceSession without transferring any ownership to it.
    # TODO(wschin): Make it an inference-session-level flag.
    # See https://github.com/pytorch/pytorch/issues/106869.
    preallocate_output: bool = False
    """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""

    use_aot_autograd: bool = True
    """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend
    to support training (i.e., backward graphs are also sent to ``OrtBackend``).

    Symbolic execution is used to capture the forward and backward passes as a single graph. Then,
    a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used
    to split the entire graph into forward sub-graph and backward sub-graph. Finally, both
    sub-graphs are compiled by ``OrtBackend``.
    """

    export_options: Optional["torch.onnx.ExportOptions"] = None
    """Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``."""

    ort_session_options: Optional["onnxruntime.SessionOptions"] = None
    """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``."""

    pre_ort_model_transforms: Optional[  # type: ignore[name-defined]
        Sequence[Callable[["onnx.ModelProto"], None]]
    ] = None
    """A list of graph transforms to be applied to the ONNX model before it
    is fed to ONNXRuntime's InferenceSession."""
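# Illustrative sketch (comment only, not executed at import time): one way the
# fields above can be combined, mirroring the docstring example. The function
# name `train_step` is a placeholder, not part of this module.
#
#   options = OrtBackendOptions(
#       preferred_execution_providers=[
#           ("CUDAExecutionProvider", {"device_id": 0}),
#           "CPUExecutionProvider",
#       ],
#       use_aot_autograd=True,
#   )
#   compiled = torch.compile(train_step, backend="onnxrt", options=options)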
@compatibility(is_backward_compatible=False)
class OrtBackend:
    """A backend that compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.

    The compiler entry point is OrtBackend.compile, which
        1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule)
           and unsupported sub-graphs.
        2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call.
        3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
    """

    def __init__(self, options: Optional[OrtBackendOptions] = None):
        self._options: Final = OrtBackendOptions() if options is None else options

        # options.export_options contains information shared between exporter and DORT.
        # For example, they should use the same decomposition table when
        #  1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py)
        #  2. calling exporter's API to convert `torch.fx.GraphModule` to ONNX model
        #     (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
        #
        # Convert user-facing option to internal option used by ONNX exporter
        # to access required information.
        # Some useful fields:
        # - Decomposition table for decomposing FX operators in exporter is
        #   self._resolved_onnx_exporter_options.decomposition_table.
        # - self._resolved_onnx_exporter_options.onnx_registry records what
        #   aten/prim ops are supported by exporter and their exporters (type: callable).
        self._resolved_onnx_exporter_options = (
            torch.onnx._internal.exporter.ResolvedExportOptions(
                torch.onnx.ExportOptions()
                if self._options.export_options is None
                else self._options.export_options
            )
        )

        # Given DORT's computation flow:
        #  1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators
        #     and send them to DORT.
        #  2. Then, DORT exports the selected sub-graphs into ONNX.
        #  3. Finally, DORT calls ORT to do the computation.
        # OrtOperatorSupport and create_onnx_friendly_decomposition_table(...)
        # must use the same support_dict. If the support_dict here contains something not
        # supported by the exporter, the exporter will fail in step 2 since the selected graphs may
        # contain unsupported operators such as aten::_who_you_are.
        # This restriction is automatically enforced since DORT and the exporter share the same
        # self._resolved_onnx_exporter_options.
        support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
            self._resolved_onnx_exporter_options.onnx_registry
        )

        extra_support_dict: Dict[str, Any] = {
            "getattr": None,
            # To send operator.getitem to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.getitem": None,
            # To send operator.mul to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.mul": None,
            "_operator.add": None,
            "_operator.sub": None,
        }

        self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
        # TODO(wschin): this is a naive implementation of cache without proper guard
        # See https://github.com/pytorch/pytorch/issues/106868.
        self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
        # Conceptually, this field is a 2-layer dictionary
        #   GraphModule 0
        #     ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 1
        #     ...
        #   GraphModule 1
        #     ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 3
        #     ...
        #   ...
        # , which caches all previous compilation results so that we can reuse them.
        # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
        # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
        # graphs captured by Dynamo and sent to OrtBackend.compile.
        self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()

        self._assert_allclose_to_baseline = False

        self.execution_count = 0

        # Function which invokes ORT to do the real computation.
        self.run = (
            _run_onnx_session_with_ortvaluevector
            if hasattr(ORTC.OrtValueVector, "push_back_batch")
            else _run_onnx_session_with_fetch
        )
    def _select_eps(
        self, graph_module: torch.fx.GraphModule, *args
    ) -> Sequence[Tuple[str, Mapping[str, Any]]]:
        inferred_eps: Tuple[str, ...] = tuple()
        if self._options.infer_execution_providers:
            if eps_from_args := _infer_ep_from_device(*args):
                # If user feeds CUDA tensor as input argument,
                # we want to use CUDA EP.
                # Thus, `eps_from_args` (deduced from input arguments)
                # has highest priority.
                inferred_eps = eps_from_args
            elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module):
                # If there is no EP in input arguments, we deduce EP from
                # graph_module's outputs. Those outputs may come from
                # FakeTensorProp or Dynamo's built-in symbolic shape inference.
                inferred_eps = eps_from_graph_module

        selected_eps = []

        for ep in (
            *(self._options.preferred_execution_providers or []),
            *_sort_eps(inferred_eps),
            *(self._options.default_execution_providers or _infer_default_eps()),
        ):
            if isinstance(ep, str):
                ep = (ep, {})
            elif isinstance(ep, tuple) and ep[1] is None:
                ep = (ep[0], {})
            if ep is not None and ep not in selected_eps:
                selected_eps.append(ep)

        return selected_eps

    def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
        """This function replaces GraphModule._wrapped_call in compiled model.

        The _wrapped_call is the underlying implementation of the forward method. Replacing
        it means we delegate the computation to _ort_acclerated_call and therefore
        onnxruntime.InferenceSession.
        """
        cached_execution_info_per_session = (
            self._all_ort_execution_info.search_reusable_session_execution_info(
                graph_module, *args
            )
        )
        if cached_execution_info_per_session:
            onnx_session = cached_execution_info_per_session.session
            input_names = cached_execution_info_per_session.input_names
            output_names = cached_execution_info_per_session.output_names
            input_value_infos = cached_execution_info_per_session.input_value_infos
            output_value_infos = cached_execution_info_per_session.output_value_infos
            input_devices = cached_execution_info_per_session.input_devices
            output_devices = cached_execution_info_per_session.output_devices
            prim_outputs = cached_execution_info_per_session.example_outputs
        else:
            # It's the first time seeing such a graph. Let's make a new session
            # (type: onnxruntime.InferenceSession) for it.

            graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
                self._resolved_onnx_exporter_options.diagnostic_context,
                graph_module,
            ).run()
            # Generate reference outputs. They are used to indicate output
            # tensors' types and devices when calling ORT.
            #
            # WARNING: The downstream code should not change prim_outputs and
            # this backend should always produce output with schema identical to prim_outputs'.

            if self._resolved_onnx_exporter_options.dynamic_shapes:
                # No pre-allocation when dynamic shape is enabled.
                self.preallocate_output = False
                extracted_outputs = _extract_graph_module_outputs(graph_module)

                def maybe_map_to_meta_val(value):
                    if hasattr(value, "meta") and "val" in value.meta:
                        # Select outputs with "val" information. Without "val",
                        # it's not possible to access output_arg.meta["val"].device.
                        return value.meta["val"]
                    else:
                        return value

                prim_outputs = _pytree.tree_map(
                    maybe_map_to_meta_val, extracted_outputs
                )
            else:
                try:
                    prim_outputs = FakeTensorProp(graph_module).propagate(
                        *args, **kwargs
                    )
                except Exception:
                    logger.warning("FakeTensorProp failed for %s", graph_module)
                    # When FakeTensorProp fails, it is not possible to preallocate output buffers
                    # because the output shapes are not inferred.
                    self.preallocate_output = False

                    # Rethrow the FakeTensorProp failure because it is not yet handled.
                    raise

            # Create the object that iterates through the nodes in the graph one-by-one
            # and calls the corresponding ONNX exporter for each node.
            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
            )
            # Cast FX variables if they would result in a schema mismatch when searching
            # for the ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
            # but ONNX expects add(double_tensor, double_tensor).
            graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
                self._resolved_onnx_exporter_options.diagnostic_context, graph_module
            ).run()
            # Start the per-node exporting process. It's conceptually a for loop
            # scanning through the nodes in the graph.
            exported = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
                op_level_debug=self._resolved_onnx_exporter_options.op_level_debug,
            )
            # Convert the exported result to ONNX ModelProto.
            onnx_model = exported.to_model_proto(
                opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
            )

            # Modify ONNX model using pre-registered graph transforms.
            # They are in-place modifications for avoiding unnecessary
            # copy of ONNX initializers.
            if self._options.pre_ort_model_transforms:
                for transform in self._options.pre_ort_model_transforms:
                    transform(onnx_model)

            onnx_model_bytes = onnx_model.SerializeToString()
            if os.environ.get("ONNXRT_DUMP_PATH", None):
                # If not empty, environment variable ONNXRT_DUMP_PATH defines the path
                # where generated onnx files should be stored.
                # This module keeps a global variable tracking the stored models.
                # If ONNXRT_DUMP_PATH="dumped/dumped_model_",
                # the first file name will be 'dumped/dumped_model_0.onnx'.
                # For every dumped model, a text file 'dumped/dumped_model_0.txt'
                # is created as well to contain the string representing the graph_module.
                _dump_onnx_model(onnx_model_bytes, graph_module=graph_module)

            # Initialize an ORT session to execute this ONNX model.
            # Note that TorchDynamo assumes all inputs/outputs are on the
            # same device, but it's subject to change (very likely with
            # dynamic shape support), so we add execution providers
            # based on the logic in _select_eps: (explicitly preferred EPs,
            # EPs inferred from inputs or graph, and the fallback default EP).
            #
            # TODO(wschin): enable external allocators.
            # See https://github.com/pytorch/pytorch/issues/106867
            onnx_session = onnxruntime.InferenceSession(
                path_or_bytes=onnx_model_bytes,
                sess_options=self._options.ort_session_options,
                providers=self._select_eps(graph_module, *args),
            )

            # Cache ORT session. It's reused for the same "graph_module".
            # Generate ONNX model and extract its input and output names.
            input_names = tuple(input.name for input in onnx_model.graph.input)
            output_names = tuple(output.name for output in onnx_model.graph.output)
            input_devices = _get_onnx_devices(args)
            # Cache devices for inputs and outputs. They are used to invoke
            # ORT session. Output devices indicate where (e.g., GPU or CPU)
            # to store outputs.
            if isinstance(prim_outputs, tuple):
                output_devices = _get_onnx_devices(prim_outputs)
            else:
                output_devices = _get_onnx_devices((prim_outputs,))

            input_value_infos = tuple(input for input in onnx_model.graph.input)
            output_value_infos = tuple(output for output in onnx_model.graph.output)

            execution_info_per_session = OrtExecutionInfoPerSession(
                session=onnx_session,
                input_names=input_names,
                input_value_infos=input_value_infos,
                output_names=output_names,
                output_value_infos=output_value_infos,
                input_devices=input_devices,
                output_devices=output_devices,
                example_outputs=prim_outputs,
            )

            self._all_ort_execution_info.cache_session_execution_info(
                graph_module, execution_info_per_session
            )

        self.execution_count += 1

        # ORT always returns a tuple of outputs. If the original output is a tensor,
        # ORT output's first element must be extracted and returned. Otherwise, type
        # mismatch may happen in downstream computation.
        is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
        normalized_prim_outputs = (
            (prim_outputs,) if is_single_tensor_output else prim_outputs
        )
        assert isinstance(normalized_prim_outputs, tuple)
        assert all(
            isinstance(elem, (torch.Tensor, torch.SymInt, int))
            for elem in normalized_prim_outputs
        )

        _nvtx_range_push("run_onnx_session_with_ortvaluevector")
        onnx_outputs = self.run(
            onnx_session,
            input_names,
            args,
            input_devices,
            output_names,
            normalized_prim_outputs,
            output_devices,
            self._options.preallocate_output,
            input_value_infos,
            normalized_prim_outputs,
        )
        _nvtx_range_pop()

        if self._assert_allclose_to_baseline:
            # Compute baseline.
            baseline_outputs = torch._prims.executor.execute(
                graph_module, *args, executor="aten"
            )
            normalized_baseline_outputs = (
                (baseline_outputs,) if is_single_tensor_output else baseline_outputs
            )
            # Ensure every output tensor is close to the corresponding baseline.
            for onnx_output, baseline_output in zip(
                onnx_outputs, normalized_baseline_outputs
            ):
                torch.testing.assert_close(onnx_output, baseline_output)
        return onnx_outputs[0] if is_single_tensor_output else onnx_outputs

    def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
        # Deferred import since CapabilityBasedPartitioner is not decorated with
        # @compatibility; importing it at the module level will result in the test
        # failing: pytest test/test_fx.py -k test_public_api_surface
        # because this module is imported into torch.onnx.
        from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

        # FX graph based partitioning based on ONNX supported ops.
        # Given a graph module
        #  GraphModule0
        #   node_0
        #   node_1
        #   node_2
        #   node_3
        #   node_4
        # If only node_2 is not supported by ONNX, this graph module will be partitioned into
        #  GraphModule0
        #   GraphModule1
        #    node_0
        #    node_1
        #   node_2
        #   GraphModule2
        #    node_3
        #    node_4
        # by calling CapabilityBasedPartitioner.partition_and_fuse.
        # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call)
        # will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT.
        if graph_module in self._partitioner_cache:
            partitioned_prim_graph_module = self._partitioner_cache[graph_module]
        else:
            prim_graph_module = graph_module
            partitioner = CapabilityBasedPartitioner(
                prim_graph_module,
                self._supported_ops,
                allows_single_node_partition=True,
            )
            partitioned_prim_graph_module = partitioner.partition_and_fuse()
            self._partitioner_cache[graph_module] = partitioned_prim_graph_module

            # Overriding fused_module's __call__() function with _ort_acclerated_call().
            # This loop goes through all graph partitions (each of them is an ONNX-representable graph)
            # and overrides their _wrapped_call function with _ort_acclerated_call.
            # Inside _ort_acclerated_call, the partition's graph is exported into ONNX and executed by ORT.
            for node in partitioned_prim_graph_module.graph.nodes:
                # TODO(wschin): use a better way to identify fused submodule
                # See https://github.com/pytorch/pytorch/issues/106872.
                if node.op == "call_module" and "fused_" in node.name:
                    fused_module = getattr(partitioned_prim_graph_module, node.name)
                    # self._ort_acclerated_call is responsible for exporting graph to ONNX,
                    # creating ORT session, and running ORT session.
                    fused_module._wrapped_call = self._ort_acclerated_call

        return partitioned_prim_graph_module
    def __call__(
        self, graph_module: torch.fx.GraphModule, args
    ) -> torch.fx.GraphModule:
        """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the ``aot_autograd`` compiler
        will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
        the ``compile`` method is invoked directly."""
        if self._options.use_aot_autograd:
            from functorch.compile import min_cut_rematerialization_partition
            from torch._dynamo.backends.common import aot_autograd

            return aot_autograd(
                fw_compiler=self.compile,
                partition_fn=min_cut_rematerialization_partition,
                decompositions=self._resolved_onnx_exporter_options.decomposition_table,
            )(graph_module, args)

        return self.compile(graph_module, args)

    __instance_cache_max_count: Final = 8
    __instance_cache: Final[List["OrtBackend"]] = []

    @staticmethod
    def get_cached_instance_for_options(
        options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
    ) -> "OrtBackend":
        """Returns a possibly cached instance of an ``OrtBackend``. If an existing
        backend was created previously through this function with the same options,
        it will be returned. Otherwise a new backend will be created, cached, and
        returned.

        Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
        will always be returned, since ``onnxruntime.SessionOptions`` cannot
        participate in caching."""

        def reusable(a: OrtBackendOptions, b: OrtBackendOptions):
            if (
                a.preferred_execution_providers != b.preferred_execution_providers
                or a.infer_execution_providers != b.infer_execution_providers
                or a.default_execution_providers != b.default_execution_providers
                or a.preallocate_output != b.preallocate_output
                or a.use_aot_autograd != b.use_aot_autograd
                or a.pre_ort_model_transforms != b.pre_ort_model_transforms
            ):
                return False

            # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled,
            # and holds too much potential state to reasonably check manually;
            # if ort_session_options is provided at all, the backend does not
            # participate in caching.
            if a.ort_session_options is not None or b.ort_session_options is not None:
                return False

            if a.export_options is b.export_options:
                return True

            # Similarly, some objects in ExportOptions are too stateful to use for
            # caching. We should revisit this.
            if a.export_options is not None and b.export_options is not None:
                return (
                    a.export_options.dynamic_shapes == b.export_options.dynamic_shapes
                    and a.export_options.op_level_debug
                    == b.export_options.op_level_debug
                    and a.export_options.diagnostic_options
                    == b.export_options.diagnostic_options
                    and a.export_options.onnx_registry is b.export_options.onnx_registry
                    and a.export_options.fake_context is b.export_options.fake_context
                )

            # We can't account for how the two option sets may differ, so it's not safe to reuse.
            return False

        if not isinstance(options, OrtBackendOptions):
            options = OrtBackendOptions(**(options or {}))

        backend = next(
            (b for b in OrtBackend.__instance_cache if reusable(b._options, options)),
            None,
        )

        if backend is None:
            assert (
                len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count
            ), (
                f"No more than {OrtBackend.__instance_cache_max_count} instances of "
                f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly "
                "to pass to `torch.compile`. "
                "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 "
                "for discussion."
            )
            OrtBackend.__instance_cache.append(backend := OrtBackend(options))

        return backend
    @staticmethod
    def clear_cached_instances():
        OrtBackend.__instance_cache.clear()

    @staticmethod
    def get_cached_instances():
        return tuple(OrtBackend.__instance_cache)


@compatibility(is_backward_compatible=False)
def torch_compile_backend(
    graph_module: torch.fx.GraphModule,
    args,
    *,
    options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
):
    return OrtBackend.get_cached_instance_for_options(options)(graph_module, args)
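# Illustrative usage sketch (comment only, so importing this module has no side
# effects). `model` and `x` are placeholders for a user's module and input.
#
#   # Cached-instance path: the registered "onnxrt" backend routes through
#   # torch_compile_backend and OrtBackend.get_cached_instance_for_options.
#   compiled = torch.compile(model, backend="onnxrt")
#
#   # Explicit-instance path, as suggested by the assertion message in
#   # OrtBackend.get_cached_instance_for_options when the cache is full:
#   backend = OrtBackend(OrtBackendOptions(preallocate_output=False))
#   compiled = torch.compile(model, backend=backend)
#
#   y = compiled(x)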