import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES

from .exported_program import ExportedProgram

if TYPE_CHECKING:
    from sympy import Symbol

    from torch._guards import Source

    from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint

__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]


class _Dim(type):
    """
    Metaclass for :func:`Dim` types.

    Arithmetic on a :func:`Dim` type (``dim + 1``, ``dim - 1``, ``2 * dim``)
    produces a derived Dim via :meth:`_derive`. Only operations with integer
    operands are accepted, and multiplication requires a positive integer.
    """

    @staticmethod
    def readable(name, min_, max_):
        """
        Render a ``Dim(...)`` expression for ``name``, omitting ``min`` / ``max``
        when they equal the defaults (2 and ``sys.maxsize - 1`` respectively).
        """
        if min_ == 2:
            min_ = None
        if max_ == sys.maxsize - 1:
            max_ = None
        if min_ is None and max_ is None:
            return f"Dim('{name}')"
        if min_ is None:
            return f"Dim('{name}', max={max_})"
        if max_ is None:
            return f"Dim('{name}', min={min_})"
        return f"Dim('{name}', min={min_}, max={max_})"

    def __add__(cls, other):
        # e.g., dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # e.g., dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        # `other - dim` would negate the dim, which is not an increasing operation.
        raise NotImplementedError(
            f"Attempted to negate {cls.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )

    def __mul__(cls, other):
        # e.g., dim * 2
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        # The name of a derived Dim is the string form of `fn` applied to
        # the sympified name of this Dim.
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})


class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Currently we only support increasing linear expressions with integer coefficients.
    In other words, a derived Dim can always be written in the form Ax + B, where
    x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
    (In particular, the latter ensures that x < y => Ax + B < Ay + B.)
    These restrictions on the form of derived Dims makes the metatheory simpler: e.g.,
    it simplifies computing ranges for derived Dims, solving for underlying regular Dims,
    deciding equalities between derived Dims, and so on.

    The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
    The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
    """

    @property
    def min(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _min_symint = self.fn(Integer(self.root.min))  # type: ignore[attr-defined]
        assert _min_symint >= 2, (
            f"Expected derived min value of {self.__name__} to be >= 2. "
            f"Please specify an appropriate min value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.min})."  # type: ignore[attr-defined]
        )
        return int(_min_symint)

    @property
    def max(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _max_symint = self.fn(Integer(self.root.max))  # type: ignore[attr-defined]
        assert _max_symint <= sys.maxsize - 1, (
            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
            f"Please specify an appropriate max value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.max})."  # type: ignore[attr-defined]
        )
        return int(_max_symint)

    def _derive(self, fn):
        # We support nesting, e.g., 2*dim + 1.
        # This is implemented by composing operations on the same root.
        # As a consequence, roots are always regular Dims (i.e., not derived Dims).
        return _DerivedDim(
            self._derived_name(fn),
            (int,),
            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
        )


def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
    """
    # Clamp the requested range into [2, sys.maxsize - 1].
    _min = 2 if min is None else builtins.max(min, 2)
    _max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1)
    assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
    dim = _Dim(name, (int,), {"min": _min, "max": _max})
    # Attribute the new type to the caller's module (falling back to "__main__").
    dim.__module__ = getattr(
        inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
    )
    return dim


def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Util to create multiple :func:`Dim` types.
    """
    return tuple(Dim(name, min=min, max=max) for name in names)


@dataclasses.dataclass
class _ConstraintTarget:
    """
    This represents input tensor dimensions.  Don't create this
    class directly; instead, use :func:`dynamic_dim`.
    """

    w_tensor: Any  # weakref to torch.Tensor
    # TODO: We don't need t_id; we can get it off of w_tensor
    t_id: int
    dim: int


class _ConstraintFactory(type):
    """
    Metaclass that ensures a private constructor for :class:`_Constraint`
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
            f"Please use torch.export.dynamic_dim() to create one"
        )

    def _create(
        cls, w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
    ):
        # Bypass the raising `__call__` above by going through `type.__call__`.
        return super().__call__(
            w_tensor, t_id, dim, constraint_range, shared, debug_name
        )


def _create_constraint(
    w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
):
    """Construct a :class:`_Constraint` through its private factory."""
    return _Constraint._create(
        w_tensor, t_id, dim, constraint_range, shared, debug_name
    )


