# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.

import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Tuple, Union

from torchgen.model import (
    BackendIndex,
    BackendMetadata,
    DispatchKey,
    NativeFunction,
    NativeFunctionsGroup,
    OperatorName,
)
from torchgen.utils import assert_never

KERNEL_KEY_VERSION = 1


# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11


ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])


@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    """Kernel-key metadata for a single operator argument."""

    arg_name: str
    # Name of a ScalarType member, e.g. "Float"
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: Tuple[int, ...]

    def to_native_string(self) -> str:
        """Serialize as ``"<ScalarType value>;<comma-separated dim order>"``.

        Example: dtype "Float" with dim order (0, 1, 2, 3) -> ``"6;0,1,2,3"``.
        A one-element dim order such as (0,) serializes as ``"0"``, with no
        trailing comma.

        Raises:
            KeyError: if ``dtype`` is not a ScalarType member name.
        """
        dtype_str = ScalarType[self.dtype].value
        # Join explicitly instead of slicing str(tuple). The slice form left a
        # trailing comma for single-element tuples ("(0,)" -> "0,").
        dim_str = ",".join(str(dim) for dim in self.dim_order)
        return f"{dtype_str};{dim_str}"


@dataclass(frozen=True)
class ETKernelKey:
    """Key identifying one kernel variant of an operator."""

    # Per-argument metadata; left empty for catch-all keys such as
    # ETKernelKey(default=True)
    arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()

    # Indicator for this kernel being used as a catch all
    default: bool = False

    # Kernel key format version, emitted as the "v<version>/" prefix
    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: Dict[str, Tuple[str, str]],
        type_alias_map: Dict[str, List[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: Dict[str, List[int]],
    ) -> List["ETKernelKey"]:
        """Generate ETKernelKeys from arg kernel specs.

        Multiple ETKernelKeys are returned because of dtype permutations
        from type_alias_map: each possible type permutation becomes one
        KernelKey.

        Args:
            args: Mapping from argument name to kernel specs.
                Kernel specs are a tuple of (dtype, dim_order).
                Currently tuple entries must be aliased via the alias map
                arguments.
            type_alias_map: Mapping from type alias to potential type enums,
                i.e. { T0 : [Double, Int] } means T0 can be either Double or
                Int. Used for lookup by args.
            dim_order_alias_map: Mapping from alias to a list of dimension
                orders. Used for lookup by args. Entries are cast with int().

        Returns:
            One ETKernelKey per combination of dtype-alias values. Aliases are
            expanded in the order they are first referenced by ``args``, so the
            result order is deterministic across runs.

        Raises:
            AssertionError: if an arg references an undefined type or
                dim_order alias.
        """
        # Cast dim order entries to int (presumably they may be strings when
        # read from YAML -- confirm against callers).
        dim_orders = {
            alias: [int(dim) for dim in dims]
            for alias, dims in dim_order_alias_map.items()
        }

        # Collect the dtype aliases used by args, in first-use order. A dict
        # is used instead of a set: a set of strings iterates in an order that
        # depends on hash randomization, which made the order of the returned
        # kernel keys vary between runs.
        dtype_alias_used: Dict[str, None] = {}
        for type_alias, dim_order_alias in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert (
                dim_order_alias in dim_orders
            ), "Undefined dim_order alias: " + str(dim_order_alias)
            dtype_alias_used[type_alias] = None

        # Generate all permutations of dtype alias values
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in dtype_alias_used
        ]
        alias_permutations = [
            dict(permutation) for permutation in itertools.product(*alias_dtypes)
        ]

        # Using each alias value permutation, generate kernel keys. Identical
        # per-arg metadata objects are shared through op_arg_cache.
        op_arg_cache: Dict[
            Tuple[str, str, Tuple[int, ...]], ETKernelKeyOpArgMeta
        ] = {}
        kernel_keys = []
        for permutation in alias_permutations:
            arg_list = []
            for arg_name, (type_alias, dim_order_alias) in args.items():
                dtype = permutation[type_alias]
                dim_order = tuple(dim_orders[dim_order_alias])
                cache_key = (arg_name, dtype, dim_order)
                if cache_key not in op_arg_cache:
                    op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(
                        arg_name, dtype, dim_order
                    )
                arg_list.append(op_arg_cache[cache_key])
            kernel_keys.append(ETKernelKey(tuple(arg_list)))
        return kernel_keys

    def to_native_string(self) -> str:
        """Serialize this key.

        Returns ``"default"`` for catch-all keys. Otherwise returns
        ``"v<version>/"`` followed by the '|'-joined per-argument strings.
        """
        if self.default:
            return "default"
        return (
            "v"
            + str(self.version)
            + "/"
            + "|".join(arg.to_native_string() for arg in self.arg_meta)
        )


@dataclass(frozen=True)
class ETKernelIndex:
    """Index of kernels per operator, keyed by ETKernelKey."""

    index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
        """Return True if at least one kernel is registered for ``g``.

        get_kernels returns ``{}`` (never ``None``) for unknown operators, so
        the check is on emptiness. The previous ``is not None`` test was
        always True.
        """
        return bool(self.get_kernels(g))

    def get_kernels(
        self, g: Union[NativeFunction, NativeFunctionsGroup]
    ) -> Dict[ETKernelKey, BackendMetadata]:
        """Return the kernels registered for ``g``'s operator, or ``{}``.

        For a NativeFunctionsGroup, the lookup uses its functional variant's
        operator name.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = g.functional
        else:
            assert_never(g)
        # .get avoids inserting into self.index even when it is a defaultdict
        return self.index.get(f.func.name, {})

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Add every backend-index entry to ``kernel_index`` in place.

        Each entry is stored under the default key. A later dispatch key
        overwrites an earlier one's default entry for the same operator.
        """
        for index in backend_indices.values():
            for op, backend_metadata in index.items():
                kernel_index.setdefault(op, {})[
                    ETKernelKey(default=True)
                ] = backend_metadata

    @staticmethod
    def from_backend_indices(
        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Build a new ETKernelIndex from backend indices.

        All entries are registered under the default key.
        """
        kernel_index: Dict[
            OperatorName, Dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
    ) -> "ETKernelIndex":
        """Add backend-index entries to this index in place.

        Returns ``self``, so calls can be chained.
        """
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.

        Requires exactly one kernel per operator. If that kernel is not
        registered under the default key, an empty BackendMetadata is used.
        """
        index: Dict[OperatorName, BackendMetadata] = {}
        for op, kernel_dict in self.index.items():
            assert (
                len(kernel_dict) == 1
            ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
            index[op] = kernel_dict.get(
                ETKernelKey(default=True),
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(
        index_a: "ETKernelIndex", index_b: "ETKernelIndex"
    ) -> "ETKernelIndex":
        """Return a new index containing the entries of both inputs.

        For a key present in both, the metadata from ``index_b`` wins. Neither
        input is modified. The inner per-operator dicts are copied; previously
        they were shared with ``index_a``, so merging silently mutated it.
        """
        combined: Dict[
            OperatorName, Dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict, {op: dict(entry) for op, entry in index_a.index.items()})
        for op, entry in index_b.index.items():
            combined[op].update(entry)
        return ETKernelIndex(combined)