from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple

from torchgen.api.types import Binding, CType, NamedCType
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    ListType,
    NativeFunction,
    OptionalType,
    Type,
)

# Separator used when splicing the (multi-line) element-conversion code into a
# single line of the default list-unboxing template below.
connector = "\n\t"

# List element base types that have a dedicated EValue list accessor, mapped to
# the accessor name. int and SymInt share `toIntList`; bool also covers sized
# lists such as bool[4].
_LIST_UNBOX_METHODS = {
    BaseTy.Tensor: "toTensorList",
    BaseTy.int: "toIntList",
    BaseTy.SymInt: "toIntList",
    BaseTy.float: "toDoubleList",
    BaseTy.bool: "toBoolList",
}


# Return unboxing function name for a NativeFunction
def name(f: NativeFunction) -> str:
    """Return the unambiguous schema name of ``f``, used as the unboxing function name."""
    return f.func.name.unambiguous_name()


@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues into these Bindings.
    Returns generated C++ code that performs the unboxing.
    A sample of generated code:

    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);
    }
    """

    # this is a callable that converts a JIT argument, into its C++ type.
    # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    # Convert all the arguments in a NativeFunction to C++ code
    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> Tuple[List[Binding], List[str]]:
        """
        Generate C++ code that unboxes every argument from the EValue stack.

        :param args: bindings for the function's arguments, in stack order; each
            binding's ``argument`` must be an ``Argument``
        :return: (the bindings renamed to the unboxed C++ variables,
            the generated C++ code lines)
        :raises Exception: if a binding's ``argument`` is not an ``Argument``
        """
        # One `EValue&` alias per stack slot, in argument order.
        code_list = [f"EValue& {arg.name} = *stack[{i}];" for i, arg in enumerate(args)]
        binding_list = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Hoisted declarations (e.g. backing vectors for ArrayRefs) are
            # emitted before the conversion code that uses them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> Tuple[str, CType, List[str], List[str]]:
        """
        Takes in the type, name and mutability corresponding to an argument, and
        generates C++ code that unboxes the EValue variable named ``arg_name``.

        :param t: a `Type` of an argument
        :param arg_name: name of the C++ EValue variable to unbox
        :param mutable: boolean for whether this argument type is mutable
        :return: a tuple of
            (1) the name of the newly created unboxed C++ variable,
            (2) its CType,
            (3) the C++ code lines that perform the unboxing,
            (4) C++ declarations that must be emitted before that code
        :raises Exception: if ``t`` is not a BaseType, OptionalType or ListType
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox a base-type argument with ``EValue::to<T>()``; needs no hoisted declarations."""
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox an optional argument with ``EValue::toOptional<T>()``.

        The element type is converted recursively only to learn its C++ type and
        to collect its hoisted declarations; the element's own conversion code
        is not emitted.
        """
        in_name = f"{arg_name}_opt_in"
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        return (
            f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
            """.split(
                "\n"
            ),
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox a list argument.

        Element types with a dedicated EValue accessor (see ``_LIST_UNBOX_METHODS``)
        and ``Tensor?[]`` use that accessor directly. Every other element type
        falls back to converting element-by-element into a hoisted
        ``std::vector`` that backs the resulting list object.
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code: List[str] = []
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )

        if isinstance(t.elem, BaseType) and t.elem.name in _LIST_UNBOX_METHODS:
            method = _LIST_UNBOX_METHODS[t.elem.name]
            code.extend(
                f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.{method}();
            """.split(
                    "\n"
                )
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
        elif (
            isinstance(t.elem, OptionalType)
            and isinstance(t.elem.elem, BaseType)
            and t.elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
c10::List<c10::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
    {out_name}.push_back({elem_name});
}}
#else
torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
#endif
            """.split(
                    "\n"
                )
            )
        else:
            # use ArrayRef as default.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            # NOTE(review): nothing generated by this class declares `{in_name}`
            # for this branch, so the emitted loop looks like it iterates over an
            # undeclared variable — confirm which EValue accessor should produce it.
            code.extend(
                f"""
for (EValue {elem_name}: {in_name}) {{
    {connector.join(res_code)}
    {vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
            """.split(
                    "\n"
                )
            )
        return code, decl