import torch
import torch._prims_common as utils

# Utilities should come BEFORE this import
# NOTE(review): presumably an import-order constraint (torch._prims_common must
# be loaded before torch._decomp) — confirm before reordering these lines.
from torch._decomp import register_decomposition
from torch._prims_common import TensorLikeType
from torch._prims_common.wrappers import out_wrapper
from torch._refs import _broadcast_shapes


# Data conversion references.
#
# Note: this module breaks the usual _refs to torch naming scheme where
# _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not
# part of _refs/__init__.py to avoid name clashes with Python builtin types
# (like int).
#
# Consequence: below the alias assignments, the names bool/float/int/complex
# in this module refer to the references defined here, not to the builtins.

__all__ = [
    # dtypes
    "bfloat16",
    "bool",
    "byte",
    "cdouble",
    "cfloat",
    "chalf",
    "char",
    "double",
    "float",
    "half",
    "int",
    "long",
    "short",
    # misc
    "complex",
    "polar",
]


def _make_conversion_method(name: str, dtype: torch.dtype):
    """Build a conversion function, published as *name*, that casts to *dtype*.

    The returned function takes a tensor ``self`` and an optional
    ``memory_format`` (default ``torch.preserve_format``) and returns
    ``self.to(dtype, memory_format=memory_format)``.

    Args:
        name: public name the generated function is exposed under.
        dtype: target dtype of the conversion.

    Returns:
        The conversion function, with ``__name__`` and ``__qualname__`` set to
        *name*.
    """

    def fn(
        self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
    ) -> TensorLikeType:
        return self.to(dtype, memory_format=memory_format)  # type: ignore[call-overload]

    # Give the closure its public identity. repr() and pickle use
    # __qualname__, so setting only __name__ would still report these as
    # "_make_conversion_method.<locals>.fn" and make them unpicklable by
    # reference. Each result is bound at module level under *name*, so a
    # qualname equal to *name* resolves back to the same object.
    fn.__name__ = name
    fn.__qualname__ = name
    return fn


bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)

bool = _make_conversion_method("bool", torch.bool)

byte = _make_conversion_method("byte", torch.uint8)

cdouble = _make_conversion_method("cdouble", torch.cdouble)

cfloat = _make_conversion_method("cfloat", torch.cfloat)

chalf = _make_conversion_method("chalf", torch.complex32)

char = _make_conversion_method("char", torch.int8)

double = _make_conversion_method("double", torch.double)

float = _make_conversion_method("float", torch.float)

half = _make_conversion_method("half", torch.half)

int = _make_conversion_method("int", torch.int)

long = _make_conversion_method("long", torch.long)

short = _make_conversion_method("short", torch.short)


@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Combine *real* and *imag* into a single complex tensor.

    Both inputs must be float16, float32 or float64 and share the same dtype.
    The result has the corresponding complex dtype, the broadcast shape of
    the two inputs, and ``real``'s layout and device.

    Args:
        real: real parts.
        imag: imaginary parts.

    Returns:
        A new complex tensor whose real and imaginary views hold the
        (broadcast) values of *real* and *imag*.

    Raises:
        RuntimeError: via ``torch._check``, if either dtype is not one of
            the allowed floating dtypes, or if the two dtypes differ.
    """
    allowed_dtypes = (torch.float32, torch.float64, torch.float16)
    torch._check(
        real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
        lambda: (
            f"Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        ),
    )
    torch._check(
        real.dtype == imag.dtype,
        lambda: (
            f"Expected object of scalar type {real.dtype} but got "
            f"scalar type {imag.dtype} for second argument"
        ),
    )
    result_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
    common_shape = _broadcast_shapes(real.shape, imag.shape)
    result = real.new_empty(
        common_shape,
        dtype=result_dtype,
        layout=real.layout,
        device=real.device,
        # pin_memory=real.is_pinned(),  # NYI
    )
    # Assigning through the real/imag views broadcasts each input into the
    # freshly allocated complex buffer.
    result.real = real
    result.imag = imag
    return result


@register_decomposition(torch._ops.ops.aten.polar)
# Note: polar has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
    """Build a complex tensor from polar coordinates.

    Elementwise computes ``abs * cos(angle)`` for the real part and
    ``abs * sin(angle)`` for the imaginary part.

    Args:
        abs: magnitudes (the parameter name shadows the builtin ``abs``
            inside this function; it is kept for API compatibility).
        angle: angles, in radians.

    Returns:
        A complex tensor of the broadcast shape of *abs* and *angle*.
    """
    # Start from torch.complex(abs, angle) instead of allocating directly, so
    # the inputs go through torch.complex's dtype checks and broadcasting and
    # the result gets the matching complex dtype. Its contents are
    # overwritten immediately below.
    result = torch.complex(abs, angle)
    result.real = abs * torch.cos(angle)
    result.imag = abs * torch.sin(angle)
    return result