@dataclasses.dataclass
class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
    """

    .. warning::
        Do not construct :class:`_Constraint` directly, use :func:`dynamic_dim` instead.

    This represents constraints on input tensor dimensions, e.g., requiring
    them to be fully polymorphic or within some range.

    """

    # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
    constraint_range: "StrictMinMaxConstraint"
    # Represent that `constraint_range` is shared with another _ConstraintTarget, which
    # typically arises because of a specified equality with another dynamic dimension.
    shared: Optional[_ConstraintTarget] = None
    debug_name: Optional[str] = None

    def _clone_with_range(self, lower=2, upper=math.inf):
        """Return a new constraint whose range is intersected with [lower, upper]."""
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.value_ranges import ValueRanges

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            self.shared,
            self.debug_name,
        )

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        return self._clone_with_range(lower=lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        return self._clone_with_range(upper=upper - 1)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # We need a serialization compatible format of the constraint so that it
        # can be saved in the graph module w/o breaking the module serialization.
        # The saved constraints will be used directly for the post-exporting pass
        # that converts constraints to runtime assertion. The saved constraints
        # will not be saved in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }

    def __eq__(self, other):
        """
        Build a constraint stating that this dynamic dim equals ``other``.

        The result carries the intersection of both ranges and records ``other``
        as its ``shared`` target. Raises ``TypeError`` if ``other`` is not a
        :class:`_Constraint`.
        """
        if not isinstance(other, _Constraint):
            raise TypeError(
                "A dynamic dim can be specified equal only to another dynamic dim. "
                f"Equality with {type(other)} is not supported."
            )

        # import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & other.constraint_range.vr,
            warn_only=False,
        )
        if self.debug_name is None:
            debug_name = other.debug_name
        else:
            assert other.debug_name is None or self.debug_name == other.debug_name
            debug_name = self.debug_name
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim),
            debug_name=debug_name,
        )


@dataclasses.dataclass
class _PhantomRoot:
    """
    This represents the root of a derived Dim where the root does not directly
    specify the shape of any input dimension, but the derived Dim does.

    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.

    The fields `name`, `constraint_range`, and `val` carried by a phantom root
    help create a symbol for it. Any derived dims with this phantom root are
    backed by expressions over this symbol.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    val: int


@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    This represents a derived Dim, whose root is either a regular constraint target
    (which directly specifies the shape of some input dimension) or a phantom root
    (which does so indirectly).
    """

    # NOTE: This is not currently a subclass of _Constraint because we do not support
    # `shared` for derived `Dim`s. Indeed, sharing is a necessary concept only for
    # legacy constraints based on `dynamic_dim`: equality can be expressed simply by
    # reusing the same (derived or normal) `Dim`.
    root: Union[_ConstraintTarget, _PhantomRoot]
    fn: Callable
    constraint_range: "StrictMinMaxConstraint"
    debug_name: Optional[str] = None

    @property
    def shared(self):
        # Some code paths expect a union of _Constraint and _DerivedConstraint.
        # Thus we expose a `shared` field that is always None.
        # TODO(avik): clean this up
        return None

    @property
    def serializable_spec(self):
        # same as _Constraint.serializable_spec
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }


Constraint = Union[_Constraint, _DerivedConstraint]


