ai-content-maker/.venv/Lib/site-packages/torchgen/selective_build/operator.py

171 lines
6.4 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
# This class holds information about a single operator used to determine
# the outcome of a selective/custom PyTorch build that doesn't include
# registration code for all the supported operators. This is done to
# reduce the size of the generated binary so that it can be deployed in
# situations where binary size comes at a premium.
#
@dataclass(frozen=True)
class SelectiveBuildOperator:
    # The name of the operator. This includes the aten::, etc... prefix
    # The operator name may or may not have the overload name. If this
    # operator name does not specify an overload name, the way to determine
    # if this entry refers to the family of operators with this base name
    # or just the operator with this name is to look at the value of the
    # 'include_all_overloads' flag in this class.
    name: str

    # True if this is a root operator (i.e. called directly from a
    # TorchScript model, etc...). An operator is considered to be a
    # root operator if it is called directly from any one of the models
    # that this instance of the pytorch library was built for. Hence, it
    # may not be a root operator in all of the models that are used in
    # this instance of the pytorch library.
    is_root_operator: bool

    # Is this operator used for on-device training? If True, then we need to
    # use the information to generate code in VariableType_N.cpp for registration
    # of training related operators. Again, this is True if this operator
    # is used for training in one or more models used by this instance of the
    # pytorch library.
    is_used_for_training: bool

    # If True, it indicates that this operator instance (object) refers to an
    # operator without the overload name and should apply to all overloads
    # which have this operator name as the base name. This flag is applicable
    # only for objects that have operator names without a DOT (period) character
    # in them.
    #
    # Note: This flag is a temporary workaround to grandfather in the current
    # static selective (custom) build mechanism, which largely ignores overload
    # names when determining whether to select operators for registration
    # purposes.
    include_all_overloads: bool

    # Debug Information at the operator level
    _debug_info: Optional[Tuple[str, ...]]

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: Dict[str, object]
    ) -> "SelectiveBuildOperator":
        """Build an operator entry from its dictionary description.

        Args:
            op_name: Name stored on the resulting operator.
            op_info: Mapping that may contain the keys "name",
                "is_root_operator", "is_used_for_training",
                "include_all_overloads" and "debug_info". Each of the three
                boolean flags defaults to True when its key is absent.
                "debug_info", when present, must be a list; its items are
                converted to strings and stored as a tuple.

        Returns:
            The operator described by ``op_name`` and ``op_info``.

        Raises:
            Exception: If ``op_info`` contains a key outside the allowed set.
            AssertionError: If ``op_info["name"]`` differs from ``op_name``,
                a flag is not a bool, or "debug_info" is not a list.
        """
        allowed_keys = {
            "name",
            "is_root_operator",
            "is_used_for_training",
            "include_all_overloads",
            "debug_info",
        }

        unexpected_keys = set(op_info.keys()) - allowed_keys
        if unexpected_keys:
            # Sorted so that the message does not depend on set iteration
            # order, which varies between interpreter runs for strings.
            raise Exception(
                "Got unexpected top level keys: {}".format(
                    ",".join(sorted(unexpected_keys)),
                )
            )

        if "name" in op_info:
            assert op_name == op_info["name"]

        is_root_operator = op_info.get("is_root_operator", True)
        assert isinstance(is_root_operator, bool)

        is_used_for_training = op_info.get("is_used_for_training", True)
        assert isinstance(is_used_for_training, bool)

        include_all_overloads = op_info.get("include_all_overloads", True)
        assert isinstance(include_all_overloads, bool)

        debug_info: Optional[Tuple[str, ...]] = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            assert isinstance(di_list, list)
            # Stored as a tuple of strings so the frozen dataclass stays hashable.
            debug_info = tuple(str(x) for x in di_list)

        return SelectiveBuildOperator(
            name=op_name,
            is_root_operator=is_root_operator,
            is_used_for_training=is_used_for_training,
            include_all_overloads=include_all_overloads,
            _debug_info=debug_info,
        )

    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> "SelectiveBuildOperator":
        """Return an operator named ``name`` with every flag set to True.

        The resulting entry carries no debug information.
        """
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            _debug_info=None,
        )

    def to_dict(self) -> Dict[str, object]:
        """Return the three flags, plus "debug_info" when it is not None.

        The operator name is not part of the returned dictionary.
        """
        ret: Dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
        }
        if self._debug_info is not None:
            ret["debug_info"] = self._debug_info

        return ret
def merge_debug_info(
    lhs: Optional[Tuple[str, ...]],
    rhs: Optional[Tuple[str, ...]],
) -> Optional[Tuple[str, ...]]:
    """Combine two optional debug-info tuples into one without duplicates.

    Args:
        lhs: First tuple of debug strings, or None.
        rhs: Second tuple of debug strings, or None.

    Returns:
        None when both inputs are None. Otherwise a tuple holding each
        distinct entry exactly once, in the order it first appears when
        reading ``lhs`` followed by ``rhs``.
    """
    if lhs is None and rhs is None:
        return None

    # dict.fromkeys drops duplicates while keeping insertion order, so the
    # result is the same on every run (a plain set would not guarantee that).
    return tuple(dict.fromkeys((lhs or ()) + (rhs or ())))
def combine_operators(
    lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator"
) -> "SelectiveBuildOperator":
    """Merge two entries that describe the same operator into a single one.

    Raises an Exception when the two names differ. Otherwise the result keeps
    the shared name, ORs each flag across both inputs, and merges the debug
    information through merge_debug_info.
    """
    names_match = str(lhs.name) == str(rhs.name)
    if not names_match:
        raise Exception(
            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
        )

    # A root operator in any of the models used by this instance of the
    # pytorch library counts as a root operator overall.
    root_in_either = lhs.is_root_operator or rhs.is_root_operator

    # Likewise, use for training in any of those models marks the operator
    # as a training operator overall.
    training_in_either = lhs.is_used_for_training or rhs.is_used_for_training

    all_overloads_in_either = lhs.include_all_overloads or rhs.include_all_overloads
    merged_debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)

    return SelectiveBuildOperator(
        name=lhs.name,
        is_root_operator=root_in_either,
        is_used_for_training=training_in_either,
        include_all_overloads=all_overloads_in_either,
        _debug_info=merged_debug_info,
    )
def merge_operator_dicts(
    lhs: Dict[str, SelectiveBuildOperator],
    rhs: Dict[str, SelectiveBuildOperator],
) -> Dict[str, SelectiveBuildOperator]:
    """Return a new name-to-operator mapping covering both inputs.

    Entries are visited from ``lhs`` first and then ``rhs``. A name seen for
    the first time is stored as-is; a name already stored is replaced by the
    result of combine_operators on the stored entry and the new one. Neither
    input mapping is modified.
    """
    merged: Dict[str, SelectiveBuildOperator] = {}
    for source in (lhs, rhs):
        for op_name, op in source.items():
            if op_name in merged:
                merged[op_name] = combine_operators(merged[op_name], op)
            else:
                merged[op_name] = op
    return merged
def strip_operator_overload_name(op_name: str) -> str:
    """Return the part of ``op_name`` that precedes its first period.

    A name without a period is returned unchanged.
    """
    base_name, _, _ = op_name.partition(".")
    return base_name