797 lines
26 KiB
Python
797 lines
26 KiB
Python
import json
|
|
import logging
|
|
|
|
import math
|
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|
|
|
import torchgen.api.cpp as cpp
|
|
from torchgen.context import native_function_manager
|
|
from torchgen.model import (
|
|
Argument,
|
|
BackendIndex,
|
|
BaseTy,
|
|
BaseType,
|
|
FunctionSchema,
|
|
NativeFunctionsGroup,
|
|
NativeFunctionsViewGroup,
|
|
OptionalType,
|
|
SelfArgument,
|
|
TensorOptionsArguments,
|
|
Type,
|
|
)
|
|
from torchgen.static_runtime import config
|
|
|
|
# Logger shared by every function in this module. getLogger() is called without
# a name, so this is the root logger and follows the root logger's configuration.
logger: logging.Logger = logging.getLogger()
|
|
|
|
|
|
def has_alias(
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
    """Report whether any entry of ``arguments`` carries a non-empty alias set.

    The ``annotation`` attribute of each entry is read defensively: entries
    that lack the attribute, or whose annotation is falsy, are skipped. An
    entry counts as aliasing when its annotation has a truthy ``alias_set``.

    Returns:
        True as soon as one aliasing entry is found, otherwise False.
    """
    annotations = (getattr(candidate, "annotation", None) for candidate in arguments)
    return any(
        bool(getattr(annotation, "alias_set", ()))
        for annotation in annotations
        if annotation
    )
|
|
|
|
|
|
# Base operator names for which code generation is refused outright.
# `is_supported` consults this set and returns False for any member.
# The entries are grouped by the reason they were excluded.
BLOCKED_OPS = frozenset(
    {
        # non cpu ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training related ops
        "_make_dual",
        # cannot call directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these ones got added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_sparse_mask_projection",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
    }
)
|
|
|
|
|
|
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    """Decide whether code can be auto-generated for the operator group ``g``.

    The checks run in this order, and the first failing one logs its reason at
    INFO level and yields False:

    1. the group is hand written according to ``config.is_hand_written``;
    2. its base operator name is listed in ``BLOCKED_OPS``;
    3. some argument type has no ``ivalue_type_conversion_method``;
    4. view groups only: the C++ return type is not exactly ``at::Tensor``
       (a view group passing this check is accepted immediately);
    5. some argument of the functional variant has no conversion method;
    6. unstructured groups only: the out variant schema does not end in
       ``Tensor(a!) out) -> Tensor(a!)`` or its name does not end in ``.out``;
    7. the C++ return type of the out variant is not ``at::Tensor &``;
    8. a non-out argument carries an alias annotation (see ``has_alias``).

    Args:
        g: the out-variant group or view group to examine. For a view group
            the ``view`` function is inspected; otherwise the ``out`` function.

    Returns:
        True when every applicable check passes, False otherwise.
    """
    if isinstance(g, NativeFunctionsViewGroup):
        base_op_name = g.view.root_name
        func = g.view.func
    else:
        base_op_name = g.out.func.name.name.base
        func = g.out.func

    if config.is_hand_written(g):
        logger.info("HAND WRITTEN: %s", base_op_name)
        return False
    if base_op_name in BLOCKED_OPS:
        logger.info("BLOCKED: %s", base_op_name)
        return False

    def all_arguments_convertible(schema: FunctionSchema) -> bool:
        # Shared by the out/view schema and the functional schema checks:
        # every argument must have a known ivalue conversion.
        for arg in schema.schema_order_arguments():
            if not ivalue_type_conversion_method(arg.type):
                # Type converting is unsupported yet.
                logger.info("NOT SUPPORTED TYPE CONVERTING: %s", schema)
                return False
        return True

    if not all_arguments_convertible(func):
        return False

    if isinstance(g, NativeFunctionsViewGroup):
        # TODO: stop doing type tests by converting to C++ and then testing
        # the string, just test the dang thing directly
        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
            # Returns a non-Tensor value.
            logger.info("NON-TENSOR RET TYPE: %s", str(func))
            return False
        return True

    # For out variant ops, we need to check the arguments of its functional func.
    if not all_arguments_convertible(g.functional.func):
        return False

    if not g.structured:
        # In case of unstructured op, we check if it has out variant implementation.
        # The out variant implementation satisfies the minimum requirement that it
        # has the output tensor as the last parameter.
        if (
            not hasattr(g, "out")
            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
            or not str(func.name).endswith(".out")
        ):
            # Previously this was the only rejection path that left no trace
            # in the log; record it like every other one.
            logger.info("NO USABLE OUT VARIANT: %s", func)
            return False
    # TODO: stop type testing by converting to C++
    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
        logger.info("NON_TENSOR RET TYPE: %s", func)
        return False
    if has_alias(func.arguments.non_out):
        # This op may create an alias of inputs.
        logger.info("INPUTS ALIAS: %s", base_op_name)
        return False
    return True
|
|
|
|
|
|
def ivalue_type_conversion_method(
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """
    Look up how an ivalue holding a value of `arg_type` is converted in C++.

    The result is a pair `(is_reference, method_call)`. `method_call` is the
    member call without a leading dot, e.g. "toTensor()" for a plain Tensor or
    "toOptional<at::Tensor>()" for an optional Tensor. `is_reference` is True
    only for the plain Tensor conversion; `generate_arg_extraction` uses it to
    decide whether the local is declared as a reference.

    Returns None when `arg_type` is neither a `BaseType` nor an `OptionalType`
    wrapping a `BaseType`, or when the base type has no entry in the table.
    """
    # BaseTy -> (conversion for the plain type, conversion for Optional[type])
    conversions = {
        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
        BaseTy.ScalarType: (
            (False, "toScalarType()"),
            (False, "toOptional<at::ScalarType>()"),
        ),
        BaseTy.str: (
            (False, "toStringView()"),
            (False, "toOptional<c10::string_view>()"),
        ),
    }

    if isinstance(arg_type, BaseType):
        variant = 0
        base_ty = arg_type.name
    elif isinstance(arg_type, OptionalType) and isinstance(arg_type.elem, BaseType):
        variant = 1
        base_ty = arg_type.elem.name
    else:
        # Anything else (for instance ListType, or an optional wrapping a
        # non-base type) is currently unsupported.
        return None

    pair = conversions.get(base_ty)
    return None if pair is None else pair[variant]
|
|
|
|
|
|
# Operator names whose generated Tensor test values are integer tensors.
should_use_int_tensor_ops_ = frozenset(
    {
        "bitwise_not",
        "bitwise_and",
        "bitwise_or",
        "bitwise_xor",
        "bitwise_left_shift",
        "bitwise_right_shift",
        "gcd",
        "lcm",
        "scatter",
        "gather",
        "_convert_indices_from_coo_to_csr",
        "_convert_indices_from_csr_to_coo",
    }
)
# Operator names whose generated Tensor test values are complex tensors.
should_use_complex_tensor_ops_ = frozenset({"view_as_real", "imag", "_conj"})


def should_use_int_tensor(op_name: str) -> bool:
    """Return True when ``op_name`` is listed in ``should_use_int_tensor_ops_``."""
    is_int_op = op_name in should_use_int_tensor_ops_
    return is_int_op


def should_use_complex_tensor(op_name: str) -> bool:
    """Return True when ``op_name`` is listed in ``should_use_complex_tensor_ops_``."""
    is_complex_op = op_name in should_use_complex_tensor_ops_
    return is_complex_op
|
|
|
|
|
|
test_tensor_dim_ops_1_ = frozenset(
|
|
(
|
|
"addmv",
|
|
"index_add",
|
|
"_convert_indices_from_coo_to_csr",
|
|
"_convert_indices_from_csr_to_coo",
|
|
"nll_loss_backward",
|
|
"dot",
|
|
"vdot",
|
|
"outer",
|
|
"ger",
|
|
)
|
|
)
|
|
test_tensor_dim_ops_2_ = frozenset(
|
|
("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
|
|
)
|
|
|
|
|
|
def test_tensor_dim(op_name: str) -> int:
|
|
if op_name in test_tensor_dim_ops_1_:
|
|
return 1
|
|
if op_name in test_tensor_dim_ops_2_:
|
|
return 2
|
|
return 3
|
|
|
|
|
|
test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
|
|
test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string)
|
|
|
|
|
|
def test_tensor_shape(op_name: str) -> str:
|
|
if op_name in test_tensor_shape_json:
|
|
return test_tensor_shape_json[op_name]
|
|
else:
|
|
return ""
|
|
|
|
|
|
def test_value_expression(
    arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
) -> str:
    """Return a C++ expression that produces a test value of ``arg_type``.

    Tensor values use the shape override from ``test_tensor_shape`` when one
    exists. Otherwise every dimension gets the same size: 16 (for ``index``
    0) or 64 (any other index) divided by the dimension count from
    ``test_tensor_dim``, rounded up and then bumped to the next even number.
    The tensor is built with ``at::randint`` for int-tensor operators,
    ``at::randn`` with a complex dtype for complex-tensor operators, and
    ``at::rand`` otherwise. Non-tensor base types map to fixed literals.

    Args:
        arg_type: a ``BaseType`` or an ``OptionalType`` wrapping a ``BaseType``;
            anything else trips an assertion.
        index: selects the size budget used for tensor values.
        op_name: base operator name used for the per-operator lookups.

    Raises:
        AssertionError: if ``arg_type`` has an unexpected shape or its base
            type has no registered test value.
    """
    shape_literal = test_tensor_shape(op_name)
    if shape_literal == "":
        size_budget = 64 if index != 0 else 16
        num_dim = test_tensor_dim(op_name)
        per_dim = math.ceil(size_budget / float(num_dim))
        # Adding the remainder modulo 2 rounds odd sizes up to the next even one.
        per_dim += per_dim % 2
        shape_literal = "{%s}" % (",".join([f"{per_dim}"] * num_dim))

    if should_use_int_tensor(op_name):
        tensor_value = f"at::randint(1, 100, {shape_literal}, at::kInt)"
    elif should_use_complex_tensor(op_name):
        tensor_value = f"at::randn({shape_literal}, at::kComplexFloat)"
    else:
        tensor_value = f"at::rand({shape_literal})"

    samples = {
        BaseTy.Tensor: tensor_value,
        BaseTy.int: "1",
        BaseTy.bool: "false",
        BaseTy.Scalar: "2",
        BaseTy.ScalarType: "at::ScalarType::Float",
        BaseTy.str: '"floor"',
    }

    if isinstance(arg_type, BaseType):
        base_ty = arg_type.name
    else:
        assert isinstance(arg_type, OptionalType) and isinstance(
            arg_type.elem, BaseType
        )
        base_ty = arg_type.elem.name
    assert base_ty in samples, "not expected type"
    return samples[base_ty]
|
|
|
|
|
|
def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
    """Return C++ statements that define one test value per schema argument.

    Each statement has the form ``auto <arg><index> = <expression>``. The
    expressions come from ``test_value_expression`` and may then be replaced
    by ``config.override_test_values``, which receives the name-to-expression
    mapping. ``schema`` must not be an out-variant schema.
    """
    assert not schema.is_out_fn()
    op_base_name = schema.name.name.base
    values_by_arg = {
        arg.name: test_value_expression(arg.type, index, op_base_name)
        for arg in schema.schema_order_arguments()
    }
    config.override_test_values(values_by_arg, op_base_name, index)
    statements = [
        f"auto {name}{index} = {expression}"
        for name, expression in values_by_arg.items()
    ]
    return ";\n ".join(statements) + ";"
|
|
|
|
|
|
def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
    """Return the comma-joined argument names of ``schema`` suffixed with ``index``.

    ``schema`` must not be an out-variant schema.
    """
    assert not schema.is_out_fn()
    suffixed_names = [f"{arg.name}{index}" for arg in schema.schema_order_arguments()]
    return ",".join(suffixed_names)
|
|
|
|
|
|
# Type spelling used for each base type in the generated test graph's argument
# list. Base types missing from this table are declared without a type.
generate_test_ir_arguments_base_ty_to_type_str_ = {
    BaseTy.Tensor: "Tensor",
    BaseTy.int: "int",
    BaseTy.float: "float",
    BaseTy.str: "str",
    BaseTy.Scalar: "int",
    BaseTy.ScalarType: "int",
    BaseTy.bool: "bool",
}


def generate_test_ir_arguments(
    schema: FunctionSchema,
) -> List[Tuple[str, Optional[str]]]:
    """Describe every argument of ``schema`` as a ``(name, type)`` pair.

    The name is the argument name prefixed with ``%``. The type comes from
    ``generate_test_ir_arguments_base_ty_to_type_str_`` with a trailing ``?``
    for optional arguments, or is None when the base type is not in that table.

    Raises:
        AssertionError: if an argument (after unwrapping an optional) is not
            a ``BaseType``.
    """

    def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
        base = arg.type
        optional_suffix = ""
        if isinstance(base, OptionalType):
            base, optional_suffix = base.elem, "?"
        assert isinstance(base, BaseType)
        type_str = generate_test_ir_arguments_base_ty_to_type_str_.get(base.name)
        if type_str:
            type_str = f"{type_str}{optional_suffix}"
        return ("%" + arg.name, type_str)

    return [ir_argument(arg) for arg in schema.schema_order_arguments()]
|
|
|
|
|
|
def generate_arg_extraction(schema: FunctionSchema) -> str:
    """Return C++ statements that unpack each input of ``p_node`` into a local.

    One ``const auto`` declaration is produced per schema argument, named
    after the argument and initialised from ``p_node->Input(i)`` through the
    conversion chosen by ``ivalue_type_conversion_method``. The local is a
    reference when that conversion is flagged as one.

    Raises:
        AssertionError: if an argument type has no conversion method.
    """
    statements = []
    for position, arg in enumerate(schema.schema_order_arguments()):
        conversion = ivalue_type_conversion_method(arg.type)
        assert conversion
        by_reference, accessor = conversion
        ref_marker = "&" if by_reference else ""
        statements.append(
            f"const auto{ref_marker} {arg.name} = p_node->Input({position}).{accessor}"
        )
    return ";\n ".join(statements) + ";"
|
|
|
|
|
|
def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Return the C++ function name to call for the functional variant of ``g``.

    The kernel registered in ``backend_index`` is used only for unstructured
    groups that have one; otherwise the name is derived from the functional
    schema via ``cpp.name``.
    """
    kernel = backend_index.get_kernel(g.functional)
    if kernel is not None and not g.structured:
        return kernel.kernel
    return cpp.name(g.functional.func)
|
|
|
|
|
|
def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
    """Return the C++ function name to call for the out variant of ``g``.

    The kernel registered in ``backend_index`` is used only for unstructured
    groups that have one; otherwise the name is derived from the out schema
    via ``cpp.name``.
    """
    kernel = backend_index.get_kernel(g.out)
    if kernel is not None and not g.structured:
        return kernel.kernel
    return cpp.name(g.out.func)
|
|
|
|
|
|
def generate_non_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    """Return the C++ call expression for the functional variant of ``g``.

    The call targets ``at::cpu::`` for structured groups and ``at::native::``
    otherwise, and passes every schema argument by name in schema order.
    """
    schema = g.functional.func
    assert not schema.is_out_fn()
    kernel_name = get_kernel_name(g, backend_index)
    namespace_name = "cpu" if g.structured else "native"
    joined_args = ",".join(arg.name for arg in schema.schema_order_arguments())
    return f"at::{namespace_name}::{kernel_name}({joined_args})"
|
|
|
|
|
|
def generate_call_to_view_ops(
    g: NativeFunctionsViewGroup, backend_index: BackendIndex
) -> str:
    """Return the C++ call expression for the view operator of ``g``.

    The call always targets ``at::native::``. The function name is the kernel
    registered in ``backend_index`` when there is one, and otherwise the name
    derived from the view schema via ``cpp.name``.
    """
    schema = g.view.func
    fallback_name = cpp.name(schema)
    kernel = backend_index.get_kernel(g.view)
    kernel_name = kernel.kernel if kernel else fallback_name
    joined_args = ",".join(arg.name for arg in schema.schema_order_arguments())
    return f"at::native::{kernel_name}({joined_args})"
|
|
|
|
|
|
def generate_out_variant_call(
    g: NativeFunctionsGroup, backend_index: BackendIndex
) -> str:
    """Return the C++ call expression for the out variant of ``g``.

    Structured groups call into ``at::cpu::`` with the output arguments first,
    followed by the non-out arguments. Unstructured groups call into
    ``at::native::`` with the non-out arguments first and their single output
    argument last.

    Raises:
        AssertionError: if the schema is not an out variant, if a non-out
            argument is neither a ``SelfArgument`` nor an ``Argument``, or if
            an unstructured group does not have exactly one output argument.
    """
    schema = g.out.func
    assert schema.is_out_fn()
    kernel_name = get_out_kernel_name(g, backend_index)

    input_names = []
    for arg in schema.arguments.non_out:
        if isinstance(arg, SelfArgument):
            input_names.append(arg.argument.name)
        else:
            assert isinstance(arg, Argument)
            input_names.append(arg.name)

    if g.structured:
        # structured op starts with the output tensor argument.
        namespace_name = "cpu"
        ordered_names = [out_arg.name for out_arg in schema.arguments.out]
        ordered_names.extend(input_names)
    else:
        namespace_name = "native"
        assert len(schema.arguments.out) == 1
        ordered_names = input_names + [schema.arguments.out[0].name]

    cpp_arg_names = ",".join(ordered_names)
    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
|
|
|
|
|
|
# Operators (keyed by the schema text that precedes the opening parenthesis,
# overload name included) for which the generated test disables `check_resize`.
no_memory_resize_ops = frozenset(
    (
        "isin.Scalar_Tensor",
        "index_add",
        "dot",
        "vdot",
        "nuclear_norm",
        "histc",
        "l1_loss",
        "multi_margin_loss",
        "multilabel_margin_loss",
        "nll_loss",
        "nll_loss2d",
        "prod",
    )
)


def should_check_resize(schema: FunctionSchema) -> bool:
    """Return whether the generated test for ``schema`` should check resizing.

    The operator name is everything in ``str(schema)`` before the first
    ``(``. The result is False only when that name is listed in
    ``no_memory_resize_ops``.

    ``str.partition`` is used so that a schema string without ``(`` is taken
    whole. Slicing with ``find`` would silently drop its last character,
    because ``find`` returns -1 when the separator is missing.
    """
    type_variant_op_name, _, _ = str(schema).partition("(")
    return type_variant_op_name not in no_memory_resize_ops
|
|
|
|
|
|
def op_name_from_group(g: NativeFunctionsGroup) -> str:
    """Return the base operator name of the functional variant of ``g``."""
    operator_name = g.functional.func.name
    return operator_name.name.base
|
|
|
|
|
|
class GenOpDispatcher:
    """Renders the C++ operator registrations for supported operator groups.

    NOTE(review): leading whitespace inside the C++ templates could not be
    recovered from the reviewed copy of this file; the indentation used below
    is conventional. Confirm against the canonical file.
    """

    def out_variant(
        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
    ) -> str:
        """Render one ``REGISTER_OPERATOR_FUNCTOR`` block for ``groups``.

        Every group contributes one schema-matching branch produced by
        ``out_variant_op_generator``; the registration is named after the
        first group's operator. An empty ``groups`` yields an empty string.
        """
        if not groups:
            return ""
        branches = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsGroup)
                branches.append(self.out_variant_op_generator(group, backend_index))
        op_name = op_name_from_group(groups[0])
        body = "\n".join(branches)
        return f"""
REGISTER_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""

    def view(
        self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
    ) -> str:
        """Render one ``REGISTER_NATIVE_OPERATOR_FUNCTOR`` block for ``groups``.

        Every view group contributes one schema-matching branch produced by
        ``view_op_generator``; the registration is named via
        ``config.func_name_base_str`` applied to the first group. An empty
        ``groups`` yields an empty string.
        """
        if not groups:
            return ""
        branches = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsViewGroup)
                branches.append(self.view_op_generator(group, backend_index))
        op_name = config.func_name_base_str(groups[0])
        body = "\n".join(branches)
        return f"""
REGISTER_NATIVE_OPERATOR_FUNCTOR(
    aten::{op_name},
    aten_{op_name},
    [](Node* n) -> SROperator {{
      {body}
      LogAndDumpSchema(n);
      return nullptr;
    }});
"""

    def out_variant_op_generator(
        self, g: NativeFunctionsGroup, backend_index: BackendIndex
    ) -> str:
        """Render the branch that handles the functional schema of ``g``.

        The emitted lambda unpacks the inputs, calls the functional variant
        when output 0 is None, and otherwise resizes the existing output
        tensor to zero and calls the out variant on it.

        Raises:
            AssertionError: if the out schema does not have exactly one
                output argument.
        """
        schema = str(g.functional.func)
        populated_argument = generate_arg_extraction(g.functional.func)
        functional_variant_call = generate_non_out_variant_call(g, backend_index)
        out_arguments = g.out.func.arguments.out
        assert len(out_arguments) == 1
        out_variable_name = str(out_arguments[0].name)
        out_variant_call = generate_out_variant_call(g, backend_index)
        return f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          if (p_node->Output(0).isNone()) {{
            p_node->Output(0) = {functional_variant_call};
            return;
          }}
          auto& {out_variable_name} = p_node->Output(0).toTensor();
          fastResizeToZero({out_variable_name});
          {out_variant_call};
        }};
      }}"""

    def view_op_generator(
        self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
    ) -> str:
        """Render the branch that handles the view schema of ``g``.

        The emitted lambda unpacks the inputs and stores the result of the
        view call into output 0.
        """
        view_schema = g.view.func
        schema = str(view_schema)
        populated_argument = generate_arg_extraction(view_schema)
        functional_variant_call = generate_call_to_view_ops(g, backend_index)
        return f"""
      if (n->matches(torch::schema("aten::{schema}"))) {{
        return [](ProcessedNode* p_node) {{
          {populated_argument}
          p_node->Output(0) = {functional_variant_call};
        }};
      }}"""
|
|
|
|
|
|
class GenOpTestCase:
    """Renders the C++ ``TEST`` cases for supported operator groups.

    NOTE(review): leading whitespace inside the templates could not be
    recovered from the reviewed copy of this file. The indentation below is
    conventional and keeps the graph body indented deeper than its
    ``graph(...)`` header; confirm against the canonical file.
    """

    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
        """Render one test case per out-variant group, joined by newlines.

        An empty ``groups`` yields an empty string.
        """
        if not groups:
            return ""
        test_cases = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsGroup)
                test_cases.append(self.out_variant_op_test_case_generator(group))
        return "\n".join(test_cases)

    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
        """Render one test case per view group, joined by newlines.

        An empty ``groups`` yields an empty string.
        """
        if not groups:
            return ""
        test_cases = []
        for group in groups:
            with native_function_manager(group):
                assert is_supported(group)
                assert isinstance(group, NativeFunctionsViewGroup)
                test_cases.append(self.view_op_test_case_generator(group))
        return "\n".join(test_cases)

    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
        """Render the test case for the functional schema of ``g``.

        The test runs the operator twice through ``testStaticRuntime``: once
        with the index-0 test values alone, and once with the index-1 values
        as a second argument set. ``check_resize`` follows
        ``should_check_resize``.

        Raises:
            AssertionError: if the schema text has no ``(`` after position 0,
                if the overload name does not start with the base operator
                name, or if the schema does not return exactly one Tensor.
        """
        schema = g.functional.func
        schema_str = str(schema)
        paren_at = schema_str.find("(")
        assert paren_at > 0
        # Overload names such as "op.overload" become "op_overload".
        type_variant_op_name = schema_str[:paren_at].replace(".", "_")
        op_name = op_name_from_group(g)
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        declarations = [
            arg_name if arg_type is None else f"{arg_name}: {arg_type}"
            for arg_name, arg_type in arg_types
        ]
        arg_declarations = ", ".join(declarations)
        arg_names = ", ".join([arg_name for arg_name, _ in arg_types])
        assert (
            len(schema.returns) == 1
            and isinstance(schema.returns[0].type, BaseType)
            and schema.returns[0].type.name is BaseTy.Tensor
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        test_value_definitions2 = generate_test_value_definitions(schema, 1)
        test_value_names2 = generate_test_value_names(schema, 1)
        check_resize = "true" if should_check_resize(schema) else "false"
        return f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

  {test_value_definitions2}
  std::vector<IValue> args2{{{test_value_names2}}};
  testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});

}}
"""

    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
        """Render the test case for the view schema of ``g``.

        The test runs the operator once through ``testStaticRuntime`` with
        the index-0 test values.

        Raises:
            AssertionError: if the schema text has no ``(`` after position 0,
                if the overload name does not start with the view's root
                name, or if the schema does not return exactly one Tensor.
        """
        schema = g.view.func
        schema_str = str(schema)
        paren_at = schema_str.find("(")
        assert paren_at > 0
        # Overload names such as "op.overload" become "op_overload".
        type_variant_op_name = schema_str[:paren_at].replace(".", "_")
        op_name = g.view.root_name
        assert type_variant_op_name.startswith(op_name)

        arg_types = generate_test_ir_arguments(schema)
        declarations = [
            arg_name if arg_type is None else f"{arg_name}: {arg_type}"
            for arg_name, arg_type in arg_types
        ]
        arg_declarations = ", ".join(declarations)
        arg_names = ", ".join([arg_name for arg_name, _ in arg_types])
        assert (
            len(schema.returns) == 1
            and isinstance(schema.returns[0].type, BaseType)
            and schema.returns[0].type.name is BaseTy.Tensor
        )
        test_value_definitions = generate_test_value_definitions(schema, 0)
        test_value_names = generate_test_value_names(schema, 0)
        return f"""
TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  const std::string script = R"IR(
    graph({arg_declarations}):
        %bias: None = prim::Constant()
        %ret = aten::{op_name}({arg_names})
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  {test_value_definitions}
  std::vector<IValue> args{{{test_value_names}}};
  testStaticRuntime(script, args);
}}
"""
|