def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
    """
    .. warning::
        (This feature is DEPRECATED. See :func:`Dim` instead.)

    :func:`dynamic_dim` constructs a :class:`_Constraint` object that describes the dynamism of
    a dimension ``index`` of tensor ``t``. :class:`_Constraint` objects should be passed to
    ``constraints`` argument of :func:`export`.

    Args:
        t (torch.Tensor): Example input tensor that have dynamic dimension size(s)
        index (int): Index of dynamic dimension; must lie in ``[0, t.dim() - 1]``
        debug_name (Optional[str]): Optional name stored on the returned constraint

    Returns:
        A :class:`_Constraint` object that describes shape dynamism. It can be passed to :func:`export` so
        that :func:`export` does not assume static size of specified tensor, i.e. keeping it dynamic
        as a symbolic size rather than specializing according to size of example tracing input.

    Raises:
        UserError: if ``t`` is not a tensor, if ``t`` is 0-dimensional, or if
            ``index`` is outside ``[0, t.dim() - 1]``.

    Specifically :func:`dynamic_dim` can be used to express following types of dynamism.

    - Size of a dimension is dynamic and unbounded::

        t0 = torch.rand(2, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size rather than always being static size 2
        constraints = [dynamic_dim(t0, 0)]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with a lower bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
        # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) >= 5,
            dynamic_dim(t1, 1) > 2,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with an upper bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with a upper bound of 16 (inclusive)
        # Second dimension of t1 can be dynamic size with a upper bound of 8 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) <= 16,
            dynamic_dim(t1, 1) < 8,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic and it is always equal to size of another dynamic dimension::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # Sizes of second dimension of t0 and first dimension are always equal
        constraints = [
            dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Mix and match all types above as long as they do not express conflicting requirements

    """
    from torch._dynamo.exc import UserError, UserErrorType

    if not isinstance(t, torch.Tensor):
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected tensor as input to dynamic_dim but got {type(t)}",
        )

    if t.dim() < 1:
        raise UserError(
            UserErrorType.DYNAMIC_DIM, "Cannot mark 0-dimension tensors to be dynamic"
        )

    # Negative indices are rejected as well: the documented valid range is
    # [0, t.dim() - 1], and the index is stored verbatim on the constraint.
    if index < 0 or index >= t.dim():
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]"
            f" but got {index}, which is out of bounds for the given tensor.",
        )

    # Import sympy locally
    import sympy

    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
    from torch.utils._sympy.value_ranges import ValueRanges

    return _create_constraint(
        weakref.ref(t),
        id(t),
        index,
        StrictMinMaxConstraint(
            vr=ValueRanges(lower=2, upper=sympy.oo), warn_only=False
        ),
        debug_name=debug_name,
    )


def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
):
    """
    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
    fields of `EqualityConstraint`) based on a given input `constraint`.
    """

    source, *other_sources = get_sources(constraint.t_id, constraint.dim)
    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
    # constraints that make src0 "equal" to src1, ..., srcN.
    source_pairs.extend((source, other_source) for other_source in other_sources)
    if not isinstance(constraint, _DerivedConstraint):
        if constraint.shared is not None:
            # Moreover, when t.size()[dim] is specified equal to t'.size()[dim']
            # and t'.size()[dim'] maps to src1', ..., srcN', we add
            # constraints that also make src0 "equal" to src1', ..., srcN'.
            other_sources = get_sources(constraint.shared.t_id, constraint.shared.dim)
            source_pairs.extend(
                (source, other_source) for other_source in other_sources
            )
    else:
        # branch based on the root of the _DerivedConstraint
        if not isinstance(constraint.root, _PhantomRoot):
            # either root points to an input source
            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]  # type: ignore[assignment]
        else:
            # or root points to a phantom symbol
            if constraint.root.name in phantom_symbols:
                root = phantom_symbols[constraint.root.name]  # type: ignore[assignment]
            else:
                # create a phantom symbol in the shape env based on the _PhantomRoot
                root = shape_env.create_symbol(
                    val=constraint.root.val,
                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
                    constraint_dim=constraint.root.constraint_range,
                )
                phantom_symbols[constraint.root.name] = root  # type: ignore[assignment]

        fn = constraint.fn
        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
        # Here source describes an input and root might describe another input or a phantom symbol.
        derived_equalities.append((source, root, fn))


