import json import logging import math from typing import Dict, List, Optional, Sequence, Tuple, Union import torchgen.api.cpp as cpp from torchgen.context import native_function_manager from torchgen.model import ( Argument, BackendIndex, BaseTy, BaseType, FunctionSchema, NativeFunctionsGroup, NativeFunctionsViewGroup, OptionalType, SelfArgument, TensorOptionsArguments, Type, ) from torchgen.static_runtime import config logger: logging.Logger = logging.getLogger() def has_alias( arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]] ) -> bool: for arg in arguments: annotation = getattr(arg, "annotation", None) if not annotation: continue alias_set = getattr(annotation, "alias_set", ()) if alias_set: return True return False BLOCKED_OPS = frozenset( ( # non cpu ops "sparse_sampled_addmm", "hspmm", "linalg_svdvals", # sparse ops "sspaddmm", "coalesce", "_indices", "indices", "_values", "values", "crow_indices", "col_indices", # deprecated ops "floor_divide", "ger", # buggy ops "conj_physical", # P495807361 "binary_cross_entropy", # P496394764 "arccosh", # uncommon ops "cholesky", "lu_solve", "linalg_cholesky", "linalg_householder_product", "linalg_ldl_solve", "_compute_linear_combination", # training related ops "_make_dual", # cannot call directly "_fw_primal", # no documentation "_index_reduce", # TODO: these ones got added recently and need manual inspection "_new_zeros_with_same_feature_meta", "_conj_physical", "binary_cross_entropy_with_logits", "bincount", "conv_tbc", "copy", "_copy_from", "_copy_from_and_resize", "count_nonzero", "cudnn_affine_grid_generator", "cudnn_affine_grid_generator_backward", "cudnn_grid_sampler", "diag_embed", "embedding", "embedding_dense_backward", "_embedding_bag_dense_backward", "_embedding_bag_per_sample_weights_backward", "grid_sampler_2d", "_grid_sampler_2d_cpu_fallback", "grid_sampler_3d", "isnan", "mkldnn_linear", "median", "nanmedian", "_sparse_sparse_matmul", "batch_norm_backward_elemt", "_euclidean_dist", "pixel_shuffle", "pixel_unshuffle", "channel_shuffle", "_reshape_nested_backward", "relu", "prelu", "celu", "slice_scatter", "select_scatter", "diagonal_scatter", "sum", "_mkldnn_transpose", "_nested_tensor_from_mask", "_nested_from_padded", "_nested_tensor_size", "_nested_from_padded_and_nested_example", "_standard_gamma_grad", "_dirichlet_grad", "native_norm", "_sparse_softmax", "_sparse_softmax_backward_data", "_sparse_log_softmax", "_sparse_log_softmax_backward_data", "zero", "_sparse_addmm", "sparse_mask", "_sparse_mask_projection", "_to_dense", "_coalesce", "_coalesced", "copy_sparse_to_sparse", "to_sparse", "to_sparse_csr", "to_sparse_csc", "to_mkldnn", "quantize_per_tensor_dynamic", "quantize_per_channel", "q_per_channel_scales", "q_per_channel_zero_points", "int_repr", "_make_per_channel_quantized_tensor", "set", "lift", "lift_fresh", "lift_fresh_copy", "masked_scatter", "_masked_softmax", "_masked_softmax_backward", "put", "index_reduce", "trace", "_cholesky_solve_helper", "dist", "max", "_torch_cuda_cu_linker_symbol_op", "glu_jvp", "glu_backward_jvp", "hardswish_backward", "rrelu_with_noise_backward", "mkldnn_adaptive_avg_pool2d_backward", "_adaptive_avg_pool2d_backward", "_adaptive_avg_pool3d_backward", "isinf", "linalg_lu_solve", "linalg_vecdot", "linalg_matrix_exp", "linalg_eigvalsh", "_test_warn_in_autograd", "_test_autograd_multiple_dispatch_view", "_test_autograd_multiple_dispatch_view_copy", "_segment_reduce", "_segment_reduce_backward", "_fw_primal_copy", "_make_dual_copy", "view_as_real_copy", "view_as_complex_copy", "_conj_copy", "_neg_view_copy", "diagonal_copy", "detach_copy", "squeeze_copy", "t_copy", "unsqueeze_copy", "_indices_copy", "_values_copy", "indices_copy", "values_copy", "crow_indices_copy", "col_indices_copy", "ccol_indices", "ccol_indices_copy", "row_indices", "row_indices_copy", "unfold_copy", "alias_copy", "_triton_multi_head_attention", "special_airy_ai", "special_bessel_j0", "special_bessel_j1", "special_bessel_y0", "special_bessel_y1", "special_chebyshev_polynomial_t", "special_chebyshev_polynomial_u", "special_chebyshev_polynomial_v", "special_chebyshev_polynomial_w", "special_hermite_polynomial_h", "special_hermite_polynomial_he", "special_laguerre_polynomial_l", "special_legendre_polynomial_p", "special_modified_bessel_i0", "special_modified_bessel_i1", "special_modified_bessel_k0", "special_modified_bessel_k1", "special_scaled_modified_bessel_k0", "special_scaled_modified_bessel_k1", "special_shifted_chebyshev_polynomial_t", "special_shifted_chebyshev_polynomial_u", "special_shifted_chebyshev_polynomial_v", "special_shifted_chebyshev_polynomial_w", "special_spherical_bessel_j0", "_foobar", "_nested_tensor_strides", ) ) def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: base_op_name = "" func = None if isinstance(g, NativeFunctionsViewGroup): base_op_name = g.view.root_name func = g.view.func else: base_op_name = g.out.func.name.name.base func = g.out.func if config.is_hand_written(g): logger.info("HAND WRITTEN: %s", base_op_name) return False if base_op_name in BLOCKED_OPS: logger.info("BLOCKED: %s", base_op_name) return False for arg in func.schema_order_arguments(): maybe_method = ivalue_type_conversion_method(arg.type) if not maybe_method: # Type converting is unsupported yet. logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func) return False if isinstance(g, NativeFunctionsViewGroup): # TODO: stop doing type tests by converting to C++ and then testing # the string, just test the dang thing directly if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type(): # Returns a non-Tensor value. logger.info("NON-TENSOR RET TYPE: %s", str(func)) return False return True # For out variant ops, we need to check the arguments of its functional func. for arg in g.functional.func.schema_order_arguments(): maybe_method = ivalue_type_conversion_method(arg.type) if not maybe_method: # Type converting is unsupported yet. logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func) return False if not g.structured: # In case of unstructured op, we check if it has out variant implementation. # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last # parameter. if ( not hasattr(g, "out") or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)") or not str(func.name).endswith(".out") ): return False # TODO: stop type testing by converting to C++ if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type(): logger.info("NON_TENSOR RET TYPE: %s", func) return False if has_alias(func.arguments.non_out): # This op may create an alias of inputs. logger.info("INPUTS ALIAS: %s", base_op_name) return False return True def ivalue_type_conversion_method( arg_type: Union[BaseType, OptionalType, Type] ) -> Optional[Tuple[bool, str]]: """ Return the method call expression of `c10::ivalue' to convert its contained value to the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor, this function returns ".toTensor()", so that it can be appended to the ivalue's variable name to get the value of the expected type. """ type_conversion_methods = { BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional()")), BaseTy.int: ((False, "toInt()"), (False, "toOptional()")), BaseTy.bool: ((False, "toBool()"), (False, "toOptional()")), BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional()")), BaseTy.ScalarType: ( (False, "toScalarType()"), (False, "toOptional()"), ), BaseTy.str: ( (False, "toStringView()"), (False, "toOptional()"), ), } base_ty_object = None if isinstance(arg_type, BaseType): base_ty_object = arg_type.name elif isinstance(arg_type, OptionalType): if not isinstance(arg_type.elem, BaseType): # ListType is currently unsupported. return None base_ty_object = arg_type.elem.name else: return None if base_ty_object not in type_conversion_methods: return None methods = type_conversion_methods[base_ty_object] if isinstance(arg_type, BaseType): return methods[0] return methods[1] should_use_int_tensor_ops_ = frozenset( ( "bitwise_not", "bitwise_and", "bitwise_or", "bitwise_xor", "bitwise_left_shift", "bitwise_right_shift", "gcd", "lcm", "scatter", "gather", "_convert_indices_from_coo_to_csr", "_convert_indices_from_csr_to_coo", ) ) should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj")) def should_use_int_tensor(op_name: str) -> bool: return op_name in should_use_int_tensor_ops_ def should_use_complex_tensor(op_name: str) -> bool: return op_name in should_use_complex_tensor_ops_ test_tensor_dim_ops_1_ = frozenset( ( "addmv", "index_add", "_convert_indices_from_coo_to_csr", "_convert_indices_from_csr_to_coo", "nll_loss_backward", "dot", "vdot", "outer", "ger", ) ) test_tensor_dim_ops_2_ = frozenset( ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t") ) def test_tensor_dim(op_name: str) -> int: if op_name in test_tensor_dim_ops_1_: return 1 if op_name in test_tensor_dim_ops_2_: return 2 return 3 test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}' test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string) def test_tensor_shape(op_name: str) -> str: if op_name in test_tensor_shape_json: return test_tensor_shape_json[op_name] else: return "" def test_value_expression( arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str ) -> str: tensor_size_ex = test_tensor_shape(op_name) if tensor_size_ex == "": num_tensors = 16 if index == 0 else 64 num_dim = test_tensor_dim(op_name) size_per_dim = math.ceil(num_tensors / float(num_dim)) size_per_dim += size_per_dim % 2 tensor_size_ex = "{%s}" % (",".join([f"{size_per_dim}"] * num_dim)) if should_use_int_tensor(op_name): tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)" elif should_use_complex_tensor(op_name): tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)" else: tensor_expression = f"at::rand({tensor_size_ex})" value_expressions = { BaseTy.Tensor: tensor_expression, BaseTy.int: "1", BaseTy.bool: "false", BaseTy.Scalar: "2", BaseTy.ScalarType: "at::ScalarType::Float", BaseTy.str: '"floor"', } base_ty_object = None if isinstance(arg_type, BaseType): base_ty_object = arg_type.name else: assert isinstance(arg_type, OptionalType) and isinstance( arg_type.elem, BaseType ) base_ty_object = arg_type.elem.name assert base_ty_object in value_expressions, "not expected type" value_expression = value_expressions[base_ty_object] return value_expression def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str: assert not schema.is_out_fn() schema_name = schema.name.name.base arg_map = {} for arg in schema.schema_order_arguments(): test_value_exp = test_value_expression(arg.type, index, schema_name) arg_map[arg.name] = test_value_exp config.override_test_values(arg_map, schema_name, index) arg_populations = [] for arg_name, arg_value in arg_map.items(): arg_populations.append(f"auto {arg_name}{index} = {arg_value}") return ";\n ".join(arg_populations) + ";" def generate_test_value_names(schema: FunctionSchema, index: int) -> str: assert not schema.is_out_fn() return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments()) generate_test_ir_arguments_base_ty_to_type_str_ = { BaseTy.Tensor: "Tensor", BaseTy.int: "int", BaseTy.float: "float", BaseTy.str: "str", BaseTy.Scalar: "int", BaseTy.ScalarType: "int", BaseTy.bool: "bool", } def generate_test_ir_arguments( schema: FunctionSchema, ) -> List[Tuple[str, Optional[str]]]: def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]: t = arg.type add_optional = False if isinstance(t, OptionalType): t = t.elem add_optional = True assert isinstance(t, BaseType) type_str = None if t.name in generate_test_ir_arguments_base_ty_to_type_str_: type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name] if type_str and add_optional: type_str = f"{type_str}?" return ("%" + arg.name, type_str) return [ir_argument(arg) for arg in schema.schema_order_arguments()] def generate_arg_extraction(schema: FunctionSchema) -> str: arg_populations = [] for i, arg in enumerate(schema.schema_order_arguments()): maybe_method = ivalue_type_conversion_method(arg.type) assert maybe_method is_reference, type_conversion_method = maybe_method reference = "&" if is_reference else "" arg_populations.append( f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}" ) return ";\n ".join(arg_populations) + ";" def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: kernel = backend_index.get_kernel(g.functional) if g.structured or kernel is None: return cpp.name(g.functional.func) return kernel.kernel def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str: kernel = backend_index.get_kernel(g.out) if g.structured or kernel is None: return cpp.name(g.out.func) return kernel.kernel def generate_non_out_variant_call( g: NativeFunctionsGroup, backend_index: BackendIndex ) -> str: schema = g.functional.func assert not schema.is_out_fn() kernel_name = get_kernel_name(g, backend_index) arg_names = (arg.name for arg in schema.schema_order_arguments()) namespace_name = "cpu" if g.structured else "native" return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' def generate_call_to_view_ops( g: NativeFunctionsViewGroup, backend_index: BackendIndex ) -> str: schema = g.view.func kernel_name = cpp.name(schema) kernel = backend_index.get_kernel(g.view) if kernel: kernel_name = kernel.kernel arg_names = (arg.name for arg in schema.schema_order_arguments()) namespace_name = "native" return f'at::{namespace_name}::{kernel_name}({",".join(arg_names)})' def generate_out_variant_call( g: NativeFunctionsGroup, backend_index: BackendIndex ) -> str: schema = g.out.func assert schema.is_out_fn() arg_names = [] kernel_name = get_out_kernel_name(g, backend_index) if g.structured: # structured op starts with the output tensor argument. arg_names = [out_arg.name for out_arg in schema.arguments.out] else: arg_names = [] for arg in schema.arguments.non_out: if isinstance(arg, SelfArgument): arg_names.append(arg.argument.name) else: assert isinstance(arg, Argument) arg_names.append(arg.name) if not g.structured: assert len(schema.arguments.out) == 1 arg_names.append(schema.arguments.out[0].name) cpp_arg_names = ",".join(arg_names) namespace_name = "cpu" if g.structured else "native" return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})" no_memory_resize_ops = frozenset( ( "isin.Scalar_Tensor", "index_add", "dot", "vdot", "nuclear_norm", "histc", "l1_loss", "multi_margin_loss", "multilabel_margin_loss", "nll_loss", "nll_loss2d", "prod", ) ) def should_check_resize(schema: FunctionSchema) -> bool: schema_str = str(schema) type_variant_op_name = schema_str[: schema_str.find("(")] return type_variant_op_name not in no_memory_resize_ops def op_name_from_group(g: NativeFunctionsGroup) -> str: return g.functional.func.name.name.base class GenOpDispatcher: def out_variant( self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex ) -> str: if not groups: return "" generated_type_variants = [] for g in groups: with native_function_manager(g): assert is_supported(g) assert isinstance(g, NativeFunctionsGroup) generated_type_variant = self.out_variant_op_generator(g, backend_index) generated_type_variants.append(generated_type_variant) op_name = op_name_from_group(groups[0]) body = "\n".join(generated_type_variants) generated = f""" REGISTER_OPERATOR_FUNCTOR( aten::{op_name}, aten_{op_name}, [](Node* n) -> SROperator {{ {body} LogAndDumpSchema(n); return nullptr; }}); """ return generated def view( self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex ) -> str: if not groups: return "" generated_type_variants = [] for g in groups: with native_function_manager(g): assert is_supported(g) assert isinstance(g, NativeFunctionsViewGroup) generated_type_variant = self.view_op_generator(g, backend_index) generated_type_variants.append(generated_type_variant) op_name = config.func_name_base_str(groups[0]) body = "\n".join(generated_type_variants) generated = f""" REGISTER_NATIVE_OPERATOR_FUNCTOR( aten::{op_name}, aten_{op_name}, [](Node* n) -> SROperator {{ {body} LogAndDumpSchema(n); return nullptr; }}); """ return generated def out_variant_op_generator( self, g: NativeFunctionsGroup, backend_index: BackendIndex ) -> str: functional = g.functional schema = str(functional.func) populated_argument = generate_arg_extraction(g.functional.func) functional_variant_call = generate_non_out_variant_call(g, backend_index) assert len(g.out.func.arguments.out) == 1 out_variable_name = str(g.out.func.arguments.out[0].name) out_variant_call = generate_out_variant_call(g, backend_index) generated = f""" if (n->matches(torch::schema("aten::{schema}"))) {{ return [](ProcessedNode* p_node) {{ {populated_argument} if (p_node->Output(0).isNone()) {{ p_node->Output(0) = {functional_variant_call}; return; }} auto& {out_variable_name} = p_node->Output(0).toTensor(); fastResizeToZero({out_variable_name}); {out_variant_call}; }}; }}""" return generated def view_op_generator( self, g: NativeFunctionsViewGroup, backend_index: BackendIndex ) -> str: schema = str(g.view.func) populated_argument = generate_arg_extraction(g.view.func) functional_variant_call = generate_call_to_view_ops(g, backend_index) generated = f""" if (n->matches(torch::schema("aten::{schema}"))) {{ return [](ProcessedNode* p_node) {{ {populated_argument} p_node->Output(0) = {functional_variant_call}; }}; }}""" return generated class GenOpTestCase: def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str: if not groups: return "" generated_type_variants = [] for g in groups: with native_function_manager(g): assert is_supported(g) assert isinstance(g, NativeFunctionsGroup) generated_type_variant = self.out_variant_op_test_case_generator(g) generated_type_variants.append(generated_type_variant) return "\n".join(generated_type_variants) def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str: if not groups: return "" generated_type_variants = [] for g in groups: with native_function_manager(g): assert is_supported(g) assert isinstance(g, NativeFunctionsViewGroup) generated_type_variant = self.view_op_test_case_generator(g) generated_type_variants.append(generated_type_variant) return "\n".join(generated_type_variants) def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str: schema = g.functional.func schema_str = str(schema) assert schema_str.find("(") > 0 type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") op_name = op_name_from_group(g) assert type_variant_op_name.startswith(op_name) arg_types = generate_test_ir_arguments(schema) arg_declarations = ", ".join( ( arg_name if arg_type is None else f"{arg_name}: {arg_type}" for arg_name, arg_type in arg_types ) ) arg_names = ", ".join((arg_name for arg_name, _ in arg_types)) assert ( len(schema.returns) == 1 and isinstance(schema.returns[0].type, BaseType) and schema.returns[0].type.name is BaseTy.Tensor ) test_value_definitions = generate_test_value_definitions(schema, 0) test_value_names = generate_test_value_names(schema, 0) test_value_definitions2 = generate_test_value_definitions(schema, 1) test_value_names2 = generate_test_value_names(schema, 1) check_resize = "true" if should_check_resize(schema) else "false" generated = f""" TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ const std::string script = R"IR( graph({arg_declarations}): %bias: None = prim::Constant() %ret = aten::{op_name}({arg_names}) %cloned = aten::clone(%ret, %bias) return (%cloned) )IR"; {test_value_definitions} std::vector args{{{test_value_names}}}; testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); {test_value_definitions2} std::vector args2{{{test_value_names2}}}; testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize}); }} """ return generated def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str: schema = g.view.func schema_str = str(schema) assert schema_str.find("(") > 0 type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_") op_name = g.view.root_name assert type_variant_op_name.startswith(op_name) arg_types = generate_test_ir_arguments(schema) arg_declarations = ", ".join( ( arg_name if arg_type is None else f"{arg_name}: {arg_type}" for arg_name, arg_type in arg_types ) ) arg_names = ", ".join((arg_name for arg_name, _ in arg_types)) assert ( len(schema.returns) == 1 and isinstance(schema.returns[0].type, BaseType) and schema.returns[0].type.name is BaseTy.Tensor ) test_value_definitions = generate_test_value_definitions(schema, 0) test_value_names = generate_test_value_names(schema, 0) generated = f""" TEST(StaticRuntime, autogen_{type_variant_op_name}) {{ const std::string script = R"IR( graph({arg_declarations}): %bias: None = prim::Constant() %ret = aten::{op_name}({arg_names}) %cloned = aten::clone(%ret, %bias) return (%cloned) )IR"; {test_value_definitions} std::vector args{{{test_value_names}}}; testStaticRuntime(script, args); }} """ return generated