# mypy: ignore-errors

import functools
import itertools
import math
import sys
from typing import Callable, Union

import torch
import torch._custom_op
import torch._logging

from torch._ops import OpOverload
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
)

from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensor,
    in_kernel_invocation_manager,
    run_fallback_kernel,
    UnsupportedOperatorException,
)

from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label

pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

op_implementations_dict = {}
op_implementations_checks = []


aten = torch._ops.ops.aten


def ordered_set(*items):
    return dict.fromkeys(items, True)


# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
    if device.type == "hpu":
        return False
    return True


_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)


_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.is_pinned.default,
    aten.to.device,
    aten.to.prim_Device,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)


def contains_tensor_types(type):
    tensor_type = torch._C.TensorType.get()
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    assert isinstance(func, OpOverload)
    schema = func._schema
    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return False
    # TODO: no real reason to restrict multiple outputs
    return (
        len(schema.returns) == 1
        and schema.returns[0].type is torch._C.TensorType.get()
    )


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:
                register_op_impl(op)(op_impl)
        else:
            assert callable(run_impl_check)
            op_implementations_checks.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator


@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)

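
# A minimal registration sketch (illustrative only; not part of this module's
# registrations).  `register_op_impl` accepts a single OpOverload, a list/tuple
# of overloads, or a predicate over OpOverloads.  A hypothetical impl that just
# re-runs the op under the meta kernel (the same pattern `resize_as_` uses
# below) would look like:
#
#     @register_op_impl(aten.alias.default)
#     def _alias_impl(fake_mode, func, *args, **kwargs):
#         with in_kernel_invocation_manager(fake_mode):
#             return func(*args, **kwargs)
#
# Exact-overload registrations go into `op_implementations_dict`; predicate
# registrations are appended to `op_implementations_checks` and tried in
# registration order.
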
@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")
    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)


@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    inp = new_kwargs.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False


# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)


# Don't default to default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)


@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # TODO: remove me
    return constructors(fake_mode, func, *args, **kwargs)


# index.Tensor is data-dependent in only some conditions
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)


@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    if output_size is None:
        if (
            fake_mode.shape_env
            is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            raise DynamicOutputShapeException(func)

        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
        # TODO: consider a memo
    return repeats.new_empty(output_size)


@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
    if is_float_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symint()
    elif is_boolean_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")


@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if arg.nonzero_memo is None:
        nnz = fake_mode.shape_env.create_unbacked_symint()

        # This is unsound, but it works well in practice
        # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
        # TODO: Add a config knob to turn off this unsound behavior
        #
        # NB: If numel < 2, the bounds here might be COMPLETELY
        # disjoint with what can actually occur.  But this is fine:
        # remember, the hypothesis is that if your later code works
        # with N >= 2, it will work with N = 1 and N = 0.
        maxval = sys.maxsize - 1

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()):
            # Don't upgrade the range if numel is less than two, since we then
            # have an empty range which makes things go explodey.  We also
            # don't allow for 2 because that would specialize the unbacked
            # SymInt to 2, which is also likely to be buggy.
            if arg.numel() > 2:
                maxval = int(arg.numel())

            _constrain_range_for_size(nnz, max=maxval)

        arg._nonzero_memo = nnz
        arg._nonzero_memo_vc = arg._version

    return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)


@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )

    if not has_free_symbols(self.numel()):
        if self.numel() > 2:
            maxval = int(self.numel())

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))


# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)


# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)


def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)
        if not is_noncontiguous_supported(out_device):
            out = out.new_empty(out.shape)

    if out is new_kwargs["input"]:
        return out  # copy_
    return FakeTensor(fake_mode, out, out_device)


_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    return op.namespace in _is_builtin_namespaces


def has_meta(func):
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")


@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    tensor_lists = []
    for arg in itertools.chain(args, kwargs.values()):
        if (
            isinstance(arg, (list, tuple))
            and len(arg)
            and isinstance(arg[0], torch.Tensor)
        ):
            tensor_lists.append(arg)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists
    out_fake = []

    for i, meta_t in enumerate(out_meta):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        out_fake.append(
            fake_mode.fake_tensor_converter.from_meta_and_device(
                fake_mode, meta_t, device
            )
        )

    return out_fake


# Don't default to default device handling,
# since this op can take in non-zero sized cpu
# index tensors with cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    # ensure nonzero call goes to fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)


# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real
# devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)


# takes in multiple devices, don't default to default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)


# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]


@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values = new_kwargs["values"]
    self_device = new_kwargs["input"].fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    if func is aten.index_put_.default:
        return new_kwargs["input"]
    else:
        return out


@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )


@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"


@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input unsqueeze is done in Convolution.cpp we get a segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )


@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)
    max_seqlen_batch_k = key.size(2)

    query_t = query.transpose(1, 2)
    # empty_like already returns a fake tensor so we don't need to convert it
    attention = torch.empty_like(query_t).transpose(1, 2)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            device=query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: attn_bias, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = convert_tensor(
        torch.empty(
            (B, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    res = res.transpose(1, 2)

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset


@register_op_impl(aten._flash_attention_forward.default)
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    cum_seq_q = kwargs["cum_seq_q"]
    cum_seq_k = kwargs["cum_seq_k"]
    max_q = kwargs["max_q"]
    max_k = kwargs["max_k"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # Cuda Path
    # note: empty_like already returns a fake tensor, we don't need to wrap it
    attention = torch.empty_like(query)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # See Note [Seed and Offset]:
    return (
        attention,
        logsumexp,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._efficient_attention_forward.default)
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    cu_seqlens_q = kwargs["cu_seqlens_q"]
    max_seqlen_q = kwargs["max_seqlen_q"]
    max_seqlen_k = kwargs["max_seqlen_k"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(
            B,
            M,
            num_heads,
            Kv,
            dtype=query.dtype,
            device="meta",
        ),
        query.device,
    )

    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
    actual_max_seqlen_q = M
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
    logsumexp_dim = (
        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    )
    logsum_exp = convert_tensor(
        torch.empty(
            (logsumexp_batch_dim, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k


FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
    def impl_decorator(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return impl_decorator


# infer_size_impl in ExpandUtils
def infer_size(a, b):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB.  This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known.  However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB test is now expect_true and we can defer it as a runtime
        # assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i}",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
    return tuple(expandedSizes)


def make_fast_binary_impl(slow_ref):
    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)
        operands = args

        # compute_shape
        has_scalars = False
        has_tensors = False
        final_shape = None
        for op in operands:
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if len(shape) == 0:
                has_scalars = True
            else:
                has_tensors = True
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                and op.shape == final_shape
            ):
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        output_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )

        # check all tensors on same device
        # cpu scalars are assumed to be allowed
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python)
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl


@functools.lru_cache(None)
def get_fast_op_impls():
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(torch._refs.div)
    )
    return FAST_OP_IMPLEMENTATIONS
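

# A minimal usage sketch (illustrative only; not executed on import).  Assuming
# FakeTensorMode from torch._subclasses.fake_tensor consults the registrations
# above, constructor ops route through `constructors`, and the fast binary
# impls returned by `get_fast_op_impls` are eligible to short-circuit simple
# elementwise ops:
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#
#     with FakeTensorMode():
#         x = torch.empty(2, 3)    # handled by `constructors`
#         y = torch.ones_like(x)   # handled via _like_tensor_constructors
#         z = x + y                # candidate for the fast binary path
#     assert z.shape == (2, 3)
#
# Shapes, dtypes, and devices are computed from the meta kernels; no real
# storage is allocated for any of the tensors above.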