def _process_dynamic_shapes(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
    """
    Convert a ``dynamic_shapes`` specification into a list of constraints.

    The arguments bound to the signature of ``f`` are walked in lockstep with
    ``dynamic_shapes``; every :func:`Dim` found next to a tensor dimension is
    turned into a constraint on that dimension.

    Args:
        f: Callable, ``torch.nn.Module`` (its ``forward`` signature is used), or
            ``ExportedProgram`` (its ``module()`` is used).
        args: Positional example inputs.
        kwargs: Keyword example inputs; ``None`` is treated as empty.
        dynamic_shapes: Specification mirroring the structure of the inputs.

    Returns:
        ``None`` if ``dynamic_shapes`` is ``None`` or empty, otherwise the list
        of constraints derived from it.

    Raises:
        UserError: if ``dynamic_shapes`` does not match the structure of the
            inputs, or contains inconsistent definitions of the same Dim.
    """
    from collections.abc import Mapping, Sequence

    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return None

    kwargs = kwargs if kwargs is not None else {}

    def tree_zip(combined_args, dynamic_shapes):
        # Yield (tensor, shape spec) pairs by recursing through both structures.
        if isinstance(combined_args, (tuple, list)):
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for i, shape in enumerate(dynamic_shapes):
                yield from tree_zip(combined_args[i], shape)
        elif isinstance(combined_args, dict):
            if not isinstance(dynamic_shapes, Mapping):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for k, shape in dynamic_shapes.items():
                # Equal lengths do not imply equal keys: report a key that has
                # no matching input as invalid input rather than a bare KeyError.
                try:
                    arg = combined_args[k]
                except KeyError:
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Expected dynamic_shapes of a {type(combined_args)} to have keys "
                        f"{list(combined_args.keys())}, but found unexpected key {k!r} "
                        f"in {dynamic_shapes}",
                    ) from None
                yield from tree_zip(arg, shape)
        elif type(combined_args) in SUPPORTED_NODES:
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a user-registered class (e.g., "
                    f"{type(combined_args)}) to be a Sequence that matches the "
                    f"flattened structure, but got {dynamic_shapes} instead",
                )
            yield from tree_zip(
                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
                dynamic_shapes,
            )
        elif isinstance(combined_args, torch.Tensor):
            yield (combined_args, dynamic_shapes)
        else:
            if dynamic_shapes is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                    f"got {dynamic_shapes} instead",
                )

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
    phantom_roots: Dict[str, _PhantomRoot] = {}
    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []

    def to_constraint(dim, tensor, i):
        # Build the constraint that `dim` places on dimension `i` of `tensor`.
        import sympy

        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.solve import try_solve
        from torch.utils._sympy.value_ranges import ValueRanges

        def root_value():
            # given tensor.shape[i] is the value of dim = fn(root),
            # find the value of root
            symbol = sympy.Symbol(dim.root.__name__, integer=True)
            expr = dim.fn(symbol)
            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
            if solution is not None:
                return int(solution[1])  # type: ignore[call-overload]
            else:
                raise UserError(  # noqa: TRY200
                    UserErrorType.CONSTRAINT_VIOLATION,
                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
                    f"of the form {expr}, where {symbol} is an integer",
                )

        if isinstance(dim, _DerivedDim):
            # generate a _DerivedConstraint where the root is:
            # - either a _ConstraintTarget (if dim.root directly describes an input shape)
            # - or a _PhantomRoot (otherwise)
            dim_root = dim.root  # type: ignore[attr-defined]
            if dim_root.__name__ in symbols:
                # root represents an input shape dimension
                root_constraint = symbols[dim_root.__name__][0]
                root = _ConstraintTarget(
                    root_constraint.w_tensor,
                    root_constraint.t_id,
                    root_constraint.dim,
                )
            elif dim_root.__name__ not in phantom_roots:
                # create a phantom root
                root = _PhantomRoot(  # type: ignore[assignment]
                    name=dim_root.__name__,
                    constraint_range=StrictMinMaxConstraint(
                        vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
                        warn_only=False,
                    ),
                    val=root_value(),
                )
                phantom_roots[dim_root.__name__] = root  # type: ignore[assignment]
            else:
                root = phantom_roots[dim_root.__name__]  # type: ignore[assignment]
            constraint = _DerivedConstraint(
                weakref.ref(tensor),
                id(tensor),
                i,
                root,
                dim.fn,  # type: ignore[attr-defined]
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max),
                    warn_only=False,
                ),
                debug_name=dim.__name__,
            )
            if isinstance(root, _PhantomRoot):
                # NOTE(avik): since we have not processed all inputs yet, we may replace this
                # with a root that does represent an input shape dimension later (see below)
                derived_constraints_with_phantom_root.append(constraint)
        else:
            constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
            if dim.min != 2:
                constraint = constraint >= dim.min
            if dim.max != sys.maxsize - 1:
                constraint = constraint <= dim.max
        return constraint

    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        # Every use of the same Dim name must carry the same (min, max).
        if dim.__name__ in symbols:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )

        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def update_symbols(tensor, shape):
        # A shape spec is either a {dim index: Dim} dict or a per-dimension
        # tuple/list; both are handled by the same loop below.
        if isinstance(shape, dict):
            indexed_dims = shape.items()
        elif isinstance(shape, (tuple, list)):
            indexed_dims = enumerate(shape)
        else:
            if shape is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Unexpected dynamic_shape {shape} of Tensor, "
                    "try None instead",
                )
            return

        for i, dim in indexed_dims:
            if isinstance(dim, _Dim):
                check_same_bounds(dim)
                constraint = to_constraint(dim, tensor, i)
                symbols[dim.__name__].append(constraint)
            elif dim is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                    "try None instead",
                )

    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)
    )
    combined_args = signature.bind(*args, **kwargs).arguments

    # This means user didn't specify dynamic shapes with argument names.
    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
        update_symbols(tensor, shape)

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
        if phantom_root_name in symbols:
            # We found an input shape dimension corresponding to this name, so we
            # do not need a phantom symbol for it after all.
            # NOTE(avik): Overall we want to maintain the invariant that roots that
            # are phantom symbols are really "phantom," i.e., they cannot be represented
            # by any input source. This is important when we are deciding derived equalities,
            # since we can focus our attention exclusively on input sources: deciding
            # derived equalities involving phantom symbols are, in comparison, trivial.
            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]

    for dynamic_dims in symbols.values():
        if all(
            isinstance(dynamic_dim, _DerivedConstraint) for dynamic_dim in dynamic_dims
        ):
            constraints.extend(dynamic_dims)
        else:
            primary, *others = dynamic_dims
            if others:
                for other in others:
                    constraints.append(primary == other)  # type: ignore[arg-type]
            else:
                constraints.append(primary)

    return constraints  # type: ignore[return-value]


