340 lines
11 KiB
Python
340 lines
11 KiB
Python
|
"""Module for handling symbolic function registration."""
|
||
|
|
||
|
import warnings
|
||
|
from typing import (
|
||
|
Callable,
|
||
|
Collection,
|
||
|
Dict,
|
||
|
Generic,
|
||
|
Optional,
|
||
|
Sequence,
|
||
|
Set,
|
||
|
TypeVar,
|
||
|
Union,
|
||
|
)
|
||
|
|
||
|
from torch.onnx import _constants, errors
|
||
|
from torch.onnx._internal import _beartype
|
||
|
|
||
|
# Opset versions are plain integers; this alias only makes signatures more readable.
OpsetVersion = int
|
||
|
|
||
|
|
||
|
def _dispatch_opset_version(
|
||
|
target: OpsetVersion, registered_opsets: Collection[OpsetVersion]
|
||
|
) -> Optional[OpsetVersion]:
|
||
|
"""Finds the registered opset given a target opset version and the available opsets.
|
||
|
|
||
|
Args:
|
||
|
target: The target opset version.
|
||
|
registered_opsets: The available opsets.
|
||
|
|
||
|
Returns:
|
||
|
The registered opset version.
|
||
|
"""
|
||
|
if not registered_opsets:
|
||
|
return None
|
||
|
|
||
|
descending_registered_versions = sorted(registered_opsets, reverse=True)
|
||
|
# Linear search for the opset version, which is fine since the number of opset
|
||
|
# versions is small.
|
||
|
|
||
|
if target >= _constants.ONNX_BASE_OPSET:
|
||
|
# Always look down toward opset 1 when the target is >= ONNX_BASE_OPSET (opset 9).
|
||
|
# When a custom op is register at opset 1, we want to be able to discover it as a
|
||
|
# fallback for all opsets >= ONNX_BASE_OPSET.
|
||
|
for version in descending_registered_versions:
|
||
|
if version <= target:
|
||
|
return version
|
||
|
return None
|
||
|
|
||
|
# target < opset 9. This is the legacy behavior to support opset 7 and opset 8.
|
||
|
# for caffe2 support. We search up toward opset 9.
|
||
|
for version in reversed(descending_registered_versions):
|
||
|
# Count back up until _constants.ONNX_BASE_OPSET
|
||
|
if target <= version <= _constants.ONNX_BASE_OPSET:
|
||
|
return version
|
||
|
|
||
|
return None
|
||
|
|
||
|
|
||
|
# Key type parameter used by OverrideDict.
_K = TypeVar("_K")
|
||
|
# Value type parameter used by OverrideDict.
_V = TypeVar("_V")
|
||
|
|
||
|
|
||
|
class OverrideDict(Generic[_K, _V], Collection[_K]):
    """A mapping that layers custom entries on top of built-in (base) entries.

    Three dictionaries are kept: the base entries, the overrides, and a merged
    view in which an override wins over the base entry for the same key. All
    read operations (lookup, iteration, length, truthiness) go through the merged
    view. Removing an override re-exposes the base entry, if one exists.
    """

    def __init__(self):
        self._base: Dict[_K, _V] = {}
        self._overrides: Dict[_K, _V] = {}
        self._merged: Dict[_K, _V] = {}

    def set_base(self, key: _K, value: _V) -> None:
        """Stores a base entry; it is visible only while the key is not overridden."""
        self._base[key] = value
        if key in self._overrides:
            # The override keeps shadowing the (new) base value.
            return
        self._merged[key] = value

    def in_base(self, key: _K) -> bool:
        """Reports whether the base dictionary has an entry for ``key``."""
        return key in self._base

    def override(self, key: _K, value: _V) -> None:
        """Installs an override for ``key`` that shadows any base entry."""
        self._overrides[key] = value
        self._merged[key] = value

    def remove_override(self, key: _K) -> None:
        """Drops the override for ``key`` and re-exposes the base entry, if any."""
        self._overrides.pop(key, None)  # type: ignore[arg-type]
        self._merged.pop(key, None)  # type: ignore[arg-type]
        try:
            restored = self._base[key]
        except KeyError:
            # Nothing underneath the override; the key disappears entirely.
            return
        self._merged[key] = restored

    def overridden(self, key: _K) -> bool:
        """Reports whether ``key`` currently has an override."""
        return key in self._overrides

    def __getitem__(self, key: _K) -> _V:
        return self._merged[key]

    def get(self, key: _K, default: Optional[_V] = None):
        """Returns the merged value for ``key``, or ``default`` when absent."""
        return self._merged.get(key, default)

    def __contains__(self, key: object) -> bool:
        return key in self._merged

    def __iter__(self):
        return iter(self._merged)

    def __len__(self) -> int:
        return len(self._merged)

    def __repr__(self) -> str:
        return f"OverrideDict(base={self._base}, overrides={self._overrides})"

    def __bool__(self) -> bool:
        return len(self._merged) > 0
