from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

import yaml

from torchgen.model import NativeFunction
from torchgen.selective_build.operator import (
    merge_debug_info,
    merge_operator_dicts,
    SelectiveBuildOperator,
    strip_operator_overload_name,
)


# A SelectiveBuilder holds information extracted from the selective build
# YAML specification.
#
# It includes information about the build's selectivity, the debug_info
# associated with this selective build (opaque string), and the set of
# operators that should be included in the build.
#
@dataclass(frozen=True)
class SelectiveBuilder:
    # If true, then the build is not selective, and includes all
    # operators.
    include_all_operators: bool

    # Debug Information at the selective/custom build level.
    _debug_info: Optional[Tuple[str, ...]]

    # A dictionary of operator -> operator metadata.
    operators: Dict[str, SelectiveBuildOperator]

    # A dictionary of selected kernel tags and dtypes. Typically a
    # PyTorch Operator Kernel (function) may have many code paths
    # that are specialized for many many Tensor dtypes, so it's not
    # one per kernel function, but there could be many per kernel
    # function. The tag isn't a kernel function name, but some fragment
    # of the kernel function implementation itself.
    kernel_metadata: Dict[str, List[str]]

    # ExecuTorch only. A dictionary of operator name -> list of kernel key
    # strings used by the selected models. Each key is "/"-separated; only
    # the second segment is compared in et_get_selected_kernels() (the first
    # segment is presumably a version tag -- see the "Don't compare the
    # version" note there).
    # This is from selective.yaml
    et_kernel_metadata: Dict[str, List[str]]

    # A set of all the custom torch bind classes used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    custom_classes: Set[str]

    # A set of all the build features used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    build_features: Set[str]

    # If true, then fragments for all dtypes for all kernel functions
    # are included as well as all custom classes. This is typically set when any one of the
    # operator lists is generated from a mechanism other than
    # tracing based selective build.
    include_all_non_op_selectives: bool

    @staticmethod
    def get_nop_selector() -> "SelectiveBuilder":
        """Return a non-selective builder that includes every operator."""
        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

    @staticmethod
    def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder":
        """Build a SelectiveBuilder from an already-parsed YAML mapping.

        Missing keys fall back to empty / False defaults.

        Raises:
            Exception: if ``data`` has a top-level key outside the known set.
            AssertionError: if a known key holds a value of the wrong type.
        """
        valid_top_level_keys = {
            "include_all_non_op_selectives",
            "include_all_operators",
            "debug_info",
            "operators",
            "kernel_metadata",
            "et_kernel_metadata",
            "custom_classes",
            "build_features",
        }
        top_level_keys = set(data.keys())
        unexpected_keys = top_level_keys - valid_top_level_keys
        if len(unexpected_keys) > 0:
            # Sorted so the message is stable across runs (set order of
            # strings depends on hash randomization).
            raise Exception(
                "Got unexpected top level keys: {}".format(
                    ",".join(sorted(unexpected_keys)),
                )
            )
        include_all_operators = data.get("include_all_operators", False)
        assert isinstance(include_all_operators, bool)

        debug_info = None
        if "debug_info" in data:
            di_list = data["debug_info"]
            assert isinstance(di_list, list)

            debug_info = tuple(str(x) for x in di_list)

        operators: Dict[str, SelectiveBuildOperator] = {}
        operators_dict = data.get("operators", {})
        assert isinstance(operators_dict, dict)

        for k, v in operators_dict.items():
            operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v)

        kernel_metadata: Dict[str, List[str]] = {}
        kernel_metadata_dict = data.get("kernel_metadata", {})
        assert isinstance(kernel_metadata_dict, dict)

        for k, v in kernel_metadata_dict.items():
            kernel_metadata[str(k)] = [str(dtype) for dtype in v]

        et_kernel_metadata = data.get("et_kernel_metadata", {})
        assert isinstance(et_kernel_metadata, dict)

        custom_classes = data.get("custom_classes", [])
        assert isinstance(custom_classes, Iterable)
        custom_classes = set(custom_classes)

        build_features = data.get("build_features", [])
        assert isinstance(build_features, Iterable)
        build_features = set(build_features)

        include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
        assert isinstance(include_all_non_op_selectives, bool)

        return SelectiveBuilder(
            include_all_operators,
            debug_info,
            operators,
            kernel_metadata,
            et_kernel_metadata,
            custom_classes,  # type: ignore[arg-type]
            build_features,  # type: ignore[arg-type]
            include_all_non_op_selectives,
        )

    @staticmethod
    def from_yaml_str(config_contents: str) -> "SelectiveBuilder":
        """Parse ``config_contents`` with ``yaml.safe_load`` and build from it."""
        contents = yaml.safe_load(config_contents)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_yaml_path(config_path: str) -> "SelectiveBuilder":
        """Load the YAML file at ``config_path`` and build from it."""
        # Binary mode lets PyYAML detect UTF-8 / UTF-16 from the BOM instead
        # of depending on the platform's locale-default text encoding.
        with open(config_path, "rb") as f:
            contents = yaml.safe_load(f)
            return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_legacy_op_registration_allow_list(
        allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool
    ) -> "SelectiveBuilder":
        """Build a selector from a flat set of operator names.

        Every operator gets the same root/training flags and has all of its
        overloads included. Non-op selectives (dtypes, custom classes) are
        all included.
        """
        operators = {}
        for op in allow_list:
            operators[op] = {
                "name": op,
                "is_root_operator": is_root_operator,
                "is_used_for_training": is_used_for_training,
                "include_all_overloads": True,
            }
        return SelectiveBuilder.from_yaml_dict(
            {
                "operators": operators,
                "include_all_non_op_selectives": True,
            }
        )

    def is_operator_selected(self, name: str) -> bool:
        """Return True if ``name`` is selected.

        ``name`` may carry an overload name. An overloaded name is also
        selected when its base operator is listed with
        ``include_all_overloads``.
        """
        if self.include_all_operators:
            return True

        if name in self.operators:
            return True
        name = strip_operator_overload_name(name)
        return name in self.operators and self.operators[name].include_all_overloads

    def is_native_function_selected(self, func: NativeFunction) -> bool:
        """Return True if the operator for ``func`` is selected."""
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected(op_name)

    def is_operator_selected_for_training(self, name: str) -> bool:
        """Return True if ``name`` is selected and marked as used for training.

        The training flag is read from the exact entry for ``name``. It is
        also read from the base (overload-stripped) entry when that entry has
        ``include_all_overloads``.
        """
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        # Placeholder for "no entry": every flag is False.
        not_training_op = SelectiveBuildOperator(
            name="",
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        op = not_training_op
        if name in self.operators:
            op = self.operators[name]

        name = strip_operator_overload_name(name)
        base_op = not_training_op
        if name in self.operators:
            base_op = self.operators[name]

        return op.is_used_for_training or (
            base_op.include_all_overloads and base_op.is_used_for_training
        )

    def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
        """Return True if the operator for ``func`` is selected for training."""
        op_name = op_name_from_native_function(func)
        return self.is_operator_selected_for_training(op_name)

    def is_root_operator(self, name: str) -> bool:
        """Return True if ``name`` is selected and marked as a root operator.

        An exact entry for ``name`` takes precedence. Otherwise the base
        (overload-stripped) entry counts only if it has
        ``include_all_overloads``.
        """
        if not self.is_operator_selected(name):
            return False
        if self.include_all_operators:
            return True

        if name in self.operators:
            op: SelectiveBuildOperator = self.operators[name]
            return op.is_root_operator
        name = strip_operator_overload_name(name)
        if name not in self.operators:
            return False
        base_op: SelectiveBuildOperator = self.operators[name]
        return base_op.include_all_overloads and base_op.is_root_operator

    def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
        """Return True if ``dtype`` is selected for ``kernel_tag``.

        Always True when all operators or all non-op selectives are included.
        """
        if self.include_all_operators or self.include_all_non_op_selectives:
            return True

        return (
            kernel_tag in self.kernel_metadata
            and dtype in self.kernel_metadata[kernel_tag]
        )

    def et_get_selected_kernels(self, op_name: str, kernel_key: List[str]) -> List[str]:
        """
        Return a list of kernel keys that cover the used ops.

        Args:
            op_name: operator name to look up in ``et_kernel_metadata``.
            kernel_key: candidate kernel keys available for ``op_name``; may
                contain the special key ``"default"``.

        Returns:
            If ``op_name`` has no ExecuTorch kernel metadata: ``kernel_key``
            itself when ``include_all_operators`` is set, otherwise ``[]``.
            Otherwise, a sorted list with no duplicates. For each model
            kernel key it holds the first non-``"default"`` candidate whose
            second "/"-separated segment matches, or ``"default"`` when no
            candidate matches.

        Raises:
            Exception: if a model kernel key has no matching candidate and
                ``"default"`` is not among the candidates.
        """
        # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used.
        if op_name not in self.et_kernel_metadata:
            return kernel_key if self.include_all_operators else []
        # Otherwise, only return the specific kernel keys.

        result_set = set()

        for model_kernel_keys in self.et_kernel_metadata[op_name]:
            key_found = False
            for key in kernel_key:
                # Don't compare the version for now
                if (
                    key != "default"
                    and key.split("/")[1] == model_kernel_keys.split("/")[1]
                ):
                    result_set.add(key)
                    key_found = True
                    break
            if not key_found:
                if "default" not in kernel_key:
                    raise Exception("Missing kernel for the model")
                else:
                    result_set.add("default")

        # Sorted so the result (and any code generated from it) does not
        # depend on set iteration order, which varies between runs.
        return sorted(result_set)

    def to_dict(self) -> Dict[str, object]:
        """Serialize to a YAML-ready dict that ``from_yaml_dict`` accepts.

        Set-valued fields and the dtype lists in ``kernel_metadata`` are
        emitted as sorted lists.
        """
        ret: Dict[str, object] = {
            "include_all_non_op_selectives": self.include_all_non_op_selectives,
            "include_all_operators": self.include_all_operators,
        }
        operators = {}
        for op_name, op in self.operators.items():
            operators[op_name] = op.to_dict()
        ret["operators"] = operators

        if self._debug_info is not None:
            ret["debug_info"] = sorted(self._debug_info)

        ret["kernel_metadata"] = {
            k: sorted(v) for (k, v) in self.kernel_metadata.items()
        }

        ret["et_kernel_metadata"] = self.et_kernel_metadata

        ret["custom_classes"] = sorted(self.custom_classes)
        ret["build_features"] = sorted(self.build_features)

        return ret


def merge_kernel_metadata(
    lhs: Dict[str, List[str]],
    rhs: Dict[str, List[str]],
) -> Dict[str, List[str]]:
    """Union two ``kernel_metadata`` dicts tag by tag.

    Returns a new dict. Tags keep first-seen order (lhs, then new rhs tags),
    and each tag maps to the sorted, de-duplicated union of its dtypes.
    Sorting makes the result deterministic across runs and matches
    ``merge_et_kernel_metadata``.
    """
    merged: Dict[str, Set[str]] = defaultdict(set)
    for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
        merged[tag_name].update(dtypes)
    return {tag_name: sorted(dtypes) for tag_name, dtypes in merged.items()}


def merge_et_kernel_metadata(
    lhs: Dict[str, List[str]],
    rhs: Dict[str, List[str]],
) -> Dict[str, List[str]]:
    """Union two ``et_kernel_metadata`` dicts op by op.

    Returns a new dict. Ops keep first-seen order (lhs, then new rhs ops),
    and each op maps to the sorted, de-duplicated union of its kernel keys.
    """
    merged: Dict[str, Set[str]] = defaultdict(set)
    for source in (lhs, rhs):
        for op, kernel_keys in source.items():
            merged[op].update(kernel_keys)
    return {op: sorted(kernel_keys) for op, kernel_keys in merged.items()}


def combine_selective_builders(
    lhs: SelectiveBuilder, rhs: SelectiveBuilder
) -> SelectiveBuilder:
    """Return a new SelectiveBuilder that selects everything either input selects.

    Boolean flags are OR-ed. Operators and debug info are merged with the
    helpers from ``torchgen.selective_build.operator``. Kernel metadata and
    the custom-class and build-feature sets are unioned.
    """
    include_all_operators = lhs.include_all_operators or rhs.include_all_operators
    debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
    operators = merge_operator_dicts(lhs.operators, rhs.operators)
    kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata)
    et_kernel_metadata = merge_et_kernel_metadata(
        lhs.et_kernel_metadata, rhs.et_kernel_metadata
    )
    include_all_non_op_selectives = (
        lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
    )
    custom_classes = lhs.custom_classes.union(rhs.custom_classes)
    build_features = lhs.build_features.union(rhs.build_features)
    return SelectiveBuilder(
        include_all_operators,
        debug_info,
        operators,
        kernel_metadata,
        et_kernel_metadata,
        custom_classes,
        build_features,
        include_all_non_op_selectives,
    )


def op_name_from_native_function(f: NativeFunction) -> str:
    # This was originally read from the 'operator_name_with_overload' field in the
    # declaration dict, which was the part before the first '(' in 'schema_string'.
    return f"{f.namespace}::{f.func.name}"