def _process_constraints(
    fake_mode,
    graph_module: torch.fx.GraphModule,
    num_lifted_params_buffers: int,
    example_inputs: List[torch.Tensor],
) -> Dict:
    """
    Process the constraints stored in the graph module to return something more readable.

    Args:
        fake_mode: Object whose ``shape_env.var_to_range`` supplies ranges for
            root symbols that have no direct range constraint.

        graph_module (torch.fx.GraphModule): GraphModule returned from
            dynamo.export, which contains the "input_shape_constraints" and
            "inline_constraints" metadata

        num_lifted_params_buffers: Number of leading placeholder nodes that do
            not correspond to an entry of ``example_inputs``.

        example_inputs: Flattened list of example inputs used to export the graph module

    Returns:
        range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of
            symbols (from SymInts) appearing in the fake tensors in
            node.meta["val"] to their range constraints, which are a tuple
            containing (lower, upper) constraints.
    """
    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        InputDim,
    )

    # Import sympy locally
    from torch.fx.experimental.symbolic_shapes import SymInt
    from torch.utils._sympy.value_ranges import ValueRanges

    input_shape_constraints = graph_module.meta.get("input_shape_constraints", [])
    inline_constraints = graph_module.meta.get("inline_constraints", [])

    # Create dict mapping tensor_id to node names
    tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list)
    # Create dict mapping placeholder node names to their nodes
    placeholder_nodes: Dict[str, torch.fx.Node] = {}
    for i, node in enumerate(graph_module.graph.nodes):
        if node.op != "placeholder":
            # All placeholder nodes should be together in the beginning of the
            # graph
            break
        if i >= num_lifted_params_buffers:
            example_input = example_inputs[i - num_lifted_params_buffers]
            tensor_id_to_nodes[id(example_input)].append(node.name)
            placeholder_nodes[node.name] = node

    # Create dict mapping (node name, dim) a list of range (lower, upper)
    # constraints
    multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list)
    for constraint in input_shape_constraints:
        for node in tensor_id_to_nodes[constraint["t_id"]]:
            node_dim = InputDim(node, constraint["dim"])

            # Accumulate range constraints
            multi_range_constraints[node_dim].append(
                ValueRanges(constraint["min"], constraint["max"])
            )

    # Create dict mapping symbol to a singular range (lower, upper),
    # seeded with the inline constraints.
    range_constraints: Dict[Any, ValueRanges] = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }

    free_symbols: Set["Symbol"] = set()
    # Add input range constraints to range_constraints
    for input_dim, multi_range_constraint in multi_range_constraints.items():  # type: ignore[assignment]
        # Simplify the range constraints into a single range constraint
        # Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10]
        min_vals = [rc.lower for rc in multi_range_constraint]
        max_vals = [rc.upper for rc in multi_range_constraint]
        min_val = max(min_vals)  # type: ignore[type-var]
        max_val = min(max_vals)  # type: ignore[type-var]
        assert min_val <= max_val  # type: ignore[operator]

        # Add input node range constraints
        val = placeholder_nodes[input_dim.input_name].meta["val"]
        assert isinstance(val, FakeTensor)
        symint = val.shape[input_dim.dim]
        assert isinstance(
            symint, SymInt
        ), f"Expected SymInt but got {symint}: {type(symint)}"
        symbol = symint.node.expr
        range_constraints[symbol] = ValueRanges(min_val, max_val)
        free_symbols.update(symbol.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol]

    return range_constraints