|
||
|
|
||
|
|
||
|
class _SymbolicFunctionGroup:
|
||
|
"""Different versions of symbolic functions registered to the same name.
|
||
|
|
||
|
O(number of registered versions of an op) search is performed to find the most
|
||
|
recent version of the op.
|
||
|
|
||
|
The registration is delayed until op is used to improve startup time.
|
||
|
|
||
|
Function overloads with different arguments are not allowed.
|
||
|
Custom op overrides are supported.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, name: str) -> None:
|
||
|
self._name = name
|
||
|
# A dictionary of functions, keyed by the opset version.
|
||
|
self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict()
|
||
|
|
||
|
def __repr__(self) -> str:
|
||
|
return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})"
|
||
|
|
||
|
def __getitem__(self, key: OpsetVersion) -> Callable:
|
||
|
result = self.get(key)
|
||
|
if result is None:
|
||
|
raise KeyError(key)
|
||
|
return result
|
||
|
|
||
|
# TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes
|
||
|
# a problem.
|
||
|
def get(self, opset: OpsetVersion) -> Optional[Callable]:
|
||
|
"""Find the most recent version of the function."""
|
||
|
version = _dispatch_opset_version(opset, self._functions)
|
||
|
if version is None:
|
||
|
return None
|
||
|
|
||
|
return self._functions[version]
|
||
|
|
||
|
def add(self, func: Callable, opset: OpsetVersion) -> None:
|
||
|
"""Adds a symbolic function.
|
||
|
|
||
|
Args:
|
||
|
func: The function to add.
|
||
|
opset: The opset version of the function to add.
|
||
|
"""
|
||
|
if self._functions.in_base(opset):
|
||
|
warnings.warn(
|
||
|
f"Symbolic function '{self._name}' already registered for opset {opset}. "
|
||
|
f"Replacing the existing function with new function. This is unexpected. "
|
||
|
f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
|
||
|
errors.OnnxExporterWarning,
|
||
|
)
|
||
|
self._functions.set_base(opset, func)
|
||
|
|
||
|
def add_custom(self, func: Callable, opset: OpsetVersion) -> None:
|
||
|
"""Adds a custom symbolic function.
|
||
|
|
||
|
Args:
|
||
|
func: The symbolic function to register.
|
||
|
opset: The corresponding opset version.
|
||
|
"""
|
||
|
self._functions.override(opset, func)
|
||
|
|
||
|
def remove_custom(self, opset: OpsetVersion) -> None:
|
||
|
"""Removes a custom symbolic function.
|
||
|
|
||
|
Args:
|
||
|
opset: The opset version of the custom function to remove.
|
||
|
"""
|
||
|
if not self._functions.overridden(opset):
|
||
|
warnings.warn(
|
||
|
f"No custom function registered for '{self._name}' opset {opset}"
|
||
|
)
|
||
|
return
|
||
|
self._functions.remove_override(opset)
|
||
|
|
||
|
def get_min_supported(self) -> OpsetVersion:
|
||
|
"""Returns the lowest built-in opset version supported by the function."""
|
||
|
return min(self._functions)
|
||
|
|
||
|
|
||
|
class SymbolicRegistry:
    """Registry of symbolic functions, keyed by qualified name.

    Each qualified name ('domain::op') maps to a ``_SymbolicFunctionGroup``
    holding the functions registered for that name at various opset versions.
    """

    def __init__(self) -> None:
        self._registry: Dict[str, _SymbolicFunctionGroup] = {}

    def register(
        self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False
    ) -> None:
        """Registers a symbolic function under ``name`` at ``opset``.

        Args:
            name: The qualified name of the function to register. In the form of
                'domain::op'. E.g. 'aten::add'.
            opset: The opset version of the function to register.
            func: The symbolic function to register.
            custom: Whether the function is a custom function that overrides
                existing ones.

        Raises:
            ValueError: If the separator '::' is not in the name.
        """
        if "::" not in name:
            raise ValueError(
                f"The name must be in the form of 'domain::op', not '{name}'"
            )

        # Fetch the group for this name, creating it on first registration.
        group = self._registry.get(name)
        if group is None:
            group = _SymbolicFunctionGroup(name)
            self._registry[name] = group

        registrar = group.add_custom if custom else group.add
        registrar(func, opset)

    def unregister(self, name: str, opset: OpsetVersion) -> None:
        """Removes the custom function registered under ``name`` at ``opset``.

        Unknown names are ignored silently.

        Args:
            name: The qualified name of the function to unregister.
            opset: The opset version of the function to unregister.
        """
        group = self._registry.get(name)
        if group is not None:
            group.remove_custom(opset)

    def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]:
        """Returns the function group for ``name``, or None if it is unknown."""
        return self._registry.get(name)

    def is_registered_op(self, name: str, version: int) -> bool:
        """Reports whether ``name`` resolves to a function for opset ``version``."""
        group = self.get_function_group(name)
        return group is not None and group.get(version) is not None

    def all_functions(self) -> Set[str]:
        """Returns the set of all registered qualified names."""
        return set(self._registry)
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
    custom: bool = False,
) -> Callable:
    """Builds a decorator that registers a symbolic function in the registry.

    Usage::

    ```
    @onnx_symbolic("aten::symbolic_b", opset=10, decorate=[quantized_aten_handler(scale=1/128, zero_point=0)])
    @symbolic_helper.parse_args("v", "v", "b")
    def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value:
        ...
    ```

    Args:
        name: The qualified name of the function in the form of 'domain::op'.
            E.g. 'aten::add'.
        opset: The opset version, or versions, to register the function at.
        decorate: A sequence of decorators to apply to the function before it is
            registered. They are applied in order, the first one innermost.
        custom: Whether the function is a custom symbolic function.

    Returns:
        A decorator that registers the function and returns it undecorated.

    Raises:
        ValueError: If the separator '::' is not in the name. Raised when the
            returned decorator is applied.
    """

    def wrapper(func: Callable) -> Callable:
        # Build the variant that actually gets registered.
        to_register = func
        for apply_decorator in () if decorate is None else decorate:
            to_register = apply_decorator(to_register)

        # A single version is treated as a one-element sequence.
        opset_versions = (opset,) if isinstance(opset, OpsetVersion) else opset
        for opset_version in opset_versions:
            registry.register(name, opset_version, to_register, custom=custom)

        # Hand back the undecorated function: the decorators in "decorate" apply
        # only to the instance being registered.
        return func

    return wrapper
|
||
|
|
||
|
|
||
|
@_beartype.beartype
def custom_onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
) -> Callable:
    """Builds a decorator that registers a custom symbolic function.

    Equivalent to calling ``onnx_symbolic`` with ``custom=True``.

    Args:
        name: The qualified name of the function.
        opset: The opset version, or versions, of the function.
        decorate: A sequence of decorators to apply to the function.

    Returns:
        The decorator.

    Raises:
        ValueError: If the separator '::' is not in the name.
    """
    decorator = onnx_symbolic(name, opset, decorate, custom=True)
    return decorator
|
||
|
|
||
|
|
||
|
# The registry for all symbolic functions.
|
||
|
# Module-level instance; the decorator returned by onnx_symbolic registers into it.
registry = SymbolicRegistry()
|