214 lines
7.6 KiB
Python
214 lines
7.6 KiB
Python
|
from dataclasses import dataclass
|
||
|
from typing import Callable, List, Sequence, Tuple
|
||
|
|
||
|
from torchgen.api.types import Binding, CType, NamedCType
|
||
|
from torchgen.model import (
|
||
|
Argument,
|
||
|
BaseTy,
|
||
|
BaseType,
|
||
|
ListType,
|
||
|
NativeFunction,
|
||
|
OptionalType,
|
||
|
Type,
|
||
|
)
|
||
|
|
||
|
# Separator used when splicing a list of generated C++ lines (e.g. a list
# element's unboxing code) into an enclosing generated snippet: a newline
# followed by a tab.
connector: str = "\n\t"
|
||
|
|
||
|
|
||
|
# Return unboxing function name for a NativeFunction
|
||
|
def name(f: NativeFunction) -> str:
|
||
|
return f.func.name.unambiguous_name()
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues into these Bindings,
    returning generated C++ code that performs the unboxing.

    A sample generated code:
    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);

    }
    """

    # A callable that converts a JIT argument into its C++ type.
    # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> Tuple[List[Binding], List[str]]:
        """
        Convert all the arguments of a NativeFunction to C++ unboxing code.

        :param args: bindings in stack order; each must wrap an `Argument`
        :return: (bindings renamed to their unboxed C++ variables,
                  generated C++ lines)
        :raises Exception: if a binding does not wrap an `Argument`
        """
        # First alias every stack slot to a named EValue reference.
        code_list: List[str] = [
            f"EValue& {arg.name} = *stack[{i}];" for i, arg in enumerate(args)
        ]
        binding_list: List[Binding] = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Declarations are emitted before the conversion code, which may
            # refer to them (e.g. the backing std::vector of a list argument).
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> Tuple[str, CType, List[str], List[str]]:
        """
        Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
        (1) the name of the newly created unboxed C++ variable
        (2) the CType of that variable
        (3) the C++ code necessary to unbox the argument
        (4) C++ declarations that must be emitted before that code
            (e.g. backing storage for non-specialized list arguments)
        :param t: a `Type` of an argument
        :param arg_name: argument name
        :param mutable: boolean for whether this argument type is mutable
        :return: (unboxed variable name, CType, code lines, declaration lines)
        :raises Exception: if `t` is not a BaseType, OptionalType or ListType
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox a scalar/base-typed EValue via `EValue::to<T>()`; no declarations."""
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox an optional EValue via `EValue::toOptional<T>()`."""
        in_name = f"{arg_name}_opt_in"
        # The element conversion is only needed for its CType (to spell the
        # toOptional<...> template argument); its unboxing code is discarded,
        # while its declarations are passed through.
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        return (
            f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
""".split("\n"),
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox a list EValue, using a dedicated accessor when one exists."""
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code: List[str] = []
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )

        # Element base types whose lists have a dedicated EValue accessor.
        list_accessors = {
            BaseTy.Tensor: "toTensorList",
            BaseTy.int: "toIntList",
            BaseTy.SymInt: "toIntList",
            BaseTy.float: "toDoubleList",
            # handle list type with size, e.g., bool[4]
            BaseTy.bool: "toBoolList",
        }

        if isinstance(t.elem, BaseType) and t.elem.name in list_accessors:
            accessor = list_accessors[t.elem.name]
            code.extend(
                f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.{accessor}();
""".split("\n")
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
        elif (
            isinstance(t.elem, OptionalType)
            and isinstance(t.elem.elem, BaseType)
            and t.elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
c10::List<c10::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
    {out_name}.push_back({elem_name});
}}
#else
torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
#endif
""".split("\n")
            )
        else:
            # use ArrayRef as default.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            # NOTE(review): on this path `in_name` is never declared in the
            # generated C++ (only the Tensor?[] branch above declares it), so
            # the emitted range-for appears to reference an undefined
            # identifier. Confirm the intended EValue list accessor before
            # relying on this branch.
            code.extend(
                f"""
for (EValue {elem_name}: {in_name}) {{
    {connector.join(res_code)}
    {vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
""".split("\n")
            )
        return code, decl
|