import dataclasses
|
|
import importlib
|
|
import logging
|
|
import os
|
|
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Final,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
from typing_extensions import TypeAlias
|
|
|
|
import torch
|
|
import torch._C
|
|
import torch._ops
|
|
import torch._prims.executor
|
|
import torch.fx
|
|
from torch._subclasses.fake_tensor import FakeTensor
|
|
from torch.fx._compatibility import compatibility
|
|
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
|
from torch.fx.passes.operator_support import OperatorSupport
|
|
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
|
|
from torch.utils import _pytree
|
|
|
|
try:
|
|
# Use try-except to initialize package-dependent global variables.
|
|
import onnx
|
|
import onnxruntime # type: ignore[import]
|
|
from onnxruntime.capi import _pybind_state as ORTC # type: ignore[import]
|
|
|
|
    # This is not used directly in DORT but is needed by the underlying exporter,
    # so we still need to check that it exists.
|
|
importlib.import_module("onnxscript")
|
|
|
|
import torch.onnx
|
|
import torch.onnx._internal
|
|
import torch.onnx._internal.diagnostics
|
|
import torch.onnx._internal.exporter
|
|
import torch.onnx._internal.fx.decomposition_table
|
|
import torch.onnx._internal.fx.passes
|
|
from torch.onnx._internal.fx import fx_onnx_interpreter
|
|
from torch.onnx._internal.fx.type_utils import (
|
|
_TORCH_DTYPE_TO_NUMPY_DTYPE,
|
|
_TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
|
|
from_python_type_to_onnx_tensor_element_type,
|
|
)
|
|
|
|
_SUPPORT_ONNXRT = True
|
|
except ImportError:
|
|
_SUPPORT_ONNXRT = False
|
|
|
|
__all__ = [
|
|
"is_onnxrt_backend_supported",
|
|
"torch_compile_backend",
|
|
"OrtExecutionProvider",
|
|
"OrtBackendOptions",
|
|
"OrtBackend",
|
|
]
|
|
|
|
|
|
def is_onnxrt_backend_supported() -> bool:
|
|
"""Returns ``True`` if ONNX Runtime dependencies are installed and usable
|
|
to support TorchDynamo backend integration; ``False`` otherwise.
|
|
|
|
Example::
|
|
|
|
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
|
|
>>> import torch
|
|
>>> if torch.onnx.is_onnxrt_backend_supported():
|
|
... @torch.compile(backend="onnxrt")
|
|
... def f(x):
|
|
... return x * x
|
|
... print(f(torch.randn(10)))
|
|
... else:
|
|
... print("pip install onnx onnxscript onnxruntime")
|
|
...
|
|
"""
|
|
return _SUPPORT_ONNXRT
|
|
|
|
|
|
_dumped_onnx_model: Dict[str, int] = {}
|
|
|
|
|
|
def _dump_onnx_model(
|
|
model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None
|
|
) -> str:
|
|
"""Stores the onnx model into a file.
|
|
The name is "{ONNXRT_DUMP_PATH}{N}.onnx"
|
|
where *N* is the number of files already stored with
|
|
this prefix.
|
|
    If graph_module is not None, the graph is also stored as text in a file
    with the same name but a ``.txt`` extension.
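
    Example (an illustrative sketch; the prefix below is hypothetical)::

        >>> os.environ["ONNXRT_DUMP_PATH"] = "/tmp/dort_model_"
        >>> # The next dump is written to "/tmp/dort_model_0.onnx" (plus
        >>> # "/tmp/dort_model_0.txt" if a graph_module is passed), the one
        >>> # after that to "/tmp/dort_model_1.onnx", and so on.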
|
|
"""
|
|
prefix = os.environ.get("ONNXRT_DUMP_PATH", None)
|
|
if not prefix:
|
|
return ""
|
|
n = _dumped_onnx_model.get(prefix, -1) + 1
|
|
filename = f"{prefix}{n}.onnx"
|
|
with open(filename, "wb") as f:
|
|
f.write(model_string)
|
|
_dumped_onnx_model[prefix] = n
|
|
if graph_module is not None:
|
|
filename_txt = f"{prefix}{n}.txt"
|
|
with open(filename_txt, "w", encoding="utf-8") as f:
|
|
f.write(str(graph_module.graph))
|
|
return filename
|
|
|
|
|
|
def _infer_default_eps() -> Sequence[str]:
|
|
# TODO: select a good default based on the capabilities of the host
|
|
# e.g. DML on Windows, etc.
|
|
return ["CPUExecutionProvider"]
|
|
|
|
|
|
def _nvtx_range_push(name: str):
|
|
"""If PyTorch is installed with CUDA support, this starts NVTX range.
|
|
|
|
Check torch.cuda.nvtx.range_push's document for more details.
|
|
"""
|
|
if torch.cuda.is_available():
|
|
torch.cuda.nvtx.range_push(name)
|
|
|
|
|
|
def _nvtx_range_pop():
|
|
"""If PyTorch is installed with CUDA support, this terminates NVTX range.
|
|
|
|
Check torch.cuda.nvtx.range_pop's document for more details.
|
|
"""
|
|
if torch.cuda.is_available():
|
|
torch.cuda.nvtx.range_pop()
|
|
|
|
|
|
def _get_ort_device_type(device_type: str):
|
|
if device_type == "cuda":
|
|
return ORTC.OrtDevice.cuda()
|
|
if device_type == "cpu":
|
|
return ORTC.OrtDevice.cpu()
|
|
    # The PyTorch "ort" device is mapped to ORT's NPU OrtDevice type.
|
|
if device_type == "ort":
|
|
return ORTC.OrtDevice.npu()
|
|
raise ValueError("Unsupported device type: " + device_type)
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
# Uncomment the following lines to print out development info.
|
|
# logging.basicConfig(level=logging.WARNING)
|
|
# logger.setLevel(logging.WARNING)
|
|
|
|
|
|
class OrtOperatorSupport(OperatorSupport):
|
|
"""Operator support for ONNXRuntime backend.
|
|
|
|
    It makes a two-level support decision: one level via support_dict and the other
    via extra_support_dict. The support_dict logic is implemented in
    OrtOperatorSupport.is_node_supported, while extra_support_dict is consumed by
    the base class' OperatorSupport.is_node_supported.
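
    Example (an illustrative sketch; the real tables are derived from the ONNX registry)::

        >>> supported_aten_ops = {torch.ops.aten.add.Tensor}  # aten overloads exportable to ONNX
        >>> extra_ops = {"_operator.getitem": None}  # None means "supported for all input types"
        >>> op_support = OrtOperatorSupport(supported_aten_ops, extra_ops)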
|
|
"""
|
|
|
|
def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
|
|
# Use extra_support_dict[op_name] = None to indicate
|
|
# we support op_name with all input types. Otherwise,
|
|
# see support_dict (type: SupportDict) in operator_support.py
|
|
# for specifying supported types.
|
|
super().__init__(extra_support_dict)
|
|
self._onnx_support_dict = support_dict
|
|
|
|
def is_node_supported(
|
|
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
|
|
) -> bool:
|
|
# OperatorSupport.is_node_supported returns True for non-callable nodes.
|
|
# Since ORT can't execute them, we return False here to override the base
|
|
# behavior.
|
|
if node.op not in CALLABLE_NODE_OPS:
|
|
return False
|
|
        # This is the only place that decides whether an aten op is supported via support_dict.
|
|
if node.op == "call_function" and node.target in self._onnx_support_dict:
|
|
logger.warning(
|
|
"support_dict supports node.target: %s (type: %s)",
|
|
node.target,
|
|
type(node.target),
|
|
)
|
|
return True
|
|
        # If node.target is not in support_dict, we still want to check whether torch.jit.script
        # can convert it to an ONNX equivalent. Let's use the base mechanism to do this.
        # See extra_support_dict for the supported ops.
|
|
if super().is_node_supported(submodules, node):
|
|
logger.warning(
|
|
"extra_support_dict supports node.target: %s (type: %s)",
|
|
node.target,
|
|
type(node.target),
|
|
)
|
|
return True
|
|
logger.warning(
|
|
"support_dict and extra_support_dict don't support node.target: %s (type: %s)",
|
|
node.target,
|
|
type(node.target),
|
|
)
|
|
return False
|
|
|
|
|
|
def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
|
|
"""
|
|
    In torch.fx.Graph, a placeholder is a special assignment node. If it is not
    executed at the beginning, it could overwrite values already computed by
    upstream nodes.
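
    For example (an illustrative ordering, not a real trace), a graph whose nodes are ordered as

        %add = call_function[target=aten.add](%x, 1)
        %y = placeholder[target=y]

    is rewritten so that ``%y`` is prepended before ``%add``, i.e., all placeholders
    end up at the front of the graph.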
|
|
"""
|
|
|
|
graph = graph_module.graph
|
|
placeholders = []
|
|
first_not_placeholder = None
|
|
for node in graph.nodes:
|
|
if node.op == "placeholder":
|
|
placeholders.append(node)
|
|
if first_not_placeholder is None and node.op != "placeholder":
|
|
first_not_placeholder = node
|
|
if first_not_placeholder is None:
|
|
return
|
|
for placeholder in placeholders:
|
|
first_not_placeholder.prepend(placeholder)
|
|
|
|
|
|
def _infer_ep_from_device(*args) -> Tuple[str, ...]:
|
|
"""Return the first valid device (i.e., GPU or CPU) in argument list."""
|
|
eps = []
|
|
for arg in args:
|
|
if hasattr(arg, "device"):
|
|
device = arg.device
|
|
if device.type == "cuda":
|
|
eps.append("CUDAExecutionProvider")
|
|
elif device.type == "cpu":
|
|
eps.append("CPUExecutionProvider")
|
|
return tuple(eps)
|
|
|
|
|
|
def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
|
|
placeholders = []
|
|
for node in graph_module.graph.nodes:
|
|
if node.op == "placeholder":
|
|
if hasattr(node, "meta") and "val" in node.meta:
|
|
assert isinstance(node.meta["val"], torch.Tensor)
|
|
placeholders.append(node)
|
|
return tuple(placeholders)
|
|
|
|
|
|
def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
|
|
"""Collect "val" fields from outputs metadata in this torch.fx.GraphModule."""
|
|
for node in graph_module.graph.nodes:
|
|
if node.op == "output":
|
|
# Output node is unique. Let's retrieve output values from
|
|
# this node's input list. And then just return.
|
|
return node.args[0]
|
|
raise ValueError("No output node found in this torch.fx.GraphModule.")
|
|
|
|
|
|
def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
|
|
"""Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule."""
|
|
flattened_output_args, _ = _pytree.tree_flatten(
|
|
_extract_graph_module_outputs(graph_module)
|
|
)
|
|
# Output arguments with example value (type: torch.Tensor) in the `graph_module`.
|
|
selected_output_args = [
|
|
output_arg.meta["val"]
|
|
for output_arg in flattened_output_args
|
|
# output_arg must have tensor for its device information.
|
|
# Otherwise, skip it.
|
|
if (hasattr(output_arg, "meta") and "val" in output_arg.meta)
|
|
]
|
|
return _infer_ep_from_device(*selected_output_args)
|
|
|
|
|
|
def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
|
|
"""Sort execution providers in eps based on pre-set priority."""
|
|
|
|
def get_execution_provider_priority(ep: str) -> int:
|
|
if ep == "CPUExecutionProvider":
|
|
# Lowest priority.
|
|
return 2
|
|
if ep == "CUDAExecutionProvider":
|
|
# Higher priority than CPU but lower than
|
|
# other specialized EPs.
|
|
return 1
|
|
# Highest priority.
|
|
return 0
|
|
|
|
unique_eps = set(eps)
|
|
return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))
|
|
|
|
|
|
def _get_onnx_devices(
|
|
values: Tuple[
|
|
Union[
|
|
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
|
|
],
|
|
...,
|
|
]
|
|
) -> Tuple["ORTC.OrtDevice", ...]:
|
|
def _device_id_or_zero(device_id: int) -> int:
|
|
return device_id or 0
|
|
|
|
def _map_tensor_or_sym_to_device(
|
|
value: Union[
|
|
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
|
|
],
|
|
    ) -> "ORTC.OrtDevice":
|
|
if isinstance(value, torch.Tensor):
|
|
return ORTC.OrtDevice(
|
|
_get_ort_device_type(value.device.type),
|
|
ORTC.OrtDevice.default_memory(),
|
|
_device_id_or_zero(value.device.index),
|
|
)
|
|
elif isinstance(
|
|
value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool)
|
|
):
|
|
return ORTC.OrtDevice(
|
|
_get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0
|
|
)
|
|
else:
|
|
raise ValueError("Unsupported value type: " + str(type(value)))
|
|
|
|
if len(values) > 0:
|
|
ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values)
|
|
return ort_devices
|
|
else:
|
|
return (_map_tensor_or_sym_to_device(1),)
|
|
|
|
|
|
def _get_ortvalues_from_torch_tensors(
|
|
tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
|
|
) -> "ORTC.OrtValueVector":
|
|
ortvalues = ORTC.OrtValueVector()
|
|
ortvalues.reserve(len(tensors))
|
|
dtypes = []
|
|
shapes = []
|
|
data_ptrs = []
|
|
|
|
for tensor in tensors:
|
|
dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
|
|
shapes.append(tensor.size())
|
|
data_ptrs.append(tensor.data_ptr())
|
|
ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
|
|
return ortvalues
|
|
|
|
|
|
def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
|
|
if tensor.is_sparse:
|
|
raise ValueError("sparse tensor is not yet supported.")
|
|
out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
|
|
return out
|
|
|
|
|
|
def _adjust_scalar_from_fx_to_onnx(
|
|
dynamo_value: Union[
|
|
torch.Tensor,
|
|
int,
|
|
float,
|
|
bool,
|
|
],
|
|
value_info: "onnx.ValueInfoProto", # type: ignore[name-defined]
|
|
) -> torch.Tensor:
|
|
"""Helper function to wrap PyTorch variables as torch.Tensor"""
|
|
if (
|
|
isinstance(dynamo_value, torch.Tensor)
|
|
and len(value_info.type.tensor_type.shape.dim) == 0
|
|
and dynamo_value.shape == (1,)
|
|
):
|
|
        # ONNX expects a scalar with an empty shape.
|
|
# In contrast, PyTorch usually allows implicit
|
|
# conversion between shape=() and shape=(1,).
|
|
#
|
|
# Below, PyTorch's shape (1,) is reshaped to ().
|
|
return torch.squeeze(dynamo_value)
|
|
elif isinstance(dynamo_value, int):
|
|
return torch.tensor(dynamo_value, dtype=torch.int64)
|
|
elif isinstance(dynamo_value, float):
|
|
return torch.tensor(dynamo_value, dtype=torch.float32)
|
|
elif isinstance(dynamo_value, bool):
|
|
return torch.tensor(dynamo_value, dtype=torch.bool)
|
|
else:
|
|
assert isinstance(dynamo_value, torch.Tensor)
|
|
return dynamo_value.contiguous()
|
|
|
|
|
|
def _adjust_scalar_from_onnx_to_fx(
|
|
tensor: torch.Tensor,
|
|
prim_value: Union[
|
|
torch.Tensor,
|
|
torch.SymInt,
|
|
int,
|
|
torch.SymFloat,
|
|
float,
|
|
torch.SymBool,
|
|
bool,
|
|
],
|
|
) -> Union[torch.Tensor, int, float, bool,]:
|
|
"""Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
|
|
assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
|
|
if isinstance(
|
|
prim_value,
|
|
(torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool),
|
|
):
|
|
# Convert tensor back to scalar to match Dynamo's expectation.
|
|
return tensor.item()
|
|
return tensor
|
|
|
|
|
|
def _run_onnx_session_with_ortvaluevector(
|
|
sess: "onnxruntime.InferenceSession",
|
|
input_names: Tuple[str, ...],
|
|
inputs: Tuple[torch.Tensor, ...],
|
|
input_devices: Tuple["ORTC.OrtDevice", ...],
|
|
output_names: Tuple[str, ...],
|
|
outputs: Tuple[torch.Tensor, ...],
|
|
output_devices: Tuple["ORTC.OrtDevice", ...],
|
|
preallocate_output: bool,
|
|
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
|
|
normalized_prim_outputs: Tuple[
|
|
Union[
|
|
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
|
|
],
|
|
...,
|
|
],
|
|
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
|
|
_nvtx_range_push("contiguous")
|
|
inputs = tuple(
|
|
_adjust_scalar_from_fx_to_onnx(arg, value_info)
|
|
for arg, value_info in zip(inputs, input_value_infos)
|
|
)
|
|
_nvtx_range_pop()
|
|
|
|
_nvtx_range_push("push_back_batch")
|
|
ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)
|
|
|
|
    # Pre-allocate output PyTorch tensors and use their (device-affine) buffers as the
    # output OrtValues. Because those OrtValues are then not allocated or owned by ORT,
    # there is no need to convert them back to torch tensors and transfer ownership.
|
|
if preallocate_output:
|
|
pth_outputs = tuple(
|
|
_to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs
|
|
)
|
|
ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
|
|
else:
|
|
ort_outputs = ORTC.OrtValueVector()
|
|
_nvtx_range_pop()
|
|
|
|
_nvtx_range_push("run_with_ortvaluevector")
|
|
run_options = onnxruntime.RunOptions()
|
|
run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
|
|
sess.run_with_ortvaluevector(
|
|
run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices
|
|
)
|
|
_nvtx_range_pop()
|
|
|
|
# Post-processing step:
|
|
# wrap ORT's outputs to the schema represented by
|
|
# `prim_output` (obtained by running the original
|
|
# torch.fx.GraphModule).
|
|
if preallocate_output:
|
|
# Profile the ORT-to-PyTorch type cast below
|
|
_nvtx_range_push("after run_with_ortvaluevector")
|
|
# Outputs are stored on pre-allocated torch.Tensors' memory,
|
|
# so this case doesn't need to convert ORTValue to torch.Tensor.
|
|
pth_outputs = tuple(
|
|
_adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc]
|
|
for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
|
|
)
|
|
_nvtx_range_pop()
|
|
return pth_outputs
|
|
else:
|
|
# Profile the two ORT-to-PyTorch type casts below
|
|
_nvtx_range_push("after run_with_ortvaluevector")
|
|
# Map ORTValue to torch.Tensor.
|
|
pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(
|
|
ort_outputs
|
|
)
|
|
# Change some torch.Tensor to int, float, bool.
|
|
pth_outputs = tuple(
|
|
_adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc]
|
|
for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
|
|
)
|
|
_nvtx_range_pop()
|
|
return pth_outputs
|
|
|
|
|
|
def _run_onnx_session_with_fetch(
|
|
sess: "onnxruntime.InferenceSession",
|
|
input_names: Tuple[str, ...],
|
|
inputs: Tuple[torch.Tensor, ...],
|
|
input_devices: Tuple["ORTC.OrtDevice", ...],
|
|
output_names: Tuple[str, ...],
|
|
outputs: Tuple[torch.Tensor, ...],
|
|
output_devices: Tuple["ORTC.OrtDevice", ...],
|
|
preallocate_output: bool,
|
|
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
|
|
normalized_prim_outputs: Tuple[
|
|
Union[
|
|
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
|
|
],
|
|
...,
|
|
],
|
|
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
|
|
inputs = tuple(
|
|
_adjust_scalar_from_fx_to_onnx(arg, value_info)
|
|
for arg, value_info in zip(inputs, input_value_infos)
|
|
)
|
|
feed = {
|
|
name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy())
|
|
for name, tensor in zip(input_names, inputs)
|
|
}
|
|
ort_outputs = sess.run(output_names, feed)
|
|
pth_outputs = tuple(
|
|
_adjust_scalar_from_onnx_to_fx(
|
|
torch.from_numpy(value),
|
|
prim_output,
|
|
)
|
|
for value, prim_output in zip(ort_outputs, normalized_prim_outputs)
|
|
)
|
|
return pth_outputs
|
|
|
|
|
|
class OrtExecutionInfoPerSession:
|
|
"""Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""
|
|
|
|
def __init__(
|
|
self,
|
|
session: "onnxruntime.InferenceSession",
|
|
input_names: Tuple[str, ...],
|
|
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
|
|
output_names: Tuple[str, ...],
|
|
output_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
|
|
input_devices: Tuple["ORTC.OrtDevice", ...],
|
|
output_devices: Tuple["ORTC.OrtDevice", ...],
|
|
example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
|
|
):
|
|
# Carrier of ONNX model and its executor.
|
|
self.session: onnxruntime.InferenceSession = session
|
|
# For the ONNX model stored in self.session, self.input_names[i] is the
|
|
# name of the i-th positional input.
|
|
self.input_names: Tuple[str, ...] = input_names
|
|
        # self.input_names[i]'s type information is stored in self.input_value_infos[i].
|
|
self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos # type: ignore[name-defined]
|
|
# Similar to self.input_names, but for outputs.
|
|
self.output_names: Tuple[str, ...] = output_names
|
|
# Similar to self.input_value_infos but for outputs.
|
|
self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined]
|
|
# For the ONNX model stored in self.session, self.input_devices[i] is the
|
|
# i-th positional input's device.
|
|
self.input_devices: Tuple["ORTC.OrtDevice", ...] = input_devices
|
|
# Similar to self.input_devices, but for outputs.
|
|
self.output_devices: Tuple["ORTC.OrtDevice", ...] = output_devices
|
|
        # These are the outputs of executing the original torch.fx.GraphModule with example inputs
|
|
# (i.e., args passed into OrtBackend._ort_acclerated_call).
|
|
self.example_outputs: Union[
|
|
Tuple[torch.Tensor, ...], torch.Tensor
|
|
] = example_outputs
|
|
|
|
def is_supported(self, *args):
|
|
        # Compare the args against the input schema of the ONNX model and
        # return True only if every argument matches.
|
|
if len(args) != len(self.input_value_infos):
|
|
return False
|
|
for arg, value_info in zip(args, self.input_value_infos):
|
|
if not isinstance(arg, (torch.Tensor, float, int)):
|
|
return False
|
|
|
|
# Check Python scalars such as int, float, and bool.
|
|
if isinstance(arg, (int, float, bool)):
|
|
# Map, e.g., float to onnx.TensorProto.FLOAT.
|
|
onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg))
|
|
if onnx_dtype != value_info.type.tensor_type.elem_type:
|
|
return False
|
|
if len(value_info.type.tensor_type.shape.dim) != 0:
|
|
return False
|
|
continue
|
|
|
|
# Check tensor.
|
|
onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype]
|
|
if onnx_dtype != value_info.type.tensor_type.elem_type:
|
|
return False
|
|
for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
|
|
if isinstance(dim, int) and (
|
|
onnx_dim.dim_value == dim or onnx_dim.dim_param
|
|
):
|
|
continue
|
|
elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
|
|
continue
|
|
else:
|
|
return False
|
|
return True
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class OrtExecutionInfoForAllGraphModules:
|
|
def __init__(self):
|
|
# All sessions (and their related information) created by exporting the same GraphModule
|
|
# with different inputs.
|
|
self.execution_info_per_graph_module: Dict[
|
|
torch.fx.GraphModule, List[OrtExecutionInfoPerSession]
|
|
] = {}
|
|
|
|
def search_reusable_session_execution_info(
|
|
self, graph_module: torch.fx.GraphModule, *args
|
|
):
|
|
if graph_module not in self.execution_info_per_graph_module:
|
|
return None
|
|
# All execution information for ONNX models exported from the same `graph_module`
|
|
# with different inputs.
|
|
candidates = self.execution_info_per_graph_module[graph_module]
|
|
|
|
for candidate in candidates:
|
|
if candidate.is_supported(*args):
|
|
# Returns the first session that accepts this input schema.
|
|
return candidate
|
|
# No reusable session found.
|
|
return None
|
|
|
|
def cache_session_execution_info(
|
|
self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession
|
|
):
|
|
if graph_module not in self.execution_info_per_graph_module:
|
|
self.execution_info_per_graph_module[graph_module] = [info]
|
|
else:
|
|
self.execution_info_per_graph_module[graph_module].append(info)
|
|
|
|
|
|
OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]]
|
|
"""Either the name of an ONNX Runtime execution provider as a string or
|
|
a 2-tuple of the name and a dictionary of execution provider options.
|
|
|
|
Examples::
|
|
|
|
>>> "CPUExecutionProvider"
|
|
|
|
>>> ("CUDAExecutionProvider", {"device_id": 3})
|
|
|
|
"""
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
@compatibility(is_backward_compatible=False)
|
|
class OrtBackendOptions:
|
|
"""Options for constructing an ``OrtBackend``, the ONNX Runtime
|
|
backend (``"onnxrt"``) for ``torch.compile``.
|
|
|
|
Example::
|
|
|
|
>>> @torch.compile(
|
|
... backend="onnxrt",
|
|
... options=torch.onnx._OrtBackendOptions(...),
|
|
... )
|
|
... def ort_function(x):
|
|
... return x ** x
|
|
"""
|
|
|
|
preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
|
|
"""An optional sequence of execution providers to be prioritized ahead of any
|
|
execution providers that may be inferred (see ``infer_execution_providers``).
|
|
"""
|
|
|
|
infer_execution_providers: bool = True
|
|
"""Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph."""
|
|
|
|
default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
|
|
"""The default fallback execution providers. If not specified, one will be
|
|
be selected based on the host environment (most likely ``"CPUExecutionProvider"``).
|
|
"""
|
|
|
|
    # preallocate_output allows preallocating output torch.Tensor buffers and feeding them to
    # InferenceSession, so that InferenceSession does not allocate output buffers internally.
    # If an output OrtValue returned from InferenceSession is allocated internally, it needs to
    # be converted to a torch.Tensor before it is returned, and that torch.Tensor should take
    # ownership of the buffer.
    # When a custom torch device is used with a custom aten allocator, that OrtValue-to-torch.Tensor
    # conversion (currently done through DLPack) must be supported, but DLPack might not support
    # the custom torch device.
    # This can be avoided by preallocating the output buffers with the custom aten allocator and
    # handing them to InferenceSession, which then does not hold any ownership over them.
    # TODO(wschin): Make this an inference-session-level flag.
|
|
# See https://github.com/pytorch/pytorch/issues/106869.
|
|
preallocate_output: bool = False
|
|
"""If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""
|
|
|
|
use_aot_autograd: bool = True
|
|
"""Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend
|
|
to support training (i.e., backward graphs are also sent to ``OrtBackend``).
|
|
|
|
Symbolic execution is used to capture the forward pass and backward passes as a single graph.
|
|
Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used
|
|
to split the entire graph into forward sub-graph and backward sub-graph. Finally, both
|
|
sub-graphs are compiled by ``OrtBackend``.
|
|
"""
|
|
|
|
export_options: Optional["torch.onnx.ExportOptions"] = None
|
|
"""Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``."""
|
|
|
|
ort_session_options: Optional["onnxruntime.SessionOptions"] = None
|
|
"""Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``."""
|
|
|
|
pre_ort_model_transforms: Optional[ # type: ignore[name-defined]
|
|
Sequence[Callable[["onnx.ModelProto"], None]]
|
|
] = None
|
|
"""A list of graph transforms to be applied to the ONNX model before it
|
|
    is fed to ONNXRuntime's InferenceSession.
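
    For example (an illustrative, hypothetical in-place transform)::

        def _clear_doc_string(model_proto):
            model_proto.graph.doc_string = ""
    """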
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
|
|
class OrtBackend:
|
|
"""A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.
|
|
|
|
The compiler entry point is OrtBackend.compile, which
|
|
1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported
|
|
sub-graphs.
|
|
2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call.
|
|
3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
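
    Example (an illustrative sketch; an ``OrtBackend`` instance is itself a callable
    ``torch.compile`` backend)::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> backend = OrtBackend(OrtBackendOptions(preallocate_output=False))
        >>> compiled = torch.compile(lambda x: x.relu(), backend=backend)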
|
|
"""
|
|
|
|
def __init__(self, options: Optional[OrtBackendOptions] = None):
|
|
self._options: Final = OrtBackendOptions() if options is None else options
|
|
|
|
# options.export_options contains information shared between exporter and DORT.
|
|
# For example, they should use the same decomposition table when
|
|
# 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py)
|
|
# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model
|
|
# (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
|
|
#
|
|
# Convert user-facing option to internal option used by ONNX exporter
|
|
# to access required information.
|
|
# Some useful fields:
|
|
# - Decomposition table for decomposing FX operators in exporter is
|
|
# self._resolved_onnx_exporter_options.decomposition_table.
|
|
# - self._resolved_onnx_exporter_options.onnx_registry records what
|
|
# aten/prim ops are supported by exporter and their exporters (type: callable).
|
|
self._resolved_onnx_exporter_options = (
|
|
torch.onnx._internal.exporter.ResolvedExportOptions(
|
|
torch.onnx.ExportOptions()
|
|
if self._options.export_options is None
|
|
else self._options.export_options
|
|
)
|
|
)
|
|
|
|
# Given DORT's computation flow:
|
|
# 1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators
|
|
# and send them to DORT.
|
|
# 2. Then, DORT exports the selected sub-graphs into ONNX.
|
|
# 3. Finally DORT calls ORT to do the computation.
|
|
# OrtOperatorSupport and create_onnx_friendly_decomposition_table(...)
|
|
        # must use the same support_dict. If the support_dict here contains something not
        # supported by the exporter, the exporter will fail in step 2 since the selected graphs may
        # contain unsupported operators such as aten::_who_you_are.
        # This restriction is automatically satisfied since DORT and the exporter share the same
|
|
# self._resolved_onnx_exporter_options.
|
|
support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
|
|
self._resolved_onnx_exporter_options.onnx_registry
|
|
)
|
|
|
|
extra_support_dict: Dict[str, Any] = {
|
|
"getattr": None,
|
|
# To send operator.getitem to ORT, add the corresponding string
|
|
# recognized by PyTorch's OperatorSupport class.
|
|
"_operator.getitem": None,
|
|
# To send operator.mul to ORT, add the corresponding string
|
|
# recognized by PyTorch's OperatorSupport class.
|
|
"_operator.mul": None,
|
|
"_operator.add": None,
|
|
"_operator.sub": None,
|
|
}
|
|
|
|
self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
|
|
# TODO(wschin): this is a naive implementation of cache without proper guard
|
|
# See https://github.com/pytorch/pytorch/issues/106868.
|
|
self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
|
|
        # Conceptually, this field is a 2-layer dictionary
|
|
# GraphModule 0
|
|
# ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
|
|
# ONNX Model 1
|
|
# ...
|
|
# GraphModule 1
|
|
# ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
|
|
# ONNX Model 3
|
|
# ...
|
|
# ...
|
|
        # , which caches all previous compilation results so that we can reuse them.
|
|
# ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
|
|
# (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
|
|
# graphs captured by Dynamo and sent to OrtBackend.compile.
|
|
self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()
|
|
|
|
self._assert_allclose_to_baseline = False
|
|
|
|
self.execution_count = 0
|
|
|
|
        # Function which invokes ORT to do the real computation.
|
|
self.run = (
|
|
_run_onnx_session_with_ortvaluevector
|
|
if hasattr(ORTC.OrtValueVector, "push_back_batch")
|
|
else _run_onnx_session_with_fetch
|
|
)
|
|
|
|
def _select_eps(
|
|
self, graph_module: torch.fx.GraphModule, *args
|
|
) -> Sequence[Tuple[str, Mapping[str, Any]]]:
|
|
inferred_eps: Tuple[str, ...] = tuple()
|
|
if self._options.infer_execution_providers:
|
|
if eps_from_args := _infer_ep_from_device(*args):
|
|
                # If the user feeds a CUDA tensor as an input argument,
                # we want to use the CUDA EP.
                # Thus, `eps_from_args` (deduced from input arguments)
                # has the highest priority.
|
|
inferred_eps = eps_from_args
|
|
elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module):
|
|
# If there is no EP in input arguments, we deduce EP from
|
|
# graph_module's outputs. Those outputs may come from
|
|
# FakeTensorProp or Dynamo's built-in symbolic shape inference.
|
|
inferred_eps = eps_from_graph_module
|
|
|
|
selected_eps = []
|
|
|
|
for ep in (
|
|
*(self._options.preferred_execution_providers or []),
|
|
*_sort_eps(inferred_eps),
|
|
*(self._options.default_execution_providers or _infer_default_eps()),
|
|
):
|
|
if isinstance(ep, str):
|
|
ep = (ep, {})
|
|
elif isinstance(ep, tuple) and ep[1] is None:
|
|
ep = (ep[0], {})
|
|
if ep is not None and ep not in selected_eps:
|
|
selected_eps.append(ep)
|
|
|
|
return selected_eps
|
|
|
|
def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
|
|
"""This function replaces GraphModule._wrapped_call in compiled model.
|
|
|
|
The _wrapped_call is the underlying implementation of forward method. Replacing
|
|
it means we delegate the computation to _ort_acclerated_call and therefore
|
|
onnxruntime.InferenceSession.
|
|
"""
|
|
cached_execution_info_per_session = (
|
|
self._all_ort_execution_info.search_reusable_session_execution_info(
|
|
graph_module, *args
|
|
)
|
|
)
|
|
if cached_execution_info_per_session:
|
|
onnx_session = cached_execution_info_per_session.session
|
|
input_names = cached_execution_info_per_session.input_names
|
|
output_names = cached_execution_info_per_session.output_names
|
|
input_value_infos = cached_execution_info_per_session.input_value_infos
|
|
output_value_infos = cached_execution_info_per_session.output_value_infos
|
|
input_devices = cached_execution_info_per_session.input_devices
|
|
output_devices = cached_execution_info_per_session.output_devices
|
|
prim_outputs = cached_execution_info_per_session.example_outputs
|
|
else:
|
|
            # It's the first time we see such a graph. Let's make a new session
|
|
# (type: onnxruntime.InferenceSession) for it.
|
|
|
|
graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
|
|
self._resolved_onnx_exporter_options.diagnostic_context,
|
|
graph_module,
|
|
).run()
|
|
# Generate reference outputs. They are used to indicate output
|
|
# tensors' types and devices when calling ORT.
|
|
#
|
|
            # WARNING: The downstream code should not change prim_outputs, and
            # this backend should always produce outputs with a schema identical to prim_outputs'.
|
|
|
|
if self._resolved_onnx_exporter_options.dynamic_shapes:
|
|
# No pre-allocation when dynamic shape is enabled.
|
|
self.preallocate_output = False
|
|
extracted_outputs = _extract_graph_module_outputs(graph_module)
|
|
|
|
def maybe_map_to_meta_val(value):
|
|
if hasattr(value, "meta") and "val" in value.meta:
|
|
# Select outputs with "val" information. Without "val",
|
|
                        # it's not possible to access output_arg.meta["val"].device.
|
|
return value.meta["val"]
|
|
else:
|
|
return value
|
|
|
|
prim_outputs = _pytree.tree_map(
|
|
maybe_map_to_meta_val, extracted_outputs
|
|
)
|
|
else:
|
|
try:
|
|
prim_outputs = FakeTensorProp(graph_module).propagate(
|
|
*args, **kwargs
|
|
)
|
|
except Exception:
|
|
logger.warning("FakeTensorProb failed for %s", graph_module)
|
|
# When FakeTensorProp fails, it is not possible to preallocate output buffers
|
|
# because the output shapes are not inferred.
|
|
self.preallocate_output = False
|
|
|
|
                    # Rethrow the FakeTensorProp failure because it is not currently handled.
|
|
raise
|
|
|
|
# Create the object to iterate through the nodes in graph one-by-one
|
|
# and calls the corresponding ONNX exporter for each node.
|
|
fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
|
|
diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
|
|
)
|
|
            # Cast FX variables if they would result in a schema mismatch when searching
            # for the ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
|
|
# but ONNX expects add(double_tensor, double_tensor).
|
|
graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
|
|
self._resolved_onnx_exporter_options.diagnostic_context, graph_module
|
|
).run()
|
|
# Start the per-node exporting process. It's conceptually a for loop
|
|
# scanning through the nodes in the graph.
|
|
exported = fx_interpreter.run(
|
|
fx_graph_module=graph_module,
|
|
onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
|
|
op_level_debug=self._resolved_onnx_exporter_options.op_level_debug,
|
|
)
|
|
# Convert the exported result to ONNX ModelProto.
|
|
onnx_model = exported.to_model_proto(
|
|
opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
|
|
)
|
|
|
|
# Modify ONNX model using pre-registered graph transforms.
|
|
            # They are in-place modifications that avoid unnecessary copies
            # of the ONNX initializers.
|
|
if self._options.pre_ort_model_transforms:
|
|
for transform in self._options.pre_ort_model_transforms:
|
|
transform(onnx_model)
|
|
|
|
onnx_model_bytes = onnx_model.SerializeToString()
|
|
if os.environ.get("ONNXRT_DUMP_PATH", None):
|
|
                # If not empty, the environment variable ONNXRT_DUMP_PATH defines the path
                # prefix under which the generated ONNX files are stored.
                # This module keeps a global variable to track the stored models.
                # If ONNXRT_DUMP_PATH="dumped/dumped_model_",
                # the first file name will be 'dumped/dumped_model_0.onnx'.
                # For every dumped model, a text file 'dumped/dumped_model_0.txt' is created
                # as well, containing the string representation of the graph_module.
|
|
_dump_onnx_model(onnx_model_bytes, graph_module=graph_module)
|
|
|
|
            # Initialize an ORT session to execute this ONNX model.
|
|
# Note that TorchDynamo assumes all inputs/outputs are on the
|
|
# same device, but it's subject to change (very likely with
|
|
# dynamic shape support), so we add execution providers
|
|
# based on the logic in _select_eps: (explicitly preferred EPs,
|
|
            # EPs inferred from inputs or graph, and the fallback default EP).
|
|
#
|
|
# TODO(wschin): enable external allocators.
|
|
# See https://github.com/pytorch/pytorch/issues/106867
|
|
onnx_session = onnxruntime.InferenceSession(
|
|
path_or_bytes=onnx_model_bytes,
|
|
sess_options=self._options.ort_session_options,
|
|
providers=self._select_eps(graph_module, *args),
|
|
)
|
|
|
|
            # Cache the ORT session so it can be reused for the same "graph_module".
            # Extract input and output names from the generated ONNX model.
|
|
input_names = tuple(input.name for input in onnx_model.graph.input)
|
|
output_names = tuple(output.name for output in onnx_model.graph.output)
|
|
input_devices = _get_onnx_devices(args)
|
|
# Cache devices for inputs and outputs. They are used to invoke
|
|
# ORT session. Output devices indicate where (e.g., GPU or CPU)
|
|
            # to store outputs.
|
|
if isinstance(prim_outputs, tuple):
|
|
output_devices = _get_onnx_devices(prim_outputs)
|
|
else:
|
|
output_devices = _get_onnx_devices((prim_outputs,))
|
|
|
|
input_value_infos = tuple(input for input in onnx_model.graph.input)
|
|
output_value_infos = tuple(output for output in onnx_model.graph.output)
|
|
|
|
execution_info_per_session = OrtExecutionInfoPerSession(
|
|
session=onnx_session,
|
|
input_names=input_names,
|
|
input_value_infos=input_value_infos,
|
|
output_names=output_names,
|
|
output_value_infos=output_value_infos,
|
|
input_devices=input_devices,
|
|
output_devices=output_devices,
|
|
example_outputs=prim_outputs,
|
|
)
|
|
|
|
self._all_ort_execution_info.cache_session_execution_info(
|
|
graph_module, execution_info_per_session
|
|
)
|
|
|
|
self.execution_count += 1
|
|
|
|
# ORT always returns a tuple of outputs. If the original output is a tensor,
|
|
# ORT output's first element must be extracted and returned. Otherwise, type
|
|
# mismatch may happen in downstream computation.
|
|
is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
|
|
normalized_prim_outputs = (
|
|
(prim_outputs,) if is_single_tensor_output else prim_outputs
|
|
)
|
|
assert isinstance(normalized_prim_outputs, tuple)
|
|
assert all(
|
|
isinstance(elem, (torch.Tensor, torch.SymInt, int))
|
|
for elem in normalized_prim_outputs
|
|
)
|
|
|
|
_nvtx_range_push("run_onnx_session_with_ortvaluevector")
|
|
onnx_outputs = self.run(
|
|
onnx_session,
|
|
input_names,
|
|
args,
|
|
input_devices,
|
|
output_names,
|
|
normalized_prim_outputs,
|
|
output_devices,
|
|
self._options.preallocate_output,
|
|
input_value_infos,
|
|
normalized_prim_outputs,
|
|
)
|
|
_nvtx_range_pop()
|
|
|
|
if self._assert_allclose_to_baseline:
|
|
# Compute baseline.
|
|
baseline_outputs = torch._prims.executor.execute(
|
|
graph_module, *args, executor="aten"
|
|
)
|
|
            normalized_baseline_outputs = (
|
|
(baseline_outputs,) if is_single_tensor_output else baseline_outputs
|
|
)
|
|
# Ensure every output tensor is close to the corresponding baseline.
|
|
for onnx_output, baseline_output in zip(
|
|
                onnx_outputs, normalized_baseline_outputs
|
|
):
|
|
torch.testing.assert_close(onnx_output, baseline_output)
|
|
return onnx_outputs[0] if is_single_tensor_output else onnx_outputs
|
|
|
|
def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
|
|
# Deferred import since CapabilityBasedPartitioner is not decorated with
|
|
# @compatibility; importing it at the module level will result in the test
|
|
# failing: pytest test/test_fx.py -k test_public_api_surface
|
|
# because this module is imported into torch.onnx.
|
|
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
|
|
|
|
# FX graph based partitioning based on ONNX supported ops.
|
|
# Given a graph module
|
|
# GraphModule0
|
|
# node_0
|
|
# node_1
|
|
# node_2
|
|
# node_3
|
|
# node_4
|
|
# If only node_2 is not supported by ONNX, this graph module will be partitioned into
|
|
# GraphModule0
|
|
# GraphModule1
|
|
# node_0
|
|
# node_1
|
|
# node_2
|
|
# GraphModule2
|
|
# node_3
|
|
# node_4
|
|
# by calling CapabilityBasedPartitioner.partition_and_fuse.
|
|
# Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call)
|
|
# will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT.
|
|
if graph_module in self._partitioner_cache:
|
|
partitioned_prim_graph_module = self._partitioner_cache[graph_module]
|
|
else:
|
|
prim_graph_module = graph_module
|
|
partitioner = CapabilityBasedPartitioner(
|
|
prim_graph_module,
|
|
self._supported_ops,
|
|
allows_single_node_partition=True,
|
|
)
|
|
partitioned_prim_graph_module = partitioner.partition_and_fuse()
|
|
self._partitioner_cache[graph_module] = partitioned_prim_graph_module
|
|
|
|
        # Override each fused submodule's _wrapped_call with _ort_acclerated_call().
|
|
# This loop goes through all graph partitions (each of them is an ONNX-representable graph)
|
|
# and override their _wrapped_call function with _ort_accelerated_call.
|
|
# Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT.
|
|
for node in partitioned_prim_graph_module.graph.nodes:
|
|
# TODO(wschin): use a better way to identify fused submodule
|
|
# See https://github.com/pytorch/pytorch/issues/106872.
|
|
if node.op == "call_module" and "fused_" in node.name:
|
|
fused_module = getattr(partitioned_prim_graph_module, node.name)
|
|
                # self._ort_acclerated_call is responsible for exporting the graph to ONNX,
|
|
# creating ORT session, and running ORT session.
|
|
fused_module._wrapped_call = self._ort_acclerated_call
|
|
|
|
return partitioned_prim_graph_module
|
|
|
|
def __call__(
|
|
self, graph_module: torch.fx.GraphModule, args
|
|
) -> torch.fx.GraphModule:
|
|
"""If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler
|
|
will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
|
|
the ``compile`` method is invoked directly."""
|
|
if self._options.use_aot_autograd:
|
|
from functorch.compile import min_cut_rematerialization_partition
|
|
|
|
from torch._dynamo.backends.common import aot_autograd
|
|
|
|
return aot_autograd(
|
|
fw_compiler=self.compile,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
decompositions=self._resolved_onnx_exporter_options.decomposition_table,
|
|
)(graph_module, args)
|
|
|
|
return self.compile(graph_module, args)
|
|
|
|
__instance_cache_max_count: Final = 8
|
|
__instance_cache: Final[List["OrtBackend"]] = []
|
|
|
|
@staticmethod
|
|
def get_cached_instance_for_options(
|
|
options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
|
|
) -> "OrtBackend":
|
|
"""Returns a possibly cached instance of an ``OrtBackend``. If an existing
|
|
backend was created previously through this function with the same options,
|
|
it will be returned. Otherwise a new backend will be created, cached, and
|
|
returned.
|
|
|
|
Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
|
|
will always be returned, since ``onnxruntime.SessionOptions`` cannot
|
|
        participate in caching.
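
        Example (illustrative)::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
            >>> b1 = OrtBackend.get_cached_instance_for_options({"use_aot_autograd": True})
            >>> b2 = OrtBackend.get_cached_instance_for_options({"use_aot_autograd": True})
            >>> assert b1 is b2
        """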
|
|
|
|
def reusable(a: OrtBackendOptions, b: OrtBackendOptions):
|
|
if (
|
|
a.preferred_execution_providers != b.preferred_execution_providers
|
|
or a.infer_execution_providers != b.infer_execution_providers
|
|
or a.default_execution_providers != b.default_execution_providers
|
|
or a.preallocate_output != b.preallocate_output
|
|
or a.use_aot_autograd != b.use_aot_autograd
|
|
or a.pre_ort_model_transforms != b.pre_ort_model_transforms
|
|
):
|
|
return False
|
|
|
|
# onnxruntime.SessionOptions is a pybind11 object, cannot be pickled,
|
|
# and holds too much potential state to reasonably check manually;
|
|
            # if ort_session_options is provided at all, the backend does not participate
|
|
# in caching.
|
|
if a.ort_session_options is not None or b.ort_session_options is not None:
|
|
return False
|
|
|
|
if a.export_options is b.export_options:
|
|
return True
|
|
|
|
# Similarly, some objects in ExportOptions are too stateful to use for
|
|
# caching. We should revisit this.
|
|
if a.export_options is not None and b.export_options is not None:
|
|
return (
|
|
a.export_options.dynamic_shapes == b.export_options.dynamic_shapes
|
|
and a.export_options.op_level_debug
|
|
== b.export_options.op_level_debug
|
|
and a.export_options.diagnostic_options
|
|
== b.export_options.diagnostic_options
|
|
and a.export_options.onnx_registry is b.export_options.onnx_registry
|
|
and a.export_options.fake_context is b.export_options.fake_context
|
|
)
|
|
|
|
# We can't account for how the two option sets may differ, so it's not safe to reuse.
|
|
return False
|
|
|
|
if not isinstance(options, OrtBackendOptions):
|
|
options = OrtBackendOptions(**(options or {}))
|
|
|
|
backend = next(
|
|
(b for b in OrtBackend.__instance_cache if reusable(b._options, options)),
|
|
None,
|
|
)
|
|
|
|
if backend is None:
|
|
assert (
|
|
len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count
|
|
), (
|
|
f"No more than {OrtBackend.__instance_cache_max_count} instances of "
|
|
f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly "
|
|
"to pass to `torch.compile`. "
|
|
"See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 "
|
|
"for discussion."
|
|
)
|
|
OrtBackend.__instance_cache.append(backend := OrtBackend(options))
|
|
|
|
return backend
|
|
|
|
@staticmethod
|
|
def clear_cached_instances():
|
|
OrtBackend.__instance_cache.clear()
|
|
|
|
@staticmethod
|
|
def get_cached_instances():
|
|
return tuple(OrtBackend.__instance_cache)
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
|
|
def torch_compile_backend(
|
|
graph_module: torch.fx.GraphModule,
|
|
args,
|
|
*,
|
|
options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
|
|
):
|
|
return OrtBackend.get_cached_instance_for_options(options)(graph_module, args)
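

# Example (an illustrative sketch): the usual entry point is ``torch.compile`` with the
# registered "onnxrt" backend, which routes to ``torch_compile_backend`` through Dynamo's
# backend registry; ``options`` may be an ``OrtBackendOptions`` instance or a mapping of
# its field names, e.g.::
#
#     compiled = torch.compile(model, backend="onnxrt", options={"preallocate_output": False})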
|