import collections import contextlib import dataclasses import functools import itertools import logging import re import textwrap import traceback from contextlib import nullcontext from enum import Enum from functools import partial from typing import ( Any, Callable, ClassVar, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union, ) from unittest.mock import patch import sympy from sympy import Expr, Integer import torch._export.serde.schema as export_schema import torch._logging import torch.fx import torch.utils._pytree as pytree from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.utils import identity from torch._export.serde.serialize import GraphModuleSerializer from torch._higher_order_ops.auto_functionalize import can_auto_functionalize from torch._prims_common import ( compute_required_storage_length, is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymTypes from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing from . 
import config, dependencies from .codegen.common import index_prevent_reordering from .dependencies import ( extract_free_unbacked_symbols, extract_input_node_reduction_ranges, extract_read_writes, var_builder, ) from .ops_handler import OpCounterCSE from .utils import ( argsort, cache_on_self, convert_shape_to_inductor, convert_shape_to_symint, developer_warning, get_kernel_metadata, is_dynamic, pad_listlike, sympy_dot, sympy_index_symbol, sympy_product, sympy_subs, ) from .virtualized import ops, V if TYPE_CHECKING: from .graph import GraphLowering log = logging.getLogger(__name__) indent = functools.partial(textwrap.indent, prefix=" ") aten = torch.ops.aten """ [Note: Inductor IR] Inductor's IR is produced by executing 'lowering' code (see Each lowering is registered to a particular aten operator, and expects inputs that correspond to the aten schema. However, in place of torch Tensor inputs, lowerings expect Inductor TensorBox inputs. TensorBox IR represents torch tensors. Tensors are sometimes single objects owning storage, and sometimes views of another Tensor's storage. Mutating tensor operations (such as add_()) affect the underlying storage and any associated views. Other operations (such as .t_()) update metadata about the current view but don't modify the underlying storage. To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor output from an operation. But just as torch.Tensors take different forms, TensorBox IR can reference View IR or directly reference StorageBox IRs. Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) may take an existing TensorBox and point it to a new underlying View IR. 
Tensors that directly own storage are represented as a chain of: TensorBox -> StorageBox -> Buffer where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer (leaving the old buffer unmodified and functionalizing the operation). Tensors backed by views add one more indirection to the IR. TensorBox -> View -> StorageBox -> Buffer In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. """ def validate_ir(node_or_nodes): def _check_tensorbox(nodes): # Could expand this to check deeper properties # (e.g. TensorBox points to View or StorageBox) if isinstance(nodes, (list, tuple)): for node in nodes: _check_tensorbox(node) elif isinstance(nodes, dict): for node in nodes.values(): _check_tensorbox(node) else: assert isinstance( nodes, (, DynamicScalar, AssertScalar, TensorBox, sympy.logic.boolalg.Boolean, Expr, ), ), f"Found {type(nodes)}, which is not a supported top level IR node. 
See [Note: Inductor IR]" # Be picky about the accepted data structure (don't use pytree here) _check_tensorbox(node_or_nodes) def ops_wrapper(name): assert isinstance(name, str) def fn(*args, **kwargs): return getattr(ops, name)(*args, **kwargs) return fn def inverse_reorder(order): inv_order = dict(zip(order, range(len(order)))) def reindex(index): assert len(index) == len(inv_order) return [index[inv_order[i]] for i in range(len(index))] return reindex def same_reorder(order): def reindex(index): assert len(index) == len(order) return [index[order[i]] for i in range(len(index))] return reindex def fuse_reindexing(reindex1, reindex2): def reindex(index): return reindex1(reindex2(index)) return reindex NHWC_STRIDE_ORDER = [3, 0, 2, 1] def stride_order2fill_order(order): """ Convert stride order to fill order For channel last format, stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] """ lookup = {pos: idx for idx, pos in enumerate(order)} fill_order = [lookup[i] for i in range(len(order))] return fill_order def get_stride_order(seq: Sequence[int]) -> List[int]: """ Convert strides to stride order """ sorted_idx: List[int] = argsort(seq) out = [0 for _ in range(len(seq))] for i, elem in enumerate(sorted_idx): out[elem] = i return out def ir_node_to_tensor(x, guard_shape=True): if x is None: return None shape_fn: Callable[[Expr], Union[int, Expr]] if not guard_shape: shape_fn = V.graph.sizevars.size_hint else: shape_fn = identity size = [shape_fn(s) for s in x.get_size()] stride: StrideType if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: stride = make_contiguous_strides_for(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) stride = convert_shape_to_symint(stride) t = torch.empty_strided( size=size, stride=stride, dtype=dtype, device=device ).zero_() return t def may_convert_to_optional(value): if isinstance(value, list) and not value: # 
[None] makes sure the cpp wrapper codegen will generate something like # {c10::nullopt} instead of {} return [None] return value def get_device_type(x): if getattr(x, "get_device", None): return get_device_type(x.get_device()) if isinstance(x, torch.device): return x.type return None def is_triton(x): return get_device_type(x) == "cuda" def is_cpu(x): return get_device_type(x) == "cpu" class IRNode: _current_origins: ClassVar[Set[Any]] = set() @staticmethod @contextlib.contextmanager def current_origins(origins: Set[torch.fx.Node]): old = IRNode._current_origins IRNode._current_origins = old | origins try: yield finally: IRNode._current_origins = old def __post_init__(self): = set(self._current_origins) self.traceback = traceback.format_stack() if config.debug_ir_traceback else None def get_traceback(self): return self.traceback def common_repr(self): origins = f"origins={getattr(self, 'origins', '')}" if len(origins) > 64: # this can get *very* long origins = f"{origins[:61]}..." return [origins] def str_helper(self, lines): lines = lines + self.common_repr() lines = indent(",\n".join(map(str, lines))) return f"{type(self).__name__}(\n{lines}\n)" def is_user_of(self, name): return name in self.get_read_names() @cache_on_self def get_read_names(self): return { for dep in self.get_reads()} def get_dtype(self): return self.dtype def get_layout(self): raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") def get_size(self): raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") def get_numel(self): return sympy_product(self.get_size()) def is_zero_elements(self): return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] def realize(self): """ If the IRNode refers to data which has not been materialized (e.g., it is a Pointwise/Reduction that could potentially have more compute fused into it), realize the IRNode into physical memory, ending the possibility of fusing into it, 
but allowing, e.g., multiple users to access the data without having to recompute. Check StorageBox.realize for a particularly notable implementation. TODO(ezyang): I think, in principle, every IRNode should have an implementation of this, and most of the time no-op is OK, but you really do have to audit each IRNode for this, so for now, raise an error if it's not implemented. Note that some code in will catch this thrown error and suppress it with a warning. """ raise NotImplementedError(f"realize NYI on {type(self)}") def codegen_reference(self, writer=None): raise NotImplementedError(f"codegen_reference NYI on {type(self)}") # The abstract method declarations below serve to convince mypy that all IRNode instances have these functions # defined, while having no effect at runtime. We cannot create stub implementations here because other parts of # the code dynamically check for defined attributes. get_device: Callable[[], torch.device] dtype: torch.dtype get_name: Callable[[], str] get_reads: Callable[[], Any] get_stride: Callable[[], Any] get_storage_numel: Callable[[], Any] has_exceeded_max_reads: Callable[[], bool] make_loader: Callable[[], Callable[[Any], Any]] make_indexer: Callable[[], Callable[[Any], Any]] mark_reuse: Callable[[int], None] realize_hint: Callable[[], None] get_unbacked_symbol_uses: Callable[[], Set[sympy.Symbol]] @dataclasses.dataclass class Loops(IRNode): device: torch.device dtype: torch.dtype inner_fn: Callable[..., Any] ranges: List[Expr] def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: return set().union( *(free_unbacked_symbols(e) for e in self.ranges), self.inner_fn_free_unbacked_symbols(), ) def __str__(self, names=("ranges",)): return self.str_helper( [ f"'{self.device.type}'", str(self.dtype), self.inner_fn_str(), ] + [f"{name}={getattr(self, name)}" for name in names] + [f"origin_node={self.origin_node!r}"] ) def __post_init__(self): super().__post_init__() self.origin_node = None __repr__ = __str__ def get_device(self): 
return self.device def get_origin_node(self): return self.origin_node def get_size(self): return self.ranges def get_pointwise_size(self): return self.ranges def is_extern(self): return False @classmethod def create(cls, *args, **kwargs): origin_node = kwargs.pop("origin_node", None) tb = kwargs.pop("traceback", None) r = cls(*args, **kwargs) r.origin_node = origin_node r.traceback = ( tb or traceback.format_stack() if config.debug_ir_traceback else None ) return TensorBox.create(r) @staticmethod def _index(ranges, prefix="i"): return [ sympy.Integer(0) if s == 1 else sympy_index_symbol(f"{prefix}{n}") for n, s in enumerate(ranges) ] @cache_on_self def inner_fn_opcount(self): from .ir import FlexibleLayout opcounter = OpCounterCSE(V.MockHandler()) with V.set_ops_handler(opcounter), patch.object( FlexibleLayout, "allow_indexing", True ): result = self.inner_fn(*self.inner_fn_args()) return opcounter.op_count def inner_fn_args(self): return (self._index(self.ranges),) def inner_fn_str(self): return V.KernelFormatterHandler.ir_to_string( self.inner_fn, *self.inner_fn_args() ) def has_large_inner_fn(self): return self.inner_fn_opcount() > config.realize_opcount_threshold def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) return extract_free_unbacked_symbols(self.inner_fn, index) def get_reads(self): with patch.object(FlexibleLayout, "allow_indexing", True): if self.get_reduction_type(): return extract_read_writes( self.make_loader(), self.get_size(), self.get_reduction_size(), ).reads else: return extract_read_writes( self.make_loader(), self.get_size(), ).reads def get_reduction_size(self): raise NotImplementedError( f"get_reduction_size() is not implemented by {type(self)}!" ) def get_reduction_type(self): raise NotImplementedError( f"get_reduction_type() is not implemented by {type(self)}!" ) def constant_to_device(self, device): raise NotImplementedError( f"constant_to_device() is not implemented by {type(self)}!" 
) def nop_loader_fn(idx, *, dtype): if dtype.is_floating_point: return ops.constant(float("nan"), dtype) else: return ops.constant(0, dtype) class Pointwise(Loops): def make_loader(self): # Make zero-element loops into a no-op if self.is_zero_elements(): return partial(nop_loader_fn, dtype=self.dtype) return self.inner_fn def get_reduction_size(self): return [] def get_reduction_type(self): return None def store_output(self, output_name, indexer, vars): loader = self.make_loader() return, indexer(vars), loader(vars)) def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) return Pointwise(device, self.dtype, loader, self.ranges) @dataclasses.dataclass class Scatter(Pointwise): output_indexer: Callable[[List[Expr]], Expr] scatter_mode: Optional[str] = None def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) return Scatter( device, self.dtype, loader, self.ranges, self.output_indexer, self.scatter_mode, ) def store_output(self, output_name, indexer, vars): loader = self.make_loader() return output_name, indexer(self.output_indexer(vars)), loader(vars), mode=self.scatter_mode, ) class ReductionHint(Enum): INNER = 0 OUTER = 1 OUTER_TINY = 2 DEFAULT = 3 class TileHint(Enum): SQUARE = 0 DEFAULT = 1 REDUCTION_COMBINE_FN = { "any": ops_wrapper("logical_or"), "max": ops_wrapper("maximum"), "min": ops_wrapper("minimum"), "prod": ops_wrapper("mul"), "sum": ops_wrapper("add"), "xor_sum": ops_wrapper("bitwise_xor"), } def get_reduction_combine_fn(reduction_type, dtype): if reduction_type in REDUCTION_COMBINE_FN: combine_fn = REDUCTION_COMBINE_FN[reduction_type] elif reduction_type in {"argmax", "argmin"}: def combine_fn(a, b): a_value, a_index = a b_value, 
b_index = b if reduction_type == "argmin": mask =, b_value) else: mask =, b_value) equal = ops.eq(a_value, b_value) if is_float_dtype(dtype): a_isnan =, a_value) b_isnan =, b_value) mask = ops.logical_or(mask,, b_isnan)) equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan)) mask = ops.logical_or( mask, ops.logical_and(equal,, b_index)) ) return ( ops.where(mask, a_value, b_value), ops.where(mask, a_index, b_index), ) elif reduction_type == "welford_combine": def combine_fn(a, b): a_mean, a_m2, a_weight = a b_mean, b_m2, b_weight = b delta = b_mean - a_mean new_weight = a_weight + b_weight w2_over_w = b_weight / new_weight return ( a_mean + delta * w2_over_w, a_m2 + b_m2 + delta * delta * a_weight * w2_over_w, new_weight, ) else: raise NotImplementedError(f"unknown reduction_type={reduction_type}") return combine_fn @dataclasses.dataclass class Reduction(Loops): reduction_ranges: List[Expr] reduction_type: str # self.dtype represents the dst dtype src_dtype: torch.dtype reduction_hint: ReductionHint def __str__(self): return Loops.__str__( # type: ignore[call-arg] self, names=("ranges", "reduction_ranges", "reduction_type") ) def __repr__(self): return self.__str__() def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: return super().get_unbacked_symbol_uses() | set().union( *(free_unbacked_symbols(e) for e in self.reduction_ranges) ) def get_reduction_size(self): return self.reduction_ranges def get_reduction_type(self): return self.reduction_type def store_reduction(self, output_name, indexer, vars, reduction_vars): value = ops.reduction( self.dtype, self.src_dtype, self.reduction_type, self.inner_fn(vars, reduction_vars), ) return ops.store_reduction(output_name, indexer(vars), value) def index_length(self): return len(self.ranges) + len(self.reduction_ranges) def inner_fn_args(self): index = self._index(self.ranges) rindex = self._index(self.reduction_ranges, "r") return (index, rindex) def inner_fn_free_unbacked_symbols(self): index = 
self._index(self.ranges) rindex = self._index(self.reduction_ranges, "r") return extract_free_unbacked_symbols(self.inner_fn, index, rindex) def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) return Reduction( device, self.dtype, loader, self.ranges, self.reduction_ranges, self.reduction_type, self.src_dtype, ReductionHint.DEFAULT, ) @staticmethod def num_splits( device, dst_dtype, src_dtype, inner_fn, ranges, reduction_ranges, reduction_type, reduction_numel, input_node: Optional[IRNode] = None, ): def _is_static(x): return isinstance(x, (int, sympy.Integer)) reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel) numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges)) should_split = ( is_triton(device) and reduction_type not in { "argmax", "argmin", } and config.split_reductions # We don't support unbacked symints and _is_static(reduction_numel_hint) and _is_static(numel_hint) ) if not should_split: return ReductionHint.DEFAULT, 1 device_interface = get_interface_for_device(get_device_type(device)) num_sm = device_interface.Worker.get_device_properties( device ).multi_processor_count min_elements_per_thread = 32 max_elements_per_thread = 512 threads_per_sm = 2048 min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm def inner_reduction_splits(reduction_numel_hint, numel_hint): # do heuristics that's close to eager mode for split inner reduction # we leak reduction autotune configs here, and will need to refactor to avoid this later num_warps = 8 num_threads = 32 * num_warps if numel_hint >= 2 * num_sm: # don't split if there are enough outputs return 1 if reduction_numel_hint <= 8192: return 1 if reduction_numel_hint * numel_hint <= min_elements_per_device: split_size = 
min_elements_per_thread elif reduction_numel_hint * numel_hint < max_elements_per_device: target_blocks = num_sm * threads_per_sm // (2 * num_threads) blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint tmp_split_size = ( reduction_numel_hint + num_threads * blocks_per_output - 1 ) // (num_threads * blocks_per_output) divisors = sympy.divisors(reduction_numel_hint) closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) if abs(closest - tmp_split_size) < 30: # prefer even splits, but never smalle than min_elements_per_thread split_size = max(closest, min_elements_per_thread) else: split_size = tmp_split_size else: divisors = sympy.divisors(reduction_numel_hint) closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) if abs(closest - max_elements_per_thread) < 50: # prefer even splits split_size = closest else: split_size = max_elements_per_thread return (reduction_numel_hint + split_size * num_threads - 1) // ( split_size * num_threads ) def outer_reduction_splits(reduction_numel_hint, numel_hint): # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128 # extend to even smaller number of outputs num_warps = 8 num_threads = num_warps * 32 rvals_per_thread = 4 # comes from heuristics, refactor to not leak here xvals_per_block = 128 xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block if reduction_numel_hint * numel_hint < min_elements_per_device: split_size = min_elements_per_thread elif reduction_numel_hint * numel_hint < max_elements_per_device: target_blocks = num_sm * threads_per_sm // (num_threads) target_blocks = (target_blocks + xblocks - 1) // xblocks tmp_split_size = ( reduction_numel_hint + rvals_per_thread * target_blocks - 1 ) // (rvals_per_thread * target_blocks) divisors = sympy.divisors(reduction_numel_hint) closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) if abs(tmp_split_size - closest) < 20: split_size = max(closest, min_elements_per_thread) else: split_size = 
tmp_split_size else: divisors = sympy.divisors(reduction_numel_hint) closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) if abs(closest - max_elements_per_thread) < 50: # prefer even splits split_size = closest else: split_size = max_elements_per_thread return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( rvals_per_thread * split_size ) # easy cases if numel_hint == 1: split = inner_reduction_splits(reduction_numel_hint, numel_hint) if split == 1: # No need to split. return ReductionHint.INNER, split if ( len(ranges) == 0 and input_node is not None and isinstance(input_node, TensorBox) ): # Only handles the case where keep_dim = False. # Otherwise, we need to propagate reduction dim info to the stage where # the intermediate loader of the first Reduction is generated. new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( input_node ) if new_ranges is not None and new_reduction_ranges is not None: extracted_numel_hint = V.graph.sizevars.symbolic_hint( sympy_product(new_ranges + new_reduction_ranges) ) if reduction_numel_hint == extracted_numel_hint: log.debug( "Use previous IRNode's range and reduction_ranges instead of split. " "current ranges: %s, current reduction ranges: %s, current split: %d, " "new ranges: %s, new reduction ranges: %s", ranges, reduction_ranges, split, new_ranges, new_reduction_ranges, ) # If the input_node or its dependent nodes are also Reduction nodes, # use reduction_sizes of this node or its dependent nodes directly. 
return ReductionHint.INNER, -1 return ReductionHint.INNER, split if ( reduction_numel_hint <= min_elements_per_thread or numel_hint >= num_sm * 2 * 32 ): return ReductionHint.DEFAULT, 1 r = Reduction( device, dst_dtype, inner_fn, ranges, reduction_ranges, reduction_type, src_dtype, ReductionHint.DEFAULT, ) def get_read_indices(r): cb = ComputedBuffer( name=None, layout=FlexibleLayout( device=r.get_device(), dtype=r.get_dtype(), size=r.get_size(), ), data=r, ) read_writes = cb.get_read_writes() # try finding the full size producer # TODO this will fail for something like ((1, N) * (N, 1)).sum() # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare range_vars = [ r for r in read_writes.range_vars if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number) ] indices = [] changed = False for md in sorted(read_writes.reads, key=lambda x: if all(r in md.index.free_symbols for r in range_vars): indices.append(md.index) if in V.graph.name_to_buffer: buf = V.graph.name_to_buffer[] original_stride = buf.layout.stride buf.decide_layout() if buf.layout.stride != original_stride: changed = True return indices, changed indices, changed = get_read_indices(r) if changed: indices, _ = get_read_indices(r) if len(indices) == 0: # TODO determine splits when all inputs are broadcast return ReductionHint.DEFAULT, 1 (_, reduction_vars), ranges = dependencies.index_vars_squeeze( r.get_size(), r.get_reduction_size() ) num_outer = 0 num_inner = 0 for i in indices: i = V.graph.sizevars.simplify_with_ranges(i, ranges) strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) outer = all(s > 1 for s in strides) if outer: num_outer += 1 else: num_inner += 1 if num_inner > num_outer: return ReductionHint.INNER, inner_reduction_splits( reduction_numel_hint, numel_hint ) else: return ReductionHint.OUTER, outer_reduction_splits( reduction_numel_hint, numel_hint ) @staticmethod def _unroll_reduction_fn(inner_fn, 
reduction_ranges, reduction_type, src_dtype): """Convert inner_fn from a reduction to an pointwise""" reduction_ranges = [ V.graph.sizevars.evaluate_static_shape(x) for x in reduction_ranges ] combine_fn = get_reduction_combine_fn(reduction_type, src_dtype) def fn(index): return functools.reduce( combine_fn, ( value_fn(index, rindex) for rindex in itertools.product( *[range(x) for x in reduction_ranges] ) ), ) if reduction_type in ("argmin", "argmax"): flatten_index = FixedLayout( None, # type: ignore[arg-type] None, # type: ignore[arg-type] reduction_ranges, FlexibleLayout.contiguous_strides(reduction_ranges), ).make_indexer() def value_fn(index, rindex): rindex = [sympy.expand(i) for i in rindex] return ( inner_fn(index, rindex), ops.index_expr(flatten_index(rindex), torch.int64), ) return lambda index: fn(index)[1] else: value_fn = inner_fn return fn @classmethod def create( # type: ignore[override] cls, device: torch.device, dst_dtype: torch.dtype, src_dtype: torch.dtype, inner_fn: Callable[..., Any], ranges: List[Expr], reduction_ranges: List[Expr], reduction_type: str, reduction_hint: ReductionHint = ReductionHint.DEFAULT, input_node: Optional[IRNode] = None, ): reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) if reduction_numel == 0: # N.B. This is a hack to generate the literal of the given type # Ideally, we should be fixing `def constant` in # but it breaks due to hardcoded dtypes in other places def py_cnst(val): return ( bool(val) if dst_dtype == torch.bool else float(val) if dst_dtype.is_floating_point else int(val) ) rtypes_to_inits = { "sum": py_cnst(0), "xor_sum": py_cnst(0), "prod": py_cnst(1), "any": py_cnst(0), # "all" is desugared to `!any(!val)` } assert ( reduction_type in rtypes_to_inits.keys() ), f"{reduction_type} not supported for zero-dimension tensors!" 
def const_fn(index): return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) return Pointwise.create( device=device, dtype=src_dtype, inner_fn=const_fn, ranges=list(ranges), ) if reduction_numel == 1: # this reduction is actually a pointwise op if reduction_type in ("argmin", "argmax"): def fn(index): return ops.constant(0, dst_dtype) else: def fn(index): reduction_index = [sympy.Integer(0) for _ in reduction_ranges] return inner_fn(index, reduction_index) return Pointwise.create(device, dst_dtype, fn, ranges) if ( isinstance(reduction_numel, sympy.Integer) and V.graph.sizevars.size_hint(reduction_numel) < config.unroll_reductions_threshold and sympy_product(ranges) != 1 ): return Pointwise.create( device, dst_dtype, cls._unroll_reduction_fn( inner_fn, reduction_ranges, reduction_type, src_dtype ), ranges, ) # triton doesn't support reduce to single element well, so break it up hint, split = cls.num_splits( device, dst_dtype, src_dtype, inner_fn, ranges, reduction_ranges, reduction_type, reduction_numel, input_node, ) # intermediate reduction in split can contain complex indexing, # and num_splits will fail to correctly set the hint # reuse the passed hint if available if reduction_hint == ReductionHint.DEFAULT: reduction_hint = hint if split == -1: assert input_node is not None new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges( input_node # type: ignore[arg-type] ) assert new_ranges is not None assert new_reduction_ranges is not None return cls.create_multilayer_existing_ranges( device, dst_dtype, src_dtype, inner_fn, ranges, reduction_ranges, new_ranges, new_reduction_ranges, reduction_type, reduction_hint, ) elif split > 1: # triton doesn't support reduce to single element well, so break it up return cls.create_multilayer( device, dst_dtype, src_dtype, inner_fn, ranges, reduction_ranges, reduction_type, split, reduction_hint, ) return TensorBox.create( Reduction( device, dst_dtype, inner_fn, ranges, reduction_ranges, reduction_type, 
src_dtype, reduction_hint, ) ) @staticmethod def default_accumulator(reduction_type, dtype): if reduction_type in {"max", "argmax"}: if is_float_dtype(dtype): return float("-inf") elif is_boolean_dtype(dtype): return 0 else: return torch.iinfo(dtype).min if reduction_type in {"min", "argmin"}: if is_float_dtype(dtype): return float("inf") elif is_boolean_dtype(dtype): return 1 else: return torch.iinfo(dtype).max return { "sum": 0, "prod": 1, "xor_sum": 0, "any": 0, "welford_reduce": (0, 0, 0), "welford_combine": (0, 0, 0), }[reduction_type] @staticmethod def default_value(reduction_type, dtype): if reduction_type == "welford_reduce": return 0 return Reduction.default_accumulator(reduction_type, dtype) @staticmethod def _multilayer_second_step_hint( split: int, numel_hint: int, reduction_hint: ReductionHint ) -> ReductionHint: if split == -1: return reduction_hint if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: return ReductionHint.OUTER_TINY if ( split <= 1024 and numel_hint <= 256 and reduction_hint == ReductionHint.OUTER ): return ReductionHint.OUTER_TINY return reduction_hint @classmethod def _multilayer_wrap_loader( cls, loader, reduction_ranges, reduction_numel, split, block_size, default, ): reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) need_mask = not V.graph.sizevars.is_expr_static_and_true( sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] ) def wrapper_fn(index, reduction_index): (reduction_index,) = reduction_index *new_index, reduction_block = index indices = block_size * reduction_block + reduction_index def body(): return loader(new_index, reindex([indices])) if need_mask: mask = ops.index_expr(indices, torch.int32), ops.index_expr(reduction_numel, torch.int32), ) return ops.masked(mask, body, default) else: return body() return wrapper_fn @classmethod def _multilayer_wrap_loader_existing_ranges( cls, loader, original_ranges, original_reduction_ranges, new_ranges, 
new_reduction_ranges, default, ): assert len(original_ranges) == 0, f"{original_ranges}= is not equal to []" reindex = View.dynamic_reshape_indexer( original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) ) def wrapper_fn(index, reduction_index): return loader([], reindex(tuple(index) + tuple(reduction_index))) return wrapper_fn @classmethod def create_multilayer_helper( cls, device: torch.device, dst_dtype: torch.dtype, src_dtype: torch.dtype, wrapper_fn: Callable[..., Any], original_ranges: List[Expr], original_reduction_ranges: List[Expr], new_ranges: List[Expr], new_reduction_ranges: List[Expr], reduction_type: str, split: int, reduction_hint: ReductionHint, ): """ Break a large reduction up into multiple smaller reductions recursively """ # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction # in fp32 and not reduce precision by breaking up the kernel into multiple layers intermediate_dtype = ( dst_dtype if dst_dtype not in (torch.float16, torch.bfloat16) else torch.float ) intermediate = Reduction.create( device, intermediate_dtype, src_dtype, wrapper_fn, new_ranges, new_reduction_ranges, reduction_type, reduction_hint, ) intermediate.realize() intermediate_loader = intermediate.make_loader() def intermediate_fn(index, reduction_index): return intermediate_loader([*index, *reduction_index]) numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges)) reduction_hint = cls._multilayer_second_step_hint( split, numel_hint, reduction_hint ) assert original_ranges == new_ranges[: len(original_ranges)] return TensorBox.create( Reduction( device, dst_dtype, intermediate_fn, original_ranges, new_ranges[len(original_ranges) :], reduction_type, src_dtype, reduction_hint, ) ) @classmethod def create_multilayer( cls, device: torch.device, dst_dtype: torch.dtype, src_dtype: torch.dtype, inner_fn: Callable[..., Any], ranges: 
List[Expr], reduction_ranges: List[Expr], reduction_type: str, split: int, reduction_hint: ReductionHint, ): """ Break a large reduction up into multiple smaller reductions recursively """ # TODO(jansel): realize the reduction so we can do dynamic indexing reduction_numel = sympy_product(reduction_ranges) block_size = FloorDiv(reduction_numel + (split - 1), split) default = cls.default_value(reduction_type, dst_dtype) wrapper_fn = cls._multilayer_wrap_loader( inner_fn, reduction_ranges, reduction_numel, split, block_size, default ) return cls.create_multilayer_helper( device, dst_dtype, src_dtype, wrapper_fn, ranges, reduction_ranges, [*ranges, split], # type: ignore[list-item] [block_size], reduction_type, split, reduction_hint, ) @classmethod def create_multilayer_existing_ranges( cls, device: torch.device, dst_dtype: torch.dtype, src_dtype: torch.dtype, inner_fn: Callable[..., Any], original_ranges: List[Expr], original_reduction_ranges: List[Expr], new_ranges: List[Expr], new_reduction_ranges: List[Expr], reduction_type: str, reduction_hint: ReductionHint, ): """ Break a large reduction up into multiple smaller reductions recursively """ default = cls.default_value(reduction_type, dst_dtype) wrapper_fn = cls._multilayer_wrap_loader_existing_ranges( inner_fn, original_ranges, original_reduction_ranges, new_ranges, new_reduction_ranges, default, ) return cls.create_multilayer_helper( device, dst_dtype, src_dtype, wrapper_fn, original_ranges, original_reduction_ranges, new_ranges, new_reduction_ranges, reduction_type, -1, reduction_hint, ) def num_reduction_outputs(reduction_type): return 3 if "welford" in reduction_type else 1 class WelfordReduction(Reduction): output_index: int def __init__( self, device, dtype, inner_fns, ranges, reduction_ranges, reduction_type, reduction_hint, output_index, ): if len(inner_fns) == 1: loader = inner_fns[0] else: def loader(idx, reduction_idx): return tuple(fn(idx, reduction_idx) for fn in inner_fns) super().__init__( device, 
dtype, loader, ranges, reduction_ranges, reduction_type, dtype, reduction_hint, ) self.output_index = output_index def store_reduction(self, output_name, indexer, vars, reduction_vars): values = ops.reduction( self.dtype, self.src_dtype, self.reduction_type, self.inner_fn(vars, reduction_vars), ) value = values[self.output_index] return ops.store_reduction(output_name, indexer(vars), value) @classmethod def create( # type: ignore[override] cls, device: torch.device, dtype: torch.dtype, inner_fns: Sequence[Callable[..., Any]], ranges: List[Expr], reduction_ranges: List[Expr], reduction_type: str, reduction_hint: ReductionHint = ReductionHint.DEFAULT, ): assert reduction_type in {"welford_reduce", "welford_combine"} reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) def const(val): def inner_fn(idx): return ops.constant( val, dtype, ) return Pointwise.create( device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(ranges), ) if reduction_numel == 0: mean = const(0) m2 = const(0) weight = const(0) return mean, m2, weight if reduction_numel == 1: def copy(loader): def inner_fn(idx): reduction_index = [sympy.Integer(0) for _ in reduction_ranges] return loader(idx, reduction_index) return Pointwise.create( device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(ranges), ) if reduction_type == "welford_reduce": return copy(inner_fns[0]), const(0), const(1) else: return tuple(copy(fn) for fn in inner_fns) # TODO: Unrolled reduction # if ( # isinstance(reduction_numel, sympy.Integer) # and V.graph.sizevars.size_hint(reduction_numel) # < config.unroll_reductions_threshold # and sympy_product(ranges) != 1 # ): # return Pointwise.create( # device, # dst_dtype, # cls._unroll_reduction_fn( # inner_fn, reduction_ranges, reduction_type, src_dtype # ), # ranges, # ) # triton doesn't support reduce to single element well, so break it up hint, split = Reduction.num_splits( device, dtype, dtype, inner_fns[0], ranges, reduction_ranges, 
reduction_type=reduction_type, reduction_numel=reduction_numel, ) # intermediate reduction in split can contain complex indexing, # and num_splits will fail to correctly set the hint # reuse the passed hint if available if reduction_hint == ReductionHint.DEFAULT: reduction_hint = hint if split > 1: # triton doesn't support reduce to single element well, so break it up return cls.create_multilayer( device, dtype, inner_fns, ranges, reduction_ranges, reduction_type, split, reduction_hint, ) results = [ TensorBox.create( WelfordReduction( device, dtype, inner_fns, ranges, reduction_ranges, reduction_type, reduction_hint, output_idx, ) ) for output_idx in range(3) ] for t in results: t.realize() return results @staticmethod def default_value(reduction_type, dtype): return (0, 0, 0) @classmethod def create_multilayer( # type: ignore[override] cls, device: torch.device, dtype: torch.dtype, inner_fns: Sequence[Callable[..., Any]], ranges: List[Expr], reduction_ranges: List[Expr], reduction_type: str, split: int, reduction_hint: ReductionHint, ): """ Break a large reduction up into multiple smaller reductions recursively """ reduction_numel = sympy_product(reduction_ranges) need_mask = not V.graph.sizevars.is_expr_static_and_true( sympy.Eq(reduction_numel % split, 0) # type: ignore[arg-type] ) if need_mask and reduction_type != "welford_combine": # If we need mask, then "welford_reduce" doesn't work because # masked inputs shouldn't count towards the welford weight def constant(idx, reduction_idx, value): return ops.constant(value, dtype) return cls.create_multilayer( device=device, dtype=dtype, inner_fns=( inner_fns[0], partial(constant, value=0), partial(constant, value=1), ), ranges=ranges, reduction_ranges=reduction_ranges, reduction_type="welford_combine", split=split, reduction_hint=reduction_hint, ) block_size = FloorDiv(reduction_numel + (split - 1), split) intermediates = WelfordReduction.create( device, dtype, tuple( cls._multilayer_wrap_loader( loader, 
reduction_ranges, reduction_numel, split, block_size, default=0, ) for loader in inner_fns ), [*ranges, split], # type: ignore[list-item] [block_size], reduction_type, reduction_hint, ) for i in intermediates: i.realize() i_loaders = [i.make_loader() for i in intermediates] def intermediate_loader_fn(index, reduction_index, loader): return loader([*index, *reduction_index]) numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) reduction_hint = cls._multilayer_second_step_hint( split, numel_hint, reduction_hint ) return WelfordReduction.create( device, dtype, tuple( partial(intermediate_loader_fn, loader=i.make_loader()) for i in intermediates ), ranges, [split], # type: ignore[list-item] # welford_reduce turns one input into three outputs, which are combined with welford_combine "welford_combine", reduction_hint, ) @dataclasses.dataclass class Scan(Loops): scan_ranges: List[Expr] size: List[Expr] combine_fn: Callable[..., Any] reindex: Callable[[List[Expr], List[Expr]], List[Expr]] reduction_hint: ReductionHint init: int # HACK we mimick reduction def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: # TODO: Can combine_fn/reindex close over unbacked symbols? 
If so, we # need to explicitly represent the closure so we can pull out unbacked # symbols here return ( super().get_unbacked_symbol_uses() | set().union(*(free_unbacked_symbols(e) for e in self.scan_ranges)) | set().union(*(free_unbacked_symbols(e) for e in self.size)) ) def __post_init__(self): assert len(self.ranges) + len(self.scan_ranges) == len(self.size) super().__post_init__() def store_reduction(self, output_name, indexer, vars, scan_vars): idx = self.reindex(vars, scan_vars) value = self.inner_fn(idx) result = ops.scan(self.dtype, self.combine_fn, value, self.init) return, indexer(idx), result) def get_reduction_type(self): # return self.scan_op return "custom" def get_reduction_size(self): return self.scan_ranges def get_size(self): return self.size def get_pointwise_size(self): return self.ranges def index_length(self): return len(self.ranges) + len(self.scan_ranges) def inner_fn_args(self): index = self._index(self.ranges) rindex = self._index(self.scan_ranges, "r") idx = self.reindex(index, rindex) return (idx,) def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) rindex = self._index(self.scan_ranges, "r") idx = self.reindex(index, rindex) return extract_free_unbacked_symbols(self.inner_fn, idx) @classmethod def create( cls, device: torch.device, dtype: torch.dtype, inner_fn: Callable[[List[Expr]], Any], size: List[Expr], axis: int, combine_fn: Callable[..., Any], init: Any, reduction_hint: ReductionHint = ReductionHint.DEFAULT, ) -> Optional["TensorBox"]: pointwise_ranges = [*size[:axis], *size[axis + 1 :]] scan_ranges = [size[axis]] if device.type != "cuda": # TODO: CPU support return None sizevars = V.graph.sizevars scan_numel = sizevars.simplify(sympy_product(scan_ranges)) # Scan with a single element is just a copy if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)): # type: ignore[arg-type] return Pointwise.create( device=device, dtype=dtype, inner_fn=inner_fn, ranges=size, ) reduction_hint, num_splits = 
cls.num_splits( device=device, dtype=dtype, inner_fn=inner_fn, axis=axis, pointwise_ranges=pointwise_ranges, scan_ranges=scan_ranges, combine_fn=combine_fn, scan_numel=scan_numel, ) scan_type = Scan if num_splits <= 1 else SplitScan if num_splits > 1 and torch.version.hip is not None: # Fallback for split-scan on ROCm return None def reindex(index, scan_index): assert len(scan_index) == len(scan_ranges) assert len(index) == len(pointwise_ranges) return [*index[:axis], *scan_index, *index[axis:]] result = TensorBox.create( scan_type( device=device, dtype=dtype, inner_fn=inner_fn, size=size, ranges=pointwise_ranges, scan_ranges=scan_ranges, combine_fn=combine_fn, reindex=reindex, init=init, reduction_hint=reduction_hint, ) ) result.realize() return result @classmethod def num_splits( cls, device: torch.device, dtype: torch.dtype, inner_fn: Callable[[List[Expr]], Any], axis: int, pointwise_ranges: List[Expr], scan_ranges: List[Expr], combine_fn: Callable[..., Any], scan_numel: Expr, ): # TODO: custom splitting heuristic for scan def wrapper_fn(idx, reduction_idx): return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]]) return Reduction.num_splits( device=device, dst_dtype=dtype, src_dtype=dtype, inner_fn=wrapper_fn, ranges=pointwise_ranges, reduction_ranges=scan_ranges, reduction_type="sum", reduction_numel=scan_numel, ) # This signifies a scan op that should go through TritonSplitScanKernel codgen on CUDA. 
@dataclasses.dataclass class SplitScan(Scan): pass def is_storage_and_layout(x): try: as_storage_and_layout(x, freeze=False) return True except NotImplementedError: return False def is_contiguous_storage_and_layout(x): try: buffer, layout = as_storage_and_layout(x, freeze=False) return layout.is_contiguous() except NotImplementedError: return False def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None): """Try to simplify x into a StorageBox and a Layout""" if isinstance(x, TensorBox): return as_storage_and_layout(, freeze=freeze, want_contiguous=want_contiguous, stride_order=stride_order, ) if isinstance(x, StorageBox) and isinstance(, Buffer): if freeze: if want_contiguous: assert elif stride_order is not None: else: return x, if isinstance(x, ReinterpretView): # making the base of x contiguous or stride_ordered will not necessarily make # the ReinterpretView either, so don't pass along those arguments buffer, _ = as_storage_and_layout(, freeze=freeze, ) return buffer, x.layout raise NotImplementedError as_contiguous_storage_and_layout = functools.partial( as_storage_and_layout, want_contiguous=True ) def is_stride_order_storage_and_layout(x, stride_order): try: buffer, layout = as_storage_and_layout(x, freeze=False) return layout.is_stride_ordered(stride_order) except NotImplementedError: return False @dataclasses.dataclass class BaseView(IRNode): data: IRNode def get_unbacked_symbol_uses(self): return def make_reindexer(self): raise NotImplementedError(f"make_reindexer NYI on {self}") def make_indexer(self): inner = reindex = self.make_reindexer() def indexer(idx): return inner(reindex(idx)) return indexer def make_loader(self): inner = reindex = self.make_reindexer() def loader(idx): return inner(reindex(idx)) return loader @property def dtype(self): return def get_layout(self): return def get_device(self): return def get_origin_node(self): return None def get_name(self): return def get_pointwise_size(self): return self.get_size() 
def mark_reuse(self, users): return def has_exceeded_max_reads(self): return def realize(self): return def realize_hint(self): return def get_storage_numel(self): return def is_extern(self): return # type: ignore[attr-defined] def get_reads(self): with patch.object(FlexibleLayout, "allow_indexing", True): return extract_read_writes( self.make_loader(), self.get_size(), ).reads def unwrap_view(self): x: IRNode = self while isinstance(x, BaseView): x = return x def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) return Pointwise(device, self.get_dtype(), loader, self.get_size()) @dataclasses.dataclass class ExpandView(BaseView): size: List[Expr] @staticmethod def _normalize_size(x, new_size): """Replace `-1` with correct sizes""" new_size = list(map(sympy.expand, new_size)) old_size = x.get_size() old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) assert len(new_size) == len(old_size) for i in range(len(new_size)): if new_size[i] == -1: assert old_size[i] is not None new_size[i] = old_size[i] elif old_size[i] is None or old_size[i] == 1: pass else: # Expect broadcast compatibility new_size[i] = V.graph.sizevars.expect_equals( new_size[i], old_size[i], msg=f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}", ) return new_size @classmethod def create(cls, x, new_size): new_size = cls._normalize_size(x, new_size) if is_storage_and_layout(x): storage, old_layout = as_storage_and_layout(x) skip = len(new_size) - len(old_layout.size) assert skip >= 0 new_stride = [sympy.Integer(0)] * skip for stride, size in zip(old_layout.stride, old_layout.size): new_stride.append(stride if size != 1 else sympy.Integer(0)) new_layout = FixedLayout( old_layout.device, old_layout.dtype, list(new_size), new_stride, old_layout.offset, ) return ReinterpretView(storage, new_layout) return 
ExpandView(x, new_size) def get_size(self): return self.size def make_reindexer(self): target = self.get_size() actual = skip = len(target) - len(actual) def reindex(index): index = list(index[skip:]) assert len(index) == len(actual) for i in range(len(actual)): if actual[i] == 1: # zero out broadcast dimension index[i] = sympy.Integer(0) return index return reindex @dataclasses.dataclass class PermuteView(BaseView): dims: List[Expr] @classmethod def create(cls, x, dims): dims = cls._map_neg_dims(dims) assert set(dims) == set(range(len(dims))) if is_storage_and_layout(x): storage, old_layout = as_storage_and_layout(x) new_layout = FixedLayout( old_layout.device, old_layout.dtype, [old_layout.size[i] for i in dims], [old_layout.stride[i] for i in dims], old_layout.offset, ) return ReinterpretView(storage, new_layout) return PermuteView(x, dims) @classmethod def _map_neg_dims(cls, dims): return [dim if dim >= 0 else len(dims) + dim for dim in dims] def get_size(self): assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims))) size = return [size[i] for i in self.dims] def make_reindexer(self): inv = {j: i for i, j in enumerate(self.dims)} inv = [inv[i] for i in range(len(self.dims))] # type: ignore[index] assert set(inv) == set(range(len(self.dims))) def reindex(index): return [index[i] for i in inv] return reindex class SqueezeView(BaseView): @classmethod def create(cls, x, *, dim=None): if is_storage_and_layout(x): storage, old_layout = as_storage_and_layout(x) new_size = [] new_stride = [] if dim is not None: assert isinstance(dim, int), "expected integer dim argument" assert 0 <= dim and dim < len(old_layout.size) for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): if dim is None: if size != 1: new_size.append(size) new_stride.append(stride) else: if i != dim: new_size.append(size) new_stride.append(stride) else: assert size == 1, "expected squeezed size to be 1" new_layout = FixedLayout( old_layout.device, old_layout.dtype, 
new_size, new_stride, old_layout.offset, ) return ReinterpretView(storage, new_layout) if dim is None: # redirect to a generic view return View.create(x, [s for s in x.get_size() if s != 1]) else: assert x.get_size()[dim] == 1 return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) @staticmethod def squeezer(size: Tuple[sympy.Expr, ...]): new_size = [s for s in size if s != 1] not_one = [i for i, s in enumerate(size) if s != 1] length = len(size) def reindex(index: List[sympy.Expr]) -> Tuple[sympy.Expr, ...]: assert len(index) == len(not_one), f"{index} {not_one}" new_index = [sympy.Integer(0)] * length for idx, s in zip(not_one, index): new_index[idx] = s return tuple(new_index) return new_size, reindex def __init__(self, data): raise AssertionError("use SqueezeView.create()") @dataclasses.dataclass class GenericView(BaseView): size: List[Expr] reindex: Callable[..., Any] def make_reindexer(self): return self.reindex def reindex_str(self): index_old = [sympy_index_symbol(f"i{n}") for n in range(len(self.size))] index_new = list(self.reindex(index_old)) return f"lambda {', '.join(map(str, index_old))}: {index_new}" def __str__(self): return self.str_helper( [, f"size={self.size}", f"reindex={self.reindex_str()}"] ) __repr__ = __str__ @classmethod def create(cls, x, new_size, reindex): return cls(x, list(new_size), reindex) def get_size(self): return self.size @dataclasses.dataclass class View(GenericView): @staticmethod def handle_negative_index(idx, size): idx = sympy.expand(idx) size = sympy.expand(size) evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr if evaluate_expr(sympy.Lt(idx, 0)): idx = idx + size return idx @classmethod def create(cls, x, new_size): assert isinstance(new_size, (tuple, list)) old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) # Skip pointless views if V.graph.sizevars.statically_known_list_equals(old_size, new_size): return x unbacked_symbols_in_sizes = False if ( 
len(free_unbacked_symbols(old_size)) > 0 or len(free_unbacked_symbols(new_size)) > 0 ): unbacked_symbols_in_sizes = True if 0 in new_size: def fake_reindex(index): return tuple([0] * len(old_size)) return cls(x, list(new_size), fake_reindex) # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): # realize x; otherwise, the dynamic_reshape_indexer below will fail # due to the size_hint's inability to process unbacked SymInts x = ExternKernel.realize_input(x) storage, old_layout = as_contiguous_storage_and_layout(x) new_layout = FixedLayout( old_layout.device, old_layout.dtype, new_size, FlexibleLayout.contiguous_strides(new_size), old_layout.offset, ) return ReinterpretView(storage, new_layout) reindex = cls.dynamic_reshape_indexer(old_size, new_size) return cls(x, list(new_size), reindex) @staticmethod def resolve_negative_size(old_size, new_size): new_size = [V.graph.sizevars.simplify(x) for x in new_size] old_size = [V.graph.sizevars.simplify(x) for x in old_size] new_size = list(new_size) for i in range(len(new_size)): if new_size[i] == -1: new_size[i] = sympy.Integer(1) new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) break V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) return old_size, new_size @classmethod def dynamic_reshape_indexer(cls, old_size, new_size): try: reindex = cls._dynamic_reshape_indexer(old_size, new_size) except (AssertionError, IndexError): # optimistic algorithm failed, lets do a fallback flat = [sympy_product(old_size)] reindex1 = cls._dynamic_reshape_indexer(old_size, flat) reindex2 = cls._dynamic_reshape_indexer(flat, new_size) reindex = fuse_reindexing(reindex1, reindex2) return reindex @staticmethod def _dynamic_reshape_indexer(old_size, new_size): """ Perform a reshape entirely by modifying 
indexing math """ size_hint = V.graph.sizevars.size_hint vars = [sympy_index_symbol(f"view{i}") for i in range(len(new_size))] stack_new = list(zip(vars, new_size)) stack_old = list(old_size) view_expr = [] while stack_new and stack_old: size_old = stack_old.pop() var, size_new = stack_new.pop() if size_old == 1: view_expr.append(sympy.Integer(0)) stack_new.append((var, size_new)) # re-add elif size_new == 1: stack_old.append(size_old) # re-add elif size_hint(size_new) == size_hint(size_old): view_expr.append(var) V.graph.sizevars.guard_equals(size_new, size_old) elif size_hint(size_new) < size_hint(size_old): while size_hint(size_new) < size_hint(size_old): var2, size_new2 = stack_new.pop() var = var2 * size_new + var size_new = size_new * size_new2 view_expr.append(var) V.graph.sizevars.guard_equals(size_new, size_old) elif size_hint(size_new) > size_hint(size_old): divisor = sympy.Integer(1) modulus = size_old view_expr.append(ModularIndexing(var, divisor, modulus)) divisor = divisor * modulus while size_hint(size_new) > size_hint(size_old): modulus = stack_old.pop() view_expr.append(ModularIndexing(var, divisor, modulus)) divisor = divisor * modulus size_old = size_old * modulus V.graph.sizevars.guard_equals(size_new, size_old) else: raise AssertionError() while stack_old: size_old = stack_old.pop() V.graph.sizevars.guard_equals(size_old, 1) # type: ignore[arg-type] view_expr.append(sympy.Integer(0)) while stack_new: var, size_new = stack_new.pop() V.graph.sizevars.guard_equals(size_new, 1) # type: ignore[arg-type] view_expr.reverse() assert len(view_expr) == len(old_size) def reindex(index): assert len(index) == len(vars), (len(index), len(vars)) replacements = dict(zip(vars, index)) return tuple(sympy_subs(x, replacements) for x in view_expr) # type: ignore[arg-type] return reindex @dataclasses.dataclass class ReinterpretView(BaseView): """Pretend our storage has a different layout""" layout: "Layout" def __post_init__(self): super().__post_init__() if 
isinstance(, BaseView): = def __str__(self): return self.str_helper( [, self.layout, ] ) __repr__ = __str__ def get_name(self): return def get_device(self): return self.layout.device def get_origin_node(self): return None @property def dtype(self): return self.layout.dtype def get_size(self): return list(self.layout.size) def get_stride(self): return list(self.layout.stride) def make_loader(self): def loader(index): indexer = self.layout.make_indexer() return ops.load(self.get_name(), indexer(index)) return loader def make_indexer(self): return self.layout.make_indexer() def get_layout(self): return self.layout def freeze_layout(self): pass def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: return ( free_unbacked_symbols(self.layout.size) | free_unbacked_symbols(self.layout.stride) | free_unbacked_symbols(self.layout.offset) ) def codegen_reference(self, writer=None): # reinterpret_tensor is similar to as_strided except: # - offset is added to the existing offset (rather than replacing it) # - view tracking is disabled similar to unsafe_view return V.graph.wrapper_code.codegen_reinterpret_view(, self.layout.size, self.layout.stride, self.layout.offset, writer, ) class SliceView(View): @classmethod def normalize_start_end(cls, x, dim, start, end): """ Normalize start and end such that both are in the range [0, x.get_size()[dim]] and start <= end. 
""" sizevars = V.graph.sizevars dim_size = x.get_size()[dim] if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): def clamp(x, lower, upper): return sympy.Min(sympy.Max(x, lower), upper) else: def clamp(x, lower, upper): return sizevars.evaluate_min(sizevars.evaluate_max(x, lower), upper) def clamp_wrap(val, lower, upper, default): if val is None: return default val = cls.handle_negative_index(val, dim_size) return clamp(val, lower, upper) start = clamp_wrap(start, 0, dim_size, 0) end = clamp_wrap(end, start, dim_size, dim_size) return start, end @classmethod def create(cls, x, dim, start, end, step=1): step = sympy.expand(step) assert step > 0 try: if start == 0 and end >= 2**63 - 1 and step == 1: return x except TypeError: pass sizevars = V.graph.sizevars new_size = list(x.get_size()) start, end = cls.normalize_start_end(x, dim, start, end) new_size[dim] = FloorDiv(end - start + (step - 1), step) if is_storage_and_layout(x): # Fast path storage, old_layout = as_storage_and_layout(x) new_stride = list(old_layout.stride) new_stride[dim] = new_stride[dim] * step new_layout = FixedLayout( old_layout.device, old_layout.dtype, new_size, new_stride, old_layout.offset + old_layout.stride[dim] * start, ) return ReinterpretView(storage, new_layout) def reindex(index): assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" index = list(index) index[dim] = index[dim] * step + start return index # redirect to a generic view return SliceView(x, size=new_size, reindex=reindex) class BaseConstant(IRNode): dtype: torch.dtype device: torch.device def get_size(self): return () def get_device(self): return self.device def get_origin_node(self): return None def mark_reuse(self, users): pass def has_exceeded_max_reads(self): return False def get_reads(self): return () def is_extern(self): return False @dataclasses.dataclass class Constant(BaseConstant): value: Any dtype: torch.dtype device: torch.device def make_loader(self): def loader(index): return 
ops.constant(self.value, self.dtype) return loader def realize(self): pass def constant_to_device(self, device): return Constant(self.value, self.dtype, device) @dataclasses.dataclass class IndexingConstant(BaseConstant): index: Any dtype: torch.dtype device: torch.device def make_loader(self): def loader(index): return ops.index_expr(self.index, self.dtype) return loader def constant_to_device(self, device): return IndexingConstant(self.index, self.dtype, device) def is_contiguous_strides_for_shape(stride, shape): return all( size == 1 or left == right for left, right, size in zip( stride, FlexibleLayout.contiguous_strides(shape), shape ) ) @dataclasses.dataclass class Layout(IRNode): def __init__( self, device: torch.device, dtype: torch.dtype, size: List[Expr], stride: Optional[Sequence[Union[Expr, int]]], offset: Expr = Integer(0), ): assert stride is None or len(size) == len( stride ), f"size={size}, stride={stride}" self.device = device self.dtype = dtype assert all(isinstance(s, (Expr, int)) for s in size) self.size = size self._stride = stride self.offset = offset @property def stride(self): return self._stride def __str__(self): offset = "" if self.offset != 0: offset = f", offset={self.offset}" return ( f"{type(self).__name__}('{self.device.type}', {self.dtype}, " f"size={self.size}, stride={self.stride}{offset})" ) __repr__ = __str__ def is_contiguous(self): return is_contiguous_strides_for_shape(self.stride, self.size) def is_channels_last_contiguous(self): ndim = len(self.size) if ndim not in [4, 5]: return False for left, right, size in zip( self.stride, make_channels_last_strides_for(self.size), self.size # type: ignore[arg-type] ): if size != 1 and left != right: return False return True def is_transposed(self): for left, right, size in zip( self.stride, reversed(FlexibleLayout.contiguous_strides(self.size)), self.size, ): if size != 1 and left != right: return False return True def is_stride_ordered(self, order): assert len(self.stride) == 
len(order) # ignore dimensions of size 1, they dont affect layout non_1_indices = [ i for i, dim in enumerate(self.size) if V.graph.sizevars.size_hint(dim, fallback=2) != 1 ] stride = [self.stride[i] for i in non_1_indices] order = [order[i] for i in non_1_indices] def sorted_indices(arr): sorted_arr = sorted(arr) return [sorted_arr.index(element) for element in arr] # since we may have removed dimensions, need to re-sort & re-index order order = sorted_indices(order) # reorder the stride given order stride_ordered = [-1] * len(order) for i in range(len(order)): stride_ordered[order[i]] = V.graph.sizevars.size_hint(stride[i]) # check if it is in ascending order for i in range(len(order) - 1): if stride_ordered[i] > stride_ordered[i + 1]: return False return True def is_channels_last_stride_ordered(self): # create channels_last order(NCHW, NCDHW, the C is the first order). order = [0] + list(reversed(range(1, len(self.stride) - 1))) order = [len(order)] + order return self.is_stride_ordered(order) def as_fixed(self): return FixedLayout( self.device, self.dtype, self.size, self.stride, self.offset, ) def make_indexer(self): assert ( FlexibleLayout.allow_indexing ), f"convert {type(self).__name__} to FixedLayout first" return self.as_fixed().make_indexer() def __eq__(self, other) -> bool: return ( self.device == other.device and self.dtype == other.dtype and self.size == other.size and self.stride == other.stride and self.offset == other.offset ) def storage_size(self) -> sympy.Expr: return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type, return-value] class FixedLayout(Layout): """A Tensor layout we cannot change""" def __init__( self, device: torch.device, dtype: torch.dtype, size: Union[List[Expr], List[int]], stride: Optional[Sequence[Union[Expr, int]]] = None, offset: Union[Expr, int] = Integer(0), ): if stride is None: stride = FlexibleLayout.contiguous_strides(size) super().__init__( device, dtype, size, # type: 
ignore[arg-type] stride, offset, # type: ignore[arg-type] ) def make_indexer(self): """A closure containing math to read a given element""" def indexer(index): assert len(index) == len(self.stride) == len(self.size) result = self.offset for idx, stride, sz in zip(index, self.stride, self.size): if sz != 1: result = result + idx * stride return result return indexer class FlexibleLayout(Layout): """A Tensor layout we are allowed to change""" allow_indexing = False @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: return [] reversed_strides = [sympy.Integer(1)] for size in reversed(sizes[1:]): reversed_strides.append(size * reversed_strides[-1]) return list(reversed(reversed_strides)) @staticmethod def fill_ordered(sizes, order): """ Create a stride based on the order the dimensions should be filled in. In this format, channels last would be: [1, 3, 2, 0] """ assert set(range(len(sizes))) == set(order) next_stride = sympy.Integer(1) strides = [None] * len(order) for i in order: strides[i] = next_stride next_stride = next_stride * sizes[i] return strides @staticmethod def stride_ordered(sizes, order): """ Create a stride based on the sorted order of a permuted range. 
In this format, channels last would be: [3, 0, 2, 1] """ assert set(range(len(sizes))) == set(order) fill_order = stride_order2fill_order(order) return FlexibleLayout.fill_ordered(sizes, fill_order) @staticmethod def same_ordered(sizes, stride): """ Create a stride that has the same stride order as given stride For example, if given stride is [1000, 1, 100, 10], the fill order should be [1, 3, 2, 0] """ assert len(sizes) == len(stride) stride = [V.graph.sizevars.size_hint(x) for x in stride] fill_order = sorted(range(len(stride)), key=stride.__getitem__) return FlexibleLayout.fill_ordered(sizes, fill_order) def as_stride_order(self, order): return FixedLayout( self.device, self.dtype, self.size, self.stride_ordered(self.size, order), self.offset, ) def as_fill_order(self, order): return FixedLayout( self.device, self.dtype, self.size, self.fill_ordered(self.size, order), self.offset, ) def as_same_order(self, stride): return FixedLayout( self.device, self.dtype, self.size, self.same_ordered(self.size, stride), self.offset, ) def __init__(self, device, dtype, size, stride_order=None): if stride_order: strides = FlexibleLayout.fill_ordered(size, stride_order) else: strides = FlexibleLayout.contiguous_strides(size) super().__init__(device, dtype, size, strides) class AliasedLayout(Layout): """Shares the same storage as another tensor""" def __init__(self, view: Union[BaseView, "TensorBox"]): layout = view.get_layout() super().__init__( layout.device, layout.dtype, layout.size, layout.stride, ) self.view = view def make_indexer(self): return self.as_fixed().make_indexer() def maybe_guard_aligned(self): offset = self.view.get_layout().offset if offset == 0: return True from .compile_fx import ALIGNMENT return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] class NoneLayout(IRNode): # This is janky, I figured out what fields to populate by just running # the model I was interested in and adding properties/methods as needed. 
# This doesn't inherit from Layout because Layout assumes you have stuff # like sizes, but I don't really have anything here. # # If you have an ir.Node with NoneLayout, you probably need to setup # dependencies manually in scheduler def __init__(self, device): self.device = device self.size = [0] self.stride = [0] def storage_size(self): return 0 def as_fixed(self): return self class MutationLayout(Layout): def __init__(self, target: IRNode): super().__init__( target.get_device(), target.get_dtype(), target.get_size(), None, ) = target name = self.get_buffer().get_name() V.graph.mark_buffer_mutated(name) @Layout.stride.getter # type: ignore[attr-defined] def stride(self): return self.real_layout().stride def storage_size(self) -> sympy.Expr: return self.real_layout().storage_size() def get_buffer(self) -> "Buffer": def unwrap_views(target): if isinstance(target, MutationLayout): return unwrap_views( if isinstance(target, BaseView): return unwrap_views(target.unwrap_view()) if isinstance(target, MutableBox): return unwrap_views( return target result = unwrap_views( assert isinstance(result, Buffer), "MutationLayout must refer to a buffer" return result def real_layout(self): return self.get_buffer().layout @classmethod def realize_into(cls, src, dst, unsafe_alias=False): dst.realize() # NOTE: We must realize users of `dst` before we realize `src`, since # realization order determines scheduling order. Otherwise, src's # mutation would be scheduled before the existing users of dst! V.graph.mark_buffer_mutated(dst.get_name()) if isinstance(src, TensorBox): src = # We copy the contents of src into dst. In most cases this should # be fused into a single kernel by the scheduler. # NOTE: We cannot change src's layout to mutate dst directly as this # would alias src to dst, which is not correct as further mutations to # dst would effect users of src. However if there are no more users of # dst, we can alias src to dst. 
src.realize_hint() if not unsafe_alias: src = Pointwise.create( device=src.get_device(), dtype=src.get_dtype(), inner_fn=src.make_loader(), ranges=[ V.graph.sizevars.guard_equals(a, b) for a, b in zip(src.get_size(), dst.get_size()) ], ).data src.realize() assert isinstance(, FlexibleLayout) = MutationLayout(dst) return def as_fixed(self): return self def make_indexer(self): return @dataclasses.dataclass class Buffer(IRNode): # Name is sometimes None; e.g., ForceInPlace, where there isn't # a meaningful name name: Optional[str] layout: Layout # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly, # MultiOutput does NOT define this! def __post_init__(self): super().__post_init__() self.origin_node = None def make_indexer(self): return self.layout.make_indexer() def get_name(self) -> str: assert return def get_device(self): return self.layout.device def get_origin_node(self): return self.origin_node @property def dtype(self): return getattr(self.layout, "dtype", None) def get_size(self): return list(self.layout.size) def get_stride(self): return list(self.layout.stride) def get_offset(self): return self.layout.offset def get_layout(self): return self.layout def get_storage_numel(self): return self.get_numel() def is_extern(self): return False def freeze_layout(self): if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)): self.layout = self.layout.as_fixed() def freeze_layout_with_stride_order(self, order): assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_stride_order(order) def freeze_layout_with_fill_order(self, order): assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_fill_order(order) def freeze_layout_with_same_order(self, stride): assert isinstance(self.layout, FlexibleLayout) self.layout = self.layout.as_same_order(stride) def is_zero_elements(self): return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0)) # type: ignore[arg-type] def make_loader(self): # 
Loading from a zero-element buffer is a no-op if self.is_zero_elements(): return partial(nop_loader_fn, dtype=self.get_dtype()) def loader(index): indexer = self.layout.make_indexer() return ops.load(, indexer(index)) return loader def is_no_op(self): return False def codegen_reference(self, writer=None): return self.get_name() def decide_layout(self): pass def get_alias_names(self): if isinstance(self.layout, AliasedLayout): return [self.layout.view.get_name()] return () def get_mutation_names(self): if isinstance(self.layout, MutationLayout): return [] return () def get_read_writes(self): with patch.object(FlexibleLayout, "allow_indexing", True): return extract_read_writes( self.make_loader(), self.get_size(), ) def get_reads(self): return self.get_read_writes().reads def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: """ Returns the unbacked symbols which are defined by this IR node, because this is a data-dependent IR node, or item() """ # So this is a little unusual. In principle, you could imagine # defining a MultiOutputLayout buffer so that it DOES define # unbacked symints. However, we can't easily tell what symints # such a buffer defines, because MultiOutputLayout doesn't actually # define any useful information about what it returns. # # An easier and better approach is to delay the symint allocation # to the MultiOutput IR nodes, which are when we actually extract # out the buffers and know what their sizes are. # # There are two subleties here: # # 1. Suppose you have a kernel that produces out1: (i0,), out2: (i0,) # Both of these actually count as defs! The scheduler will just # arbitrarily pick one of these as the canonical definer and # ensure it stays live. It's not a big deal if we pick the # wrong one because tuple accesses are cheap, and all this means # is we accidentally keep a MultiOutput node live when it wasn't # strictly necessary. # # 2. 
Suppose you have a MultiOutput buffer whose size is (i0,), but # the MultiOutputLayout buffer it is projecting from isn't actually # dynamic; it has i0 as one of the arguments. We cannot tell this # directly from MultiOutput, we have to look at the input buffer's # uses to work this out. No big deal. if isinstance(self.layout, (NoneLayout, MultiOutputLayout)): return set() # This kernel defines all unbacked symbols... that it didn't get in as # arguments! defs = ( free_unbacked_symbols(self.get_size()) | free_unbacked_symbols(self.get_stride()) | free_unbacked_symbols(self.get_offset()) ) return defs - self.get_unbacked_symbol_uses() def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: """ Returns the unbacked symbols which are required to be in scope in order to successfully perform codegen for this buffer. For example, a buffer that corresponds to an extern kernel call that takes i0 as an argument would return {i0} here. This is used to generate necessary dependencies that ensure we actually bind i0 in codegen before you try to use it. Note that this is NOT transitive; in particular, if this buffer takes in as input another buffer with dynamic shape (e.g., (i0,)), we will not report it here, because you will already have a dependency on that buffer, which will eventually have a dependency on i0 if necessary. """ return set() def codegen_unbacked_symbol_defs(self, wrapper): # NB: If it is possible for other ir node types to return unbacked # symints, you need to make sure their codegen calls this method. # Don't forget to update get_unbacked_symbol_defs too. 
symbols_to_define = self.get_unbacked_symbol_defs() for i, s in enumerate(self.get_size()): if s in symbols_to_define: wrapper.writeline( f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.size({i}){wrapper.ending}" ) symbols_to_define.remove(s) for i, s in enumerate(self.get_stride()): if s in symbols_to_define: wrapper.writeline( f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.stride({i}){wrapper.ending}" ) symbols_to_define.remove(s) if (s := self.get_offset()) in symbols_to_define: wrapper.writeline( f"{wrapper.codegen_unbacked_symbol_decl(s)} = {self.get_name()}.storage_offset(){wrapper.ending}" ) symbols_to_define.remove(s) assert ( not symbols_to_define ), f"unbacked symint {s} not written out, check comment above" def realize(self): pass def get_workspace_size(self): """ Gets extra global memory size needed by this buffer. Some algorithms (e.g. group gemm) may require extra global memory in the generated code. """ return 0 def should_allocate(self): # Returns False by default. 
return False class InputBuffer(Buffer): pass class ConstantBuffer(InputBuffer): override_device: Optional[torch.device] = None def make_loader(self): def loader(index): indexer = self.layout.make_indexer() return ops.load( V.graph.constant_name(self.get_name(), self.override_device), indexer(index), ) return loader def constant_to_device(self, device): return ConstantBuffer( V.graph.constant_name(self.get_name(), device), self.layout ) class NoneAsConstantBuffer(IRNode): def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: return set() def codegen_reference(self, writer=None): return V.graph.wrapper_code.none_str class ShapeAsConstantBuffer(IRNode): def __init__(self, shape): super().__init__() self.shape = shape def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: return free_unbacked_symbols(self.shape) def codegen_reference(self, writer=None): return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape)) @dataclasses.dataclass class ComputedBuffer(Buffer): data: Loops def get_computed_buffer_name(self): """ Returns if it exists, otherwise returns the name of the data node if that exists. If neither exist, returns None. """ if is not None: return if hasattr(, "name"): return return None @cache_on_self def num_reads(self): return len(self.get_read_writes().reads) def get_read_writes(self): with patch.object(FlexibleLayout, "allow_indexing", True): if return extract_read_writes( self.get_store_function(),,, ) else: return extract_read_writes( self.get_store_function(),, ) def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: # Ordinarily, we'd like to just peek at the arguments list, # but ComputedBuffers have no argument list. # # Morally, this logic needs to be synchronized with the # KernelArgs.size calls, which are responsible for making symbols make # there way as kernel arguments (and it is precisely passing in one of # those symbols that establishes a dependency). 
However, we haven't # started codegen yet so we can't directly reuse that logic. # # For now, I'm just yoloing with the size of the buffer. Not sure if # it is enough. # # One thing you might wonder is if this is enough for a ComputedBuffer # denoting a reduction over i0. Empirically, it is enough, but for an # unusual reason: we only need accurate dependencies for item() call, # but it's impossible to end up with a reduction over i0 from an # item() call without a regular non-reduction buffer first. return ( free_unbacked_symbols(self.get_size()) | free_unbacked_symbols(self.get_stride()) | free_unbacked_symbols(self.get_offset()) | ) def make_loader(self): # Inline constants and index_expressions if ( hasattr(, "make_loader") and not in V.graph.mutated_buffers and self.num_reads() == 0 ): # can be inlined return return super().make_loader() def get_store_function(self): indexer = self.layout.as_fixed().make_indexer() if isinstance(, (Reduction, Scan)): return partial(,, indexer) else: assert isinstance(, Pointwise) return partial(,, indexer) def get_fill_order(self): """ If our layout is still flexible, try to determine the stride order based on stride orders of reads. TODO(jansel): A better algorithm here would look at downstream consumers of this value and try to do global graph-level layout optimization. This is also something just begging to be autotuned. 
""" if isinstance(self.layout, FlexibleLayout): (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(, ) reads = self.get_read_writes().reads reads_bufs = [ V.graph.name_to_buffer[] if in V.graph.name_to_buffer.keys() else None for r in reads ] # only consider reads to buffer of same size # ignore StarDeps because they don't contribute stride information assert all( isinstance(r, (dependencies.StarDep, dependencies.MemoryDep)) for r in reads ) reads = [ sympy_subs( r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} ) for r in reads if isinstance(r, dependencies.MemoryDep) ] if reads: if isinstance(, Scan): indices =, reduction_vars) else: indices = index_vars stride_lengths = [ V.graph.sizevars.stride_hints(expr, indices) for expr in reads # type: ignore[arg-type] ] from .scheduler import pick_loop_order return pick_loop_order(stride_lengths, self.get_size()) return None def decide_layout(self): if isinstance(self.layout, FlexibleLayout): order = self.get_fill_order() if order: self.freeze_layout_with_fill_order(order) else: self.freeze_layout() def get_default_sizes_body(self): args, var_ranges = dependencies.index_vars_squeeze(,, prefix="q" ) with patch.object(ConstantBuffer, "override_device", self.get_device()): body = LoopBody( self.get_store_function(), (args if self.get_reduction_type() else args[:1]), var_ranges, ) index_vars = [] reduce_vars: List[Any] = [] index_size = [] reduce_size = [] for v, s in var_ranges.items(): if v in args[0]: assert not reduce_vars index_vars.append(v) index_size.append(s) else: assert v in args[1] reduce_vars.append(v) reduce_size.append(s) return (index_size, reduce_size), body, (index_vars, reduce_vars) def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, ): """ This is a main place where we do loop transformations in a backend-agnostic way. 
Here we: 1) Remove any 1 dimensions 2) Fuse contiguous dimensions together 3) Reorder dimensions based on stride orders Optional argument extra_indexing_constraints can be used to append additional indexing expressions to existing ones derived from buffer's body. This can be useful to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...) on CPU by preventing indexing simplifications and obtaining index/reduce ranges for the scheduler node compatible with other nodes. """ ( (index_size, reduce_size), body, (index_vars, reduce_vars), ) = self.get_default_sizes_body() index_formulas = [*body.indexing_exprs.values()] if extra_indexing_constraints is not None: assert ( isinstance(extra_indexing_constraints, tuple) and len(extra_indexing_constraints) == 2 ) extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints assert isinstance(extra_indexing_ranges, dict) assert isinstance(extra_indexing_expr, list) assert all(isinstance(f, Expr) for f in extra_indexing_expr) expected_var_ranges = body.var_ranges assert expected_var_ranges == extra_indexing_ranges, ( expected_var_ranges, extra_indexing_ranges, ) # remove already existing expressions extra_indexing_expr = [ e for e in extra_indexing_expr if e not in index_formulas ] index_formulas += extra_indexing_expr reads_bufs = [ V.graph.name_to_buffer[reads_name] if reads_name in V.graph.name_to_buffer.keys() else None for reads_name in body.reads_name2expr.keys() ] memory_addrs = [ *body.reads_name2expr.values(), *body.writes_name2expr.values(), ] # the reordering_reindex in reads' simplify_reorder_and_tile reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs) for i, reads_buf in enumerate(reads_bufs): if isinstance(reads_buf, ComputedBuffer) and hasattr( reads_buf, "iter_reordering_reindex" ): reordering_reindex[i] = reads_buf.iter_reordering_reindex # type: ignore[has-type] def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): 
sizes, reindex0, reindex1 = self._apply_loop_reordering( x_vars, support_vars, sizes, memory_addrs, reordering_reindex ) # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] x_vars = reindex0(x_vars) sizes, reindex2, prune = V.graph.sizevars._simplify_loops( x_vars, sizes, index_prevent_reordering(index_formulas, x_vars, sizes), ) x_vars = prune(x_vars) # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas) # x_vars = prune(x_vars) # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs) reindex = fuse_reindexing(reindex1, reindex2) return sizes, reindex, reindex1 support_vars = index_vars + reduce_vars iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder( index_vars, support_vars, index_size, reordering_reindex ) reduce_ranges, reduce_reindex, _ = simplify_and_reorder( reduce_vars, support_vars, reduce_size ) # remember the reordering if not have loop collapse. if len(iter_ranges) == len(index_vars): self.iter_reordering_reindex = iter_reordering_reindex # retrace the loop body with simplification and reordering applied (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( iter_ranges, reduce_ranges, prefix="z" ) body = LoopBody( body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges ) return (iter_ranges, reduce_ranges), body @staticmethod def _apply_loop_reordering( index_vars, support_vars, sizes, memory_addrs, reordering_reindex=None, priority_idx=None, ): """ Shuffle the order of loops around to hopefully improve performance. 
""" from .scheduler import pick_loop_order if priority_idx is None: priority_idx = [] try: strides = [ V.graph.sizevars.stride_hints(expr, index_vars, support_vars) for expr in memory_addrs ] assert len(strides) == len(memory_addrs) and len(strides[0]) == len( index_vars ) # consider both layout(strides) and reordering(reordering_reindex) if reordering_reindex is not None: for i in range(len(memory_addrs)): try: strides[i] = reordering_reindex[i](strides[i]) # if len(order) != len(strides), do not reorder except AssertionError: pass order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) except Exception: if config.debug: log.warning( "Did not simplify complex index:\n%s\n%s", dict(zip(index_vars, sizes)), memory_addrs, ) order = list(range(len(sizes))) sizes = [sizes[i] for i in order] return sizes, same_reorder(order), inverse_reorder(order) def get_reduction_size(self): return def get_reduction_type(self): return def is_no_op(self): return def should_allocate(self): return True def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" return class TemplateBuffer(Buffer): """ Represents a Triton (in the future other type) of template operator that we can fuse an epilogue onto. 
""" def __init__(self, layout, inputs, make_kernel_render): super().__init__(name=None, layout=layout) self.inputs = InputsKernel.unwrap_storage(inputs) self.make_kernel_render = make_kernel_render = V.graph.register_buffer(self) def get_read_writes(self): return self.normalized_read_writes() def normalized_read_writes(self): name = self.get_name() indexer = self.layout.make_indexer() def dummy(index, rindex): assert len(rindex) == 0 return, indexer(index), "fake") deps = dependencies.extract_read_writes( dummy, self.get_size(), (), normalize=True ) deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs} return deps def get_reduction_size(self): return 1 def get_reduction_type(self): return None def is_no_op(self): return False def should_allocate(self): return True def simplify_and_reorder( self, extra_indexing_constraints: Optional[Tuple[Dict[Any, Any], List[Any]]] = None, ): return ( ( self.get_size(), (), ), None, ) class TritonTemplateBuffer(TemplateBuffer): pass class CUDATemplateBuffer(TemplateBuffer): def __init__( self, layout, inputs, make_kernel_render, workspace_size: int, template: "CUDATemplate", # type: ignore[name-defined] # noqa: F821 ): super().__init__(layout, inputs, make_kernel_render) # Global memory (in bytes) needed for this template. 
self.workspace_size = workspace_size self.template = template def get_workspace_size(self): return self.workspace_size if self.workspace_size is not None else 0 @dataclasses.dataclass class InputsKernel(Buffer): inputs: List[Buffer] def get_read_writes_input(self, x): return dependencies.StarDep(x.get_name()) def get_read_writes(self): star_dep = [] for input in self.inputs: if isinstance(input, list): star_dep.extend([self.get_read_writes_input(x) for x in input]) else: star_dep.append(self.get_read_writes_input(input)) return dependencies.ReadWrites( set(star_dep), {dependencies.StarDep(self.get_name())}, set(), [], None, op_counts=collections.Counter(), ) @classmethod def unwrap_storage_for_input(cls, x): if isinstance(x, TensorBox): x = if isinstance(x, StorageBox): x = if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): x = ExternKernel.realize_input(x) if isinstance(x, TensorBox): # when converting to ReinterpretView fails in the # realize_input call above, the result will be wrapped # into TensorBox / StorageBox pair as a result of the # cls.copy_input call; so we should unwrap recursively return cls.unwrap_storage_for_input(x) assert isinstance(x, (Buffer, ReinterpretView)), x return x @staticmethod def unwrap_storage(inputs): inputs_new = [] for x in inputs: if isinstance(x, list): x = [InputsKernel.unwrap_storage_for_input(i) for i in x] else: x = InputsKernel.unwrap_storage_for_input(x) inputs_new.append(x) return inputs_new def is_extern(self): return True class NopKernel(InputsKernel): def is_no_op(self): return True class ConcatKernel(NopKernel): """ There isn't actually a real kernel for concat, we just change the storage for the upstream data. 
""" @classmethod def create(cls, inputs, dim): device = inputs[0].get_device() dtype = inputs[0].get_dtype() new_size = list(inputs[0].get_size()) offsets_start = [0] offsets_end = [new_size[dim]] assert 0 <= dim < len(new_size) for i in range(1, len(inputs)): input_size = inputs[i].get_size() offsets_start.append(new_size[dim]) assert len(input_size) == len(new_size) assert inputs[i].get_dtype() == dtype assert inputs[i].get_device() == device for j in range(len(new_size)): if j == dim: new_size[j] = new_size[j] + input_size[j] else: new_size[j] = V.graph.sizevars.guard_equals( new_size[j], input_size[j] ) offsets_end.append(new_size[dim]) output_stride = FlexibleLayout.contiguous_strides(new_size) # If any of the inputs is in CL format, use CL format for the output for i in range(len(inputs)): x = inputs[i] if is_storage_and_layout(x): layout = x.get_layout() if ( isinstance(layout, FixedLayout) and layout.is_channels_last_contiguous() ): # use CL stride for the output output_stride = make_channels_last_strides_for(new_size) break concat_kernel = ConcatKernel( name=None, layout=FixedLayout( device=device, dtype=dtype, size=new_size, stride=output_stride, ), inputs=[], ) kernel = StorageBox(concat_kernel) buffer_names = [] for i in range(len(inputs)): input_buffer = cls.realize_into( inputs[i], SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]), ) concat_kernel.inputs.append(input_buffer) if isinstance(inputs[i].data, BaseView): input_unwrapped = inputs[i].data.unwrap_view() else: input_unwrapped = inputs[i].data if ( input_unwrapped.is_input_buffer() and inputs[i].get_device().type == "cuda" and not is_dynamic(input_buffer) ): buffer_names.append(input_buffer.get_name()) if len(buffer_names) > 1: V.graph.register_list(buffer_names) = V.graph.register_buffer(concat_kernel) concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs) return kernel @classmethod def can_realize_into_without_copy(cls, src): if isinstance(src, TensorBox): # unwrap a 
TensorBox return cls.can_realize_into_without_copy( return isinstance(, FlexibleLayout) and not isinstance(, ExternKernelAlloc ) @classmethod def realize_into(cls, src, dst): # Attempt to turn this into a ReinterpretView rather than assert. # This has concessions around layout, as as_storage_and_layout # can cause us to go from flexible to fixed layout. if not isinstance(dst, ReinterpretView): if is_storage_and_layout(dst): storage, layout = as_storage_and_layout(dst) dst = ReinterpretView(storage, layout) assert isinstance(dst, ReinterpretView), dst if isinstance(src, TensorBox): # unwrap a TensorBox return cls.realize_into(, dst) if isinstance(src, StorageBox): src.realize() # ExternKernelAlloc has specific requirements for output layout, should create a copy assert hasattr(, "layout") if cls.can_realize_into_without_copy(src): = AliasedLayout(dst) return # introduce a copy pw = Pointwise.create( device=src.get_device(), dtype=src.get_dtype(), inner_fn=src.make_loader(), ranges=[ V.graph.sizevars.guard_equals(a, b) for a, b in zip(src.get_size(), dst.get_size()) ], ) return cls.realize_into(pw, dst) def should_allocate(self): return True @dataclasses.dataclass class ExternKernel(InputsKernel): constant_args: Tuple[Any, ...] = () kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) output_view: Optional[ReinterpretView] = None python_kernel_name: Optional[str] = None cpp_kernel_name: Optional[str] = None # FIXME: in some cases we sill need to explicitly pass in ordered_kwargs_for_cpp_kernel # We shouldn't need to do this since the information can be retrieved from op_overload._schema. 
ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field( default_factory=list ) op_overload: Optional[ Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] ] = None arg_properties: Optional[List[Dict[str, Any]]] = None kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None def __init__( self, name, layout, inputs, constant_args=(), kwargs=None, output_view=None, python_kernel_name=None, cpp_kernel_name=None, ordered_kwargs_for_cpp_kernel=(), op_overload=None, ): super().__init__( name, layout, inputs, ) self.constant_args = constant_args self.kwargs = kwargs if kwargs else {} self.output_view = output_view self.python_kernel_name = python_kernel_name self.cpp_kernel_name = cpp_kernel_name self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel self.op_overload = op_overload self.collect_arg_kwarg_properties() def collect_arg_kwarg_properties(self): # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional # information for args and kwargs, e.g. 
type and default value, to help with the cpp wrapper codegen if ( isinstance(self.op_overload, torch._ops.OpOverload) and not self.ordered_kwargs_for_cpp_kernel ): self.ordered_kwargs_for_cpp_kernel = [ for x in self.op_overload._schema.arguments if x.kwarg_only ] self.arg_properties = ( [ { "name":, "type": x.real_type, "default_value": x.default_value, } for x in self.op_overload._schema.arguments if not x.kwarg_only ] if isinstance(self.op_overload, torch._ops.OpOverload) else [{} for i in range(len(self.inputs))] ) self.kwarg_properties = ( { {"type": x.real_type, "default_value": x.default_value} for x in self.op_overload._schema.arguments if x.kwarg_only } if isinstance(self.op_overload, torch._ops.OpOverload) else {} ) def decide_layout(self): if isinstance(self.layout, FlexibleLayout): self.apply_constraint() self.freeze_layout() def codegen_comment(self, wrapper): origin_str, detailed_origin_str = get_kernel_metadata(self, wrapper) if origin_str: wrapper.writeline(origin_str) def codegen(self, wrapper): raise NotImplementedError() def get_kernel_name(self): return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name @staticmethod def copy_input(x): pw = Pointwise.create( device=x.get_device(), dtype=x.get_dtype(), inner_fn=x.make_loader(), ranges=x.get_size(), origin_node=x.get_origin_node(), traceback=x.get_traceback(), ) pw.realize() return pw @classmethod def process_kernel(cls, kernel, *args, **kwargs): binded_args = {"args": args, "kwargs": kwargs} args_flat, args_spec = pytree.tree_flatten(binded_args) is_arg_tensor = [] tensor_args = [] non_tensor_args: List[Any] = [] for arg in args_flat: is_arg_tensor.append(isinstance(arg, IRNode)) if is_arg_tensor[-1]: tensor_args.append(arg) else: if isinstance(arg, sympy.Expr): arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) non_tensor_args.append(arg) def unflatten_args(new_tensor_args, new_non_tensor_args): result = [] it_tensors = iter(new_tensor_args) it_non_tensors = 
iter(new_non_tensor_args) for is_tensor in is_arg_tensor: if is_tensor: result.append(next(it_tensors)) else: result.append(next(it_non_tensors)) r = pytree.tree_unflatten(result, args_spec) return r.get("args", []), r.get("kwargs", {}) tensor_args = [cls.realize_input(x) for x in tensor_args] # freeze layout otherwise our output stride calculation might # become incorrect for x in tensor_args: if is_storage_and_layout(x): as_storage_and_layout(x, freeze=True) # We don't have generic shape formulas, so just burn in the # shapes and run an example input. # TODO(jansel): replace this with dynamic shape formulas example_args = [] # We need to retain the constant values of fake tensors that we originally # propagated the graph with, because for some operators running without a # constant would trigger an error / DataDependentException for x in tensor_args: if x.get_name() in V.graph.constants: example_args.append(V.graph.constants[x.get_name()]) else: example_args.append(ir_node_to_tensor(x, guard_shape=True)) new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) example_output = kernel(*new_args, **new_kwargs) example_out_li = ( [example_output] if not isinstance(example_output, (list, tuple)) else example_output ) for t in example_out_li: if isinstance(t, torch.Tensor) and t.is_sparse: msg = "sparsity not handled. Please file issue for sparse inference weights." if stack_trace := V.graph.current_node.meta.get("stack_trace", None): msg = f"{msg} Found from : \n {stack_trace}" V.graph.disable_cudagraphs_reason = msg # TODO: Unconditionally do this, not just when example_output has # unbacked symbols if maybe_free_unbacked_symbols(example_output): example_output = V.graph.current_node.meta["val"] return example_output, tensor_args, non_tensor_args, unflatten_args @classmethod def convert_to_reinterpret_view(cls, x): """ In order to pass this to an extern kernel we need a ReinterpretView not a View. This allows us to avoid some unneeded copies. 
""" assert isinstance(x, BaseView) if isinstance(x, ReinterpretView): return x # NOTE: Don't use extract_read_writes here as it fails when # make_loader() inlines the computation x.unwrap_view().freeze_layout() index_args, var_ranges = dependencies.index_vars_squeeze( x.get_size(), prefix="r" ) range_vars = index_args[0] index = x.make_indexer()(range_vars) index = V.graph.sizevars.simplify_with_ranges(index, var_ranges) strides = V.graph.sizevars.stride_vars(index, range_vars) offset = V.graph.sizevars.offset_var(index, range_vars) expected = sympy_dot(range_vars, strides) + offset if index != expected: log.debug( "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", strides, offset, index, ) raise NotImplementedError() return ReinterpretView(, layout=FixedLayout( device=x.get_device(), dtype=x.get_dtype(), size=x.get_size(), stride=strides, offset=offset, ), ) @classmethod def realize_input(cls, x): if x is None: return NoneAsConstantBuffer() if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): return ShapeAsConstantBuffer(x) if isinstance(x, Constant): return V.graph.add_tensor_constant( torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) ) if isinstance(x, ConstantBuffer): return x if isinstance(x, TensorBox): return cls.realize_input( if isinstance(x, ReinterpretView): return ReinterpretView(cls.realize_input(, x.get_layout()) if isinstance(x, BaseView): x.realize() if is_storage_and_layout(x.unwrap_view()): try: return cls.convert_to_reinterpret_view(x) except NotImplementedError: pass if isinstance(x, StorageBox): # TODO(jansel): impose layout preference on realized buffer x.realize() return x return cls.copy_input(x) @classmethod def require_stride1(cls, x): if is_storage_and_layout(x): if len(x.get_stride()) == 0: return x for stride in x.get_stride(): if stride == 1: return x return cls.copy_input(x) @classmethod def require_stride_order(cls, x, order): if x.get_numel() == 0: # Layout doesn't matter return x # 
require x to have the layout as strided_ordered as order if is_storage_and_layout(x): while isinstance(x.get_layout(), AliasedLayout): x = x.get_layout().view if isinstance(x.get_layout(), FlexibleLayout): # fix flexiblelayout to be FixedLayout with stride_order as_storage_and_layout( x, freeze=True, want_contiguous=False, stride_order=order ) return x elif isinstance( x.get_layout(), FixedLayout ) and x.get_layout().is_stride_ordered(order): return x elif isinstance(x.get_layout(), MutationLayout): if isinstance(x.get_layout().real_layout(), FlexibleLayout): raise AssertionError( "the MutationLayout's real layout shouldn't be FlexibleLayout" ) elif isinstance( x.get_layout().real_layout(), FixedLayout ) and x.get_layout().real_layout().is_stride_ordered(order): return x # TODO - Storage to InputBuffer if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): return x if ( isinstance(x, TensorBox) and isinstance(, BaseView) and not isinstance(, ReinterpretView) and is_storage_and_layout(x.unwrap_view()) and not isinstance(x.unwrap_view().data, ExternKernelAlloc) ): try: = cls.convert_to_reinterpret_view( return cls.require_stride_order(x, order) except NotImplementedError: pass x = cls.copy_input(x) as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) assert is_stride_order_storage_and_layout(x, order) return x @classmethod def require_channels_last(cls, x): return cls.require_stride_order(x, NHWC_STRIDE_ORDER) @classmethod def require_contiguous(cls, x): return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) def apply_constraint(self): pass def codegen_const_args(self): return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args) def codegen_args(self): args = [] for i, x in enumerate(self.inputs): if isinstance(x, list): names = [i.codegen_reference() for i in x] codegen_reference = f'[{", ".join(names)}]' args.append(codegen_reference) else: if V.graph.cpp_wrapper: assert 
self.arg_properties and i < len( self.arg_properties ), "Invalid arg_properties accessing" type_ = self.arg_properties[i].get("type") args.append( V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] type_, x, self.is_legacy_abi_kernel() ) ) else: args.append(x.codegen_reference()) args.extend(self.codegen_const_args()) return args def get_kwargs_value(self, arg_name): if arg_name in self.kwargs: return self.kwargs.get(arg_name) if self.kwarg_properties and self.kwarg_properties.get(arg_name): return self.kwarg_properties.get(arg_name).get("default_value") # type: ignore[union-attr] else: raise AssertionError(f"{arg_name} not in self.kwarg_properties") def is_legacy_abi_kernel(self): return False def codegen_kwargs(self): if V.graph.cpp_wrapper: kwargs = [] for arg_name in self.ordered_kwargs_for_cpp_kernel: v = self.get_kwargs_value(arg_name) if isinstance(v, sympy.Expr): kwargs.append(v) else: type_ = ( self.kwarg_properties.get(arg_name).get("type") # type: ignore[union-attr] if self.kwarg_properties and arg_name in self.kwarg_properties else None ) kwargs.append( V.graph.wrapper_code.val_to_cpp_arg_str( # type: ignore[arg-type] type_, v, self.is_legacy_abi_kernel() ) ) else: kwargs = [ f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}" # type: ignore[misc] for k, v in self.kwargs.items() ] return kwargs def codegen_size_asserts(self, wrapper): if config.size_asserts and not V.graph.cpp_wrapper: size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) wrapper.writeline( f"assert_size_stride({self.get_name()}, {size}, {stride})" ) def get_group_stride(self): """ get output sizes and strides, for template_codegen """ _size = self.get_size() _stride = self.get_stride() # iter_ranges = _size of output tensor, reduce_range = [] because no reduction return [_size, []], _stride def canonicalize(self): """ Manually get canonicalization of the output index """ # manually generate index 
formula for conv sizevars = V.graph.sizevars sizes = self.get_size() strides = self.get_stride() strides = [sizevars.size_hint(x) for x in strides] index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))] # reorder index vars according to stride index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) lookup = {pos: idx for idx, pos in enumerate(index_order)} order = [lookup[i] for i in range(len(lookup))] index_vars = [index_vars[i] for i in order] indexer = self.make_indexer() index = indexer(index_vars) new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( index_vars, sizes, [index] ) # assign new variables each dimension to deal with numbering mismatches # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 _, add_var = var_builder("c") replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) index = sympy_subs(sympy.expand(index), replacement) # type: ignore[arg-type] return index, tuple(new_sizes) def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: # NB: It's not necessary to check regular inputs as we automatically # have dependencies on them r = set() for arg in self.constant_args: r |= maybe_free_unbacked_symbols(arg) for arg in self.kwargs.values(): r |= maybe_free_unbacked_symbols(arg) return r def __str__(self): kernel_name = getattr(self, "python_kernel_name", None) lines = [ f"python_kernel_name={kernel_name!r}", ] lines += [ f"{}={getattr(self,}" for field in dataclasses.fields(self) ] lines.append(f"origin_node={self.origin_node!r}") return self.str_helper(lines) __repr__ = __str__ @dataclasses.dataclass class ExternKernelOut(ExternKernel): def codegen(self, wrapper): self.codegen_comment(wrapper) args = [*self.codegen_args(), *self.codegen_kwargs()] wrapper.generate_extern_kernel_out( self.output_view, self.codegen_reference(), args, self.get_kernel_name(), ) def __init__( self, layout, inputs, constant_args=(), kwargs=None, output_view=None, python_kernel_name=None, 
cpp_kernel_name=None, ordered_kwargs_for_cpp_kernel=(), op_overload=None, ): super().__init__( None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}, None, python_kernel_name, cpp_kernel_name, ordered_kwargs_for_cpp_kernel, op_overload, ) = V.graph.register_buffer(self) def should_allocate(self): return True class RandomSeeds(ExternKernelOut): def __init__(self, count: int, device: torch.device): limits = torch.iinfo(torch.int64) super().__init__( layout=FixedLayout( device=device, dtype=torch.int64, size=[count], ), inputs=[], constant_args=[limits.min, limits.max, [count]], python_kernel_name="aten.randint.low_out", cpp_kernel_name="at::randint_out", ) class ExternKernelAlloc(ExternKernel): def codegen(self, wrapper): self.codegen_comment(wrapper) args = [*self.codegen_args(), *self.codegen_kwargs()] V.graph.wrapper_code.generate_extern_kernel_alloc(self, args) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) def __init__( self, layout, inputs, constant_args=(), kwargs=None, python_kernel_name=None, cpp_kernel_name=None, ordered_kwargs_for_cpp_kernel=(), op_overload=None, ): super().__init__( None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}, None, python_kernel_name, cpp_kernel_name, ordered_kwargs_for_cpp_kernel, op_overload, ) = V.graph.register_buffer(self) def should_allocate(self): return False def apply_constraint(self): raise NotImplementedError class UserDefinedTritonKernel(ExternKernel): def get_kernel_and_configs(self): from triton.runtime.autotuner import Autotuner from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table kernel = kernel_side_table.get_kernel(self.kernel_idx) configs = [] if isinstance(kernel, Autotuner): configs = kernel.configs kernel = kernel.fn return kernel, configs def codegen(self, wrapper): kernel, configs = self.get_kernel_and_configs() # Definition of kernel new_name, triton_meta = wrapper.define_user_defined_triton_kernel( kernel, configs, 
self.kwargs ) args = self.codegen_kwargs() if V.graph.cpp_wrapper: # in C++ wrapper, we don't pass constexpr args, as they don't # get added as parameters to the PTX code compiled from the # user-defined Triton kernel (only non-constexpr args do) args = [arg for i, arg in enumerate(args) if i not in kernel.constexprs] # Call to kernel self.codegen_comment(wrapper) wrapper.generate_user_defined_triton_kernel( new_name, self.grid, configs, args, triton_meta, ) def should_allocate(self): return False def has_side_effects(self): # UserDefinedTritonKernel does not return anything, but rather # modifies input in place, do not let it get DCEd return True def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() def get_mutation_names(self): return [] def __init__(self, *, kernel_idx, grid, kernel_args): inputs = [] kwargs = dict() constant_args = [] for k, v in kernel_args.items(): if isinstance(v, TensorBox): t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) inputs.append(t) kwargs[k] = t else: constant_args.append(v) kwargs[k] = v assert len(inputs) != 0 device = inputs[0].get_device() super().__init__( None, NoneLayout(device), # type: ignore[arg-type] inputs, tuple(constant_args), kwargs, ) = V.graph.register_buffer(self) self.kernel_idx = kernel_idx self.grid = grid kernel, _ = self.get_kernel_and_configs() # If we are autotuning, not all arguments will be passed self.ordered_kwargs_for_cpp_kernel = [ arg for arg in kernel.arg_names if arg in kernel_args ] mark_node_as_mutating( self, *[a for a in kernel_args.values() if isinstance(a, TensorBox)] ) def get_alias_names(self): return [i.get_name() for i in self.inputs] def mark_node_as_mutating(cur_buffer, *mutated_ops): """ Allows ops in mutated_ops to be marked as being mutated as well as indicates to the scheduler that these ops depend on cur_buffer. 
""" for op in mutated_ops: assert isinstance(op, IRNode), op V.graph.mark_buffer_mutated(op.get_name()) assert hasattr(op, "layout") MutationOutput(op.layout, op, cur_buffer) class MutationOutput(ExternKernel): def get_mutation_names(self): return [self.inputs[0].get_name()] def __init__(self, layout, input, parent): super().__init__(None, layout, [input, parent], ()) = V.graph.register_buffer(self) def should_allocate(self): return False def is_no_op(self): return True def has_side_effects(self): return True def get_alias_names(self): return [self.inputs[0].get_name()] class InplaceBernoulliFallback(ExternKernel): """ This needs to be a custom class to handle mutation properly """ def codegen(self, wrapper): (x,) = (t.codegen_reference() for t in self.inputs) wrapper.writeline( f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}" ) def should_allocate(self): return False def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() def __init__(self, x, *constant_args): super().__init__( None, NoneLayout(x.get_device()), # type: ignore[arg-type] self.unwrap_storage([x]), constant_args, ) = V.graph.register_buffer(self) self.python_kernel_name = "aten.bernoulli_" self.cpp_kernel_name = ( "aoti_torch_bernoulli_" if config.abi_compatible else "at::native::bernoulli_" ) mark_node_as_mutating(self, x) # Used to deal with torch.complex types class InplaceCopyFallback(ExternKernel): """ This needs to be a custom class to handle mutation properly """ def codegen(self, wrapper): (dst, src, non_blocking) = self.codegen_args() wrapper.writeline( f"{self.get_kernel_name()}({dst}, {src}, {non_blocking}){wrapper.ending}" ) def should_allocate(self): return False def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() def __init__( self, layout, inputs, constant_args, ): super().__init__( 
None, layout, inputs, constant_args, python_kernel_name="aten.copy_", cpp_kernel_name=( "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call" ), ) = V.graph.register_buffer(self) @classmethod def create(cls, dst, src, non_blocking: bool = False): inputs = [cls.realize_input(t) for t in [dst, src]] constant_args = (non_blocking,) result = InplaceCopyFallback( NoneLayout(dst.get_device()), # type: ignore[arg-type] inputs, constant_args, ) mark_node_as_mutating(result, dst) return result class MutatingFirstArgExternKernel(ExternKernel): """ This needs to be a custom class to handle mutation properly """ def codegen(self, wrapper): argrefs = [ *(t.codegen_reference() for t in self.inputs), *map(repr, self.constant_args), ] wrapper.writeline( f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}" ) def should_allocate(self): return False def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() def has_side_effects(self): return True class ResizeStorageBytes(MutatingFirstArgExternKernel): def __init__(self, variable, new_size): assert isinstance(new_size, int), "TODO: dynamic shapes" super().__init__( None, NoneLayout(variable.get_device()), # type: ignore[arg-type] self.unwrap_storage([variable]), constant_args=(new_size,), ) V.graph.mark_buffer_mutated(variable.get_name()) = V.graph.register_buffer(self) self.python_kernel_name = "inductor_ops.resize_storage_bytes_" self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_" V.graph.never_reuse_buffers.add( mark_node_as_mutating(self, variable) class ScatterFallback(ExternKernel): """ This needs to be a custom class to handle mutation properly. This class handles both aten.scatter_ and aten.scatter_reduce_. It also handle the case `src` being a scalar properly. 
""" def codegen(self, wrapper): reduce = self.kwargs["reduce"] if V.graph.cpp_wrapper: # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum get_operator_enum = {"add": "sum", "multiply": "prod"} if reduce in get_operator_enum: reduce = get_operator_enum[reduce] if self.src_is_tensor: (x, index, src) = (t.codegen_reference() for t in self.inputs) else: (x, index) = (t.codegen_reference() for t in self.inputs) src = self.constant_args[1] wrapper.generate_scatter_fallback( x, [x, self.constant_args[0], index, src], self.get_kernel_name(), self.python_kernel_name, self.src_is_tensor, reduce, self.codegen_kwargs(), ) def should_allocate(self): return False def get_cpp_kernel(self): reduce = self.kwargs["reduce"] if self.python_kernel_name == "aten.scatter_": if self.src_is_tensor: kernel = ( "at::scatter_out" if reduce is None else "at::scatter_reduce_out" ) else: assert ( reduce is None ), "Expect reduce to be None for aten.scatter_ with scalar src" kernel = "at::scatter_out" else: assert ( reduce is not None ), "Expect reduce to be not None for aten.scatter_reduce_" kernel = "at::scatter_reduce_out" return kernel def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() def __init__( self, op_overload, python_kernel_name, x, dim: int, index, src, *, reduce: Optional[str] = None, include_self: bool = True, ): assert python_kernel_name in {"aten.scatter_", "aten.scatter_reduce_"} self.src_is_tensor = isinstance(src, TensorBox) constant_args: Tuple[Any, ...] 
if self.src_is_tensor: tensors = [self.realize_input(t) for t in [x, index, src]] constant_args = (dim,) else: tensors = [self.realize_input(t) for t in [x, index]] constant_args = (dim, src) super().__init__( None, NoneLayout(x.get_device()), # type: ignore[arg-type] self.unwrap_storage(tensors), constant_args, {"reduce": reduce, "include_self": include_self}, python_kernel_name=python_kernel_name, ordered_kwargs_for_cpp_kernel=["reduce", "include_self"], op_overload=op_overload, ) self.cpp_kernel_name = self.get_cpp_kernel() = V.graph.register_buffer(self) mark_node_as_mutating(self, x) class IndexPutFallback(ExternKernel): """ This needs to be a custom class to handle mutation and indices properly """ def codegen(self, wrapper): (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs) indices = [] iter_valid_indices = iter(valid_indices) for i, _ in enumerate(self.indices): if self.indices[i] is not None: indices.append(next(iter_valid_indices)) else: indices.append(V.graph.wrapper_code.none_str) wrapper.generate_index_put_fallback( self.get_kernel_name(), x, indices, values, *self.codegen_const_args() ) def should_allocate(self): return False def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() def __init__(self, op_overload, x, indices, values, accumulate): self.indices = indices valid_indices = [i for i in indices if i is not None] tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] cpp_kernel_name = ( "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out" ) super().__init__( None, NoneLayout(x.get_device()), # type: ignore[arg-type] self.unwrap_storage(tensors), (accumulate,), python_kernel_name="aten.index_put_", cpp_kernel_name=cpp_kernel_name, op_overload=op_overload, ) = V.graph.register_buffer(self) mark_node_as_mutating(self, x) class DeviceCopy(ExternKernelOut): @classmethod def create(cls, x, device): if ( not 
x.is_extern() and all( ( in V.graph.constants and isinstance(r, dependencies.MemoryDep)) for r in x.get_reads() ) and not config.aot_inductor.use_runtime_constant_folding ): return x.constant_to_device(device) V.graph.add_device_info(device) V.graph.add_device_info(x.get_device()) developer_warning("DeviceCopy in input program") return DeviceCopy( FlexibleLayout( device=device, dtype=x.get_dtype(), size=x.get_size(), ), [cls.realize_input(x)], ) def codegen(self, wrapper): args = self.codegen_args() assert len(args) == 1 if self.output_view: wrapper.codegen_device_copy(args[0], self.output_view.codegen_reference()) else: wrapper.codegen_device_copy(args[0], self.codegen_reference()) class DynamicScalar(ExternKernel): """ The result of a call to aten._local_scalar_dense. """ def get_reads(self): return () def should_allocate(self): return False # TODO: handle bools carefully def __init__(self, sym, data): data.realize() super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] if isinstance(sym, sympy.Symbol): self.sym = sym self.is_bool = False else: # Special case for boolean. For Reasons(TM), we don't represent # boolean variables directly in sympy; instead, we generate an # indicator integer variable which we then convert to a boolean by # testing i0 == 1. We have to identify the underlying indicator # variable, and then bind i0 to the appropriate integer value # based on the runtime boolean. 
assert isinstance(sym, sympy.Eq), sym assert isinstance(sym.args[0], sympy.Symbol), sym assert sym.args[1] == 1, sym self.sym = sym.args[0] self.is_bool = True def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return {self.sym} def codegen(self, wrapper): wrapper.codegen_dynamic_scalar(self) class AssertScalar(ExternKernel): """ The result of a call to aten._assert_scalar """ def get_reads(self): return () def should_allocate(self): return False def __init__(self, scalar, msg): super().__init__( # Buffer(name, layotu) None, NoneLayout(torch.device("cpu")), # type: ignore[arg-type] # InputsKernel(inputs) [], ) # type: ignore[arg-type] self.scalar = scalar self.msg = msg def has_side_effects(self): return True def get_unbacked_symbol_uses(self): return free_unbacked_symbols(self.scalar) def codegen(self, wrapper): if V.graph.cpp_wrapper: pass else: wrapper.writeline( f"if not {V.graph.wrapper_code.codegen_python_sizevar(self.scalar)}:" ) wrapper.writeline(f" raise RuntimeError({repr(self.msg)})") # No one should ever use this buffer, but for uniformity # define the variable and assign it None wrapper.writeline(f"{self.get_name()} = None") @dataclasses.dataclass class ExternKernelNode: name: str node: export_schema.Node has_c_shim = { aten._embedding_bag.default, aten._fft_c2c.default, aten._scaled_dot_product_efficient_attention.default, aten._scaled_dot_product_flash_attention.default, aten._scaled_mm.default, aten.addmm.out, aten.bmm.out, aten.copy_.default,, aten.repeat_interleave.Tensor, aten.nonzero.default, aten.view.dtype, aten.view_as_real.default, } def get_aten_cpp_kernel_name(kernel): # Calling with the default kernel name can lead to ambiguous behavior like the following example. 
# repeat_interleave(const at::Tensor & repeats, c10::optional output_size=c10::nullopt) # repeat_interleave(const at::Tensor & self, int64_t repeats, # c10::optional dim=c10::nullopt, c10::optional output_size=c10::nullopt) assert ( isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten" ), "Invalid aten kernel" opname = ( kernel.__name__.split(".")[0] if kernel._overloadname == "default" else kernel.__name__.replace(".", "_") ) return f"at::_ops::{opname}::call" class FallbackKernel(ExternKernelAlloc): args_default_value: List[Dict[str, Any]] def __init__( self, layout, kernel, tensor_args, nontensor_args, unflatten_args, kwargs=None, ): super().__init__( layout, tuple(tensor_args), tuple(nontensor_args), op_overload=kernel, ) # We need output buffers for generating kernel arguments in the # abi-compatible mode, where we retrieve outputs by pass each individual # output through the abi-compatible interface. self.outputs: Sequence[Any] = [] self.use_runtime_dispatch = False self.abi_compatible_kernel = None assert isinstance( kernel, ( torch._ops.OpOverload, torch._ops.HigherOrderOperator, ), ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported" self.op_overload = kernel self.unflatten_args = unflatten_args self.kwargs = {} if kwargs is None else kwargs V.graph.warn_fallback(self.python_kernel_name) # args that are aliased self.alias_names: List[str] = [] # args that are mutated AND returned from the op self.mutation_names: List[str] = [] if isinstance(self.op_overload, torch._ops.HigherOrderOperator): # We assume here that HOPs with FallbackKernel are functional. # This may not always be true! HOPs must individually opt-in to # FallbackKernel, so please check this if you opt-in. return if "_c10d_functional" in # _c10d_functional kernels are lowered into _CollectiveKernel which # derives from FallbackKernel for the cpp codegen. 
The kernels # don't pass the can_auto_functionalize check, but their mutation # is handled properly by _CollectiveKernel. return schema = self.op_overload._schema # NOTE: [FallbackKernel supported operators] # We only support three types of operators: # - functional ops # - view ops # - inplace aten ops # - mutating ops that are auto-functionalizable. That is, # the operator may mutate any number of inputs, but its outputs # may not alias any of the inputs. # # The unsupported cases usually do not show up here (because # AOTAutograd functionalized them away); the only way for an in-place # op to show up here is if a lowering or pass introduced it. if torch._library.utils.mutates_and_returns_first_arg(self.op_overload): self.mutation_names.append(tensor_args[0].get_name()) return if schema.is_mutable and not can_auto_functionalize(kernel): raise NotImplementedError( f"NYI: Can't generate FallbackKernel for {kernel}" ) schema_args = schema.arguments args, kwargs = self.unflatten_args(self.inputs, self.constant_args) def handle_aliasing_and_mutation(info, arg): # Assertions to make sure we didn't mismatch args if isinstance(info.type, torch.ListType): assert isinstance(arg, (list, tuple)) is_optional_tensor = isinstance( info.type, torch.OptionalType ) and isinstance(info.type.getElementType(), torch.TensorType) if is_optional_tensor or isinstance(info.type, torch.TensorType): # PyTorch also accepts None and scalar types for args marked as "Tensor". # We're not going to check all of them here. assert not isinstance(arg, (tuple, list)) if arg is None: return if info.alias_info is None: return # can_auto_functionalize already filters out mutable List[Tensor]. # We can support this in the future, but this is very uncommon. 
assert isinstance(info.type, torch.TensorType) or is_optional_tensor self.alias_names.append(arg.get_name()) if info.alias_info.is_write: mark_node_as_mutating(self, arg) for info, arg in torch._library.utils.zip_schema(schema, args, kwargs): handle_aliasing_and_mutation(info, arg) def set_cpp_kernel(self, kernel): from .codegen.wrapper import get_cpp_op_schema assert ( not kernel._schema.is_mutable ), f"mutable {kernel.__name__} is not supported with cpp_wrapper" # These checks are here because ops that return aliasing tensors will # return type Tensor& instead of Tensor, but codegen will always write # type Tensor on the LHS. def is_not_write(arg): return arg.alias_info is None or not arg.alias_info.is_write assert all( is_not_write(x) for x in kernel._schema.arguments ), f"{kernel.__name__} with alias_info arguments is not supported with cpp_wrapper" assert all( is_not_write(x) for x in kernel._schema.returns ), f"{kernel.__name__} with alias_info returns is not supported with cpp_wrapper" self.cpp_kernel_name = self.cpp_kernel_overload_name = kernel._schema.overload_name self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] self.cpp_op_schema = get_cpp_op_schema(kernel) self.init_args_default_value(kernel._schema) def is_legacy_abi_kernel(self): return ( config.c_shim_version == "1" and "_scaled_dot_product_flash_attention" in str(self.python_kernel_name) ) def init_args_default_value(self, schema): self.args_default_value = [ { "name":, "type": x.real_type, "value": x.default_value, } for x in schema.arguments if not x.kwarg_only ] def get_pos_arg_value(self, pos, kwargs): # positional args may be provided in kwargs pos_arg_name = self.args_default_value[pos]["name"] if pos_arg_name in kwargs: log.debug( "Found argument %s with value %s from kwargs", pos_arg_name, kwargs[pos_arg_name], ) return kwargs[pos_arg_name] assert hasattr( self, "args_default_value" ), "self.args_default_value has 
to be provided" assert pos < len( self.args_default_value ), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}" arg_default_value = self.args_default_value[pos]["value"] log.debug( "Use default value %s for argument %s", arg_default_value, pos_arg_name ) return arg_default_value def codegen_args(self): @dataclasses.dataclass class Shim: ref: Any def __repr__(self): return self.ref tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] args, kwargs = self.unflatten_args(tensor_args, self.constant_args) # Now we setup abi_compatible_kernel after self.python_kernel_name # and kwargs are adjusted appropriately. # For sdpa, we need the v2 version since v1 didn't consider optional arg # FIXME: no need to do this after we switch to the torchgen-ed C shim self.abi_compatible_kernel = ( f"{self.cpp_kernel_name}_v2" if self.cpp_kernel_name in {"at::_scaled_dot_product_flash_attention"} and config.c_shim_version == "1" else self.cpp_kernel_name ) if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload): args = [ V.graph.wrapper_code.val_to_cpp_arg_str( param.real_type, x, self.is_legacy_abi_kernel() ) for param, x in zip(self.op_overload._schema.arguments, args) ] else: args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args] # Previously, we want to maintain forward-compatibility by skipping # default args in the serialized artifacts in fbcode. However, # some of our shim interfaces require default values being set. # Discussed with Sherlock offline and we decided to allow serializing # default args into the C++ wrapper code for now. We will refine this # part if we see real FC requirement. 
More details related to FC # can be found at: # if V.graph.cpp_wrapper and hasattr(self, "args_default_value"): self.fill_non_provided_args(args, kwargs, convert_val_to_str=True) # let self.codegen_kwargs handle kwargs self.kwargs.update(kwargs) return args @staticmethod def find_device(tensor_args, example_output): if tensor_args: return tensor_args[0].get_device() if isinstance(example_output, torch.Tensor): return example_output.device if isinstance(example_output, (list, tuple)): devices = {FallbackKernel.find_device(None, x) for x in example_output} # Remove None devices = [device for device in devices if device] if len(devices) == 1: return devices[0] for device in devices: if device.type == "cuda": return device return devices[0] return None def has_side_effects(self): if isinstance(self.op_overload, torch._ops.HigherOrderOperator): return False return get_schema_info(self.op_overload).is_mutable() def get_alias_names(self): return self.alias_names def get_mutation_names(self): assert len(self.mutation_names) <= 1 return self.mutation_names def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False): assert isinstance(args, (list, tuple)) if isinstance(args, tuple): args = list(args) assert hasattr(self, "args_default_value") n_args = len(args) n_pos_args = len(self.args_default_value) # For cpp wrapper, if some positional args are not provided, we need to check # if they're in the kwargs or use their default value if n_args < n_pos_args: log.debug( "%s has %d unprovided positional arguments. 
" "Will check if they are in the keyword arguments or will use default values.", self.op_overload, n_pos_args - n_args, ) pos_args = [ self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args) ] if convert_val_to_str: pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args] args.extend(pos_args) return args # ProxyExecutor Design Note # We export the ExternFallbackNodes (for custom ops) into a serialized file # and run it with a host side proxy executor to address the ABI problem # This is currently only implemented for fbcode. Eventually, we will also make this work for OSS. # Detailed design doc can be found at # def export_extern_kernel_node(self): assert isinstance(self, FallbackKernel) args, kwargs = self.unflatten_args(self.inputs, self.constant_args) args = self.fill_non_provided_args(args, kwargs) ordered_kwargs = [ kwargs.get(key, None) for key in self.ordered_kwargs_for_cpp_kernel ] serializer = GraphModuleSerializer(None, None) # type: ignore[arg-type] named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs) # type: ignore[arg-type] # serialize_outputs def handle_single_output(return_type, output): if isinstance(return_type, torch.TensorType): # For single Tensor out = output if isinstance(output, (list, tuple)): assert len(output) == 1 out = output[0] return export_schema.Argument.create( as_tensor=export_schema.TensorArgument(name=out.get_name()) ) elif isinstance(return_type, torch.ListType) and isinstance( return_type.getElementType(), torch.TensorType ): # For single TensorList return export_schema.Argument.create( as_tensors=[ export_schema.TensorArgument(name=out.get_name()) for out in output ] ) else: raise RuntimeError(f"Unsupported return type {type(return_type)}") target = self.op_overload returns = target._schema.returns # type: ignore[union-attr] if len(returns) == 1: return_type = returns[0].real_type output_arguments = [handle_single_output(return_type, self.outputs)] else: # For tuple returns, 
e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" assert isinstance(self.outputs, tuple) assert len(returns) == len(self.outputs) output_arguments = [ handle_single_output(return_schema.real_type, output) for return_schema, output in zip(returns, self.outputs) ] node = ExternKernelNode( name=self.get_name(), node=export_schema.Node(, # type: ignore[union-attr] inputs=named_arguments, outputs=output_arguments, metadata={}, ), ) V.graph.extern_kernel_nodes.append(node) return [*args, *ordered_kwargs] def codegen(self, wrapper): kernel = self.op_overload if kernel.namespace == "aten": # type: ignore[union-attr] # Aten Fallback Ops assert isinstance(kernel, torch._ops.OpOverload) if V.graph.cpp_wrapper: if ( config.is_fbcode() and kernel not in has_c_shim # C shim v2 is torchgen-ed, which should cover all aten ops. # If you do hit a missed op, please update and config.c_shim_version == "1" ): log.warning( "%s is missing a c-shim implementation, using proxy executor as fallback", kernel, ) self.use_runtime_dispatch = True self.set_cpp_kernel(kernel) else: self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel) schema = kernel._schema self.init_args_default_value(schema) else: self.python_kernel_name = str(kernel) elif isinstance(kernel, torch._ops.HigherOrderOperator): self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}" else: # For non-aten OpOverload, i.e. 
custom ops if V.graph.cpp_wrapper: self.use_runtime_dispatch = True self.set_cpp_kernel(kernel) else: self.python_kernel_name = f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" # type: ignore[union-attr] if self.use_runtime_dispatch: self.codegen_comment(wrapper) exported_args = None args = None if config.is_fbcode() and V.graph.cpp_wrapper: exported_args = self.export_extern_kernel_node() else: args = [*self.codegen_args(), *self.codegen_kwargs()] wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), args, self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, self.op_overload, exported_args, self.outputs, ) else: self.codegen_comment(wrapper) args = [*self.codegen_args(), *self.codegen_kwargs()] V.graph.wrapper_code.generate_fallback_kernel(self, args) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @staticmethod def tensor_to_layout(output: torch.Tensor): return FixedLayout( output.device, output.dtype, convert_shape_to_inductor(output.size()), convert_shape_to_inductor(output.stride()), ) @classmethod def create(cls, kernel, *args, **kwargs): fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,) context = ( V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() ) with context: ( example_output, tensor_args, non_tensor_args, unflatten_args, ) = cls.process_kernel(kernel, *args, **kwargs) device = cls.find_device(tensor_args, example_output) assert device, "Not sure where to find device info" packed = cls( MultiOutputLayout(device), kernel, tensor_args, non_tensor_args, unflatten_args, ) def generate_output(output, indices): if isinstance(output, (list, tuple)): return type(output)( generate_output(output[i], indices + [(type(output), i)]) for i in range(len(output)) ) elif isinstance(output, dict): return { key: generate_output(val, indices + [(type(output), key)]) for key, val in output.items() } elif 
isinstance(output, torch.Tensor): return MultiOutput( cls.tensor_to_layout(output), packed, indices, ) elif isinstance(output, int): return output elif isinstance(output, torch.SymInt): return output.node.expr else: assert ( output is None ), f"FallbackKernel output type {type(output)} is not supported" return None outputs = generate_output(example_output, []) if isinstance(outputs, (list, tuple, dict)): packed.outputs = outputs # type: ignore[assignment] else: packed.outputs = [outputs] return outputs def apply_constraint(self): return super().apply_constraint() @dataclasses.dataclass class ComplexView(FallbackKernel): """View a complex number as two dtyped numbers or vice versa""" def should_allocate(self): return False def get_alias_names(self): # Signal to codegen that our output buffer isn't safe to reuse return [self.inputs[0].get_name()] def __init__( self, layout, kernel, tensor_args, nontensor_args, unflatten_args, ): super().__init__( layout, kernel, tensor_args, nontensor_args, unflatten_args, ) @dataclasses.dataclass class MultiOutputLayout(IRNode): device: torch.device class MultiOutput(ExternKernel): # Given an input MultiOutputLayout buffer, indexes out an actual buffer # from that result. This doesn't actually produce multiple outputs, # that's MultiOutputLayout! 
def codegen_list_tuple_access(self, basename, indices): if len(indices) > 0: itype, i = indices[0] if itype == list: return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) elif itype == tuple: # cpp wrapper code needs to use std::get<> to access a tuple tuple_access = V.graph.wrapper_code.codegen_tuple_access( basename, self.get_name(), str(i) ) return self.codegen_list_tuple_access(tuple_access, indices[1:]) elif itype == dict: return self.codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) else: raise AssertionError("non supported index type") else: return basename def codegen(self, wrapper): wrapper.codegen_multi_output( self.get_name(), self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), ) self.codegen_unbacked_symbol_defs(wrapper) def __init__(self, layout, input, indices: List[Tuple[Any, ...]]): super().__init__(None, layout, [input], ()) = V.graph.register_buffer(self) self.indices = indices def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: return self.inputs[0].get_unbacked_symbol_uses() def should_allocate(self): return False def get_alias_names(self): return [ inp.get_name() for inp in self.inputs if isinstance(inp, FallbackKernel) and len(inp.get_alias_names()) > 0 ] def _prepare_convolution_fusion_create( cls, x: "TensorBox", weight: "TensorBox", bias: "TensorBox", padding: List[int], stride: List[int], dilation: List[int], groups: int, transposed: bool = False, output_padding: Optional[List[int]] = None, ): """ This function is a helper function to prepare inputs, layout and constant args for convolution post-op fusion's create function, including deciding the output layout (channels first or channels last), realizing inputs and make them etc. The function only supports the CPU device since conv post-op fusion kernel is only supported on CPU right now. 
""" # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size def _conv_input_size( output_size, weight_size, padding, output_padding, stride, dilation, groups ): assert len(output_size) == len(weight_size), "Expect input dim == weight dim" dim = len(output_size) assert dim > 2, "Expect input dim > 2" BATCH_DIM = 0 WEIGHT_INPUT_CHANNELS_DIM = 1 input_size = [] input_size.append(output_size[BATCH_DIM]) input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) for d in range(2, dim): kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 input_size_d = ( (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) + kernel + output_padding[d - 2] ) input_size.append(input_size_d) return list(map(int, input_size)) # The size of prepacked_weight is the prepacked weight size of deconv: # Groups > 1: [g*o, i/g, ...] # Groups == 1: [o, i, ...] # Returns original weight size in [i, o, ...] def _original_deconv_weight_size( prepacked_weight, groups, ): prepacked_weight_size = prepacked_weight.size() dim = len(prepacked_weight_size) assert dim > 2, "Expect weight dim > 2" if groups > 1: weight_size = [] weight_size.append(prepacked_weight_size[1] * groups) weight_size.append(prepacked_weight_size[0] / groups) for d in range(2, dim): weight_size.append(prepacked_weight_size[d]) else: weight_size = prepacked_weight.transpose(0, 1).size() return weight_size x.realize() weight.realize() if bias is not None: bias.realize() with V.graph.fake_mode: # TODO cleaned up the fake_tensor trace as Linear implementation x_fake = ir_node_to_tensor(x, guard_shape=True) weight_fake = ir_node_to_tensor(weight, guard_shape=True) dims = len(x_fake.size()) - 2 assert 0 < len(padding) <= dims assert 0 < len(dilation) <= dims assert 0 < len(stride) <= dims padding = pad_listlike(padding, dims) dilation = pad_listlike(dilation, dims) stride = pad_listlike(stride, dims) if output_padding is None: output_padding = pad_listlike([0], dims) else: assert 0 < len(output_padding) <= dims 
output_padding = pad_listlike(output_padding, dims) assert isinstance(groups, int) if transposed: # When transposed, the size of the prepacked oneDNN weight is different # from the PyTorch weight. We're not able to run aten conv with such # size. We infer the output size from the input params here: weight_size = _original_deconv_weight_size(weight_fake, groups) input_size = x_fake.size() output_size = _conv_input_size( input_size, weight_size, padding, output_padding, stride, dilation, groups, ) else: bias_fake = ( ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias ) output = torch.ops.aten.convolution( x_fake, weight_fake, bias_fake, stride, padding, dilation, transposed, output_padding, groups, ) output_size = output.size() req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) req_stride_order = [len(req_stride_order)] + req_stride_order output_stride = make_channels_last_strides_for(output_size) x = cls.require_stride_order(x, req_stride_order) assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), convert_shape_to_inductor(output_size), convert_shape_to_inductor(output_stride), ) constant_args = [padding, stride, dilation, groups] if transposed: constant_args.insert(1, output_padding) if bias is not None: inputs.append(bias) else: constant_args.insert(0, bias) return inputs, constant_args, kernel_layout, req_stride_order def _prepare_linear_fusion_create( cls, x: "TensorBox", weight: "TensorBox", bias: "TensorBox", ): """ This function is a helper function to prepare inputs, layout and constant args for linear post-op fusion's create function. The function only supports the CPU device since linear post-op fusion kernel is only supported on CPU right now. """ x.realize() weight.realize() if bias is not None: bias.realize() *m, _ = x.get_size() # The weight has been transposed during the qlinear weight prepack process. 
# # aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp#L291 _, oc = weight.get_size() output_size = list(m) + [oc] req_stride_order = list(reversed(range(len(x.get_size())))) x = cls.require_stride_order(x, req_stride_order) assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] output_stride = make_contiguous_strides_for(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), output_size, output_stride, ) constant_args: List[Any] = [] if bias is not None: inputs.append(bias) else: constant_args.insert(0, bias) return inputs, constant_args, kernel_layout, req_stride_order class ConvolutionUnary(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), ): super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.mkldnn._convolution_pointwise", cpp_kernel_name="mkldnn::_convolution_pointwise", ) self.cpp_kernel_key = "convolution_pointwise" self.cpp_op_schema = """ at::Tensor( const at::Tensor& input_t, const at::Tensor& weight_t, const c10::optional& bias_opt, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, c10::string_view attr, torch::List> scalars, c10::optional algorithm)""" def codegen(self, wrapper): wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @classmethod def create( cls, x: "TensorBox", weight: "TensorBox", bias: "TensorBox", padding_: List[int], stride_: List[int], dilation_: List[int], groups: int, attr, scalars: Optional[List[Any]], algorithm, ): (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups ) constant_args = constant_args + [ attr, may_convert_to_optional(scalars), algorithm, ] return ConvolutionUnary( layout=kernel_layout, 
inputs=inputs, constant_args=constant_args, ) class ConvolutionBinary(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), cpp_constant_args=(), ): super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.mkldnn._convolution_pointwise.binary", cpp_kernel_name="mkldnn::_convolution_pointwise", ) self.cpp_kernel_overload_name = "binary" self.cpp_kernel_key = "convolution_pointwise_binary" self.cpp_op_schema = """ at::Tensor( const at::Tensor& input_t, const at::Tensor& other_t, const at::Tensor& weight_t, const c10::optional& bias_opt, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, c10::string_view binary_attr, c10::optional alpha, c10::optional unary_attr, torch::List> unary_scalars, c10::optional unary_algorithm)""" self.cpp_constant_args = cpp_constant_args def codegen(self, wrapper): wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @classmethod def create( cls, x: "TensorBox", other: "TensorBox", weight: "TensorBox", bias: "TensorBox", padding_: List[int], stride_: List[int], dilation_: List[int], groups: int, binary_attr: str, binary_alpha: Optional[float], unary_attr: Optional[str], unary_scalars: Optional[List[Any]], unary_algorithm: Optional[str], ): ( inputs, constant_args, kernel_layout, req_stride_order, ) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups ) other = cls.require_stride_order(other, req_stride_order) inputs.insert(1, other) constant_args = constant_args + [ binary_attr, binary_alpha, unary_attr, may_convert_to_optional(unary_scalars), unary_algorithm, ] return ConvolutionBinary( layout=kernel_layout, inputs=inputs, constant_args=constant_args, ) class 
ConvolutionBinaryInplace(ExternKernelAlloc): def __init__( self, kernel_layout, inputs, constant_args=(), ): # Due to constrain of, other (Tensor&) should be at input[0] reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] super().__init__( kernel_layout, reordered_inputs, constant_args, None, python_kernel_name="torch.ops.mkldnn._convolution_pointwise_.binary", cpp_kernel_name="mkldnn::_convolution_pointwise_", ) self.cpp_kernel_overload_name = "binary" self.cpp_kernel_key = "convolution_pointwise_binary_" # TODO: input[0] should be at::Tensor& self.cpp_op_schema = """ at::Tensor&( at::Tensor& other_t, const at::Tensor& input_t, const at::Tensor& weight_t, const c10::optional& bias_opt, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, c10::string_view binary_attr, c10::optional alpha, c10::optional unary_attr, torch::List> unary_scalars, c10::optional unary_algorithm)""" def codegen(self, wrapper): wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, ) def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() @classmethod def create( cls, x: "TensorBox", other: "TensorBox", weight: "TensorBox", bias: "TensorBox", padding_: List[int], stride_: List[int], dilation_: List[int], groups: int, binary_attr: str, binary_alpha: Optional[float], unary_attr: Optional[str], unary_scalars: Optional[List[Any]], unary_algorithm: Optional[str], ): ( inputs, constant_args, _, req_stride_order, ) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups ) other = cls.require_stride_order(other, req_stride_order) inputs.insert(1, other) constant_args = constant_args + [ binary_attr, binary_alpha, unary_attr, may_convert_to_optional(unary_scalars), unary_algorithm, ] packed = 
ConvolutionBinaryInplace( kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] inputs=inputs, constant_args=constant_args, ) mark_node_as_mutating(packed, inputs[1]) # This op mutates in place which means that the result is not the # target but rather the input that is being mutated # init reorders the inputs, so inputs[1] becomes packed.inputs[0] return packed.inputs[0] class MKLPackedLinear(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), ): super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.mkl._mkl_linear", cpp_kernel_name="mkl::_mkl_linear", ) self.cpp_kernel_key = "mkl_linear" self.cpp_op_schema = """ at::Tensor( const at::Tensor& self, const at::Tensor& mkl_weight_t, const at::Tensor& origin_weight_t, const c10::optional& bias_opt, const int64_t prepack_batch_size)""" def codegen(self, wrapper): wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, ) @classmethod def create(cls, x, packed_w, orig_w, batch_size): x = cls.require_stride1(cls.realize_input(x)) orig_w = cls.require_stride1(cls.realize_input(orig_w)) *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] constant_args = [None, batch_size] return MKLPackedLinear( layout=FixedLayout( x.get_device(), x.get_dtype(), output_size, output_stride ), inputs=inputs, constant_args=constant_args, ) class LinearUnary(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), ): super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.mkldnn._linear_pointwise", cpp_kernel_name="mkldnn::_linear_pointwise", ) self.cpp_kernel_key = "linear_pointwise" self.cpp_op_schema = """ at::Tensor( const at::Tensor& input_t, const at::Tensor& weight_t, const c10::optional& bias_opt, 
c10::string_view attr, torch::List> scalars, c10::optional algorithm)""" def codegen(self, wrapper): wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, ) @classmethod def create(cls, x, w, b, attr, scalars, algorithm): x = cls.require_contiguous(cls.realize_input(x)) w = cls.require_contiguous(cls.realize_input(w)) *m, ic = x.get_size() oc, ic = w.get_size() inputs = [x, w] constant_args = [attr, scalars if scalars else [-1], algorithm] if b is not None: b = cls.require_contiguous(cls.realize_input(b)) inputs.append(b) else: constant_args.insert(0, None) return LinearUnary( layout=FlexibleLayout( device=x.get_device(), dtype=x.get_dtype(), size=list(m) + [oc], ), inputs=inputs, constant_args=constant_args, ) def apply_constraint(self): pass class LinearBinary(ExternKernelAlloc): kernel = "torch.ops.mkldnn._linear_pointwise.binary" def __init__( self, layout, inputs, constant_args=(), ): super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.mkldnn._linear_pointwise.binary", cpp_kernel_name="mkldnn::_linear_pointwise", ) self.cpp_kernel_overload_name = "binary" self.cpp_kernel_key = "linear_pointwise_binary" self.cpp_op_schema = """ at::Tensor( const at::Tensor& input_t, const at::Tensor& other_t, const at::Tensor& weight_t, const c10::optional& bias_opt, c10::string_view attr) """ def codegen(self, wrapper): wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, ) @classmethod def create(cls, x, y, w, b, attr): x = cls.require_contiguous(cls.realize_input(x)) y = cls.require_contiguous(cls.realize_input(y)) w = cls.require_contiguous(cls.realize_input(w)) *m, ic = x.get_size() oc, ic = w.get_size() inputs = [x, y, w] constant_args = [attr] if b is not None: b = 
cls.require_contiguous(cls.realize_input(b)) inputs.append(b) else: constant_args.insert(0, b) return LinearBinary( layout=FlexibleLayout( device=x.get_device(), dtype=x.get_dtype(), size=list(m) + [oc], ), inputs=inputs, constant_args=constant_args, ) def apply_constraint(self): pass class ConvolutionTransposeUnary(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), ): super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.mkldnn._convolution_transpose_pointwise", cpp_kernel_name="mkldnn::_convolution_transpose_pointwise", ) self.cpp_kernel_key = "convolution_transpose_pointwise" self.cpp_op_schema = """ at::Tensor( const at::Tensor& input_t, const at::Tensor& weight_t, const c10::optional& bias_opt, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, c10::string_view attr, torch::List> scalars, c10::optional algorithm)""" def codegen(self, wrapper): wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), self.codegen_args(), self.cpp_op_schema, self.cpp_kernel_key, ) @classmethod def create( cls, x: "TensorBox", weight: "TensorBox", bias: "TensorBox", padding_: List[int], output_padding_: List[int], stride_: List[int], dilation_: List[int], groups_: int, attr, scalars: Optional[List[Any]], algorithm, ): transposed = True ( inputs, constant_args, kernel_layout, _, ) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups_, transposed, output_padding_, ) constant_args = constant_args + [ attr, may_convert_to_optional(scalars), algorithm, ] return ConvolutionTransposeUnary( layout=kernel_layout, inputs=inputs, constant_args=constant_args, ) class MkldnnRnnLayer(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), ): super().__init__( layout, inputs, constant_args, None, python_kernel_name="aten.mkldnn_rnn_layer", 
cpp_kernel_name="at::mkldnn_rnn_layer", ) @classmethod def create( cls, x: "TensorBox", w0: "TensorBox", w1: "TensorBox", w2: "TensorBox", w3: "TensorBox", hx: "TensorBox", cx: "TensorBox", reverse: bool, batch_sizes: List[int], mode: int, hidden_size: int, num_layers: int, has_biases: bool, bidirectional: bool, batch_first: bool, train: bool, ): x = cls.require_stride1(cls.realize_input(x)) # If batch_first, x has been permuted in lstm before entering the mkldnn_rnn_layer. # Make sure x is contiguous in batch_first case. x.freeze_layout() w0 = cls.require_stride1(cls.realize_input(w0)) w1 = cls.require_stride1(cls.realize_input(w1)) w2 = cls.require_stride1(cls.realize_input(w2)) w3 = cls.require_stride1(cls.realize_input(w3)) hx = cls.require_stride1(cls.realize_input(hx)) hx.freeze_layout() cx = cls.require_stride1(cls.realize_input(cx)) cx.freeze_layout() input_size = x.get_size() assert len(input_size) == 3, "Expect lstm input to be 3D" # batch_first is handled in the lstm OP. When entering # rnn_layer here, we'll always have batch_first = False seq_length, mini_batch, input_size = input_size output_shape = [seq_length, mini_batch, hidden_size] hy_shape = hx.get_size() cy_shape = cx.get_size() res: List[IRNode] = [] inputs = [x, w0, w1, w2, w3, hx, cx] constant_args = [ reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train, ] packed = MkldnnRnnLayer( MultiOutputLayout(x.get_device()), inputs=inputs, constant_args=constant_args, ) def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" return make_contiguous_strides_for(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), make_contiguous_strides_for(hy_shape), make_contiguous_strides_for(cy_shape), ] output_ir = [ MultiOutput( FixedLayout( x.get_device(), x.get_dtype(), output_size, output_stride, ), packed, [(tuple, 
i)], ) for i, (output_size, output_stride) in enumerate( zip(output_sizes, output_strides) ) ] return output_ir class QConvPointWisePT2E(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), ): """ if bias is not None - inputs = [x, w, b, weight_scale, weight_zp] - const_args is: [stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, fp32_output, unary_attr, unary_scalars, unary_algorithm] else - inputs = [x, w, weight_scale, weight_zp] - const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, fp32_output, unary_attr, unary_scalars, unary_algorithm] """ self.has_bias = len(inputs) == 5 super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.onednn.qconv2d_pointwise", cpp_kernel_name="onednn::qconv2d_pointwise", ) self.cpp_kernel_key = "qconv2d_pointwise" self.cpp_op_schema = """ at::Tensor( at::Tensor act, double act_scale, int64_t act_zero_point, at::Tensor weight, at::Tensor weight_scales, at::Tensor weight_zero_points, c10::optional bias, torch::List stride, torch::List padding, torch::List dilation, int64_t groups, double inv_output_scale, int64_t output_zero_point, c10::optional output_dtype, c10::string_view attr, torch::List> scalars, c10::optional algorithm)""" def codegen(self, wrapper): # Parser the inputs and constant args = [x.codegen_reference() for x in self.inputs] const_args = [] const_args.extend(self.codegen_const_args()) x = args[0] packed_weight = args[1] bias = args[2] if self.has_bias else const_args[0] w_scale, w_zp = args[-2], args[-1] ( stride, padding, dilation, groups, x_scale, x_zp, o_inv_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm, ) = const_args[-12:] codegen_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias, stride, padding, dilation, groups, o_inv_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm, ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), 
self.get_kernel_name(), codegen_args, self.cpp_op_schema, self.cpp_kernel_key, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @classmethod def create( cls, x: "TensorBox", x_scale: float, x_zp: int, weight: "TensorBox", # packed_weight w_scale: "TensorBox", w_zp: "TensorBox", bias: "TensorBox", stride_: List[int], padding_: List[int], dilation_: List[int], groups: int, o_inv_scale: float, output_zero_point: int, output_dtype, unary_attr, unary_scalars, unary_algorithm, ): transposed = False output_padding = None (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups, transposed, output_padding, ) # swap padding and stride to align with functional conv arg order if bias is None: constant_args[1], constant_args[2] = constant_args[2], constant_args[1] else: constant_args[0], constant_args[1] = constant_args[1], constant_args[0] w_scale.realize() w_zp.realize() inputs = inputs + [w_scale, w_zp] constant_args = constant_args + [ x_scale, x_zp, o_inv_scale, output_zero_point, output_dtype, unary_attr, may_convert_to_optional(unary_scalars), unary_algorithm, ] if output_dtype is not None: assert output_dtype in [torch.float32, torch.bfloat16] # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8. 
kernel_layout.dtype = output_dtype return QConvPointWisePT2E( layout=kernel_layout, inputs=inputs, constant_args=constant_args, ) class QConvPointWiseBinaryPT2E(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), ): """ Needs input/weight/output qparams if bias is not None - inputs = [x, w, b, accum, w_scale, w_zp] - const_args = [stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] else - inputs = [x, w, accum, w_scale, w_zp] - const_args = const_args is: [bias, stride, padding, dilation, groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] """ self.has_bias = len(inputs) == 6 self.idx_for_inplace_sum = 3 if self.has_bias else 2 super().__init__( layout, inputs, constant_args, None, python_kernel_name="torch.ops.onednn.qconv2d_pointwise.binary", cpp_kernel_name="onednn::qconv2d_pointwise", ) self.cpp_kernel_overload_name = "binary" self.cpp_kernel_key = "qconv2d_pointwise_binary" self.cpp_op_schema = """ at::Tensor( at::Tensor act, double act_scale, int64_t act_zero_point, at::Tensor accum, double accum_scale, int64_t accum_zero_point, at::Tensor weight, at::Tensor weight_scales, at::Tensor weight_zero_points, c10::optional bias, torch::List stride, torch::List padding, torch::List dilation, int64_t groups, double inv_output_scale, int64_t output_zero_point, c10::optional output_dtype, c10::string_view binary_attr, c10::optional alpha, c10::optional attr, torch::List> scalars, c10::optional algorithm)""" def codegen(self, wrapper): # Parser the inputs and constant args = [x.codegen_reference() for x in self.inputs] const_args = [] const_args.extend(self.codegen_const_args()) x = args[0] packed_weight = args[1] bias = args[2] if self.has_bias else const_args[0] accum, w_scale, w_zp = args[-3], args[-2], args[-1] ( stride, padding, dilation, 
groups, x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, o_zp, output_dtype, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm, ) = const_args[-16:] conv_args = ( x, x_scale, x_zp, accum, accum_scale, accum_zp, packed_weight, w_scale, w_zp, bias, stride, padding, dilation, groups, o_inv_scale, o_zp, output_dtype, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm, ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), conv_args, self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) def get_mutation_names(self): return [self.inputs[self.idx_for_inplace_sum].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() @classmethod def create( cls, x: "TensorBox", x_scale, x_zp, accum: "TensorBox", accum_scale, accum_zp, weight: "TensorBox", # packed_weight w_scale, w_zp, bias: "TensorBox", stride_: List[int], padding_: List[int], dilation_: List[int], groups: int, o_inv_scale: "TensorBox", output_zero_point: "TensorBox", output_dtype, binary_attr, alpha, unary_attr, unary_scalars, unary_algorithm, ): transposed = False output_padding = None ( inputs, constant_args, kernel_layout, req_stride_order, ) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups, transposed, output_padding, ) accum = cls.require_stride_order(accum, req_stride_order) inputs.append(accum) # swap padding and stride to align with functional conv arg order if bias is None: constant_args[1], constant_args[2] = constant_args[2], constant_args[1] else: constant_args[0], constant_args[1] = constant_args[1], constant_args[0] w_scale.realize() w_zp.realize() inputs = inputs + [w_scale, w_zp] constant_args = constant_args + [ x_scale, x_zp, accum_scale, accum_zp, o_inv_scale, output_zero_point, output_dtype, binary_attr, alpha, unary_attr, 
may_convert_to_optional(unary_scalars), unary_algorithm, ] assert ( binary_attr == "sum" ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." packed = QConvPointWiseBinaryPT2E( layout=NoneLayout(accum.get_device()), inputs=inputs, constant_args=constant_args, ) mark_node_as_mutating(packed, accum) # Return accum since it has been inplace changed. return packed.inputs[packed.idx_for_inplace_sum] class QLinearPointwisePT2E(ExternKernelAlloc): def __init__( self, layout, inputs, constant_args=(), has_bias=True, x_scale_zp_are_tensors=False, ): """ if bias is not None - inputs = [x, w, b, weight_scale, weight_zp] - const_args is: [x_scale, x_zp, o_inv_scale, o_zp, fp32_output, unary_attr, unary_scalars, unary_algorithm] else - inputs = [x, w, weight_scale, weight_zp] - const_args is: [bias, x_scale, x_zp, o_inv_scale, o_zp, fp32_output, unary_attr, unary_scalars, unary_algorithm] """ self.has_bias = has_bias self.x_scale_zp_are_tensors = x_scale_zp_are_tensors super().__init__( layout, inputs, constant_args, None, python_kernel_name=( "torch.ops.onednn.qlinear_pointwise.tensor" if x_scale_zp_are_tensors else "torch.ops.onednn.qlinear_pointwise.default" ), cpp_kernel_name="onednn::qlinear_pointwise", ) self.cpp_kernel_overload_name = "tensor" if x_scale_zp_are_tensors else "" self.cpp_kernel_key = "qlinear_pointwise" x_scale_type_str, x_zp_type_str = ( ("at::Tensor", "at::Tensor") if x_scale_zp_are_tensors else ("double", "int64_t") ) self.cpp_op_schema = f""" at::Tensor( at::Tensor act, {x_scale_type_str} act_scale, {x_zp_type_str} act_zero_point, at::Tensor weight, at::Tensor weight_scales, at::Tensor weight_zero_points, c10::optional bias, double inv_output_scale, int64_t output_zero_point, c10::optional output_dtype, std::string post_op_name, torch::List> post_op_args, std::string post_op_algorithm)""" def codegen(self, wrapper): # Parser the inputs and constant args = [x.codegen_reference() for x in self.inputs] const_args = [] 
const_args.extend(self.codegen_const_args()) x = args[0] packed_weight = args[1] bias = args[2] if self.has_bias else const_args[0] w_scale, w_zp = args[-2], args[-1] if self.x_scale_zp_are_tensors: assert len(args) >= 4 x_scale, x_zp = args[-4], args[-3] ( o_inv_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm, ) = const_args[-6:] else: assert len(const_args) >= 8 ( x_scale, x_zp, o_inv_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm, ) = const_args[-8:] codegen_args = ( x, x_scale, x_zp, packed_weight, w_scale, w_zp, bias, o_inv_scale, o_zp, output_dtype, unary_attr, unary_scalars, unary_algorithm, ) wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), self.get_kernel_name(), codegen_args, self.cpp_op_schema, self.cpp_kernel_key, self.cpp_kernel_overload_name, ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @classmethod def create( cls, x: "TensorBox", x_scale: float, x_zp: int, weight: "TensorBox", # packed_weight w_scale: "TensorBox", w_zp: "TensorBox", bias: "TensorBox", o_inv_scale: float, output_zero_point: int, output_dtype, unary_attr, unary_scalars, unary_algorithm, ): (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( cls, x, weight, bias, ) if isinstance(x_scale, TensorBox) and isinstance(x_zp, TensorBox): x_scale.realize() x_zp.realize() inputs = inputs + [x_scale, x_zp] x_scale_zp_are_tensors = True else: assert isinstance(x_scale, float) and isinstance(x_zp, int) constant_args = constant_args + [x_scale, x_zp] x_scale_zp_are_tensors = False w_scale.realize() w_zp.realize() inputs = inputs + [w_scale, w_zp] constant_args = constant_args + [ o_inv_scale, output_zero_point, output_dtype, unary_attr, may_convert_to_optional(unary_scalars), unary_algorithm, ] if output_dtype is not None: assert output_dtype in [torch.float32, torch.bfloat16] # in _prepare_linear_fusion_create, we use x.dtype (uint8) to create kernel_layout # if we set 
fp32_output, the output buf should be dtype float32 instead of uint8. kernel_layout.dtype = output_dtype return QLinearPointwisePT2E( layout=kernel_layout, inputs=inputs, constant_args=constant_args, has_bias=(bias is not None), x_scale_zp_are_tensors=x_scale_zp_are_tensors, ) @dataclasses.dataclass class MutableBox(IRNode): """ TensorBox / StorageBox allow in-place mutation of Tensors """ data: IRNode def __getattr__(self, name): fn = getattr(, name) if callable(fn): return fn raise AttributeError(f"{type(}.{name} not callable") def realize(self): return def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: return def codegen_reference(self, writer=None): return @property def layout(self): return # type: ignore[attr-defined] def get_layout(self): return self.layout def get_size(self): return @property def dtype(self): return def __str__(self): if isinstance(, MutableBox): line0 = f"{type(self).__name__}({type(}(" endl = "))" inner = else: line0 = f"{type(self).__name__}(" inner = endl = ")" lines = [ line0, indent(str(inner)), endl, ] return "\n".join(lines) __repr__ = __str__ class TensorBox(MutableBox): @staticmethod def create(data): return TensorBox(StorageBox(data)) class StorageBox(MutableBox): def is_input_buffer(self): if isinstance(, (InputBuffer, ReinterpretView)): return in V.graph.graph_inputs return False def realize(self): if isinstance(, ( ComputedBuffer, InputsKernel, InputBuffer, ReinterpretView, TemplateBuffer, ), ): return assert isinstance(, (Pointwise, Reduction, Scan)), type( origin_node = traceback = = ComputedBuffer( name=None, layout=FlexibleLayout(,,, ),, ) = V.graph.register_buffer( = = origin_node = traceback return def realize_hint(self): """ Called on buffers we expect to be forced to realize later. 
""" if ( isinstance(, (Pointwise, Reduction)) and self.num_reads() > 1 and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one() ): self.realize() def has_exceeded_max_reads(self): return isinstance(, Pointwise) and ( self.num_reads() > config.realize_acc_reads_threshold or self.has_large_inner_fn() ) def mark_reuse(self, users): """ A heuristic to decide if we should realize a tensor that is used multiple times. """ def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): """ The heuristic for realizing reused result of heavy ops on cpu """ heavy_ops = ["exp"] # a list of heavy ops fn_str = loops.inner_fn_str() return any((op + "(") in fn_str for op in heavy_ops) if ( users > 1 and isinstance(, (Pointwise, Reduction)) and ( self.num_reads() > config.realize_reads_threshold or self.has_large_inner_fn() or (is_cpu( and should_realize_on_cpu( ) ): self.realize() @cache_on_self def num_reads(self): data = if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)): return 1 if isinstance(data, ComputedBuffer): read_writes = data.get_read_writes() else: assert isinstance(data, (Pointwise, Reduction)), type(data) read_writes = ComputedBuffer( name=None, layout=FlexibleLayout( device=data.get_device(), dtype=data.get_dtype(), size=data.get_size(), ), data=data, ).get_read_writes() return len(read_writes.reads) @cache_on_self def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self): # Skip the check for non Pointwise instances return ( (sum(read.index != 0 for read in > 1) if isinstance(, Pointwise) and all( not isinstance(read, dependencies.StarDep) for read in ) else True ) @dataclasses.dataclass class Subgraph(IRNode): name: str graph_module: torch.fx.GraphModule graph: Optional["GraphLowering"] = None @dataclasses.dataclass class Conditional(ExternKernel): predicate: Optional[DynamicScalar] = None operands: Optional[List[TensorBox]] = None true_subgraph: Optional[Subgraph] = None false_subgraph: Optional[Subgraph] = None outputs: 
Optional[List[MultiOutput]] = None def __init__( self, predicate: DynamicScalar, operands: List[TensorBox], true_subgraph: Subgraph, false_subgraph: Subgraph, layout: MultiOutputLayout, ): self.predicate = predicate self.operands = operands self.true_subgraph = true_subgraph self.false_subgraph = false_subgraph super().__init__( name=None, layout=layout, # type: ignore[arg-type] inputs=[predicate, *operands], # type: ignore[list-item] ) = V.graph.register_buffer(self) @classmethod def create( cls, predicate: TensorBox, true_fn: Subgraph, false_fn: Subgraph, operands: List[TensorBox], ): predicate = cls.realize_input(predicate) operands = [cls.realize_input(x) for x in operands] fx_operands = V.graph.current_node.args[-1] fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] for subgraph in (true_fn, false_fn): if subgraph.graph is None: # create and lower subgraphs subgraph.graph = V.graph.make_subgraph( gm=subgraph.graph_module, example_inputs=fake_operands,, ) with V.set_graph_handler(subgraph.graph):*fake_operands) true_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] false_outputs = true_fn.graph.graph_outputs # type: ignore[union-attr] def _aliased_buffers(outputs): buffers = [ output.unwrap_view() if isinstance(output, ReinterpretView) else output for output in outputs ] # assuming the same buffer is represented by the same IRNode object return len({id(buffer) for buffer in buffers}) < len(outputs) for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)): if _aliased_buffers(true_outputs): raise AssertionError( "Output aliasing is currently not supported in compiled torch.cond. 
" f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}" ) # make sure true and false outputs are structurally equivalent assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs) for i, (to, fo) in enumerate(zip(true_outputs, false_outputs)): assert to.get_size() == fo.get_size(), (i, to, fo) assert to.get_stride() == fo.get_stride(), (i, to, fo) assert to.get_device() == fo.get_device(), (i, to, fo) assert to.get_dtype() == fo.get_dtype(), (i, to, fo) assert to.get_layout().offset == fo.get_layout().offset, (i, to, fo) conditional = Conditional( predicate=predicate, operands=operands, true_subgraph=true_fn, false_subgraph=false_fn, # use predicate device for consistent codegen-ing layout=MultiOutputLayout(predicate.get_device()), ) outputs = [ MultiOutput( FixedLayout( device=output.get_device(), dtype=output.get_dtype(), size=output.get_size(), stride=output.get_stride(), offset=output.get_layout().offset, ), conditional, [(list, i)], ) # as the true and false outputs are equivalent, # we can use either of them here as a "template" for i, output in enumerate(true_outputs) ] conditional.outputs = outputs return outputs def codegen(self, wrapper): wrapper.codegen_conditional(self) class InterpreterShim(torch.fx.Interpreter): @staticmethod @functools.lru_cache(None) def _dummy_gm(): return torch.fx.symbolic_trace(identity) def __init__(self, graph, submodules): # call super() with a placeholder to avoid constructing a # GraphModule which is very expensive (it does codegen). 
super().__init__(self._dummy_gm(), garbage_collect_values=False) self.module = self # type: ignore[assignment] self.graph = graph self.submodules = submodules self.extra_traceback = False self.fetch_attr = submodules.__getitem__ self.current_node = None def run_node(self, n: torch.fx.Node) -> Any: self.current_node = n return super().run_node(n) def run(self, *args, **kwargs): with V.set_interpreter_handler(self): return super().run(*args, **kwargs) class LoopBody: """ Captures the body of a Loops subclass into an FX graph. Persists any indexing simplifications and makes it easier to analyze loop bodies. """ def __init__(self, fn, args, var_ranges): super().__init__() self.var_ranges = var_ranges self.indexing_exprs = {} self.indexing_exprs_name = {} self.reads = [] self.writes = [] self.reads_name2expr = {} self.writes_name2expr = {} self.other = [] self.submodules = {"get_index": self.get_index} self.subblocks = {} self.indirect_vars = [] self.root_block = LoopBodyBlock(self, fn, args) self.indexing = None @cache_on_self def get_nodes(self): all_graphs = itertools.chain( (self.root_block.graph,), (block.graph for block in self.subblocks.values()), ) return [node for graph in all_graphs for node in graph.nodes] @cache_on_self def bounds(self): # Doing a local import to avoid dumping all the code here from .bounds import BoundVars return BoundVars(self) def debug_str(self): lines = [f"var_ranges = {dict(self.var_ranges)}"] lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) lines.extend( [ block.debug_str(name) for name, block in itertools.chain( [("body", self.root_block)], self.subblocks.items() ) ] ) return "\n".join(lines) def add_index_expr(self, expr: sympy.Expr, category, buf_name): getattr(self, category).append(expr) if buf_name is not None: getattr(self, f"{category}_name2expr")[buf_name] = expr if expr not in self.indexing_exprs_name: name = f"index{len(self.indexing_exprs)}" self.indexing_exprs_name[expr] = name 
self.indexing_exprs[name] = expr return self.indexing_exprs_name[expr] def add_submodule(self, block, prefix): """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" if prefix[-1].isnumeric() and prefix not in self.submodules: name = prefix else: name = f"{prefix}{len(self.submodules)}" self.submodules[name] = block return name def add_indirect(self, size): name = f"indirect{len(self.indirect_vars)}" var = sympy_index_symbol(name) self.indirect_vars.append(var) return var def replace_indirect(self, old, new): """Swap in a variable used in indirect indexing""" if str(old) == str(new): return assert self.indexing is not None self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} def get_index(self, name): assert self.indexing is not None return self.indexing[name] def __call__(self, *indices): index = list(itertools.chain.from_iterable(indices)) assert len(index) == len(self.var_ranges), (index, self.var_ranges) assert all(v not in self.var_ranges for v in index) replacements = dict(zip(self.var_ranges.keys(), index)) self.indexing = { name: sympy_subs(expr, replacements) for name, expr in self.indexing_exprs.items() } result = self.root_block() self.indexing = None return result class LoopBodyBlock: """ Captures the body of a Loops subclass into an FX graph. In normal cases there will be a 1:1 mapping between LoopBody and LoopBodyBlock, hower in the case of ops.masked() the masked out operations will manifest as an extra LoopBodyBlock. 
""" def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): self.body = body def add_index(expr, category, buf_name=None): return tracer.create_proxy( "call_module", "get_index", (self.body.add_index_expr(expr, category, buf_name),), {}, ) class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] = "CaptureIndexing" def load(self, name: str, index: sympy.Expr): index = add_index(index, "reads", name) return self._inner.load(name, index) def store(self, name, index, value, mode=None): index = add_index(index, "writes", name) return, index, value, mode) def store_reduction(self, name, index, value): index = add_index(index, "writes", name) return self._inner.store_reduction(name, index, value) def reduction(self, dtype, src_dtype, reduction_type, value): result = self._inner.reduction(dtype, src_dtype, reduction_type, value) if "welford" in reduction_type: return tuple(result[i] for i in range(3)) return result def index_expr(self, index, dtype): if isinstance(index, (int, sympy.Integer)): return self._inner.constant(int(index), dtype) index = add_index(index, "other") return self._inner.index_expr(index, dtype) def bucketize( self, values, offsets_name: str, offsets_size: sympy.Expr, indexing_dtype: torch.dtype, right: bool, ): offsets_size = add_index(offsets_size, "other") return self._inner.bucketize( values, offsets_name, offsets_size, indexing_dtype, right ) @staticmethod def masked(mask_proxy, masked_body: Callable[..., Any], other_proxy): """ Recursively capture the masked out body in another LoopBodyBlock """ subblock: LoopBodyBlock def shim(mask, other): return V.ops.masked(mask, subblock, other) name = self.body.add_submodule(shim, "masked_subblock") subblock = LoopBodyBlock(self.body, masked_body, []) self.body.subblocks[name] = subblock return tracer.create_proxy( "call_module", name, (mask_proxy, other_proxy), {} ) @staticmethod def scan( dtype_proxy, combine_fn: Callable[..., Any], value_proxy, init_proxy ): def 
shim(dtype, value, init): return V.ops.scan(dtype, combine_fn, value, init) name = self.body.add_submodule(shim, "scan") return tracer.create_proxy( "call_module", name, (dtype_proxy, value_proxy, init_proxy), {} ) def frexp(self, value_proxy): result = self._inner.frexp(value_proxy) # Proxies are iterable, but some methods expect tuples/lists return (result[0], result[1]) @staticmethod def indirect_indexing(index_proxy, size, check=True): """ Flow data from tensors into indexing formulas. Introduce a call_module to update the indexing. """ var = self.body.add_indirect(size) def set_indirect(new_var): self.body.replace_indirect( var, V.ops.indirect_indexing(new_var, size, check) ) tracer.create_proxy( "call_module", self.body.add_submodule(set_indirect, f"set_{var}"), (index_proxy,), {}, ) return var @staticmethod def output(result): tracer.create_proxy("output", "output", (result,), {}) tracer = torch.fx.Tracer() tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) from .index_propagation import IndexPropagation from .sizevars import SimplifyIndexing handler: Any = SimplifyIndexing( CaptureIndexing(proxy_ops), self.body.var_ranges ) if config.constant_and_index_propagation: handler = IndexPropagation(handler) with V.set_ops_handler(handler): # This indirection is just a cute way to get IndexPropagation to # unwrap the return value. ops.output(fn(*args)) self.graph = tracer.graph def __call__(self): graph = self.graph submodules = self.body.submodules return InterpreterShim(graph, submodules).run(V.get_ops_handler()) def debug_str(self, name="block"): code = torch.fx.GraphModule(self.body.submodules, self.graph).code return re.sub( # strip `; del var0` suffixes to make output prettier r";[^\n]*", "", code.strip().replace("def forward(", f"def {name}("), ) class Wait(ExternKernelAlloc): """ Wait should not be used by itself. 
It should always be constructed in tandem with a collective op that produces a work to wait on. """ def __init__( self, layout, inputs, constant_args=(), ): super().__init__(layout, inputs, constant_args) def should_allocate(self): return False def codegen(self, wrapper): from .codegen.wrapper import ReuseLine wrapper.add_import_once( "from torch.distributed._functional_collectives_impl import _wait_tensor" ) (input_collective,) = (t.codegen_reference() for t in self.inputs) wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})") # wait op still needs to produce a 'buffer' that represents the tensor output. # this is a symbolic gesture, and it gets handled by WrapperCodegen. # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective') # to a new name (`self.get_name()`) and `del`s the old name. wrapper.writeline(ReuseLine(wrapper, self.inputs[0], self, delete_old=False)) @classmethod def create(cls, collective_op: "TensorBox"): # TODO(whc) i'm not sure what's going on here, this probably means I missed something upstream collective_op.decide_layout() return Wait( layout=AliasedLayout(collective_op), inputs=[collective_op], ) def get_alias_names(self): # Signal to codegen that our output buffer isn't safe to reuse return [self.inputs[0].codegen_reference()] def get_mutation_names(self): # The generated `_wait_tensor` op mutates the input tensor return [self.inputs[0].codegen_reference()] class CollectiveKernel(ExternKernel): """ Each collective should follow the pattern: - extend InPlaceCollectiveKernel or OutOfPlaceCollectiveKernel. 
- the kernel delegates into c10d processgroup, which returns a 'work' obj - the work obj is registered via _register_tensor_work so it can be waited on later """ def __init__(self, layout, inputs, constant_args): super().__init__(None, layout, inputs, constant_args) = V.graph.register_buffer(self) def should_emit_register_tensor_work(self): return True def should_emit_find_or_create_pg(self): return True def codegen_collective(self, wrapper, output_name, input_names): # factor so the boilerplate can be handled in CollectiveKernel.codegen raise NotImplementedError("Must implement") def codegen_output(self, wrapper, output_name, input_names): # factor so the boilerplate can be handled in CollectiveKernel.codegen raise NotImplementedError("Must implement") @classmethod def wrap_inputs_as_inplace(cls, inputs): def wrap_input(var): op = InPlaceHint( FlexibleLayout(var.get_device(), var.get_dtype(), var.get_size()), var ) return TensorBox.create(op) return list(map(wrap_input, inputs)) def codegen(self, wrapper): wrapper.add_import_once("import torch.distributed as dist") wrapper.add_import_once("import torch.distributed.distributed_c10d as c10d") wrapper.add_import_once( "import torch.distributed._functional_collectives_impl as fun_col_impl" ) # extract references to our args in string form for codegen output input_names = [t.codegen_reference() for t in self.inputs] output_name = self.get_name() tag, ranks, group_size = self.constant_args if self.should_emit_find_or_create_pg(): # TODO: avoid more than one ref of the same pg (even though they are cached inside the api) wrapper.writeline( f"{output_name}_pg = c10d._find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})" ) self.codegen_output(wrapper, output_name, input_names) self.codegen_collective(wrapper, output_name, input_names) if self.should_emit_register_tensor_work(): wrapper.writeline( f"fun_col_impl._register_tensor_work({output_name}, {output_name}_work)" ) class 
InPlaceCollectiveKernel(CollectiveKernel): """ InPlaceCollectiveKernel are those with in-out arguments such as all_reduce. Extend this kernel if your collective needs to modify its inputs in-place. """ def __init__(self, layout, inputs, constant_args): super().__init__(layout, inputs, constant_args) def should_allocate(self): return False def has_side_effects(self): return True def codegen_output(self, wrapper, output_name, input_names): if len(input_names) > 1: wrapper.writeline(f"{output_name} = [{','.join(input_names)}] ") else: wrapper.writeline(f"{output_name} = {input_names[0]}") class OutOfPlaceCollectiveKernel(CollectiveKernel): """ OutOfPlaceCollectiveKernel are those that allocate their outputs and leave their inputs inplace, such as all_gather. """ def __init__(self, layout, inputs, outputs, constant_args): super().__init__(layout, inputs + outputs, constant_args) self.outputs = outputs self.original_inputs = inputs # NOTE: As seen in issue #108780, output buffers of out-of-place collectives # could be incorrectly reused. As a safety measure, here we just ban the reuse of them. # TODO: A better fix is to figure out how to propagate the aliases properly, # so that the buffer is only reused after all its users have consumed it. 
for x in self.outputs: V.graph.never_reuse_buffers.add( def should_allocate(self): return False def has_side_effects(self): return True def codegen_output(self, wrapper, output_name, input_names): input_names = [t.codegen_reference() for t in self.original_inputs] wrapper.writeline(f"{output_name}_inputs = [{','.join(input_names)}]") wrapper.writeline(f"{output_name} = [{','.join( for x in self.outputs)}]") @classmethod def create_output_buffers(cls, inputs, size_cb=None): outputs = [] for input in inputs: new_size = input.get_size() if size_cb is not None: size_cb(new_size) # new_size[0] *= group_size buff = OutputBuffer( layout=FlexibleLayout( device=input.get_device(), dtype=input.get_dtype(), size=new_size, ), ) outputs.append(buff) return outputs @classmethod def create_output_nodes(cls, coll, output_buffers): return [ MultiOutputNoSizeAssert( out_t.layout, coll, f"[{i}]", ) for i, out_t in enumerate(output_buffers) ] class InPlaceHint(ExternKernel): """ Helper OP to encode an in/out argument that tries to make it inplace whenever possible. Wrap the input of your inplace op to enable this behavior. The design is based on two key decisions: - this node is responsible for allocating the in/out buffer used by the collective. This is controlled by the ``should_allocate`` method that returns True here and False for the collective node - The scheduler special-case this node and enable it to reuse its input. 
""" def codegen(self, wrapper): input_name = self.inputs[0].codegen_reference() output_name = self.get_name() if not wrapper.did_reuse(self, self.inputs[0]): wrapper.writeline(f"{output_name}.copy_({input_name}) #no reuse") def __init__(self, layout, input): input = self.realize_input(input) super().__init__(None, layout, self.unwrap_storage([input]), ()) = V.graph.register_buffer(self) def should_allocate(self): return True class OutputBuffer(ExternKernel): """ Represent the output buffer used by ops that require multiple of them """ def __init__(self, layout): super().__init__(name=None, layout=layout, inputs=[]) = V.graph.register_buffer(self) def should_allocate(self): return True def codegen(self, wrapper): wrapper.writeline(f"# collective out buffer {}") class MultiOutputNoSizeAssert(MultiOutput): """ Extract partial output from a multi-output OP. Works like MultiOutput but doesn't assert size. This must be a property guaranteed by the op emitting this. """ def __init__(self, layout, input, index): super().__init__(layout, input, []) self.index = index def codegen(self, wrapper): wrapper.writeline( f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}" ) class Broadcast(InPlaceCollectiveKernel): def __init__(self, layout, inputs, constant_args, src): super().__init__(layout, inputs, constant_args) self.src = src def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() @classmethod def create( cls, x: "TensorBox", src: int, tag: str, ranks: List[int], group_size: int ): inplace_inputs = cls.wrap_inputs_as_inplace([x]) packed = Broadcast( layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] inputs=inplace_inputs, constant_args=[tag, ranks, group_size], src=src, ) mark_node_as_mutating(packed, inplace_inputs[0]) return inplace_inputs[0] def codegen_collective(self, wrapper, output_name, input_names): wrapper.writeline( f"{output_name}_work = 
dist.broadcast(" f"{output_name}, async_op=True, group={output_name}_pg, src={self.src})" ) class AllReduceCoalesced(InPlaceCollectiveKernel): def __init__(self, layout, inputs, constant_args, reduce_op): super().__init__(layout, inputs, constant_args) self.reduce_op = reduce_op def should_allocate(self): return False def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() @classmethod def create( cls, inputs: List["TensorBox"], reduce_op: str, tag: str, ranks: List[int], group_size: int, ): inplace_inputs = cls.wrap_inputs_as_inplace(inputs) packed = AllReduceCoalesced( layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] inputs=inplace_inputs, constant_args=[tag, ranks, group_size], reduce_op=reduce_op, ) mark_node_as_mutating(packed, inplace_inputs[0]) return inplace_inputs def codegen_collective(self, wrapper, output_name, input_names): wrapper.writeline( f"{output_name}_work = dist.all_reduce_coalesced(" f"{output_name}, " f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), " f"group={output_name}_pg, " "async_op=True)" ) class AllReduce(InPlaceCollectiveKernel): def __init__(self, layout, inputs, constant_args, reduce_op): super().__init__(layout, inputs, constant_args) self.reduce_op = reduce_op def get_mutation_names(self): return [self.inputs[0].get_name()] def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]: return set() @classmethod def create( cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int ): inplace_inputs = cls.wrap_inputs_as_inplace([x]) packed = AllReduce( layout=NoneLayout(inplace_inputs[0].get_device()), # type: ignore[arg-type] inputs=inplace_inputs, constant_args=[tag, ranks, group_size], reduce_op=reduce_op, ) mark_node_as_mutating(packed, inplace_inputs[0]) return inplace_inputs[0] def codegen_collective(self, wrapper, output_name, input_names): wrapper.writeline( f"{output_name}_work = 
dist.all_reduce(" f"{output_name}, async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))" ) class AllGatherIntoTensor(OutOfPlaceCollectiveKernel): def __init__(self, layout, inputs, outputs, constant_args): super().__init__(layout, inputs, outputs, constant_args) @classmethod def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int): inputs = [cls.realize_input(x)] def compute_size(new_size): new_size[0] *= group_size outputs = cls.create_output_buffers(inputs, compute_size) layout = MultiOutputLayout(inputs[0].get_device()) packed = AllGatherIntoTensor( layout=layout, inputs=inputs, outputs=outputs, constant_args=[tag, ranks, group_size], ) return cls.create_output_nodes(packed, outputs)[0] def codegen_collective(self, wrapper, output_name, input_names): wrapper.writeline( f"{output_name}_work = dist.all_gather_into_tensor(" f"{output_name}[0], {output_name}_inputs[0], async_op=True, group={output_name}_pg)" ) class ReduceScatterTensor(OutOfPlaceCollectiveKernel): def __init__(self, layout, inputs, outputs, constant_args, reduce_op): super().__init__(layout, inputs, outputs, constant_args) self.reduce_op = reduce_op @classmethod def create( cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int, ): inputs = [cls.realize_input(x)] def compute_size(new_size): new_size[0] //= group_size outputs = cls.create_output_buffers(inputs, compute_size) layout = MultiOutputLayout(inputs[0].get_device()) packed = ReduceScatterTensor( layout=layout, inputs=inputs, outputs=outputs, constant_args=[tag, ranks, group_size], reduce_op=reduce_op, ) return cls.create_output_nodes(packed, outputs)[0] def codegen_collective(self, wrapper, output_name, input_names): wrapper.writeline( f"{output_name}_work = dist.reduce_scatter_tensor(" f"{output_name}[0], {output_name}_inputs[0], " f"async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))" ) class 
AllGatherIntoTensorCoalesced(OutOfPlaceCollectiveKernel): def __init__(self, layout, inputs, outputs, constant_args): super().__init__(layout, inputs, outputs, constant_args) @classmethod def create( cls, inputs: List["TensorBox"], tag: str, ranks: List[int], group_size: int, ): inputs = [cls.realize_input(x) for x in inputs] def compute_size(new_size): new_size[0] *= group_size outputs = cls.create_output_buffers(inputs, compute_size) layout = MultiOutputLayout(inputs[0].get_device()) packed = AllGatherIntoTensorCoalesced( layout=layout, inputs=inputs, outputs=outputs, constant_args=[tag, ranks, group_size], ) return outputs # return cls.create_output_nodes(packed, outputs) def codegen_collective(self, wrapper, output_name, input_names): wrapper.writeline( f"{output_name}_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback(" f"output_tensors={output_name}, " f"input_tensors={output_name}_inputs, " f"group={output_name}_pg, " "async_op=True)" ) class ReduceScatterTensorCoalesced(OutOfPlaceCollectiveKernel): def __init__(self, layout, inputs, outputs, constant_args, reduce_op): super().__init__(layout, inputs, outputs, constant_args) self.reduce_op = reduce_op @classmethod def create( cls, inputs: List["TensorBox"], reduce_op: str, tag: str, ranks: List[int], group_size: int, ): inputs = [cls.realize_input(x) for x in inputs] def compute_size(new_size): new_size[0] //= group_size outputs = cls.create_output_buffers(inputs, compute_size) layout = MultiOutputLayout(inputs[0].get_device()) _ = ReduceScatterTensorCoalesced( layout=layout, inputs=inputs, outputs=outputs, constant_args=[tag, ranks, group_size], reduce_op=reduce_op, ) return outputs def codegen_collective(self, wrapper, output_name, input_names): wrapper.writeline( f"{output_name}_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback(" f"output_tensors={output_name}, " f"input_tensors={output_name}_inputs, " f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), " 
f"group={output_name}_pg, " "async_op=True)" ) # TODO(yifu): replace the CollectiveKernel IR hierarchy with _CollectiveKernel. class _CollectiveKernel(FallbackKernel): def should_allocate(self): return False def has_side_effects(self): return True # This is identical to FallbackKernel.set_cpp_kernel(), minus the # part that checks against input aliasing and mutation. def set_cpp_kernel(self, kernel): from .codegen.wrapper import get_cpp_op_schema self.cpp_kernel_name = self.cpp_kernel_overload_name = kernel._schema.overload_name self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr] self.cpp_op_schema = get_cpp_op_schema(kernel) self.ordered_kwargs_for_cpp_kernel = [ for x in kernel._schema.arguments if x.kwarg_only ] # NOTE: [In-Place Collective Safety] # Between the initiation and completion of an in-place collective, the # input buffers are subject to both volatile reads and volatile writes. # They must not be read, written to or reused by another kernel. To ensure # the constraints, we model collective -> wait_tensor as as two-step # mutation of the input buffers. 
@classmethod
def create_inplace(
    cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
) -> None:
    """Lower an in-place collective op into a node of this kernel class.

    The node is created with a ``NoneLayout`` (it allocates no output buffer
    of its own). Instead, every tensor in ``inputs`` is recorded as mutated
    by the new node through ``MutationOutput``. This is the first half of the
    collective -> wait_tensor two-step mutation described in
    NOTE: [In-Place Collective Safety].

    Args:
        kernel: the op overload to call. ``kernel._name`` supplies the C++
            kernel name; the Python kernel name is the same string with
            ``::`` replaced by ``.``.
        inputs: a TensorBox, or a (possibly nested) list of TensorBoxes,
            that the collective mutates in place.
        *args, **kwargs: the remaining op arguments, forwarded unchanged to
            ``process_kernel``.

    Returns:
        None. Callers keep using ``inputs``; the mutation is recorded on the
        graph as a side effect.
    """
    cpp_kernel_name = kernel._name
    python_kernel_name = cpp_kernel_name.replace("::", ".")
    with V.graph.fake_mode:
        (
            example_output,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
    # The collective operates on its inputs directly, so each one must be a
    # materialized buffer rather than unrealized Pointwise/Reduction IR.
    for tensor_arg in tensor_args:
        tensor_arg.realize()

    packed = cls(
        NoneLayout(tensor_args[0].get_device()),
        kernel,
        tensor_args,
        non_tensor_args,
        unflatten_args,
    )
    packed.cpp_kernel_name = cpp_kernel_name
    packed.python_kernel_name = python_kernel_name

    def mark_mutation(x):
        # A view does not own its storage. Record the mutation against the
        # underlying buffer so that every alias of it picks up the dependency.
        if isinstance(x.data, BaseView):
            x = x.data.unwrap_view()
        # Every input (view or not) must be marked as mutated by the
        # collective; constructing the MutationOutput registers it.
        MutationOutput(x.layout, x, packed)

    pytree.tree_map(mark_mutation, inputs)

# NOTE: [Out-of-Place Collective Safety]
# Between the initiation and completion of an out-of-place collective:
#
# Input buffers:
# - Are subject to volatile reads
# - Can be read by another kernel
# - Must not be written to or reused by another kernel
#
# Output buffers:
# - Are subject to volatile writes
# - Must not be read, written to or reused by another kernel
#
# To ensure the safety of input buffers without sacrificing read
# availability, we add input buffers as read deps of wait_tensor kernels.
#
# To ensure the safety of output buffers, we model wait_tensor as a
# mutation to the output buffer. Note we also assume the user program is
# correct and the output buffer is not consumed by kernels other than
# wait_tensor.
#
# TODO(yifu): add a pre-grad pass to validate the correctness of collective
# usage in the user program.
@classmethod def create_out_of_place( cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs ): cpp_kernel_name = kernel._name python_kernel_name = cpp_kernel_name.replace("::", ".") with V.graph.fake_mode: ( example_output, tensor_args, non_tensor_args, unflatten_args, ) = cls.process_kernel(kernel, inputs, *args, **kwargs) for tensor_arg in tensor_args: tensor_arg.realize() if isinstance(example_output, list): device = cls.find_device(tensor_args, example_output) packed = cls( MultiOutputLayout(device), kernel, tensor_args, non_tensor_args, unflatten_args, ) packed.cpp_kernel_name = cpp_kernel_name packed.python_kernel_name = python_kernel_name packed.outputs = [ MultiOutput( cls.tensor_to_layout(tensor), packed, [(list, i)], ) for i, tensor in enumerate(example_output) ] return packed.outputs else: packed = cls( cls.tensor_to_layout(example_output), kernel, tensor_args, non_tensor_args, unflatten_args, ) packed.cpp_kernel_name = cpp_kernel_name packed.python_kernel_name = python_kernel_name packed.outputs = [packed] return packed class _WaitKernel(_CollectiveKernel): def get_volatile_reads(self): inp = self.inputs[0] if isinstance(inp, _CollectiveKernel): # Out-of-place single-output return [inp.inputs[0]] elif isinstance(inp, MultiOutput): # This can be two things: # 1. Out-of-place multi-output coll # 2. In-place coll with inputs coming from another MultiOutput coll = inp.inputs[0] # Case 1 if isinstance(coll, _CollectiveKernel): _, idx = inp.indices[0] return [coll.inputs[idx]] # Case 2 return [] else: # In-place requires no additional deps handling for volatile # reads since the inputs are mutated. 
return [] @classmethod def create_wait(cls, kernel, inp: TensorBox) -> None: with V.graph.fake_mode: ( example_output, tensor_args, non_tensor_args, unflatten_args, ) = cls.process_kernel(kernel, inp) packed = cls( NoneLayout(inp.get_device()), kernel, tensor_args, non_tensor_args, unflatten_args, ) if isinstance(, BaseView): inp = MutationOutput(inp.layout, inp, packed) def get_read_writes(self): read_writes = super().get_read_writes() # See [Out-of-Place Collective Safety]. volatile_reads = self.get_volatile_reads() for vr in volatile_reads: read_writes.reads.add(dependencies.StarDep(vr.get_name())) return read_writes # NB: recursive structure here reflects val_to_arg_str, avoid # calling free_unbacked_symbols on "exotic" types that don't get pexpr # treatment def maybe_free_unbacked_symbols(s): if isinstance(s, (SymTypes, sympy.Expr)): # This branch should be impossible in return position return free_unbacked_symbols(s) elif isinstance(s, (tuple, list)): r = set() for t in s: r |= maybe_free_unbacked_symbols(t) return r elif isinstance(s, torch.Tensor): # This branch is impossible in constant-args position return free_unbacked_symbols(s) else: return set() class AllToAllSingle(OutOfPlaceCollectiveKernel): def __init__( self, layout, inputs, outputs, constant_args, output_split_sizes, input_split_sizes, ): super().__init__(layout, inputs, outputs, constant_args) self.output_split_sizes = output_split_sizes self.input_split_sizes = input_split_sizes def get_unbacked_symbol_uses(self) -> Set[sympy.Symbol]: r = set() if self.output_split_sizes is not None: r |= free_unbacked_symbols(self.output_split_sizes) if self.input_split_sizes is not None: r |= free_unbacked_symbols(self.input_split_sizes) return r @classmethod def create( cls, x: "TensorBox", output_split_sizes: Optional[List[Expr]], input_split_sizes: Optional[List[Expr]], tag: str, ranks: List[int], group_size: int, ): inputs = [cls.realize_input(x)] def compute_size(new_size): if output_split_sizes is not 
None: new_size[0] = sum(output_split_sizes) outputs = cls.create_output_buffers(inputs, compute_size) layout = MultiOutputLayout(inputs[0].get_device()) packed = AllToAllSingle( layout=layout, inputs=inputs, outputs=outputs, constant_args=[tag, ranks, group_size], output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, ) return cls.create_output_nodes(packed, outputs)[0] def codegen_collective(self, wrapper, output_name, input_names): tag, ranks, group_size = self.constant_args # TODO: might be necessary to do some pretty printing on # split sizes wrapper.writeline( f"{output_name}_work = dist.all_to_all_single(" f"{output_name}[0], {output_name}_inputs[0], " f"output_split_sizes={self.output_split_sizes}, " f"input_split_sizes={self.input_split_sizes}, " f"group={output_name}_pg, async_op=True)" )