from dataclasses import dataclass
from typing import Dict, Optional, Tuple


# This class holds information about a single operator used to determine
# the outcome of a selective/custom PyTorch build that doesn't include
# registration code for all the supported operators. This is done to
# reduce the size of the generated binary so that it can be deployed in
# situations where binary size comes at a premium.
#
@dataclass(frozen=True)
class SelectiveBuildOperator:
    """Immutable description of one operator selected for a custom build."""

    # The name of the operator. This includes the aten::, etc... prefix.
    # The operator name may or may not have the overload name. If this
    # operator name does not specify an overload name, the way to determine
    # if this entry refers to the family of operators with this base name
    # or just the operator with this name is to look at the value of the
    # 'include_all_overloads' flag in this class.
    name: str

    # True if this is a root operator (i.e. called directly from a
    # TorchScript model, etc...). An operator is considered to be a
    # root operator if it is called directly from any one of the models
    # that this instance of the pytorch library was built for. Hence, it
    # may not be a root operator in all of the models that are used in
    # this instance of the pytorch library.
    is_root_operator: bool

    # Is this operator used for on-device training? If True, then we need to
    # use the information to generate code in VariableType_N.cpp for
    # registration of training related operators. Again, this is True if this
    # operator is used for training in one or more models used by this
    # instance of the pytorch library.
    is_used_for_training: bool

    # If True, it indicates that this operator instance (object) refers to an
    # operator without the overload name and should apply to all overloads
    # which have this operator name as the base name. This flag is applicable
    # only for objects that have operator names without a DOT (period)
    # character in them.
    #
    # Note: This flag is a temporary workaround to grandfather in the current
    # static selective (custom) build mechanism, which largely ignores
    # overload names when determining whether to select operators for
    # registration purposes.
    include_all_overloads: bool

    # Debug information at the operator level (None when absent).
    _debug_info: Optional[Tuple[str, ...]]

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: Dict[str, object]
    ) -> "SelectiveBuildOperator":
        """Build an operator entry from one value of a selective-build mapping.

        Args:
            op_name: The operator name (the key under which op_info was found).
            op_info: Per-operator settings. Only the keys "name",
                "is_root_operator", "is_used_for_training",
                "include_all_overloads" and "debug_info" are accepted. The
                three boolean flags default to True when absent.

        Returns:
            A new SelectiveBuildOperator.

        Raises:
            Exception: if op_info contains keys outside the allowed set.
            AssertionError: if "name" disagrees with op_name, a flag is not a
                bool, or "debug_info" is not a list.
        """
        allowed_keys = {
            "name",
            "is_root_operator",
            "is_used_for_training",
            "include_all_overloads",
            "debug_info",
        }
        unexpected_keys = set(op_info.keys()) - allowed_keys
        if unexpected_keys:
            # Sort so the message is stable across runs; iterating a set of
            # strings depends on hash randomization.
            raise Exception(
                "Got unexpected top level keys: {}".format(
                    ",".join(sorted(str(k) for k in unexpected_keys)),
                )
            )

        if "name" in op_info:
            assert op_name == op_info["name"], (
                f"Operator name mismatch: key is '{op_name}' but 'name' is "
                f"'{op_info['name']}'"
            )

        def _flag(key: str) -> bool:
            # Read a boolean flag that defaults to True when absent.
            value = op_info.get(key, True)
            assert isinstance(value, bool), (
                f"'{key}' for operator '{op_name}' must be a bool, got {value!r}"
            )
            return value

        is_root_operator = _flag("is_root_operator")
        is_used_for_training = _flag("is_used_for_training")
        include_all_overloads = _flag("include_all_overloads")

        debug_info: Optional[Tuple[str, ...]] = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            assert isinstance(di_list, list), (
                f"'debug_info' for operator '{op_name}' must be a list, "
                f"got {type(di_list).__name__}"
            )
            debug_info = tuple(str(x) for x in di_list)

        return SelectiveBuildOperator(
            name=op_name,
            is_root_operator=is_root_operator,
            is_used_for_training=is_used_for_training,
            include_all_overloads=include_all_overloads,
            _debug_info=debug_info,
        )

    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> "SelectiveBuildOperator":
        """Build an entry for a legacy operator name that has no overload name.

        All flags are set to True and no debug info is attached.
        """
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            _debug_info=None,
        )

    def to_dict(self) -> Dict[str, object]:
        """Serialize to a dict of the flags plus "debug_info" when present.

        The operator name is not included; it is expected to be the key
        under which this dict is stored.
        """
        ret: Dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
        }
        if self._debug_info is not None:
            ret["debug_info"] = self._debug_info

        return ret


def merge_debug_info(
    lhs: Optional[Tuple[str, ...]],
    rhs: Optional[Tuple[str, ...]],
) -> Optional[Tuple[str, ...]]:
    """Merge two debug-info tuples so that each entry shows up just once.

    Returns None only when both inputs are None. Entries keep their first
    occurrence order (lhs entries first, then new entries from rhs), so the
    result is deterministic across runs.
    """
    if lhs is None and rhs is None:
        return None

    # dict.fromkeys deduplicates while preserving insertion order, unlike
    # set(), whose iteration order for strings varies with hash seeding.
    return tuple(dict.fromkeys((lhs or ()) + (rhs or ())))


def combine_operators(
    lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator"
) -> "SelectiveBuildOperator":
    """Combine two entries for the same operator name into one.

    Every flag is OR-ed together and the debug info is merged.

    Raises:
        Exception: if the two operators have different names.
    """
    if str(lhs.name) != str(rhs.name):
        raise Exception(
            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
        )

    return SelectiveBuildOperator(
        name=lhs.name,
        # Consider this operator to be a root operator if it is a
        # root operator in any of the models used in this instance of
        # the pytorch library.
        is_root_operator=lhs.is_root_operator or rhs.is_root_operator,
        # Consider this operator to be a training operator if it is
        # an operator used for training in any of the models used
        # in this instance of the pytorch library.
        is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training,
        include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads,
        _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
    )


def merge_operator_dicts(
    lhs: Dict[str, SelectiveBuildOperator],
    rhs: Dict[str, SelectiveBuildOperator],
) -> Dict[str, SelectiveBuildOperator]:
    """Return a new dict holding the union of lhs and rhs.

    Operators present in both inputs are combined with combine_operators.
    Keys from lhs come first, followed by keys found only in rhs. Neither
    input is mutated.
    """
    operators: Dict[str, SelectiveBuildOperator] = dict(lhs)
    for op_name, op in rhs.items():
        if op_name in operators:
            op = combine_operators(operators[op_name], op)
        operators[op_name] = op

    return operators


def strip_operator_overload_name(op_name: str) -> str:
    """Return op_name without its overload suffix (text after the first '.')."""
    return op_name.split(".")[0]