import functools
import logging
import math
import sys
import typing
from typing import Optional

import torch
import torch._decomp as decomp
import torch._prims_common as utils
import torch.ao.quantization.fx._decomposed
from torch._decomp import (
    core_aten_decompositions,
    get_decompositions,
    remove_decompositions,
)
from torch._decomp.decompositions import (
    _grid_sampler_2d as decomp_grid_sampler_2d,
    pw_cast_for_opmath,
)
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch._higher_order_ops.out_dtype import out_dtype
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    type_to_dtype,
)

from . import config, inductor_prims

log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

inductor_decompositions = get_decompositions(
    [
        aten._adaptive_avg_pool2d_backward,
        aten.arange,
        aten.bitwise_and_,
        aten.bitwise_or_,
        aten.clamp_min_,
        aten.dist,
        aten.empty_like,
        aten.flip,
        aten.gelu,
        aten.hardtanh,
        aten.index_select,
        aten.lcm,
        aten.leaky_relu,
        aten.linalg_vector_norm,
        aten._log_softmax,
        aten.max_pool2d_with_indices_backward,
        aten._native_batch_norm_legit,
        aten._native_batch_norm_legit_functional,
        aten._native_batch_norm_legit_no_training,
        aten.native_batch_norm,
        aten.native_group_norm,
        aten.native_layer_norm,
        aten.nll_loss2d_backward,
        aten._softmax,
        aten.sin_,
        aten.sqrt_,
        out_dtype,
        aten._to_copy,
        aten.tril_indices,
        aten.triu_indices,
        aten.upsample_bilinear2d.vec,
    ]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}

# Remove unwanted decompositions included via the core ATen decompositions from
# the Inductor decomp table.
decomps_to_exclude = [
    aten._unsafe_index,
    aten._scaled_dot_product_flash_attention_for_cpu.default,  # See comments in torch/_decomp/decompositions.py
    aten.clamp_max,
    aten.clamp_min,
    aten.glu,  # inductor lowers this directly
    aten.split.Tensor,  # inductor lowers this directly
    aten.squeeze,  # inductor lowers this directly
    aten.sum,  # inductor lowers this directly
    aten.unbind,  # inductor lowers this directly
]

remove_decompositions(decompositions, decomps_to_exclude)


def register_decomposition(ops):
    for op in [ops] if callable(ops) else ops:
        if op in decompositions:
            log.warning("duplicate decomp: %s", ops)
    return decomp.register_decomposition(ops, decompositions)


# TODO: for now, inductor doesn't handle asserts,
# because the condition is a symbool that becomes a tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor, msg):
    return


# Like `assert_async_msg_decomp`, this is implemented as a no-op.
@register_decomposition([aten._functional_assert_async.msg])
def functional_assert_async_msg_decomp(tensor, msg):
    return


@register_decomposition([aten.sym_constrain_range_for_size.default])
def sym_constrain_range_for_size(symbol, *, min=None, max=None):
    return


@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
    if min is not None:
        x = x.clamp_min(min)
    if max is not None:
        x = x.clamp_max(max)
    return x


@register_decomposition([aten.full])
def full(size, fill_value, **kwargs):
    dtype = kwargs.get("dtype")
    if dtype is None:
        kwargs["dtype"] = type_to_dtype(type(fill_value))
        return aten.full(size, fill_value, **kwargs)
    return NotImplemented
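
# A minimal sketch (illustrative only; `_example_clamp_decomp` is a hypothetical
# helper that is never called from this module): registered decompositions are
# ordinary Python callables, so the clamp() decomposition above can be exercised
# directly and should agree with eager torch.clamp.
def _example_clamp_decomp():
    x = torch.randn(8)
    return torch.allclose(clamp(x, min=-0.5, max=0.5), torch.clamp(x, -0.5, 0.5))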

# Not really sure how to put this into the main library.  PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(size, physical_layout, **kwargs):
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p
    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)


@register_decomposition([aten.convolution_backward])
def convolution_backward(
    grad_output,
    input,
    weight,
    bias_sizes,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    if not output_mask[2] or grad_output.device.type != "cuda":
        return NotImplemented
    grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
    grad_inp, grad_weight, _ = aten.convolution_backward(
        grad_output,
        input,
        weight,
        bias_sizes,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        [output_mask[0], output_mask[1], False],
    )
    return (grad_inp, grad_weight, grad_bias)


@register_decomposition([aten.log2])
def log2(x):
    return torch.log(x) * (1.0 / math.log(2.0))


@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
    ten_pow_decimals = 10.0**decimals
    return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)


@register_decomposition([aten.bmm])
@pw_cast_for_opmath
def bmm(self, batch2):
    if config.coordinate_descent_tuning:
        if self.shape[1] == 1 or batch2.shape[2] == 1:
            out = (self.unsqueeze(-1) * batch2.unsqueeze(1)).sum(dim=2)
            return out
    if self.device.type == "cpu":
        if self.size(1) == 1 and batch2.size(-1) == 1:
            return torch.sum(
                self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
            ).unsqueeze(1)
    return NotImplemented


@register_decomposition([aten.addmm])
@pw_cast_for_opmath
def addmm(self, mat1, mat2, beta=1, alpha=1):
    if self.device.type == "cpu":
        if mat1.size(0) == 1 and mat2.size(-1) == 1:
            out = torch.sum(
                mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
            return alpha * out + beta * self
        if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16:
            out = (mat1.T * mat2).sum(dim=0, keepdim=True)
            return alpha * out + beta * self
    return NotImplemented
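
# A small numeric sanity sketch (illustrative only; `_example_addmm_vector_path`
# is a hypothetical helper that is never called from this module): for the
# 1xK @ Kx1 CPU fast path above, the sum-of-products form matches torch.addmm
# up to floating-point rounding.
def _example_addmm_vector_path():
    bias = torch.randn(1, 1)
    mat1 = torch.randn(1, 8)
    mat2 = torch.randn(8, 1)
    expected = torch.addmm(bias, mat1, mat2, beta=2.0, alpha=3.0)
    out = torch.sum(
        mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True
    ).unsqueeze(0)
    decomposed = 3.0 * out + 2.0 * bias
    return torch.allclose(expected, decomposed)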

@register_decomposition([aten.mm])
@pw_cast_for_opmath
def mm(self, input2):
    from torch.fx.experimental.symbolic_shapes import (
        definitely_true,
        guard_size_oblivious,
    )

    # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
    # todo: Look into why and fix it (hopefully)
    if config.coordinate_descent_tuning:
        if self.shape[0] == 1 or input2.shape[1] == 1:
            return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
    if self.device.type == "cpu":
        if (
            guard_size_oblivious(self.size(-1) == 1)
            and guard_size_oblivious(self.size(0) > 0)
            and guard_size_oblivious(input2.size(0) == 1)
            and (self.dtype == input2.dtype)
            and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32)
        ):
            return torch.cat([self[i, :] * input2 for i in range(self.size(0))])
        if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious(
            input2.size(-1) == 1
        ):
            return torch.sum(
                self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True
            ).unsqueeze(0)
    return NotImplemented
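
# A quick numeric sketch (illustrative only; `_example_mm_broadcast_path` is a
# hypothetical helper that is never called from this module): the matrix-vector
# branch above broadcasts the operands and reduces over the shared K dimension,
# which is the same computation as torch.mm written as an elementwise product
# followed by a sum.
def _example_mm_broadcast_path():
    a = torch.randn(1, 16)
    b = torch.randn(16, 4)
    decomposed = (a.unsqueeze(2) * b.unsqueeze(0)).sum(dim=1)
    return torch.allclose(decomposed, a @ b, atol=1e-6)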

# This pass does two things:
# - Eliminate cat when there is only one tensor input
# - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
#   don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(tensors, dim=0):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    def non_empty_tensor(x):
        # For better or worse, this is a valid cat:
        #
        #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
        #
        # We'd like to eliminate naughtiness like this for downstream passes
        # like split_cat.  The easiest way is to just drop such inputs
        # (guarding that they are non-zero).
        #
        # Is it permissible for this filtering to be size-oblivious?  A case
        # where this could matter is cat([(2, 2), (u0,)], dim=0); if u0
        # happened to be zero, we would have liked to have filtered it out.
        # But actually, the ONLY way this could have passed is if u0 == 0,
        # so by the time we get here we have already installed a deferred
        # runtime assert forcing u0 to be zero.  So if this hasn't happened,
        # we know that the unbacked SymInt has appropriate size and there are
        # no problems.
        return len(x.shape) != 1 or guard_size_oblivious(x.shape[0] > 0)

    filtered_tensors = list(filter(non_empty_tensor, tensors))

    if len(filtered_tensors) == 1:
        return filtered_tensors[0].clone()
    elif 1 < len(filtered_tensors) < len(tensors):
        # on the first call, when we remove empty tensors, we redispatch recursively
        return aten.cat.default(filtered_tensors, dim)
    # when no filtering has occurred, return NotImplemented so we fall back to the
    # op itself instead of recursing forever (no more decomposition is needed)
    return NotImplemented


@register_decomposition([aten.angle])
def angle(x):
    if x.is_complex():
        return torch.where(
            torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
        )

    # when x is a real number:
    #   if x >= 0, return 0
    #   if x < 0, return pi
    #   if x is nan, return nan
    _, dtype = elementwise_dtypes(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )
    pi = torch.scalar_tensor(math.pi, dtype=dtype, device=x.device)
    ret = torch.where(x < 0, pi, 0.0)
    return torch.where(torch.isnan(x), float("nan"), ret)


@register_decomposition([aten.add])
def add(x, y, *, alpha=None):
    x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
    y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
    if not x_is_complex_tensor or not y_is_complex_tensor:
        return NotImplemented
    z = y
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)
    return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)


@register_decomposition([aten.conj_physical])
def conj_physical(self):
    assert not self.is_complex(), "TODO: implement this"
    return self


@register_decomposition([aten.lift, aten.detach_])
def lift(self):
    return self


@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
    assert generator is None
    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)


@register_decomposition([aten.fmin, prims.fmin])
def fmin(self, other):
    return torch.where(torch.isnan(other) | (other > self), self, other)


@register_decomposition([aten.fmax, prims.fmax])
def fmax(self, other):
    return torch.where(torch.isnan(other) | (other < self), self, other)


@register_decomposition(aten.amax)
def amax(self, dim=None, keepdim=False):
    if self.dtype == torch.bool:
        return torch.any(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition(aten.amin)
def amin(self, dim=None, keepdim=False):
    if self.dtype == torch.bool:
        return torch.all(self, dim=dim, keepdim=keepdim)
    return NotImplemented


@register_decomposition([aten.narrow_copy])
def narrow_copy(self, dim, start, length):
    return torch.narrow(self, dim, start, length).clone()


@register_decomposition([aten.expand_copy])
def expand_copy(self, size, *, implicit=False):
    return aten.expand(self, size, implicit=implicit).clone()


@register_decomposition([aten.view_copy.default])
def view_copy_default(self, size):
    return aten.view(self, size).clone()


@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(self, dtype):
    return self.to(dtype).clone()


def get_like_layout(
    tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
) -> torch.memory_format:
    # TODO: _to_copy tensor to stride permutation
    if memory_format is torch.preserve_format or memory_format is None:
        return utils.suggest_memory_format(tensor)
    else:
        return memory_format
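
# A minimal sketch (illustrative only; `_example_rand_like_layout` is a
# hypothetical helper that is never called from this module): the *_like
# decompositions below turn the "like" op into a plain factory call followed by
# .to(memory_format=...), where get_like_layout picks the input's suggested
# layout whenever no explicit memory format was requested.
def _example_rand_like_layout():
    x = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
    y = torch.rand([*x.size()], dtype=x.dtype, device=x.device).to(
        memory_format=get_like_layout(x, None)
    )
    return y.is_contiguous(memory_format=torch.channels_last)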

@register_decomposition(aten.rand_like)
def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    return torch.rand(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randn_like)
def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
    return torch.randn(
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.full_like)
def full_like(
    self,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    requires_grad=False,
    memory_format=torch.preserve_format,
):
    return torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        layout=layout or self.layout,
        device=device or self.device,
        requires_grad=requires_grad,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.default)
def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
    return aten.randint.low(
        0,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint_like.low_dtype)
def randint_like_low(
    self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
):
    return aten.randint.low(
        low,
        high,
        [*self.size()],
        dtype=dtype or self.dtype,
        device=device or self.device,
        **kwargs,
    ).to(memory_format=get_like_layout(self, memory_format))


@register_decomposition(aten.randint.default)
def randint(high, size, **kwargs):
    return aten.randint.low(0, high, size, **kwargs)


# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor
# is whether scale and zero_point are Python scalars or scalar tensors.
@register_decomposition(quantized_decomposed.quantize_per_tensor.default)
def quantize_per_tensor_default_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)


# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor
# is whether scale and zero_point are Python scalars or scalar tensors.
@register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
def dequantize_per_tensor_default_decomp_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return (input.to(torch.float32) - zero_point) * scale


@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    if input.dtype == torch.bfloat16:
        input = input.to(torch.float32)
    inv_scale = 1.0 / scale
    return torch.clamp(
        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
    ).to(dtype)


@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return (input.to(torch.float32) - zero_point.to(torch.int32)) * scale.to(
        torch.float32
    )
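
# A small worked sketch (illustrative only; `_example_quant_round_trip` is a
# hypothetical helper that is never called from this module) of the affine
# quantization formulas above: q = clamp(round(x / scale) + zero_point,
# quant_min, quant_max), and dequantize maps q back to (q - zero_point) * scale.
def _example_quant_round_trip():
    x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
    scale, zero_point, quant_min, quant_max = 0.01, 0, -128, 127
    q = torch.clamp(
        torch.round(x * (1.0 / scale)) + zero_point, quant_min, quant_max
    ).to(torch.int8)
    xq = (q.to(torch.float32) - zero_point) * scale
    return q, xq  # q == [-100, 0, 50, 100]; xq is approximately x again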

@register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
def q_embedding_bag_byte_unpack_decomp(packed):
    def bitcast_u8_to_f32(u8):
        x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
        if sys.byteorder == "little":
            return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
        else:
            return ((x << 24) + (y << 16) + (z << 8) + w).view(torch.float32)[..., None]

    scales = bitcast_u8_to_f32(packed[..., -8:-4])
    offsets = bitcast_u8_to_f32(packed[..., -4:])
    return packed[..., :-8].to(torch.float32) * scales + offsets


@register_decomposition([aten.grid_sampler_2d])
@pw_cast_for_opmath
def grid_sampler_2d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> torch.Tensor:
    # We do not expand the grid (_expand_grid=False) on cpu for performance reasons.
    # Experimenting locally, it was found that compiled CUDA code is accelerated
    # by ~5x and CPU code by ~2x on bicubic mode if we expand the grid from
    # (N, H, W, 2) into (N, C, H, W, 2). However, this leads to a ~0.8x slowdown
    # on CPU bilinear mode, channels first. Thus we apply this hack to not expand
    # the grid for this case.
    _expand_grid = not (
        a.device == torch.device("cpu")
        and interpolation_mode == 0
        and a.is_contiguous(memory_format=torch.contiguous_format)
    )

    output = decomp_grid_sampler_2d(
        a,
        grid=grid,
        interpolation_mode=interpolation_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
        _expand_grid=_expand_grid,
    )
    return output


@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1):
    return aten._foreach_add.List(
        self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1):
    return aten._foreach_add.List(
        self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
    )


@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
    return aten._foreach_add.List(
        start_tensors,
        aten._foreach_mul.Scalar(
            aten._foreach_sub.List(end_tensors, start_tensors), weight
        ),
    )


@aten.miopen_batch_norm.default.py_impl(torch._C.DispatchKey.Autograd)
@register_decomposition(aten.miopen_batch_norm)
def miopen_batch_norm(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    running_mean: typing.Optional[torch.Tensor],
    running_var: typing.Optional[torch.Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )

    if training:
        return (a, b, c)
    return (
        a,
        weight.new_zeros((0,)),
        weight.new_zeros((0,)),
    )


@functools.lru_cache(None)
def fast_random_decomps():
    return {**decompositions, **extra_random_decomps}


def select_decomp_table():
    """decomps can change based on config"""
    if config.fallback_random:
        return decompositions
    return fast_random_decomps()
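
# A minimal usage sketch (illustrative only; `_example_trace_with_decomps` is a
# hypothetical helper that is never called from this module): the table returned
# by select_decomp_table() maps ATen ops to the Python implementations in this
# file, and can be handed to tracing utilities such as make_fx so that the
# listed ops appear in their decomposed form in the traced graph.
def _example_trace_with_decomps():
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.log2(x).round(decimals=2)

    gm = make_fx(f, decomposition_table=select_decomp_table())(torch.rand(4))
    return gm  # log2 / round.decimals show up as their decompositions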

@register_decomposition(aten.masked_scatter)
def masked_scatter(self, mask, source):
    if self.device.type == "cuda":
        # This two-step algorithm is the same as eager CUDA; for eager CPU we
        # use a one-shot serial iteration.
        self, mask = aten.broadcast_tensors([self, mask])
        source_idx = mask.reshape(-1).cumsum(0) - 1
        return inductor_prims.masked_scatter_with_index(self, mask, source_idx, source)
    return NotImplemented


@register_decomposition(quantized_decomposed.choose_qparams.tensor)
def choose_qparams_tensor(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
):
    min_val, max_val = torch.aminmax(input)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.max(scale, torch.Tensor([eps]))
    zero_point = quant_min - torch.round(min_val / scale).to(torch.int)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale.to(torch.float64), zero_point.to(torch.int64)


@register_decomposition(aten.put)
def put(self, index, source, accumulate=False):
    flattened = self.flatten()
    flattened = torch.index_put(
        flattened, [index], source.reshape(index.shape), accumulate
    )
    return flattened.reshape(self.shape)


@register_decomposition(aten.put_)
def put_(self, index, source, accumulate=False):
    out = aten.put(self, index, source, accumulate=accumulate)
    return self.copy_(out)
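
# A brief sketch (illustrative only; `_example_put` is a hypothetical helper
# that is never called from this module) of the aten.put decomposition above:
# put indexes `self` as if it were flattened, so it reduces to index_put on the
# flattened view followed by a reshape back to the original shape.
def _example_put():
    x = torch.zeros(2, 3)
    index = torch.tensor([0, 4])
    source = torch.tensor([10.0, 20.0])
    out = torch.index_put(x.flatten(), [index], source.reshape(index.shape), False)
    return out.reshape(x.shape)  # [[10., 0., 0.], [0., 20., 0.]]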