4363 lines
191 KiB
Python
4363 lines
191 KiB
Python
|
# mypy: ignore-errors
|
||
|
|
||
|
"""
|
||
|
``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with
|
||
|
our symbolic shapes reasoning system that is used heavily in torch.compile. Although
|
||
|
this is not generally considered public API, when writing framework code in PyTorch
|
||
|
as well as extensions to PyTorch (e.g., in custom operator implementations), you may
|
||
|
need to make use of these APIs to setup dynamic shapes support appropriately.
|
||
|
"""
|
||
|
|
||
|
import builtins
|
||
|
import collections
|
||
|
import functools
|
||
|
import inspect
|
||
|
import itertools
|
||
|
import logging
|
||
|
import math
|
||
|
import operator
|
||
|
import re
|
||
|
import sys
|
||
|
import threading
|
||
|
import traceback
|
||
|
from collections import defaultdict
|
||
|
from contextlib import contextmanager
|
||
|
from dataclasses import dataclass, field
|
||
|
from enum import Enum
|
||
|
from functools import lru_cache
|
||
|
from typing import (
|
||
|
Any,
|
||
|
cast,
|
||
|
Callable,
|
||
|
Dict,
|
||
|
Iterable,
|
||
|
List,
|
||
|
Optional,
|
||
|
Sequence,
|
||
|
Set,
|
||
|
Tuple,
|
||
|
Type,
|
||
|
Union,
|
||
|
TYPE_CHECKING
|
||
|
)
|
||
|
from typing_extensions import TypeAlias
|
||
|
|
||
|
import torch
|
||
|
import torch.fx
|
||
|
import torch.fx.traceback as fx_traceback
|
||
|
from torch.fx.experimental import _config as config
|
||
|
|
||
|
from torch.fx.experimental.recording import (
|
||
|
FakeTensorMeta,
|
||
|
ShapeEnvEvent,
|
||
|
record_shapeenv_event,
|
||
|
replay_shape_env_events,
|
||
|
shape_env_check_state_equal
|
||
|
)
|
||
|
from torch.fx.experimental.sym_node import SymNode, SymTypes
|
||
|
|
||
|
# NB: The sym_* functions are used via getattr() and must be imported here.
|
||
|
from torch import SymBool, SymFloat, SymInt
|
||
|
from torch._guards import ShapeGuard, Source, TracingContext
|
||
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||
|
from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator
|
||
|
from torch.utils._sympy.solve import try_solve
|
||
|
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
|
||
|
from torch.utils._sympy.singleton_int import SingletonInt
|
||
|
from torch.utils._traceback import format_frame, CapturedTraceback
|
||
|
from torch._utils_internal import signpost_event
|
||
|
from torch._subclasses.meta_utils import is_sparse_any
|
||
|
|
||
|
from torch._logging import LazyString
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from torch._dynamo.source import TensorPropertySource
|
||
|
|
||
|
InputList = List
|
||
|
DimList = List
|
||
|
|
||
|
log = logging.getLogger(__name__)
|
||
|
|
||
|
class GuardOnDataDependentSymNode(RuntimeError):
    """Dedicated ``RuntimeError`` subclass, named for guards on data-dependent SymNodes.

    The class adds no state or behavior of its own; it only provides a
    distinct exception type that callers can catch separately.
    """
|
||
|
|
||
|
import sympy
|
||
|
from sympy.printing.str import StrPrinter
|
||
|
from sympy.printing.precedence import precedence, PRECEDENCE
|
||
|
|
||
|
aten = torch._ops.ops.aten # type: ignore[has-type]
|
||
|
|
||
|
__all__ = [
|
||
|
"has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", "is_concrete_int",
|
||
|
"guard_int", "guard_float", "guard_scalar", "canonicalize_bool_expr",
|
||
|
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
|
||
|
"is_concrete_bool", "is_nested_int", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
|
||
|
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
|
||
|
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
|
||
|
"guard_size_oblivious",
|
||
|
]
|
||
|
|
||
|
# FX node metadata keys for symbolic shape FX graph.
|
||
|
SHAPEENV_EVENT_KEY = "shapeenv_event"
|
||
|
CURRENT_NODE_KEY = "current_node"
|
||
|
|
||
|
# These are modules that contain generic code for interacting with ShapeEnv
|
||
|
# which are unlikely to identify a particular interesting guard statement
|
||
|
@lru_cache(None)
def uninteresting_files() -> Set[str]:
    """Return the source-file paths of modules that only hold generic ShapeEnv plumbing.

    The result is computed once and then served from the ``lru_cache``.

    Returns:
        The set of ``inspect.getfile`` results for this module and the
        torch modules listed below.
    """
    # Imported lazily, inside the function. Note these statements also bind
    # ``torch`` as a local name for the lookups below.
    import torch._inductor.sizevars
    import torch._library.abstract_impl
    import torch._subclasses.meta_utils
    import torch._subclasses.fake_tensor

    generic_modules = (
        sys.modules[__name__],
        torch.fx.experimental.recording,
        torch.fx.experimental.sym_node,
        torch.fx.interpreter,
        torch,
        torch._inductor.sizevars,
        torch._library.abstract_impl,
        torch._subclasses.meta_utils,
        torch._subclasses.fake_tensor,
    )
    return set(map(inspect.getfile, generic_modules))
|
||
|
|
||
|
# We don't bother with the metaclass as all of the dispatching logic happens
|
||
|
# entirely from Python
|
||
|
#
|
||
|
# Didn't bother with ancestors for now, unlikely to have multiple modes for
|
||
|
# symints right now
|
||
|
|
||
|
class ConstraintViolationError(RuntimeError):
    """Dedicated ``RuntimeError`` subclass, named for violated constraints.

    It carries no extra state or behavior beyond ``RuntimeError``.
    """
|
||
|
|
||
|
def has_symbolic_sizes_strides(elem) -> bool:
    """Return the ``_has_symbolic_sizes_strides`` attribute of ``elem``."""
    flag = elem._has_symbolic_sizes_strides
    return flag
|
||
|
|
||
|
Int = Union[torch.SymInt, int]


def create_contiguous(shape: Sequence[Int]) -> List[Int]:
    """Return the contiguous (row-major) strides for ``shape``.

    The stride of a dimension is the product of the sizes of all dimensions
    that come after it, so the innermost dimension always has stride 1.

    Args:
        shape: Sizes of each dimension, outermost first. Entries may be plain
            ints or SymInts.

    Returns:
        A new list of strides with one entry per dimension. For an empty
        ``shape`` the result is ``[1]``, matching the previous behavior.
    """
    strides: List[Int] = [1]
    # Accumulate from the innermost dimension outwards. The sizes that
    # contribute are shape[1:] (every dimension except the outermost);
    # using shape[:-1] here yields wrong strides for non-uniform shapes.
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    strides.reverse()
    return strides
|
||
|
|
||
|
def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int:
    """
    Return the hint for an int, i.e. the underlying real value observed at
    runtime.

    A plain ``int`` is returned unchanged. For a ``SymInt`` the hint is
    obtained through ``a.node.require_hint(fallback)``: when no hint is
    available (e.g., because of data dependent shapes), ``fallback`` is used
    if it is not None, otherwise an error is raised.
    """
    if not isinstance(a, torch.SymInt):
        # Exact type check: subclasses of int (such as bool) are rejected.
        assert type(a) is int, a
        return a
    return a.node.require_hint(fallback)
|
||
|
|
||
|
Scalar = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool]
|
||
|
|
||
|
def has_hint(a: Scalar) -> bool:
|
||
|
if isinstance(a, SymTypes):
|
||
|
return a.node.has_hint()
|
||
|
return True
|
||
|
|
||
|
def is_concrete_int(a: Union[int, SymInt]) -> bool:
    r""" Check whether an int-like value is concrete.

    A plain ``int`` is always concrete. A ``SymInt`` is concrete when its
    underlying expression is a sympy ``Integer``.

    Args:
        a (SymInt or int): Object to test for being a concrete int
    """
    assert isinstance(a, (SymInt, int))

    if isinstance(a, int):
        return True

    return isinstance(a.node.expr, sympy.core.numbers.Integer)
|
||
|
|
||
|
# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime.
|
||
|
# So make sure only type checker evaluates this alias.
|
||
|
# Xref: https://www.internalfb.com/diff/D53324783
|
||
|
SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean"
|
||
|
|
||
|
def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool:
    """
    Guard on a symbolic boolean expression in a size oblivious way.

    This is typically used when a non-oblivious test would guard on a data
    dependent value whose value is not known at compile time. A guard tested
    this way may diverge in behavior from how regular PyTorch semantics would
    treat it. For more information, see
    https://github.com/pytorch/pytorch/pull/118579

    A plain ``bool`` is returned unchanged; any other non-SymBool input trips
    an assertion.
    """
    if not isinstance(expr, torch.SymBool):
        assert isinstance(expr, bool)
        return expr
    return expr.node.guard_size_oblivious("", 0)
|
||
|
|
||
|
def canonicalize_bool_expr(expr: SympyBoolean) -> SympyBoolean:
    r""" Canonicalize a boolean expression.

    Relations are rewritten as ``<`` / ``<=`` (or ``==`` / ``!=``)
    comparisons in which the non-constant terms sit on the left-hand side and
    the numeric constants on the right-hand side. And / Or / Not expressions
    are first converted to CNF and their sub-expressions are then
    canonicalized recursively. Any other input is returned unchanged.

    nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924

    Args:
        expr (sympy.Expr): Expression to canonicalize
    """
    relation_types = (sympy.Rel, sympy.Eq, sympy.Ne)
    connective_types = (sympy.And, sympy.Or, sympy.Not)

    # Neither a relation nor a logical connective: nothing to canonicalize.
    if not isinstance(expr, relation_types + connective_types):
        return expr

    normalized = expr
    if isinstance(expr, connective_types):
        normalized = sympy.logic.boolalg.to_cnf(expr)
    return _canonicalize_bool_expr_impl(normalized)
|
||
|
|
||
|
def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean:
    """
    Canonicalize a single And / Or / relational expression.

    After canonicalization, Ge/Gt relations are guaranteed to have been
    eliminated (rewritten to Le/Lt, respectively).
    """
    if isinstance(expr, (sympy.And, sympy.Or)):
        canonical_args = [canonicalize_bool_expr(arg) for arg in expr.args]
        return type(expr)(*canonical_args)

    flipped = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le}
    if isinstance(expr, tuple(flipped)):
        # a > b  is rewritten as  b - a < 0  (and likewise >= as <=).
        difference = expr.rhs - expr.lhs
        relation = flipped[type(expr)]
    else:
        assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne))
        difference = expr.lhs - expr.rhs
        relation = type(expr)

    if not isinstance(difference, sympy.Add):
        return relation(difference, 0)

    # Keep the non-numeric terms on the left and move the numeric constants,
    # negated, to the right.
    terms = difference.args
    constants = [term for term in terms if term.is_number]
    symbolic = [term for term in terms if not term.is_number]
    return relation(sympy.Add(*symbolic), -sympy.Add(*constants))
|
||
|
|
||
|
def is_concrete_bool(a: Union[bool, SymBool]) -> bool:
    r""" Check whether a bool-like value is concrete.

    A plain ``bool`` is always concrete. A ``SymBool`` is concrete when its
    underlying expression is sympy's ``BooleanTrue`` or ``BooleanFalse``.

    Args:
        a (SymBool or bool): Object to test for being a concrete bool
    """
    assert isinstance(a, (SymBool, bool))

    if isinstance(a, bool):
        return True

    constant_types = (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse)
    return isinstance(a.node.expr, constant_types)
|
||
|
|
||
|
def is_nested_int(s):
    """Return whether ``s`` is a ``torch.SymInt`` whose node reports ``is_nested_int()``.

    Anything that is not a ``torch.SymInt`` yields ``False``.
    """
    if not isinstance(s, torch.SymInt):
        return False
    return s.node.is_nested_int()
|
||
|
|
||
|
def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
    """Yield the sympy expressions found inside ``val``.

    Supported inputs: symbolic scalars, sympy expressions, plain
    int/float/bool, tensors, tuples/lists of any of these, and None.
    Sparse tensors contribute only their sizes; other tensors contribute
    sizes, strides and storage offset. Any other input raises
    ``AssertionError`` once the generator is advanced.
    """
    if isinstance(val, SymTypes):
        # This also applies to the jagged layout NestedTensor case as
        # nested ints are not symbolic
        if is_symbolic(val):
            yield val.node.expr
        return

    if isinstance(val, sympy.Basic):
        yield val
        return

    if isinstance(val, (int, float, bool)):
        return

    if is_sparse_any(val):
        yield from _iterate_exprs(val.size())
        return

    if isinstance(val, torch.Tensor):
        # Each accessor is called only when the generator reaches it.
        for accessor in (val.size, val.stride, val.storage_offset):
            yield from _iterate_exprs(accessor())
        return

    if isinstance(val, (tuple, list)):
        for item in val:
            yield from _iterate_exprs(item)
        return

    if val is None:
        return

    raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
|
||
|
|
||
|
def free_symbols(val: Union[SymInt, torch.Tensor]) -> Set[sympy.Symbol]:
    """Return the union of the free symbols of every expression inside ``val``.

    ``None`` and inputs that contain no expressions give an empty set.
    """
    if val is None:
        return set()

    exprs = _iterate_exprs(val)
    # ``union`` needs a receiver, so pull off the first expression by hand
    # and treat "no expressions at all" as the identity (empty set).
    exhausted = object()
    head = next(exprs, exhausted)
    if head is exhausted:
        return set()

    return head.free_symbols.union(*(e.free_symbols for e in exprs))
|
||
|
|
||
|
def has_free_symbols(val: Union[SymInt, torch.Tensor]) -> bool:
    """Faster version of bool(free_symbols(val))"""
    # True as soon as one contained expression is not a plain number.
    return any(not e.is_number for e in _iterate_exprs(val))
|
||
|
|
||
|
# Like free_symbols, but filtered to only report unbacked symbols
|
||
|
def free_unbacked_symbols(x):
    """Return the free symbols of ``x`` whose names start with "u" or "f"."""
    # NB: keep synced with is_unbacked_symint
    unbacked_prefixes = ("u", "f")
    return {sym for sym in free_symbols(x) if sym.name.startswith(unbacked_prefixes)}
|
||
|
|
||
|
# WARNING: Don't use this on Dynamo produced graphs, they don't have meta
|
||
|
# setup!
|
||
|
def is_symbol_binding_fx_node(node) -> Optional[sympy.Symbol]:
    """Return the symbol bound by ``node``, or ``None`` if it binds none.

    A node binds a symbol when it is a ``placeholder`` whose ``meta["val"]``
    is a ``torch.SymInt`` backed by a bare ``sympy.Symbol``.
    """
    if node.op != "placeholder" or "val" not in node.meta:
        return None

    val = node.meta["val"]
    if isinstance(val, torch.SymInt) and isinstance(val.node.expr, sympy.Symbol):
        return val.node.expr
    return None
|
||
|
|
||
|
def find_symbol_binding_fx_nodes(graph):
    """Map each symbol bound by a node of ``graph`` to that node.

    Nodes are visited in ``graph.nodes`` order; if several nodes bind the
    same symbol, the last one wins.
    """
    bindings = {}
    for node in graph.nodes:
        if is_symbol_binding_fx_node(node):
            bindings[node.meta["val"].node.expr] = node
    return bindings
|
||
|
|
||
|
def definitely_true(a):
    """
    Return True only when ``a`` can be determined to be True, possibly
    introducing a guard in the process. If ``a`` depends on some unbacked
    SymInt, the result may be False even though some value of that SymInt
    would make the expression True.

    When is it appropriate to use definitely_true?

    1. If a higher level combinator like parallel_or/parallel_and fits,
       prefer it; those are definitely safe (modulo short-circuiting).
    2. It can be used if the program would behave equivalently were
       definitely_true to always return False (parallel_or/parallel_and are
       examples of this pattern, modulo short-circuiting).
    3. It may even be OK if the program wouldn't behave equivalently, so
       long as the change is semantics preserving: for example the program
       errors in more cases than before (but otherwise behaves identically),
       or some quantity changes in a way that doesn't matter (e.g., strides
       often fall in this bucket.)
    """
    if not isinstance(a, SymBool):
        return bool(a)
    # Without a hint nothing can be concluded, so answer False.
    return guard_bool(a) if a.node.has_hint() else False
|
||
|
|
||
|
def definitely_false(a):
    """
    Return True only when ``a`` can be determined to be False, possibly
    introducing a guard in the process. If ``a`` depends on some unbacked
    SymInt, the result may be False even though some value of that SymInt
    would make the expression ``a`` False. See definitely_true for more
    usage guidance.
    """
    if not isinstance(a, SymBool):
        return not bool(a)
    # Without a hint nothing can be concluded, so answer False.
    return (not guard_bool(a)) if a.node.has_hint() else False
|
||
|
|
||
|
def statically_known_true(x: Union[bool, SymBool]) -> bool:
    """Returns True if x can be simplified to a constant and is true.

    .. note::
        No new guards are introduced here, so the expression may still turn
        out to be true at runtime even when this function returns False.

    Args:
        x (bool, SymBool): The expression to try statically evaluating

    """
    if not isinstance(x, SymBool):
        assert isinstance(x, bool)
        return x

    expr = x.node.expr
    shape_env = x.node.shape_env
    try:
        static_value = shape_env._maybe_evaluate_static(expr)
        if static_value is not None:
            return bool(static_value)
    except Exception:
        # Best effort: a failed simplification simply means "not known".
        log.debug("Could not simplify %s", expr)
    return False
|
||
|
|
||
|
|
||
|
def parallel_or(*args):
    """
    Evaluate the logical OR of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely True.
    """
    # First a pass that cannot add guards, then one that may; only if both
    # fail to find a True argument do we fall back to a plain ``any``.
    for probe in (statically_known_true, definitely_true):
        if any(probe(a) for a in args):
            return True
    return any(args)
|
||
|
|
||
|
def parallel_and(*args):
    """
    Evaluate the logical AND of several arguments, avoiding guarding on
    unbacked SymInts if another argument is definitely False.
    """
    # Pass 1: is the negation of any argument statically known to hold?
    for a in args:
        if statically_known_true(torch.sym_not(a)):
            return False
    # Pass 2: is any argument definitely False?
    for a in args:
        if definitely_false(a):
            return False
    return all(args)
|
||
|
|
||
|
def sym_eq(x, y):
    """
    Like ==, but when run on list/tuple, it will recursively test equality
    and join the per-element results with ``&``, without guarding.

    Two tuples or two lists of different lengths compare as ``False``.
    Operands that are neither a matching pair of sequences nor both
    int/SymInt raise ``AssertionError``.
    """
    both_tuples = isinstance(x, tuple) and isinstance(y, tuple)
    both_lists = isinstance(x, list) and isinstance(y, list)
    if both_tuples or both_lists:
        if len(x) != len(y):
            return False
        combined = True
        for left, right in zip(x, y):
            combined = operator.and_(combined, sym_eq(left, right))
        return combined

    if isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
        return x == y

    raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
|
||
|
|
||
|
def guard_scalar(a):
    """Dispatch ``a`` to guard_bool / guard_int / guard_float by its type.

    Raises:
        AssertionError: if ``a`` is not a (Sym)bool, (Sym)int or (Sym)float.
    """
    # Order matters: bool is checked before int.
    dispatch = (
        ((SymBool, bool), guard_bool),
        ((SymInt, int), guard_int),
        ((SymFloat, float), guard_float),
    )
    for accepted_types, guard in dispatch:
        if isinstance(a, accepted_types):
            return guard(a)
    raise AssertionError(f"unrecognized scalar {a}")
|
||
|
|
||
|
|
||
|
@record_shapeenv_event()
def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
    """Intersect the recorded range of symbol ``s`` with ``[compiler_min, compiler_max]``.

    The intersection is stored back into ``shape_env.var_to_range[s]``; a
    symbol with no recorded range starts from ``ValueRanges.unknown()``.
    A log line is emitted only when the stored range actually changed.
    """
    requested = ValueRanges(compiler_min, compiler_max)
    previous = shape_env.var_to_range.get(s, ValueRanges.unknown())
    narrowed = previous & requested
    shape_env.var_to_range[s] = narrowed
    if narrowed != previous:
        log.info("_constrain_symbol_range %s [%s, %s]", s, narrowed.lower, narrowed.upper)
|
||
|
|
||
|
|
||
|
def _advise_is_size(a):
|
||
|
"""
|
||
|
Don't use this directly; use torch._check_is_size instead.
|
||
|
|
||
|
This is a softer version of _constrain_range_for_size (with min=0,
|
||
|
max=Inf). Instead of forcibly constraining a variable (and erroring if we
|
||
|
failed to constrain it), it will simply advise us that a size is
|
||
|
constrained in some way. We will always defer a runtime assert for this
|
||
|
constraint if we cannot prove it at compile-time, but we we only
|
||
|
*sometimes* learn useful extra information at compile-time with this
|
||
|
information. This is in contrast to constrain_range_for_size, where if
|
||
|
you don't call that on a fresh unbacked symint, chances are we will choke.
|
||
|
|
||
|
TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed
|
||
|
code. Right now this is only really used in code with AOTAutograd trace
|
||
|
through, so it is not a big problem that this isn't supported, but in
|
||
|
principle all of this code should be Dynamo'able too.
|
||
|
|
||
|
TODO: I didn't support min/max because I didn't have a use case where this
|
||
|
actually helped. In principle we can support it, it just makes the
|
||
|
implementation below more complicated.
|
||
|
"""
|
||
|
|
||
|
# This must always succeed, because the sole allowed caller _check_is_size
|
||
|
# was responsible for expect_true'ing this
|
||
|
assert a >= 0
|
||
|
|
||
|
# NB: it's important not to constrain range for size for *hinted* SymInts,
|
||
|
# because it is not only unsound, it will immediately trip our asserts
|
||
|
# that hints have to be consistent with static analysis! If you somehow
|
||
|
# have an unbounded SymInt that later constrains to 1, this will be
|
||
|
# inconsistent with the range
|
||
|
if (
|
||
|
isinstance(a, SymInt)
|
||
|
and isinstance(a.node, SymNode)
|
||
|
and not a.node.has_hint()
|
||
|
and isinstance(a.node.expr, sympy.Symbol)
|
||
|
):
|
||
|
_constrain_range_for_size(a)
|
||
|
|
||
|
@record_shapeenv_event()
def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
    """
    This function is NOT INTENDED to be used by itself.

    Constrains the symbol behind SymInt ``a`` to ``[min, max]`` and adds it
    to the ShapeEnv's ``size_like`` set.

    Args:
        a: The SymInt to constrain; its expression must be a bare symbol.
        min: Inclusive lower bound; ``None`` means 0.
        max: Inclusive upper bound; ``None`` means unbounded (``sympy.oo``).

    Raises:
        ValueError: if ``a`` is a SymFloat/SymBool, or if ``max < min``.
    """

    if isinstance(a, (SymFloat, SymBool)):
        raise ValueError("Constraining SymFloat/SymBool is nyi")

    assert isinstance(a, SymInt), "can only constrain range for SymInt"
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"

    if min is None:
        min = 0
    if max is None:
        max = sympy.oo

    if max < min:
        # The second literal must be an f-string, otherwise the placeholders
        # are printed verbatim instead of the offending values.
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )

    _constrain_symbol_range(
        a.node.shape_env,
        a.node.expr,
        compiler_min=min,
        compiler_max=max,
    )
    a.node.shape_env.size_like.add(a.node.expr)
|
||
|
|
||
|
|
||
|
# inclusive both ways
|
||
|
@record_shapeenv_event()
def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
    """
    Applies a constraint that the passed in SymInt must lie between min-max
    inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning
    that it can be used on unbacked SymInts).  If min/max are None, we assume
    that the dimension is unbounded in that direction.  Repeated application
    of constrain_range intersects the ranges.  This is a fairly low level API
    that doesn't have a lot of safety guarantees (TODO: provide higher level
    APIs).

    Currently, we use this API in the following circumstance: when we allocate
    an unbacked SymInt, denoting an integer quantity which is data dependent,
    we ordinarily do not know anything about what values it may take.  This
    means that any sort of guard on it will immediately fail.  However, in
    many cases, we know something about the unbacked SymInt: for example, we
    know that nonzero(x).size(0) must be >= 0.  We use constrain_range to
    narrow the possible range, declaring that negative symbols are impossible.
    This permits us to definitely answer True to queries like 'nnz >= 0', even
    if we don't know what the actual (hinted) value of 'nnz' is.  In fact, we
    actually use constrain_range to unsoundly discharge common guards: for an
    unbacked SymInt produced by nonzero, we will also assume that it is not
    equal to 0/1 (even though these are perfectly possible values at runtime),
    because we generally expect graphs that are valid for N=2 to also be valid
    for N=1.

    Raises:
        ValueError: if ``max < min``, or if ``a`` is a plain int outside the
            requested range.
        ValueRangeError: if ``a`` is backed by a constant sympy Integer that
            lies outside the requested range.
    """
    if min is None:
        min = -sympy.oo
    if max is None:
        max = sympy.oo

    if max < min:
        # The second literal must be an f-string, otherwise the placeholders
        # are printed verbatim instead of the offending values.
        raise ValueError(
            "Maximum value to constrain_as_size can't be less than the specified min value, "
            f"received min={min} and max={max}"
        )

    # Concrete values are only validated; there is nothing to record.
    if isinstance(a, int):
        if not (min <= a <= max):
            raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
        return

    if isinstance(a.node.expr, sympy.Integer):
        if not (min <= int(a.node.expr) <= max):
            raise ValueRangeError(f"Invalid value {int(a.node.expr)} for range [{min}:{max}]")
        return
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"

    # TODO: Shouldn't we install a guard if the symbol is backed?  Or is the
    # semantics that this is an "unchecked" assert (but is this actually
    # something useful?  Might be better to restrict only for unbacked
    # SymInt).
    _constrain_symbol_range(
        a.node.shape_env,
        a.node.expr,
        compiler_min=min,
        compiler_max=max,
    )
|
||
|
|
||
|
|
||
|
@record_shapeenv_event()
def constrain_unify(a, b):
    """
    Given two SymInts, constrain them so that they must be equal.  NB:
    this will not work with SymInts that represent nontrivial expressions
    (yet!)

    If neither argument is a SymInt they are simply asserted equal. If
    exactly one is a SymInt, its symbol is replaced by the other value. If
    both are SymInts of the same ShapeEnv, ``b``'s symbol is replaced by
    whatever ``a``'s symbol currently resolves to.
    """
    # TODO: this does not install a deferred runtime assert yet

    # TODO: Maybe dedupe this with _maybe_guard_rel?
    a_is_symint = isinstance(a, SymInt)
    b_is_symint = isinstance(b, SymInt)

    if not a_is_symint and not b_is_symint:
        assert a == b
        return

    if not a_is_symint:
        assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
        b.node.shape_env.replacements[b.node.expr] = sympy.Integer(a)
        return

    # TODO: Actually, we can support this as long as one of them is a symbol.
    # NB: We can't actually do "unification" as our operators are not
    # injective
    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
    shape_env = a.node.shape_env

    if not b_is_symint:
        shape_env.replacements[a.node.expr] = sympy.Integer(b)
        return

    assert a.node.shape_env is b.node.shape_env
    assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
    shape_env.replacements[b.node.expr] = shape_env._find(a.node.expr)
|
||
|
|
||
|
# Assume that a boolean is true for the purposes of subsequent symbolic
|
||
|
# reasoning. This will keep track of corresponding runtime checks to verify
|
||
|
# that the result is upheld: either as a regular guard, or as a special set
|
||
|
# of asserts which are triggered when an unbacked SymInt is allocated.
|
||
|
#
|
||
|
# DO NOT use this function for these cases:
|
||
|
#
|
||
|
# - This is inappropriate for "branching" conditions (where both
|
||
|
# true and false result in valid programs). We will always assume
|
||
|
# the condition evaluates true, and so it will never be possible
|
||
|
# to trace the false condition when you use it. For true branching
|
||
|
# on unbacked SymInts, you must use torch.cond; if you incorrectly
|
||
|
# use expect_true in this case, you will make the false branch
|
||
|
# unreachable (as we will simply assume that only the true branch
|
||
|
# is ever exercised).
|
||
|
#
|
||
|
# - This is inappropriate for situations where you know some other system
|
||
|
# invariant guarantees that this property holds, since you don't
|
||
|
# really need to insert a runtime check in that case. Use something
|
||
|
# like constrain_range in that case.
|
||
|
#
|
||
|
# This API has a hitch. To avoid having to reimplement error reporting
|
||
|
# capabilities, this function CAN return False. The invariant is that
|
||
|
# the surrounding code must raise an error when this function returns
|
||
|
# False. This is quite low level, so we recommend using other functions
|
||
|
# like check() which enforce this in a more intuitive way.
|
||
|
#
|
||
|
# By the way, this name is a nod to the __builtin_expect macro,
|
||
|
# which is used similarly (but unlike __builtin_expect, you MUST fail
|
||
|
# in the unlikely branch.) (I think expect is a good name; in recent
|
||
|
# versions of C++, this is replaced with [[likely]], which is weaker
|
||
|
# and not accurate for this function!)
|
||
|
def expect_true(a, skip: int = 0):
    """Assume ``a`` is true for subsequent symbolic reasoning.

    For a ``SymBool`` this forwards to ``a.node.expect_true`` with the file
    name and line number of a calling frame; a plain ``bool`` is returned
    unchanged. The result CAN be False, in which case the surrounding code
    must raise an error.

    Args:
        a: The SymBool or bool to expect.
        skip: Number of additional frames to step back past the immediate
            caller when picking the location to report.
    """
    if not isinstance(a, SymBool):
        assert type(a) is bool, a
        return a

    # TODO: check perf implications of this
    caller = inspect.currentframe()
    for _hop in range(skip + 1):  # always run this loop at least once
        caller = caller.f_back
    return a.node.expect_true(caller.f_code.co_filename, caller.f_lineno)
|
||
|
|
||
|
def guard_bool(a):
    """Return ``a`` as a concrete bool, guarding through its node if it is a SymBool.

    A value that is neither a SymBool nor exactly a ``bool`` trips an assertion.
    """
    if not isinstance(a, SymBool):
        assert type(a) is bool, a
        return a
    return a.node.guard_bool("", 0)  # NB: uses Python backtrace
|
||
|
|
||
|
def guard_int(a):
    """Return ``a`` as a concrete int, guarding through its node if it is a SymInt.

    A value that is neither a SymInt nor exactly an ``int`` (bool is
    rejected) trips an assertion.
    """
    if not isinstance(a, SymInt):
        assert type(a) is int, a
        return a
    return a.node.guard_int("", 0)  # NB: uses Python backtrace
|
||
|
|
||
|
def guard_float(a):
    """Return ``a`` as a concrete float, guarding through its node if it is a SymFloat.

    A value that is neither a SymFloat nor a ``float`` instance trips an assertion.
    """
    if not isinstance(a, SymFloat):
        assert isinstance(a, float), a
        return a
    return a.node.guard_float("", 0)  # NB: uses Python backtrace
|
||
|
|
||
|
# Given a GraphModule, return all the FakeTensors for all the placeholders
|
||
|
def fx_placeholder_vals(gm):
    """Return ``meta['val']`` of every placeholder node of ``gm.graph``, in node order."""
    vals = []
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            vals.append(node.meta['val'])
    return vals
|
||
|
|
||
|
def fx_placeholder_targets(gm):
    """Return the ``target`` of every placeholder node of ``gm.graph``, in node order."""
    targets = []
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            targets.append(node.target)
    return targets
|
||
|
|
||
|
# Given a GraphModule and arguments to run it with, evaluate that the guards
|
||
|
# for its associated ShapeEnv are satisfied by the passed arguments. This
|
||
|
# WILL check for duck sizing.
|
||
|
def eval_guards(gm, *args, ignore_static=True):
    """Evaluate the guards of ``gm.shape_env`` against ``args``.

    The placeholder values of ``gm`` and ``args`` are passed to
    ``gm.shape_env.evaluate_guards_for_args`` and its result is returned.
    """
    placeholders = fx_placeholder_vals(gm)
    shape_env = gm.shape_env
    return shape_env.evaluate_guards_for_args(placeholders, args, ignore_static=ignore_static)
|
||
|
|
||
|
def bind_symbols(gm, *args):
    """Return ``gm.shape_env.bind_symbols`` applied to the placeholder values of ``gm`` and ``args``."""
    placeholders = fx_placeholder_vals(gm)
    shape_env = gm.shape_env
    return shape_env.bind_symbols(placeholders, args)
|
||
|
|
||
|
def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges):
    """
    Assert that each end of ``bound`` is either Boolean, or not finite, or
    can be computed in exact precision via rational arithmetic.

    The only exception to this is the rare case when the user calls `sqrt(s0)`
    sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still)
    """
    # Lower end first, then upper end; each is fetched only when checked.
    for side in ("lower", "upper"):
        endpoint = getattr(bound, side)
        assert (
            endpoint.is_rational
            or endpoint.is_Boolean
            or not endpoint.is_finite
            or expr.has(sympy.Pow)
        ), (bound, expr)
|
||
|
|
||
|
class DimDynamic(Enum):
    """
    Policy for how a symbol is allocated for a dimension.

    Defaulting to DYNAMIC is always sound, but the DUCK and STATIC policies
    can give better trace-time and compile-time performance, since they
    reduce the number of allocated symbols and generally make your graph
    more static.

    NB: If we notice you've applied a constraint to the dimension, we will
    force it to DYNAMIC for simplicity.

    DimDynamic is controlled by a variety of higher level UX features.
    Currently:

    - In eager mode, the default policy is DUCK.
        - The default is changed to STATIC with assume_static_by_default.
        - An individual dim is marked DYNAMIC if you mark_dynamic_dim.
    - In export mode, the default policy is STATIC.
        - An individual dim is marked DYNAMIC if you mention it as dynamic_dim
          in the constraints kwarg.
    """
    # The dimension is treated symbolically.
    DYNAMIC = 0
    # The dimension is treated symbolically, except that when its hint
    # matches another dynamic dimension the two symbols are unified
    # ("duck sizing").
    DUCK = 1
    # The dimension is treated statically, based on its hint.
    STATIC = 2
|
||
|
|
||
|
|
||
|
# NB: These constraints affect both clients and backends: given some
|
||
|
# constraint C, the client must pass inputs that satisfy the constraint,
|
||
|
# while a backend must not introduce guards BEYOND this constraint.
|
||
|
# For clarity, we document the implications on both sides for both the client
|
||
|
# and the backend.
|
||
|
#
|
||
|
# NB: These constraints are on a *single* dimension. In principle, we could
|
||
|
# also have multi-dimension constraints, but our guess is that this is not
|
||
|
# actually useful and so we are not supporting it right now.
|
||
|
#
|
||
|
# NB: Strict constraints are typically only suitable for export, as in eager
|
||
|
# a backend like inductor may validly introduce extra, discretionary guards
|
||
|
# to improve performance of code. A StrictMinMaxConstraint would be brittle
|
||
|
# under future optimizations performed by inductor; we don't guarantee
|
||
|
# eager code with StrictMinMaxConstraint will keep working in the future!
|
||
|
|
||
|
@dataclass(frozen=True)
class Constraint:
    """Frozen base record shared by the constraint dataclasses in this module."""

    # Only field of the base class; subclasses inherit it.
    warn_only: bool
|
||
|
|
||
|
@dataclass(frozen=True)
class StrictMinMaxConstraint(Constraint):
    """
    For clients: the size at this dimension must lie within 'vr' (a lower
    and an upper bound, inclusive-inclusive) AND it must be non-negative and
    should not be 0 or 1 (but see NB below).

    For backends: no guard on this dimension may exist that is not implied
    by the given lower and upper bound.  Whatever the lower bound is, the
    backend can assume the size is non-negative and that it is not 0 or 1.

    An unbounded StrictMinMaxConstraint can be thought of as a strict version
    of "RelaxedUnspecConstraint".

    NB: Export will often unsoundly assume that a graph works for 0/1, even
    though at trace time we assumed size is not 0 or 1.  The idea is that
    if we produce a graph that works for a range of values, it will be OK
    for N=0/1 too.
    """
    vr: ValueRanges

    def render(self, source: Source):
        """Format the constraint as ``lower <= name <= upper``."""
        # TODO: better printing for -oo and oo
        lower_bound = self.vr.lower
        dim_name = source.name()
        upper_bound = self.vr.upper
        return f"{lower_bound} <= {dim_name} <= {upper_bound}"
|
||
|
|
||
|
@dataclass(frozen=True)
class RelaxedUnspecConstraint(Constraint):
    """
    Client side: there is no explicit constraint; the constraint is whatever
    the guards produced by tracing imply.

    Backend side: at least TWO distinct values for the size at this
    dimension must satisfy the guards on this dimension.

    Put differently, this constraint separates "we don't care if this
    dimension specializes or not" from "this dimension must be
    unspecialized." It says little about which specializations are allowed:
    for instance, a guard that the size is even is still acceptable under an
    unspec constraint. That is what makes RelaxedUnspecConstraint useful for
    eager mode, where the backend compiler may add constraints to otherwise
    dynamic dimensions; asserting that there are NO guards would be brittle
    because compilers should be able to add extra constraints. To assert
    that there are no guards, use StrictMinMaxConstraint with an unbounded
    ValueRanges.
    """

    def render(self, source: Source):
        """Return ``RelaxedUnspecConstraint(<source name>)`` as a string."""
        return "RelaxedUnspecConstraint({})".format(source.name())
|
||
|
|
||
|
# NB: A value of None means the client-side constraint is simply whatever the
# guards from tracing imply, and that a backend is free to add any guards it
# likes (up to and including fully specializing the value).
DimConstraint = Optional[Union[StrictMinMaxConstraint, RelaxedUnspecConstraint]]
|
||
|
|
||
|
@dataclass(frozen=True)
class EqualityConstraint(Constraint):
    """
    Represent and decide various kinds of equality constraints between input sources.

    A "source pair" is a pair of input sources for dynamic dimensions that
    are specified equal. `source_pairs` are stored in a union-find forest so
    that transitive equality of two such sources can be checked efficiently.

    A "derived equality" relates an input source to an expression over a root.
    The root is either another input source (corresponding to some dynamic
    dimension) or a phantom symbol that does not directly represent any
    dynamic dimension. `derived_equalities` involving input sources are kept
    in a transitively-closed map, so that we can efficiently check whether an
    input source is transitively equal to a given expression over another
    input source.
    (NOTE: By contrast, deciding whether an input source is transitively equal
    to a given expression over a phantom symbol is easy; such expressions are
    already in canonical form, so the problem reduces to symbolic expression
    equality.)
    """

    source_pairs: List[Tuple[Source, Source]]
    derived_equalities: List[Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]]]
    phantom_symbols: List[sympy.Symbol]

    def __post_init__(self):
        """Build the lookup structures behind `is_equal` and `is_derived`.

        Example: Suppose we are given:
          source_pairs [a = b, b = c]
          derived_equalities [d = c + 1, e = d - 1]
        First a union find is built from source_pairs:
          _parents = {a: a, b: a, c: a}
        Then canonical symbolic expressions are computed by applying
        derived_equalities in order, each one reusing earlier definitions:
          _defs = {d: c + 1, e: (c + 1) - 1 aka c}
        """
        # The dataclass is frozen, so the two private maps are attached by
        # going around the generated __setattr__.
        #
        # _parents: input source -> input source; conceptually the directed
        # edges of a union-find forest.
        object.__setattr__(self, "_parents", {})
        # _defs: input source -> "canonical" symbolic expression, i.e., a
        # unary expression whose symbol corresponds to a regular Dim (not a
        # derived Dim).
        object.__setattr__(self, "_defs", {})

        # Fold every source pair into the union-find forest.
        for lhs, rhs in self.source_pairs:
            self._union(self._find(lhs), self._find(rhs))

        # Record derived equalities as a transitively-closed map.
        # NOTE(avik): the union-find forest is reused to canonicalize input sources.
        for derived, root, fn in self.derived_equalities:
            # A phantom symbol is used as-is; a source root is first replaced
            # by its canonical expression.
            root_expr = root if isinstance(root, sympy.Symbol) else self._rewrite(root)
            self._defs[self._find(derived)] = fn(root_expr)

    def _find(self, source):
        """Follow parent edges from ``source`` up to the root of its class."""
        node = source
        while node in self._parents:
            node = self._parents[node]
        return node

    def _union(self, root1, root2):
        """Merge two classes by pointing ``root1`` at ``root2`` (no-op if equal)."""
        if root1 != root2:
            self._parents[root1] = root2

    def _rewrite(self, src):
        """Return the canonical expression that stands for ``src``."""
        # A source is always represented by the root of its equivalence class.
        canonical = self._find(src)
        if canonical not in self._defs:
            # No definition recorded: the source is represented by a symbol
            # named after it.
            return sympy.Symbol(canonical.name())
        # NOTE(avik): A plain lookup is enough because definitions are always
        # transitively-closed; otherwise recursive rewriting would be needed.
        return self._defs[canonical]

    def is_equal(self, source1, source2):
        """Decide whether the two sources are (transitively) specified equal."""
        # Same root in the union-find forest?
        if self._find(source1) == self._find(source2):
            return True
        # Otherwise source1 must be derived equal to source2 via identity.
        return self.is_derived(source1, source2, lambda x: x)

    def is_derived(self, src, symbol_src, fn):
        """Decide whether ``src`` has the same definition as ``fn`` applied to ``symbol_src``."""
        lhs = self._rewrite(src)
        rhs = fn(self._rewrite(symbol_src))
        return lhs == rhs
|
||
|
|
||
|
|
||
|
def _assert_symbol_context(symbolic_context):
    """Assert that ``symbolic_context`` is an instance of a SymbolicContext subclass.

    The bare ``SymbolicContext`` base class itself is rejected.
    """
    is_context = isinstance(symbolic_context, SymbolicContext)
    assert is_context, "Invalid symbolic_context object"
    is_concrete = type(symbolic_context) is not SymbolicContext
    assert is_concrete, "Illegal usage of symbolic_context ABC"
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class SymbolicContext:
    """
    Describes how symbols should be created in
    ``create_symbolic_sizes_strides_storage_offset``; e.g., whether they
    should be static or dynamic.

    It is an abstract base class because another variant will probably be
    added that says "use exactly these SymInts, don't allocate fresh
    symbols."
    """
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class StatelessSymbolicContext(SymbolicContext):
    """
    Symbols in ``create_symbolic_sizes_strides_storage_offset`` are created
    according to a symbolic_context determination expressed through
    ``DimDynamic`` and ``DimConstraint``. This will cause fresh symbols to
    be allocated.
    """

    dynamic_sizes: DimList[DimDynamic]
    constraint_sizes: DimList[DimConstraint] = None
    # Populated for the base when the tensor is a view. It describes how to
    # allocate symbols when the base is recursively fakeified during view
    # fake-ification.
    view_base_context: Optional[SymbolicContext] = None
    # TODO: add storage offset and stride symbolic_context

    def __post_init__(self):
        """Default ``constraint_sizes`` to one ``None`` entry per dynamic size."""
        if self.constraint_sizes is not None:
            return
        unconstrained = [None] * len(self.dynamic_sizes)
        # Frozen dataclass: assign through object.__setattr__.
        object.__setattr__(self, 'constraint_sizes', unconstrained)
|
||
|
|
||
|
|
||
|
# note [Tensor Fakification and Symbol Caching]
|
||
|
#
|
||
|
# As of the time of this note, dynamo creates a fresh fake tensor mode for backends.
|
||
|
# The reason we do this is because there are certain classes of operations, namely,
|
||
|
# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor
|
||
|
# state at the end of a dynamo trace is different than the fake tensor state at the beginning
|
||
|
# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation,
|
||
|
# view relationships, etc.
|
||
|
#
|
||
|
# As we create a new fake mode, we also lose the memoization that comes with it. Rather than
|
||
|
# transfer the memoization cache, we instead transfer the shape env. However, with this
|
||
|
# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in
|
||
|
# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across
|
||
|
# recompilations.
|
||
|
#
|
||
|
# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass
|
||
|
# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext.
|
||
|
# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is
|
||
|
# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors
|
||
|
# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env
|
||
|
# is used.
|
||
|
# TODO(voz): Shape env validation
|
||
|
@dataclass(frozen=True)
class StatefulSymbolicContext(StatelessSymbolicContext):
    """
    Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via
    a symbolic_context determination as given by a cache of Source:Symbol. A cache hit
    will reuse a stored symbol, and a cache miss will write to this cache.

    This behaves like StatelessSymbolicContext, except the cache supersedes the
    other values - dynamic_sizes and constraint_sizes will not be read if we cache
    hit.

    It is the cache owner's responsibility to maintain the lifecycle of the cache
    w/r/t different shape_envs, clearing, etc.
    """

    # Source of the tensor this context belongs to. Required in practice:
    # __post_init__ asserts that it is not None.
    tensor_source: Source = None
    # Why is this keyed on int first?
    # That integer is actually the id of the shape_env. This cache short-circuits symbol
    # creation, and we must store it per shape env. Now, while tracing invariants are a single
    # shape env per tracing context, and every new frame gets a new shape_env. So where would we have
    # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events
    # is invoked, and creates a new shape_env. Replaying events against this new shape_env will
    # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never
    # get recorded in var_to_val, etc.
    # TODO(voz): consider a weakref to the shape_env here
    shape_env_to_source_to_symbol_cache : Dict[int, Dict["TensorPropertySource", "sympy.Expr"]] = None

    def __post_init__(self):
        """Apply the parent defaults, validate ``tensor_source``, and create the cache if absent."""
        # Run the parent's hook so that ``constraint_sizes`` receives the same
        # per-dimension default as in StatelessSymbolicContext; overriding
        # __post_init__ without this call would leave it as None.
        super().__post_init__()
        # The None default is annoying, but required because of dataclass limitations
        assert self.tensor_source is not None
        # Only create a cache when none was supplied. An explicit identity
        # test keeps a caller-provided (possibly still empty) dict, so the
        # cache owner keeps observing writes made through this context.
        if self.shape_env_to_source_to_symbol_cache is None:
            object.__setattr__(self, 'shape_env_to_source_to_symbol_cache', {})
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class SubclassSymbolicContext(StatefulSymbolicContext):
    """
    The correct symbolic context for a given inner tensor of a traceable tensor subclass
    may differ from that of the outer symbolic context. This structure allows for this
    flexibility, with inner symbolic contexts mapped via attr -> symbolic context.
    """

    # attribute name -> symbolic context of that inner tensor; defaults to an
    # empty dict (filled in by __post_init__ when left as None).
    inner_contexts: Dict[str, SymbolicContext] = None

    def __post_init__(self):
        """Run the parent initialization, then default ``inner_contexts`` to ``{}``."""
        super().__post_init__()
        if self.inner_contexts is None:
            # This dataclass is frozen, so a plain ``self.inner_contexts = {}``
            # raises dataclasses.FrozenInstanceError; go through
            # object.__setattr__ like the parent classes do.
            object.__setattr__(self, 'inner_contexts', {})
|
||
|
|
||
|
|
||
|
def is_symbolic(val: Union[int, SymInt, float, SymFloat, bool, SymBool]) -> bool:
    """Return whether ``val`` is symbolic.

    Plain ``int``/``float``/``bool`` values are never symbolic; for anything
    else the answer is delegated to ``val.node.is_symbolic()``.
    """
    if not isinstance(val, (int, float, bool)):
        return val.node.is_symbolic()
    return False
|
||
|
|
||
|
# Tuple of the indicator expression classes known to this module; at present
# it contains only IsNonOverlappingAndDenseIndicator.
# NOTE(review): presumably consumed by isinstance() checks elsewhere - confirm.
IndicatorTypes = (IsNonOverlappingAndDenseIndicator,)
|
||
|
|
||
|
@lru_cache(256)
def safe_expand(r):
    """Return ``sympy.expand(r)``, falling back to ``r`` itself when that is not possible.

    Objects without an ``expand`` attribute are returned untouched. If the
    expansion hits a ``RecursionError``, a warning is logged and the
    unexpanded ``r`` is returned. Results are memoized (up to 256 entries).
    """
    if not hasattr(r, 'expand'):
        return r
    try:
        return sympy.expand(r)
    except RecursionError:
        log.warning("RecursionError in sympy.expand(%s)", r)
        return r
|
||
|
|
||
|
def error():
    """Unconditionally raise ``AssertionError``; marks code that must never run."""
    unreachable = AssertionError("shouldn't be hit")
    raise unreachable
|
||
|
|
||
|
|
||
|
# TODO: Deduplicate this with torch/_prims_common/__init__.py
|
||
|
def eval_is_non_overlapping_and_dense(sizes, strides):
    """Return 1 if the layout is non-overlapping and dense, else 0.

    The boolean result of ``_eval_is_non_overlapping_and_dense`` is passed
    through ``guard_bool`` and then converted to ``int``.
    """
    verdict = _eval_is_non_overlapping_and_dense(sizes, strides)
    return int(guard_bool(verdict))
|
||
|
|
||
|
def _eval_is_non_overlapping_and_dense(sizes, strides):
|
||
|
dim = len(sizes)
|
||
|
|
||
|
# Short-circuits for tensors of rank one, which are
|
||
|
# non-overlapping and "dense" if their stride is one
|
||
|
# or it is a 0/1 element tensor
|
||
|
if dim == 1:
|
||
|
return strides[0] == 1 or sizes[0] < 2
|
||
|
|
||
|
# Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
|
||
|
# Sorts (length, stride) pairs by stride
|
||
|
lengths_and_strides = sorted(
|
||
|
zip(sizes, strides), key=operator.itemgetter(1)
|
||
|
)
|
||
|
|
||
|
# Unlike the C++ code, we don't move the 0/1 size dimensions to the
|
||
|
# end. So we have to keep going for this code.
|
||
|
expected_stride = 1
|
||
|
for length, stride in lengths_and_strides:
|
||
|
|
||
|
if length == 1:
|
||
|
continue
|
||
|
|
||
|
if stride != expected_stride:
|
||
|
return False
|
||
|
|
||
|
expected_stride *= length
|
||
|
|
||
|
return True
|
||
|
|
||
|
|
||
|
def cast_symbool_to_symint_guardless(symbool: torch.SymBool) -> torch.SymInt:
    """Turn a SymBool into a SymInt that is 1 where the bool holds and 0 otherwise.

    The integer expression is a ``sympy.Piecewise`` over the SymBool's
    expression, and the new node is created in the SymBool's shape env with
    the SymBool's (required) hint converted to ``int``.
    """
    node = symbool.node
    int_sym = sympy.Piecewise((1, node.expr), (0, True))
    hint = int(node.require_hint())
    return node.shape_env.create_symintnode(int_sym, hint=hint)
|
||
|
|
||
|
# Lookup table from an operation name (a string) to the Python callable that
# evaluates that operation.
SYMPY_INTERP = {
    # absolute value and comparisons
    "Abs": operator.abs,
    "Eq": operator.eq,
    "Ne": operator.ne,
    "Gt": operator.gt,
    "Lt": operator.lt,
    "Le": operator.le,
    "Ge": operator.ge,
    # extrema
    "Min": min,
    "Max": max,
    # division family
    "Mod": operator.mod,
    "FloorDiv": operator.floordiv,
    "TrueDiv": operator.truediv,
    # helpers defined in this module
    "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense,
    # rounding
    "floor": math.floor,
    "ceiling": math.ceil,
    "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless,
    "Round": builtins.round,
    "RoundDecimal": builtins.round,
}
|
||
|
|
||
|
|
||
|
def _lru_cache(fn, maxsize=None):
|
||
|
"""
|
||
|
Wrapper around lru_cache that clears when new info about shapes has been
|
||
|
updated.
|
||
|
|
||
|
Use lru_cache if the output is always the same, regardless of the
|
||
|
constraints we know now (i.e. evaluate_expr)
|
||
|
|
||
|
Use _lru_cache otherwise.
|
||
|
|
||
|
Also note that this depends on _update_version_counter being called on the
|
||
|
shape environment whenever the constraints are updated, otherwise the cache
|
||
|
will not be cleared.
|
||
|
"""
|
||
|
fn_cache = lru_cache(maxsize)(fn)
|
||
|
prior_version = 0
|
||
|
|
||
|
if config.validate_shape_env_version_key:
|
||
|
prior_key = None
|
||
|
|
||
|
@functools.wraps(fn)
|
||
|
def wrapper(self, *args, **kwargs):
|
||
|
nonlocal prior_version, prior_key
|
||
|
if prior_key is None:
|
||
|
prior_key = self._get_key()
|
||
|
|
||
|
if prior_version != self._version_counter:
|
||
|
fn_cache.cache_clear()
|
||
|
prior_version = self._version_counter
|
||
|
prior_key = self._get_key()
|
||
|
else:
|
||
|
assert prior_key == self._get_key(), \
|
||
|
"ShapeEnv cache key changed without version being updated!"
|
||
|
|
||
|
return fn_cache(self, *args, **kwargs)
|
||
|
|
||
|
else:
|
||
|
|
||
|
@functools.wraps(fn)
|
||
|
def wrapper(self, *args, **kwargs):
|
||
|
nonlocal prior_version
|
||
|
if prior_version != self._version_counter:
|
||
|
fn_cache.cache_clear()
|
||
|
prior_version = self._version_counter
|
||
|
|
||
|
return fn_cache(self, *args, **kwargs)
|
||
|
|
||
|
wrapper.cache_clear = fn_cache.cache_clear
|
||
|
wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined]
|
||
|
return wrapper
|
||
|
|
||
|
|
||
|
# This is pretty similar to ShapeGuard but it also comes with a message,
|
||
|
# and is exclusively used for things that MUST be true (unlike guards,
|
||
|
# which can evaluate False, in which case you just choose not to use
|
||
|
# a particular specialization)
|
||
|
@dataclass(frozen=True)
class RuntimeAssert:
    """Immutable record of an expression that MUST hold, with its message and stack.

    Only ``expr`` takes part in the generated ``repr``; ``msg`` and ``stack``
    are excluded from it.
    """

    # The expression that is asserted.
    expr: sympy.Expr
    # Message attached to the assertion (hidden from repr).
    msg: str = field(repr=False)
    # Stack information attached to the assertion, as a string (hidden from repr).
    stack: str = field(repr=False)
|
||
|
|
||
|
|
||
|
class ShapeGuardPrinter(StrPrinter):
    """String printer that spells boolean operators the Python way and prints
    each symbol through ``source_ref`` applied to its first known source."""

    def __init__(
        self,
        symbol_to_source,
        source_ref,
        var_to_sources,
    ):
        super().__init__()
        # symbol -> sequence of sources; the first entry is the one printed
        self.symbol_to_source = symbol_to_source
        # callable turning a source into the text that is emitted for it
        self.source_ref = source_ref
        # symbol -> sources; consulted only for the assertion message below
        self.var_to_sources = var_to_sources

    def _print_Not(self, expr):
        operand = self.parenthesize(expr.args[0], PRECEDENCE["Not"])
        return f"not {operand}"

    def _print_And(self, expr):
        return self.stringify(expr.args, " and ", PRECEDENCE["And"])

    def _print_Or(self, expr):
        return self.stringify(expr.args, " or ", PRECEDENCE["Or"])

    def _print_Symbol(self, expr) -> str:
        """Print ``expr`` via ``source_ref`` of the first source registered for it."""
        assert isinstance(expr, sympy.Symbol), str(type(expr))

        def describe_known_symbols():
            # Only evaluated when the assertion below fails.
            names_by_symbol = {
                symbol: [s.name() for s in sources]
                for symbol, sources in self.symbol_to_source.items()
            }
            return repr(names_by_symbol)

        assert self.symbol_to_source.get(expr), (
            f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) "
            f"not in {describe_known_symbols()}. If this assert is failing, it could be "
            "due to the issue described in https://github.com/pytorch/pytorch/pull/90665"
        )
        first_source = self.symbol_to_source[expr][0]
        return self.source_ref(first_source)
|
||
|
|
||
|
|
||
|
class LoggingShapeGuardPrinter(ShapeGuardPrinter):
    """ShapeGuardPrinter that prints each symbol as the plain name of its source.

    ``var_to_sources`` is used both as the symbol-to-source mapping and as
    the mapping consulted for diagnostics.
    """

    def __init__(self, var_to_sources):
        def source_name(source):
            return source.name()

        super().__init__(var_to_sources, source_name, var_to_sources)
|
||
|
|
||
|
|
||
|
class DynamicDimConstraintPrinter(StrPrinter):
    """
    Printer for dynamic dim constraints.
    - t.size()[d] is printed as dynamic_dim(t, d)
    - Eq(_, _), Mod(_, _), etc. are printed as _ == _, _ % _, etc.

    The output is used to suggest code for specifying dynamic dim constraints.
    """

    def __init__(self, symbol_to_source, source_name_to_debug_name):
        super().__init__()
        self.symbol_to_source = symbol_to_source
        self.source_name_to_debug_name = source_name_to_debug_name

    def print_source(self, source) -> str:
        """Return the text used for ``source``.

        With a non-empty ``source_name_to_debug_name`` this is the plain
        source name; otherwise it is ``dynamic_dim(<base name>, <idx>)``.
        """
        if not self.source_name_to_debug_name:
            return f"dynamic_dim({source.base.name()}, {source.idx})"
        return source.name()

    def _print_Symbol(self, expr) -> str:
        """Print ``expr`` as the first source registered for it."""
        assert isinstance(expr, sympy.Symbol), str(type(expr))
        assert self.symbol_to_source.get(expr), (
            f"Unknown symbol {expr} created by constraints solver"
        )
        first_source = self.symbol_to_source[expr][0]
        return self.print_source(first_source)

    def _print_Relational(self, expr):
        """Print a relation as ``<lhs> <op> <rhs>``, parenthesizing each side as needed."""
        level = precedence(expr)
        lhs = self.parenthesize(expr.lhs, level)
        rel_op = expr.rel_op
        rhs = self.parenthesize(expr.rhs, level)
        return f"{lhs} {rel_op} {rhs}"
|
||
|
|
||
|
|
||
|
class DimConstraints:
|
||
|
"""
|
||
|
Custom solver for a system of constraints on symbolic dimensions.
|
||
|
Solutions are "static" values or simplified "dynamic" constraints.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, symbol_to_source, var_to_val, marked_dynamic, source_name_to_debug_name):
    """Set up empty solver state.

    ``symbol_to_source`` and ``source_name_to_debug_name`` are handed to the
    DynamicDimConstraintPrinter used for printing solutions; ``var_to_val``
    supplies concrete values of symbols (used as hints); ``marked_dynamic``
    is stored as the collection of symbols that are marked dynamic.
    """
    # We try to solve systems of inequalities with 1 free variable.
    self._univariate_inequalities: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)
    # Among them, we prioritize solving for a free variable that has equalities.
    # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys()
    # and removing a symbol from the former => removing it from the latter.
    self._symbols_with_equalities: Set[sympy.Symbol] = set()
    # A solution of a free variable with equalities becomes a substitution.
    # We use these substitutions to simplify other constraints.
    # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions.
    self._substitutions: Dict[sympy.Symbol, sympy.Integer] = {}

    # In general, constraints may have // and % operations.
    # Of course, // can be expressed in terms of / and %.
    # Our inequality solver can handle / but not %. So we need to transform them away.
    # We do so by using the values of variables as hints to evaluate %.
    # For soundness we record additional congruence guards and solve them separately.
    self._var_to_val: Dict[sympy.Symbol, sympy.Integer] = var_to_val
    # Congruence guards grouped by free symbol (it is indexed as
    # self._congruences[s] by the other methods), hence a mapping rather
    # than a flat set.
    self._congruences: Dict[sympy.Symbol, Set[sympy.Expr]] = defaultdict(set)

    # We do not try to (directly) solve inequalities with > 1 free variables.
    # NOTE: free variables in these inequalities cannot also be in _substitutions.
    self._multivariate_inequalities: Set[sympy.Expr] = set()

    # We park external equalities between free variables here.
    self._symbolic_equivalences: List[Tuple[Source, sympy.Expr]] = []

    # Solutions come in two forms:
    # - (static) specializations
    # - (dynamic) inequalities / congruences
    self._static_results: Set[str] = set()
    self._dynamic_results: Set[str] = set()

    # printer for solutions
    self._dcp = DynamicDimConstraintPrinter(symbol_to_source, source_name_to_debug_name)

    # inconsistencies found on substituting with concrete values / static solutions
    self._inconsistencies: List[str] = []

    # symbols that are marked dynamic
    self._marked_dynamic = marked_dynamic
|
||
|
|
||
|
def rewrite_with_congruences(self, s, expr):
    """
    Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
    What remains are rational operators (in particular of the form b / d) that our inequality solver can handle.
    The added congruences are solved separately (using our congruence solver, see below).
    """
    def reduce_with_hint(base, divisor):
        # Shared by both handlers: rewrite the operands recursively, evaluate
        # b % d to a value k using the known values of the symbols as a
        # "hint", and record the guard b % d == k (stored as (b - k) % d)
        # unless it is trivially zero.
        base = self.rewrite_with_congruences(s, base)
        divisor = self.rewrite_with_congruences(s, divisor)
        mod_reduced = base.subs(self._var_to_val) % divisor.subs(self._var_to_val)
        congruence = (base - mod_reduced) % divisor
        if congruence != 0:
            self._congruences[s].add(congruence)
        return base, divisor, mod_reduced

    def mod_handler(*args):
        # Given an expression of the form b % d with free variable s, rewrite
        # b % d to the hinted value k while adding the guard b % d == k.

        # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF
        # the original expression always evaluates to a constant value (i.e., it does not vary with s).
        # In other words,
        # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with
        #   the original expression;
        # - while it may be possible to find solutions of s with the original expression that are not
        #   solutions with the rewritten expression, in that case the original expression cannot evaluate
        #   to the same value for all solutions of s.
        #
        # Should we be worried about this incompleteness? No, because of the following reasons:
        # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech
        #    (i.e., "don't let perfect be the enemy of the good").
        # 2. We already have a tradition of using hints to add guards in the compiler for making progress.
        # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards
        #    we generate (or simplify to) seem to be of the form b % d == k where k is a constant.
        #
        # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2.
        # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we
        # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution!
        base, divisor = args
        _, _, mod_reduced = reduce_with_hint(base, divisor)
        return mod_reduced

    def floor_div_handler(*args):
        # Given an expression of the form b // d with free variable s,
        # rewrite b // d to (b - k) / d, where k is the hinted value of
        # b % d, while adding the guard b % d == k.

        # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d
        # and eliminating b % d as above.
        base, divisor = args
        base, divisor, mod_reduced = reduce_with_hint(base, divisor)
        return (base - mod_reduced) / divisor

    if expr.has(Mod):
        expr = expr.replace(Mod, mod_handler)
    if expr.has(FloorDiv):
        expr = expr.replace(FloorDiv, floor_div_handler)
    return expr
|
||
|
|
||
|
def add(self, expr) -> bool:
    """Record ``expr`` as a constraint.

    Returns whether the expression is a trivial constraint (i.e., an obvious tautology).
    """
    if expr == sympy.true:
        return True

    orig_expr = expr
    # TODO(avik): https://github.com/pytorch/pytorch/issues/101093
    # `expr` may fail the consistency check merely because of precision
    # errors: once its free symbols are replaced by their concrete values we
    # might end up comparing floats. Until that is fixed, such failures are
    # only recorded here and raised later. See solve().
    if orig_expr.subs(self._var_to_val) == sympy.false:
        self._inconsistencies.append(f"{orig_expr} is inconsistent!")

    if isinstance(expr, sympy.Ne):
        # nothing useful can be done with disequalities, so they are dropped
        return False

    free_symbols = expr.free_symbols
    assert free_symbols, f"Did not expect constraint with no free variables: {expr}"

    if len(free_symbols) > 1:
        # multivariate: record and move on
        self._multivariate_inequalities.add(expr)
        return False

    # univariate: can be processed immediately
    s = next(iter(free_symbols))
    # eliminate // and % (see documentation of `rewrite_with_congruences` above)
    congruences_before = len(self._congruences[s])
    expr = self.rewrite_with_congruences(s, expr)
    congruences_after = len(self._congruences[s])
    if expr == sympy.true:
        # trivial only if the rewrite did not record any new congruence
        return congruences_before == congruences_after

    if expr.subs(self._var_to_val) == sympy.false:
        self._inconsistencies.append(
            f"{expr}, obtained by rewriting {orig_expr} with congruences, "
            "is inconsistent!"
        )
    if isinstance(expr, sympy.Eq):
        # special status for symbols that have equalities (see `solve` below)
        self._symbols_with_equalities.add(s)
    self._univariate_inequalities[s].add(expr)
    return False
|
||
|
|
||
|
def add_equality(self, source, expr):
    """Record that ``source`` is equal to ``expr``.

    A numeric ``expr`` is stored right away as a static result of the form
    ``"<source name> == <expr>"``; any other expression is parked in
    ``_symbolic_equivalences``.
    """
    if not expr.is_number:
        # these will resolve to either specializations or dynamic equality constraints
        self._symbolic_equivalences.append((source, expr))
        return
    # specialization, right here
    self._static_results.add(f"{source.name()} == {expr}")
|
||
|
|
||
|
def _reduce_congruences(self):
    """Combine the recorded congruences of every symbol.

    Returns a dict mapping each symbol to a set of congruences: the single
    combined congruence obtained from all linearly solvable ones (if any),
    plus every remaining congruence that could not be confirmed against it.
    """
    tmp = sympy.Symbol("tmp", integer=True)
    reduced_congruences = {}
    for s, congruences in self._congruences.items():
        remainder_modulus_pairs = []
        congruences_to_check = set()
        for congruence in congruences:
            base, divisor = congruence.args
            # The congruence reads base % divisor == 0 with free variable s. So:
            # - it is turned into the equation base = divisor * tmp;
            # - that equation is solved for s, giving a linear solution in tmp.
            symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s])
            # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear
            # for how to interpret the results.
            solved = False
            if s == symbol:
                # The solution has the form s = modulus*tmp + remainder.
                modulus, remainder = sympy.polys.polytools.div(solution, tmp)
                if all(isinstance(part, sympy.Integer) for part in (modulus, remainder)):
                    # Reduce the remainder modulo the modulus.
                    remainder_modulus_pairs.append((remainder % modulus, modulus))
                    solved = True
            if not solved:
                # No unique solution to the equation was obtained.
                # No problem, it will be checked below.
                congruences_to_check.add(congruence)

        if not remainder_modulus_pairs:
            reduced_congruences[s] = congruences_to_check
            continue

        # Finally solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i).
        # The solution will be a congruence of the form s = r mod m.
        # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT.
        remainder, modulus = sympy.ntheory.modular.solve_congruence(*remainder_modulus_pairs)
        combined = {(s - remainder) % modulus}
        reduced_congruences[s] = combined
        substitution = {s: modulus * tmp + remainder}
        # Keep only those unsolved congruences that the combined solution
        # does not already satisfy.
        combined.update(
            pending for pending in congruences_to_check
            if not sympy.checksol(pending, substitution)
        )

    return reduced_congruences
|
||
|
|
||
|
def _raise_inconsistencies(self):
|
||
|
if self._inconsistencies:
|
||
|
msg = "\n".join(self._inconsistencies)
|
||
|
self._inconsistencies.clear()
|
||
|
raise ValueError(f"The following inconsistencies were found:\n{msg}")
|
||
|
|
||
|
def _force_specialization(self, s):
|
||
|
val = self._var_to_val[s]
|
||
|
self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
|
||
|
self._substitutions[s] = val
|
||
|
|
||
|
def _specialize_divisor_symbols(self):
    """Specialize every symbol that occurs in a divisor of a multivariate constraint.

    For each FloorDiv / Mod atom in the multivariate inequalities, all free
    symbols of its divisor are forced to their known values. The
    multivariate inequalities are then re-added with those substitutions
    applied, inconsistencies are raised, and the per-symbol bookkeeping for
    the specialized symbols is dropped.
    """
    for expr in self._multivariate_inequalities:
        for atom in expr.atoms(FloorDiv, Mod):
            _, divisor = atom.args
            for s in divisor.free_symbols:
                self._force_specialization(s)

    # Re-add the multivariate constraints with the substitutions applied;
    # `add` re-files each one as univariate or multivariate.
    multivariate_inequalities = self._multivariate_inequalities
    self._multivariate_inequalities = set()
    for expr in multivariate_inequalities:
        self.add(expr.subs(self._substitutions))
    self._raise_inconsistencies()

    # Drop bookkeeping for symbols that now have a substitution. Both maps
    # are rebuilt as defaultdict(set), the type they are created with:
    # `add` and `rewrite_with_congruences` index them by symbol without
    # checking for presence, which a plain dict would answer with KeyError.
    self._univariate_inequalities = defaultdict(set, {
        s: exprs
        for s, exprs in self._univariate_inequalities.items()
        if s not in self._substitutions
    })
    self._congruences = defaultdict(set, {
        s: congruences
        for s, congruences in self._congruences.items()
        if s not in self._substitutions
    })
|
||
|
|
||
|
def solve(self, disable_congruences=True, disable_equivalences=True):
    """Solve the system of constraint equations to find simplified constraints

    The work proceeds in phases, each feeding the next:

    1. Symbols that have an equality are solved one at a time; each solution
       becomes a static result and a substitution, and the multivariate
       inequalities are re-added with that substitution applied.
    2. Symbols appearing in divisors are specialized.
    3. Reduced congruences that cannot be verified against a substitution
       become dynamic results when supported; otherwise the symbol is
       specialized if ``disable_congruences`` is set.
    4. The remaining univariate inequalities are reduced to dynamic (range)
       results.
    5. Symbolic equivalences are re-added after substitution; what remains
       becomes dynamic equality results.

    Args:
        disable_congruences: when True, a symbol with an unsupported
            congruence is force-specialized.
        disable_equivalences: when True, the free symbols of an unsupported
            equivalence are force-specialized.

    Raises:
        ValueError: via ``_raise_inconsistencies`` when inconsistencies
            have been recorded.
    """
    self._raise_inconsistencies()
    # as long as there are symbols with equalities, solve for them
    # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols)
    while self._symbols_with_equalities:
        s = self._symbols_with_equalities.pop()
        exprs = self._univariate_inequalities.pop(s)
        solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
        if isinstance(solution, sympy.And):
            # pick the equality out of the conjunction; fall back to the
            # conjunction itself so that the assert below reports it
            solution = next((arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution)
        assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}"
        symbol, val = solution.args
        assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
        # because this is univariate, the solution is a specialization
        self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
        # add this as a substitution to simplify other constraints
        self._substitutions[s] = val

        # simplify multivariate inequalities: some of them will now become univariate!
        multivariate_inequalities = self._multivariate_inequalities
        self._multivariate_inequalities = set()
        for expr in multivariate_inequalities:
            self.add(expr.subs(s, self._substitutions[s]))
        self._raise_inconsistencies()

    self._specialize_divisor_symbols()

    # solve linear congruences
    # NOTE(avik): We do not need to solve them for symbols that have already been specialized.
    reduced_congruences = self._reduce_congruences()
    for s, congruences in reduced_congruences.items():
        for congruence in congruences:
            # any congruence that cannot be checked becomes a dynamic constraint as well
            if s not in self._substitutions or not sympy.checksol(congruence, {s: self._substitutions[s]}):
                if self._is_supported_congruence(congruence):
                    base, divisor = congruence.args
                    # introduce a fresh symbol named after the debug name of
                    # s, prefixed with an underscore, and express s through it
                    tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
                    tmp = sympy.Symbol(tmp_name, integer=True)
                    from torch._dynamo.source import ConstantSource
                    self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)]
                    # NOTE(review): r is indexed with [1] below; presumably
                    # try_solve returns a pair here -- confirm it cannot be None
                    r = try_solve(sympy.Eq(base, divisor * tmp), s)
                    self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1])))
                elif disable_congruences:
                    self._force_specialization(s)
                    self._univariate_inequalities.pop(s, None)

    # remaining symbols have only pure inequalities (no equalities)
    for s, exprs in self._univariate_inequalities.items():
        try:
            solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s)
            # because this is univariate, the solution is a dynamic (range) constraint
            if isinstance(solution, sympy.Or):
                # keep the first disjunct that holds for the recorded values
                # NOTE(review): StopIteration from next() is not caught by
                # the except clause below if no disjunct holds -- confirm
                solution = next(iter(arg for arg in solution.args if arg.subs(self._var_to_val)))
            if isinstance(solution, sympy.And):
                for arg in solution.args:
                    self._dynamic_results.add(self._dcp.doprint(arg))
            else:
                self._dynamic_results.add(self._dcp.doprint(solution))
        except (NotImplementedError, AssertionError) as e:
            # fall back to emitting the unreduced inequalities as-is
            log.warning("Failed to reduce inequalities: %s", e)
            for expr in exprs:
                self._dynamic_results.add(self._dcp.doprint(expr))

    # simplify symbolic equivalences: some of them will now become specializations!
    symbolic_equivalences = self._symbolic_equivalences
    self._symbolic_equivalences = []
    for source, expr in symbolic_equivalences:
        if disable_equivalences and not self._is_supported_equivalence(expr):
            for s in expr.free_symbols:
                self._force_specialization(s)
                # drop every dynamic result whose text mentions the symbol
                sexpr = self._dcp._print_Symbol(s)
                self._dynamic_results = {r for r in self._dynamic_results if sexpr not in r}
        self.add_equality(source, expr.subs(self._substitutions))

    # remaining symbolic equivalences become dynamic equality constraints
    for source, expr in self._symbolic_equivalences:
        self._dynamic_results.add(f"{self._dcp.print_source(source)} == {self._dcp.doprint(expr)}")
|
||
|
|
||
|
@classmethod
def _is_supported_equivalence(cls, expr):
    """Return whether ``expr`` can be expressed with the supported Dim ops.

    Currently supported Dim ops are linear expressions with integer
    coefficients, so ``expr`` may only contain +, *, ints, and a single
    occurrence of a symbol.
    (See also documentation of dynamic_shapes._DerivedDim.)

    Args:
        expr: a sympy expression.

    Returns:
        True if ``expr`` is a symbol, or a binary ``Add``/``Mul`` of a
        supported sub-expression and an integer; False otherwise.
    """
    if isinstance(expr, (sympy.Add, sympy.Mul)):
        # sympy flattens Add/Mul into n-ary nodes (e.g. ``s0 + s1 + 1`` has
        # three args). Those are never of the supported binary shape, and
        # unpacking them into two names would raise ValueError.
        if len(expr.args) != 2:
            return False
        lhs, rhs = expr.args
        return (
            (cls._is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or
            (isinstance(lhs, sympy.Integer) and cls._is_supported_equivalence(rhs))
        )
    return isinstance(expr, sympy.Symbol)
|
||
|
|
||
|
@classmethod
def _is_supported_congruence(cls, congruence):
    """Return whether ``congruence`` can be expressed with supported Dim ops.

    Congruences that can be currently expressed with supported Dim ops are
    of the form (x + a) % b == 0, where x is a Dim and a and b are constants.
    This allows us to derive x as b*y - a for some Dim y.
    (See also documentation of dynamic_shapes._DerivedDim.)

    Args:
        congruence: an expression whose ``args`` are ``(base, divisor)``.

    Returns:
        True if ``base`` is a symbol, or a symbol plus an integer, and
        ``divisor`` is an integer; False otherwise.
    """
    base, divisor = congruence.args
    if isinstance(base, sympy.Add):
        # An Add with more than two terms (e.g. ``x**2 + x + 1``) is not of
        # the form ``x + a``. Reject it instead of failing on the two-way
        # unpacking below.
        if len(base.args) != 2:
            return False
        lhs, rhs = base.args
        cond = (
            (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer)) or
            (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol))
        )
    else:
        cond = isinstance(base, sympy.Symbol)
    cond = cond and isinstance(divisor, sympy.Integer)
    return cond
|
||
|
|
||
|
def forced_specializations(self):
    """Return the specialized values of symbols that were marked dynamic.

    Returns:
        A dict mapping a label for each substituted symbol found in
        ``self._marked_dynamic`` to its substituted value. The label is the
        symbol's first source name, prefixed with ``"<debug name> = "``
        when debug names are available.
    """
    debug_names = self._dcp.source_name_to_debug_name
    specialized = {}
    for symbol, value in self._substitutions.items():
        if symbol not in self._marked_dynamic:
            continue
        source_name = self._dcp.symbol_to_source[symbol][0].name()
        if debug_names:
            label = f"{debug_names[source_name]} = {source_name}"
        else:
            label = source_name
        specialized[label] = value
    return specialized
|
||
|
|
||
|
def remove_redundant_dynamic_results(self):
    """Remove constraints of the form 2 <= dynamic_dim(...) as 2 is the default
    lower bound.

    Each such constraint is rewritten to the bare ``dynamic_dim(...)`` text.
    The bare form is kept only when it does not already occur inside one of
    the other retained constraints. ``self._dynamic_results`` is replaced
    with the resulting set.
    """
    kept = set()
    bare_dims = []
    for constraint in self._dynamic_results:
        # Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
        # There is no change in behavior since 2 is the default lower bound.
        stripped = re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", constraint)
        if stripped == constraint:
            kept.add(constraint)
        else:
            bare_dims.append(stripped)
    for bare in bare_dims:
        # drop dynamic_dim(t, 0) as a standalone constraint when it already
        # appears as part of another retained constraint
        if not any(bare in other for other in kept):
            kept.add(bare)
    self._dynamic_results = kept
|
||
|
|
||
|
def prettify_results(
    self,
    original_signature: inspect.Signature,
    constraint_violation_error=None,
    forced_specializations=None,
):
    """Format a message for constraint violation errors

    Two output formats are produced:

    * When ``self._dcp.source_name_to_debug_name`` is non-empty, the static
      and dynamic results are rewritten in terms of debug names and turned
      into a "Suggested fixes" list of ``Dim(...)`` definitions and derived
      assignments.
    * Otherwise, results are grouped per argument of ``original_signature``
      and rendered as ``specializations`` / ``specify_constraints`` code
      snippets.

    Args:
        original_signature: signature whose parameter names are used to
            group and order the results (second format only).
        constraint_violation_error: optional exception; its first argument
            is searched for a ``Constraints violated (...)`` list of names.
        forced_specializations: optional mapping of labels to specialized
            values, reported at the top of the message.

    Returns:
        The formatted message as a string.
    """
    if self._dcp.source_name_to_debug_name:
        def transform(s):
            # replace every known source name in s by its debug name
            for k, v in self._dcp.source_name_to_debug_name.items():
                s = s.replace(k, v)
            return s

        # maps an expression string to a dict with optional "min", "max"
        # and "eq" entries
        results = defaultdict(dict)

        def flip(op):
            # mirror a relational operator, for when its operands are swapped
            if op == "<=":
                return ">="
            if op == ">=":
                return "<="
            if op == "<":
                return ">"
            if op == ">":
                return "<"
            assert op == "=="
            return op

        def relation_with_digit(expr, op, digit):
            # record ``expr op digit`` as an inclusive bound or an equality
            if op == "<=":
                results[expr]["max"] = digit
            elif op == "<":
                results[expr]["max"] = digit - 1
            elif op == ">=":
                results[expr]["min"] = digit
            elif op == ">":
                results[expr]["min"] = digit + 1
            else:
                assert op == "=="
                results[expr]["eq"] = digit

        for s in self._static_results.union(self._dynamic_results):
            t = transform(s)
            if t == s:
                # no debug name applied to this result; skip it
                continue
            left, op, right = re.split(r"( == | <= | >= | < | > )", t)
            op = op.strip()
            if op == "==" and left == right:
                continue
            if right.isdigit():
                relation_with_digit(left, op, int(right))
            elif left.isdigit():
                relation_with_digit(right, flip(op), int(left))
            else:
                assert op == "=="
                results[left]["eq"] = sympy.sympify(right)

        buf = ""
        debug_names = set()
        if forced_specializations:
            debug_names.update(k.split(" = ")[0] for k in forced_specializations.keys())
            buf += (
                f"Specializations unexpectedly required ({', '.join(debug_names)})! "
                "For more information, run with TORCH_LOGS=\"+dynamic\".\n"
            )
            for s, val in forced_specializations.items():
                buf += f" - {s} must be specialized to {val} because the guards generated for it are too complex.\n"

        dims = []
        others = []
        match = None
        if constraint_violation_error:
            match = re.search(r"Constraints violated \((.*)\)", constraint_violation_error.args[0])
        if match is not None:
            debug_names.update(match.expand(r'\1').split(', '))

        for k, c in sorted(results.items()):
            # if k not in debug_names:
            #     continue
            if "eq" in c:
                other = c["eq"]
                if isinstance(other, int):
                    others.append(f"{k} = None # {other}")
                elif self._is_supported_equivalence(other):
                    s = next(iter(other.free_symbols))
                    # NOTE(review): `results` is keyed by strings (see the
                    # re.split above) while `s` is a sympy symbol, so this
                    # membership test may never match -- confirm whether
                    # `str(s)` was intended.
                    if s not in results:
                        # derive bounds for the root symbol from the bounds
                        # on the derived expression
                        modulus, remainder = sympy.polys.polytools.div(other, s)
                        c_min = c.get("min", 2)
                        min_ = math.ceil((c_min - remainder) / modulus)
                        c_max = c.get("max", sys.maxsize - 1)
                        max_ = math.floor((c_max - remainder) / modulus)
                        dims.append(f"{s} = Dim('{s}', min={min_}, max={max_}) # {c_min} <= {other} <= {c_max}")
                    others.append(f"{k} = {other}")
            else:
                min_ = c.get("min", None)
                if min_ == 2:
                    # 2 is treated as the default lower bound; omit it
                    min_ = None
                max_ = c.get("max", None)
                if min_ is not None and max_ is not None:
                    dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})")
                elif min_ is not None:
                    dims.append(f"{k} = Dim('{k}', min={min_})")
                elif max_ is not None:
                    dims.append(f"{k} = Dim('{k}', max={max_})")
                else:
                    dims.append(f"{k} = Dim('{k}')")

        buf += "\nSuggested fixes:\n "
        buf += "\n ".join(dims + others)

        return buf

    # Note: Model inputs are wrapped as LocalSource in dynamo.
    # LocalSource.name() wraps the name with L[""]. We use regular
    # expression to do the replacement to avoid traversing up
    # the source hierarchy manually.
    def extract_and_rewrite_local(dc):
        # returns (argument name, dc with every L['...'] unwrapped), or None
        # when dc mentions no local
        match = re.search(r"L\['(.+?)'\]", dc)
        if match is None:
            return
        arg = match.expand(r'\1')
        dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
        return arg, dc

    def group(results, args_index):
        # bucket results by the argument they mention, ordered by the
        # argument's position in the signature
        groups = defaultdict(list)
        for dc in results:
            local = extract_and_rewrite_local(dc)
            if local is None:
                # This can happen, e.g., with `assume_constant_result`.
                # In that case, we drop the constraint.
                # TODO(avik) Maybe we should generate an assertion here?
                continue
            arg, dc = local
            if arg in args_index:
                groups[args_index[arg]].append(dc)
            else:
                # This can happen, e.g., with decorators that change the signature.
                # In that case, we drop the constraint. Seems hard to do better. :/
                # TODO(avik) Maybe warn that `arg` in not in `signature`?
                continue
        sorted_groups = []
        for idx, dcs in sorted(groups.items()):
            _, arg = idx
            sorted_groups.append((arg, sorted(dcs)))
        return sorted_groups

    signature = original_signature.replace(return_annotation=inspect.Signature.empty)
    # maps an argument name to (position, name) so groups sort by position
    args_index = {}
    for i, arg in enumerate(signature.parameters.keys()):
        args_index[arg] = (i, arg)

    def print_results(grouped, indent, result_fn):
        # append each group to buf under a "# <arg>:" header, with a blank
        # line between groups
        nonlocal buf

        space = False
        for arg, results in grouped:
            if space:
                buf += "\n"
            else:
                space = True
            buf += f"\n{indent}# {arg}:"
            for result in results:
                buf += f"\n{indent}{result_fn(result)}"

    buf = ""
    if forced_specializations:
        buf += (
            "Some dynamic dimensions need to be specialized because "
            "the constraints inferred for them are too complex to specify.\n"
        )
        for s, val in forced_specializations.items():
            buf += f" - {s}, which was marked dynamic, must be specialized to {val}.\n"
    indent = 4 * " "
    if self._static_results:
        grouped_static_results = group(self._static_results, args_index)
        buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
        buf += f"\n```\ndef specializations{str(signature)}:"
        print_results(
            grouped_static_results,
            indent,
            lambda result: f"assert {result}",
        )
        buf += "\n```\n"
    if self._dynamic_results:
        grouped_dynamic_results = group(self._dynamic_results, args_index)
        buf += "\nThe following dimensions CAN be dynamic."
        buf += "\nPlease use the following code to specify the constraints they must satisfy:"
        buf += f"\n```\ndef specify_constraints{str(signature)}:"
        buf += f"\n{indent}return ["
        print_results(
            grouped_dynamic_results,
            indent * 2,
            lambda result: f"{result},",
        )
        buf += f"\n{indent}]\n```\n"
    return buf
|
||
|
|
||
|
|
||
|
# Thread-local storage; holds the ``suppress_guards`` flag that the
# ShapeEnv ``_suppress_guards_*`` methods read and write.
TLS = threading.local()
|
||
|
|
||
|
|
||
|
class ShapeEnv:
|
||
|
# This is a wrapper over the actual __init__ function.
|
||
|
#
|
||
|
# Where to add a new constructor parameter to ShapeEnv?
|
||
|
# =====================================================
|
||
|
# This __init__ function should be used only for parameters related to event recording.
|
||
|
# These are parameters that we don't wish to pass down the road to new ShapeEnv instances
|
||
|
# created from replaying events.
|
||
|
#
|
||
|
# If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event
|
||
|
# recording, do so in the _init function.
|
||
|
def __init__(
    self, *,
    should_record_events: Optional[bool] = None,
    tracked_fakes: Optional[List[Any]] = None,
    **kwargs
) -> None:
    """Set up event recording around the real initializer ``_init``.

    Args:
        should_record_events: explicit switch for event recording. When
            None, recording is enabled only if translation validation is on
            and its bisection has not been disabled in the config.
        tracked_fakes: list of tracked fakes to keep a reference to.
        **kwargs: forwarded unchanged to ``_init``.
    """
    self._init(**kwargs)

    # Disable event recording when replaying: these kwargs are stored in
    # the constructor event created below.
    kwargs["should_record_events"] = False

    from torch.fx.experimental.validator import translation_validation_enabled
    self._translation_validation_enabled = translation_validation_enabled()

    # If not specified, enable event recording if both:
    #   - Translation validation is on
    #   - Translation validation bisection is not disabled
    if should_record_events is None:
        should_record_events = (
            self._translation_validation_enabled
            and not config.translation_validation_no_bisect
        )
    self.should_record_events = should_record_events

    # The recording check needs both recording itself and its config flag.
    self.check_recorded_events = (
        self.should_record_events and config.check_shape_env_recorded_events
    )

    # This will make sure we only record the top-level function call.
    self.is_recording = not self.should_record_events
    # Keep track of the list of tracked fakes.
    self.tracked_fakes = tracked_fakes
    # List of events for reconstructing ShapeEnv at arbitrary points in time.
    # It starts with the constructor event itself when recording is enabled.
    recorded: List[ShapeEnvEvent] = []
    if self.should_record_events:
        recorded.append(ShapeEnvEvent(ShapeEnv, kwargs=kwargs))
    self.events: List[ShapeEnvEvent] = recorded
|
||
|
|
||
|
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
|
||
|
# tests. Accept their output with:
|
||
|
#
|
||
|
# EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal
|
||
|
#
|
||
|
def _init(
    self, *,
    allow_scalar_outputs=True,
    allow_dynamic_output_shape_ops=True,
    # NB: These are legacy configuration that help us make good choices
    # when the constraint/dynamic dims are not explicitly passed to us.
    # Ideally we will fix all call sites to be explicit and not have
    # implicit choices, but this apparently was pretty involved.
    assume_static_by_default=False,
    # Note - On 0/1 specialization
    #
    # The following options affect decisions we make about eager
    # specialization.  Disabling them will increase trace time (as we do
    # more symbolic reasoning) and can also harm the quality of generated
    # code (because inductor may not be able to specialize for bounds
    # being equal--although if we later respecialize because of a guard,
    # your code may be just as good as it was before.)
    #
    # When True, eagerly specialize input sizes which have 0/1.
    specialize_zero_one=True,
    # When True, assume input sizes which have the same size are
    # symbolically equal.
    duck_shape=True,
    # For debugging
    co_fields=None,
    # XXX Add any new settings that could affect FakeTensor evaluation
    # to: torch._subclasses.fake_tensor._ShapeEnvSettings
):
    """Initialize every field that is not related to event recording.

    All parameters are keyword-only; see the inline comments in the
    signature for their meaning. Besides storing the settings, this creates
    the empty bookkeeping containers, the cache version counter and, when
    translation validation is enabled, the validator and its FX graph.
    """
    # Not directly used by ShapeEnv; indirectly used by FakeTensor
    self.allow_scalar_outputs = allow_scalar_outputs
    self.allow_dynamic_output_shape_ops = allow_dynamic_output_shape_ops
    self.guards: List[ShapeGuard] = []
    # Maps symbolic ints to their original concrete values
    # Currently populated from tensors
    self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
    # Maps symbolic ints to their min/max range.  These ranges
    # are conservative: the int MUST fall in the range, but the
    # range may contain ints which may not actually appear in
    # practice
    self.var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
    self.source_name_to_debug_name: Dict[str, str] = {}
    self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
    self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
    # Maps from sympy ints to expressions representing them
    # Populated from equality guards (i.e. a.shape[0] == b.shape[0])
    self.replacements: Dict[sympy.Symbol, sympy.Expr] = {}
    # Set holds a % b expressions that evaluate to 0.
    self.divisible: Set[sympy.Expr] = set()
    # Set that holds "size-like" symbols.  When we perform
    # "size-oblivious" tests, these can be assumed to be >= 2.
    self.size_like: Set[sympy.Symbol] = set()
    # Duck-shaping says that if two input tensors have the same size,
    # they get assigned the same symbolic variable
    self.val_to_var: Dict[int, sympy.Expr] = {}
    if specialize_zero_one:
        # 0 and 1 map to constants rather than to symbols
        self.val_to_var = {0: sympy.Integer(0), 1: sympy.Integer(1)}
    self.unbacked_symfloat_counter = itertools.count()
    self.unbacked_symint_counter = itertools.count()
    # Similar to guards, but these MUST evaluate to true and can
    # only be evaluated at runtime midway through (i.e., they always
    # involve unbacked symints)
    #
    # For efficiency reasons, we index in the following way.  Suppose you have
    # a runtime assert i0 + i1 <= s1.  We pick the most recently allocated
    # symbol in the source expression and add the assert to the list for
    # that symbol e.g., {i1: [i0 + i1 <= s1]}.
    #
    # We access the runtime asserts in two situations:
    #
    #   - When we are guarding on an expression, we will attempt to
    #     statically evaluate it, in case the unbacked SymInts can
    #     simplify away.  If we have a runtime assert, we may be able
    #     to discharge the guard entirely.  We only need to attempt
    #     runtime asserts that mention freevars of the expression in
    #     question.
    #
    #   - When we are performing codegen (in Inductor for eager, or
    #     when finalizing the export FX graph), we need to know what
    #     extra runtime asserts to insert.  Whenever an unbacked
    #     SymInt comes into scope, all runtime asserts involving it
    #     become eligible for insertion (so long as all of their other
    #     free unbacked symbols are also in scope).  We technically
    #     can handle any choice of key by kicking inexpressible asserts
    #     to the next unbacked symbol to wait on, but if we choose the
    #     latest key, an assert will only show up at the moment when
    #     we can actually codegen it.
    self.deferred_runtime_asserts: Dict[sympy.Symbol, List[RuntimeAssert]] = {}
    # This exists so we can efficiently invalidate the cache (it's used as
    # part of the cache key); otherwise we'd have to iterate through
    # deferred_runtime_asserts to compute its length
    self.num_deferred_runtime_asserts = 0
    self.assume_static_by_default = assume_static_by_default
    self.specialize_zero_one = specialize_zero_one
    self.duck_shape = duck_shape
    self.log = log
    self.log.debug("create_env")
    self.frozen = False
    self.dim_constraints: Optional[DimConstraints] = None
    self.counter = collections.Counter()
    # Mapping from sympy.Symbol to the number of guards which mention this
    # symbol
    self.symbol_guard_counter = collections.Counter()
    # A selection of important fields on co_field; solely used for
    # signpost_event
    self.co_fields = co_fields if co_fields else {}

    # Version counter used to invalidate cached values
    # (_get_key reads replacements, divisible and
    # num_deferred_runtime_asserts, all of which are initialized above)
    self._prev_cache_key = self._get_key()
    self._version_counter = 0

    # Cache for FX nodes.
    # Maps a tuple of:
    #   1. node's target
    #   2. list of arguments
    # to the already built node.
    # This drastically reduces the size of the FX graph, avoiding
    # duplicated nodes.
    self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
    self.source_to_symbol: Dict[str, sympy.Symbol] = {}

    from torch.fx.experimental.validator import translation_validation_enabled
    self._translation_validation_enabled = translation_validation_enabled()

    if self._translation_validation_enabled:
        from torch.fx.experimental.validator import TranslationValidator

        self.validator = TranslationValidator()
        self.graph = torch.fx.Graph()
        # Create an output graph and start inserting before that.
        # This is needed when 'deepcopy'-ing this object.
        self.graph.inserting_before(self.graph.output(None))

        # Mapping of each node name to the node itself.
        #
        # This is useful for matching an FX node from a recorded ShapeEnv.graph
        # to the FX node of the ShapeEnv we are running the event on.
        #
        # Whenever you add a node to self.graph, you must add a mapping to this
        # variable. Otherwise, the built FX graph on the replayed ShapeEnv will
        # not be valid.
        self.name_to_node: Dict[str, torch.fx.Node] = {}
|
||
|
|
||
|
def check_equal(self, other: "ShapeEnv") -> None:
    """Check that ``other`` holds the same state as this ShapeEnv.

    The comparison is delegated to ``shape_env_check_state_equal``. It is
    given the names of the fields to ignore and a function that normalizes
    the remaining field values before they are compared.
    """
    # ShapeEnv fields that are not relevant for the outcome of
    # ShapeEnv.produce_guards call:
    #   - Debugging variables
    #   - Translation validation related variables
    #   - Events recording related variables
    non_state_variable_names = (
        "counter",
        "log",
        "var_to_stack",
        "fx_node_cache",
        "graph",
        "validator",
        "check_recorded_events",
        "should_record_events",
        "is_recording",
        "tracked_fakes",
        "events",
        "source_name_to_debug_name",
        "_prev_cache_key",
        "_version_counter",
    )

    # Maps the value of each to-be-compared field to what should actually
    # be compared.
    #
    # Extend this when a field mixes state with debugging information.
    # For example, ShapeGuard holds the actual guard (sympy.Expr) together
    # with the stack at the point where it was added; only the expression
    # is compared, the stack is thrown away.
    def map_value(key: str, value: Any) -> Any:
        if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"):
            from copy import copy

            # For itertools.count(), compare the next integer the iterator
            # would return. Note that the iterator must be copied first;
            # otherwise we would be mutating the object.
            return next(copy(value))
        if key == "guards":
            # Reduce the list of ShapeGuard to a list of expressions.
            return [guard.expr for guard in value]
        if key == "deferred_runtime_asserts":
            # Reduce each list of RuntimeAsserts to a list of expressions.
            return {
                symbol: [runtime_assert.expr for runtime_assert in runtime_asserts]
                for symbol, runtime_asserts in value.items()
            }
        if key == "name_to_node":
            # Only check that the set of keys is the same.
            return set(value.keys())
        if key == "symbol_guard_counter":
            # Skip this for comparisons
            return None
        return value

    shape_env_check_state_equal(self, other, non_state_variable_names, map_value)
|
||
|
|
||
|
def _snapshot_tracked_fakes(self) -> Optional[List[Any]]:
|
||
|
if self.tracked_fakes is None:
|
||
|
return None
|
||
|
|
||
|
from torch._dynamo.variables.builder import TrackedFake
|
||
|
|
||
|
def maybe_transform_fake(fake: TrackedFake):
|
||
|
inner_fake = fake.fake \
|
||
|
if isinstance(fake.fake, torch.SymInt) \
|
||
|
else FakeTensorMeta.from_fake(fake.fake)
|
||
|
# Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a
|
||
|
# FakeTensorMeta for two reasons:
|
||
|
# 1. this is all the information we need when recording ShapeEnvEvents.
|
||
|
# 2. it works even if each TrackedFake changes its metadata.
|
||
|
return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type]
|
||
|
|
||
|
return [maybe_transform_fake(fake) for fake in self.tracked_fakes]
|
||
|
|
||
|
def _last_event_index(self) -> int:
|
||
|
return len(self.events) - 1
|
||
|
|
||
|
@contextmanager
|
||
|
def _recording(self):
|
||
|
self.is_recording = True
|
||
|
try:
|
||
|
yield
|
||
|
finally:
|
||
|
self.is_recording = False
|
||
|
|
||
|
@record_shapeenv_event()
def freeze(self):
    """Freeze this ShapeEnv so that it stops accumulating guards.

    Once frozen, further guards generated on this ShapeEnv are ignored and
    only a warning is emitted, which may lead to accuracy problems.
    """
    # Only the flag is set here.
    self.frozen = True
|
||
|
|
||
|
def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:
    """Return the integer symbol associated with ``source``.

    The symbol is named after ``source.name()`` and cached in
    ``self.source_to_symbol`` under that name.

    Args:
        source: the source to create (or look up) a symbol for.

    Returns:
        None when translation validation is disabled; otherwise the cached
        symbol for the source's name.
    """
    if not self._translation_validation_enabled:
        return None
    srcname = source.name()
    # The cache is keyed by the source *name*, so the lookup must use the
    # name too. Testing membership with the Source object never matched, so
    # the entry was rebuilt on every call.
    if srcname not in self.source_to_symbol:
        self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True)
    return self.source_to_symbol[srcname]
|
||
|
|
||
|
def _add_z3var(self, symbol: sympy.Symbol, type: Type) -> None:
    """Register ``symbol`` with the validator; no-op without translation validation."""
    if not self._translation_validation_enabled:
        return
    self.validator.add_var(symbol, type)
|
||
|
|
||
|
def _add_target_expr(self, expr) -> None:
|
||
|
if self._translation_validation_enabled:
|
||
|
self.validator.add_target_expr(expr)
|
||
|
|
||
|
def _add_assertion(self, expr) -> None:
|
||
|
if self._translation_validation_enabled:
|
||
|
self.validator.add_assertion(expr)
|
||
|
|
||
|
def _check_translation_validate(self) -> None:
|
||
|
if self._translation_validation_enabled:
|
||
|
self.validator.validate()
|
||
|
|
||
|
@record_shapeenv_event()
def _create_fx_call_function(
        self,
        op: Callable,
        args: Tuple,
) -> Tuple[Optional[torch.fx.Node], bool]:
    """Return the FX node for ``op(*args)``, creating it if needed.

    Args:
        op: the callable to be lifted into the FX graph.
        args: the call arguments; also part of the cache key.

    Returns:
        A pair ``(node, fresh)``. ``node`` is the cached or newly created
        node, or None when no node exists for this call. ``fresh`` is True
        only when the node was created by this call.
    """
    # Cache this tuple in order to avoid duplicated nodes.
    node_key = (op, args)

    if not self._translation_validation_enabled or node_key in self.fx_node_cache:
        # Nothing to build: report whatever the cache holds (possibly nothing).
        return self.fx_node_cache.get(node_key, None), False

    from torch.fx.experimental.validator import z3op

    # Presence of None in the arguments implies that we should ignore this operation.
    if any(a is None for a in args):
        # We check if we are not mixing SymNode that should not be ignored
        # (fx_node is not None) with those that should (fx_node is None).
        assert all(not isinstance(a, torch.fx.Node) for a in args)
        return None, False

    lifted_op = z3op(op, self.validator)

    # If translation validation is enabled, all arguments must have its
    # own FX node.
    assert all(a is not None for a in args), f"missing arg in FX graph ({op.__name__}): {args}"
    created = self.graph.call_function(lifted_op, args)
    self.fx_node_cache[node_key] = created
    self.name_to_node[created.name] = created

    return created, True
|
||
|
|
||
|
def _create_fx_placeholder_and_z3var(
        self,
        symbol: sympy.Symbol,
        type: Type,
) -> Optional[torch.fx.Node]:
    """Return the FX placeholder for ``symbol``, creating it on first use.

    On creation, a validator variable of the given ``type`` is registered
    and the symbol is attached to the placeholder's metadata.

    Returns:
        None when translation validation is disabled; otherwise the cached
        placeholder node.
    """
    if not self._translation_validation_enabled:
        return None

    node_key = (self.graph.placeholder, (symbol,))

    # Check if we haven't added this symbol already.
    # If so, skip the placeholder creation, as it
    # generates invalid Python code.
    if node_key not in self.fx_node_cache:
        # Add a Z3 variable according to 'type'.
        self._add_z3var(symbol, type)
        # Create the FX placeholder out of a mangled name: parentheses are
        # dropped, then every other non-alphanumeric character becomes '_'.
        without_parens = re.sub(r'[()]', '', symbol.name)
        mangled_name = re.sub(r'[^a-zA-Z0-9]', '_', without_parens)
        placeholder = self.graph.placeholder(mangled_name)
        self.fx_node_cache[node_key] = placeholder
        self.name_to_node[placeholder.name] = placeholder
        # Attach the 'symbol' to the placeholder so that we can retrieve
        # the Z3 variable later.
        placeholder.meta["symbol"] = symbol

    return self.fx_node_cache[node_key]
|
||
|
|
||
|
def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
|
||
|
if self._translation_validation_enabled and node is not None:
|
||
|
self.name_to_node.pop(node.name)
|
||
|
self.graph.erase_node(node)
|
||
|
|
||
|
def _add_fx_node_metadata(self, node: torch.fx.Node) -> None:
    """Tag ``node`` with the index of the last event and the current node.

    The metadata is written only when event recording is enabled.
    """
    from torch._dynamo.utils import get_current_node

    if not self.should_record_events:
        return
    node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index()
    node.meta[CURRENT_NODE_KEY] = get_current_node()
|
||
|
|
||
|
def _suppress_guards_tls(self):
    """Return the thread-local ``suppress_guards`` flag (False if unset)."""
    suppressed = getattr(TLS, "suppress_guards", False)
    return suppressed
|
||
|
|
||
|
@record_shapeenv_event()
def _suppress_guards_enter(self):
    """Turn on the thread-local ``suppress_guards`` flag."""
    TLS.suppress_guards = True
|
||
|
|
||
|
@record_shapeenv_event()
def _suppress_guards_exit(self):
    """Turn off the thread-local ``suppress_guards`` flag."""
    TLS.suppress_guards = False
|
||
|
|
||
|
@contextmanager
def suppress_guards(self):
    """Context manager to ignore all guards generated inside

    Calls ``_suppress_guards_enter`` on entry and ``_suppress_guards_exit``
    on exit, including when the body raises.
    """
    self._suppress_guards_enter()
    try:
        yield
    finally:
        # Always leave suppression mode, even on error.
        self._suppress_guards_exit()
|
||
|
|
||
|
def _get_key(self):
    """
    Summarise the guard "state" accumulated so far as a small tuple.

    The tuple is (number of replacements, number of divisible entries,
    number of deferred runtime asserts); a change in it signals that cached
    results need to be invalidated.
    """
    replacement_count = len(self.replacements)
    divisible_count = len(self.divisible)
    return replacement_count, divisible_count, self.num_deferred_runtime_asserts
|
||
|
|
||
|
def _update_version_counter(self):
    """Bump ``_version_counter`` if the key from ``_get_key()`` has changed.

    Reads vastly outnumber writes on the shape environment, so the cache key
    is condensed into a monotonically increasing counter that is cheaper to
    compare in _lru_cache.
    """
    latest_key = self._get_key()
    if latest_key == self._prev_cache_key:
        # State is unchanged; leave the counter alone.
        return
    self._prev_cache_key = latest_key
    self._version_counter += 1
|
||
|
|
||
|
def _produce_dyn_sizes(self,
                       ex_size: Sequence[int],
                       source: Source,
                       symbolic_context: SymbolicContext
                       ) -> List[sympy.Expr]:
    """Convert ``ex_size`` to a tuple and delegate to ``_produce_dyn_sizes_from_int_tuple``."""
    size_as_tuple = tuple(ex_size)
    return self._produce_dyn_sizes_from_int_tuple(size_as_tuple, source, symbolic_context)
|
||
|
|
||
|
def _produce_dyn_sizes_from_int_tuple(self,
                                      tensor_size: Tuple[int],
                                      source: Source,
                                      symbolic_context: SymbolicContext,
                                      ) -> List[sympy.Expr]:
    """Create one symbol per dimension of ``tensor_size``.

    Each dimension ``i`` is passed to ``create_symbol`` together with a
    ``TensorPropertySource`` for its SIZE property and the matching entries of
    ``symbolic_context.dynamic_sizes`` / ``symbolic_context.constraint_sizes``.
    The results are returned in dimension order.
    """
    assert not any(is_symbolic(val) for val in tensor_size), f"Expect size to be a plain tuple of ints but got {tensor_size}"
    from torch._dynamo.source import TensorPropertySource, TensorProperty
    _assert_symbol_context(symbolic_context)
    per_dim_policy = symbolic_context.dynamic_sizes
    per_dim_constraint = symbolic_context.constraint_sizes
    return [
        self.create_symbol(
            extent,
            TensorPropertySource(source, TensorProperty.SIZE, axis),
            per_dim_policy[axis],
            per_dim_constraint[axis],
            symbolic_context=symbolic_context,
        )
        for axis, extent in enumerate(tensor_size)
    ]
|
||
|
|
||
|
def create_symbolic_sizes_strides_storage_offset(
    self,
    ex: torch.Tensor,
    source: Source,
    *,
    symbolic_context: Optional[SymbolicContext] = None,
):
    """
    Produce symbolic sizes, strides and a storage offset for the tensor ``ex``.

    Strides are expressed in terms of the sizes where possible, so that no
    unnecessary symbolic variables are introduced. The actual allocation is
    done by ``_create_symbolic_sizes_strides_storage_offset``; this wrapper
    only reduces the tensor's metadata to plain ints first.
    """

    # Dynamo may hand us FakeTensors whose sizes are already SymInts, e.g.
    # make_fx(opt_f(), tracing_mode="symbolic"). Symbols in this shape env are
    # then created from the backed hints behind those SymInts.
    #
    # Case 1: the SymInt is backed. Dynamo can proceed with FakeTensors that
    #         have a concrete shape; produce_guards will trigger
    #         specializations on the outer stuff.
    #
    # Case 2: the SymInt is unbacked. require_hint() throws a data dependent
    #         error.
    #
    # This is probably fine for now, but note that it has implications for the
    # original shape_env when guards are checked in a different order.
    #
    # Example
    # -------
    #   @torch.compile()
    #   def opt_f(x: bool, y: Tensor):
    #     if x == True:
    #       return y + torch.randn([4])
    #     else:
    #       return y
    #
    # Depending on the call sequence, two different guard sets may be installed:
    #
    #   1. opt_f(False, y):
    #        - "x == False" (always works for any size y)
    #
    #   2. opt_f(True, y):
    #        - triggers recompilation and results in guards like
    #          "x == True and y.size(0) == 4"
    #          (or "y.size(0) == 4 and x == True")
    #
    # The order in which guards are checked matters: if the True-branch guard
    # is checked before the False-branch one and, within the True branch, the
    # y.size(0) check precedes x == True, we may end up with an unnecessary
    # shape specialization for y.
    def _concretize(value) -> int:
        assert isinstance(value, (int, torch.SymInt))
        if not is_symbolic(value):
            return value
        assert value.node.shape_env is not self, \
            "expect the symbol is created from an shape env other than current one."
        return value.node.require_hint()

    concrete_size = tuple(map(_concretize, ex.size()))
    concrete_stride = tuple(map(_concretize, ex.stride()))
    concrete_offset = _concretize(ex.storage_offset())
    dynamic_flags = [_is_dim_dynamic(ex, d) for d in range(ex.dim())]

    return self._create_symbolic_sizes_strides_storage_offset(
        concrete_size,
        concrete_stride,
        concrete_offset,
        dynamic_flags,
        source,
        symbolic_context=symbolic_context,
    )
|
||
|
|
||
|
@record_shapeenv_event()
def _create_symbolic_sizes_strides_storage_offset(
    self,
    ex_size: Sequence[int],
    ex_stride: Sequence[int],
    ex_storage_offset: int,
    is_dim_dynamic: Sequence[bool],
    source: Source,
    *,
    symbolic_context: Optional[SymbolicContext] = None,
):
    """Allocate symbolic sizes, strides and a storage offset from concrete values.

    Args:
        ex_size: concrete size of every dimension.
        ex_stride: concrete stride of every dimension.
        ex_storage_offset: concrete storage offset.
        is_dim_dynamic: per-dimension flags, consulted only when
            ``symbolic_context`` is ``None``.
        source: base source; a ``TensorPropertySource`` is derived from it for
            every size, stride and the storage offset.
        symbolic_context: per-dimension dynamism and constraints. When
            ``None``, a ``StatelessSymbolicContext`` is built here.

    Returns:
        A 3-tuple ``(sizes, strides, storage_offset)`` where ``sizes`` and
        ``strides`` are tuples and every element comes from
        ``create_symintnode``.
    """
    dim = len(ex_size)

    # Reimplement the legacy behavior
    if symbolic_context is None:
        constraint_dims = [None] * dim
        dynamic_dims = []
        for i in range(dim):
            # NB: This is encapsulation breaking!  Legacy behavior was
            # bad.
            if is_dim_dynamic[i]:
                r = DimDynamic.DYNAMIC
            elif self.assume_static_by_default:
                r = DimDynamic.STATIC
            else:
                r = DimDynamic.DUCK
            dynamic_dims.append(r)
        # NOTE(review): this assignment discards the list built by the loop
        # above, so is_dim_dynamic and assume_static_by_default have no effect
        # on this path -- confirm whether that is intended.
        dynamic_dims = [DimDynamic.DUCK] * dim
        # symbolic_context is None - set one
        symbolic_context = StatelessSymbolicContext(dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims)
    # We got a StatelessSymbolicContext
    _assert_symbol_context(symbolic_context)
    constraint_dims = symbolic_context.constraint_sizes
    dynamic_dims = symbolic_context.dynamic_sizes

    # TODO: make this configurable from outside symbolic_context; we made a symbolic_context
    # decision here where if all sizes are static, we are going to
    # specialize all of the inner strides/offset too. We don't have to
    # do this, and arguably we should ALWAYS allow for dynamic offset,
    # this is cheap.
    # TODO: This should be DYNAMIC, using DUCK for BC
    dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK

    assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}"
    assert len(constraint_dims) == dim

    from torch._dynamo.source import TensorPropertySource, TensorProperty
    size: List[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple(ex_size, source, symbolic_context)
    # stride[i] stays None until a symbolic expression has been bound for it.
    stride: List[Optional[sympy.Expr]] = [None] * len(size)
    # Strides of 0 and 1 are kept as integer constants.
    for i, val in enumerate(ex_stride):
        if val in (0, 1):
            stride[i] = sympy.Integer(val)
    # Repeat until every stride has an expression. Each pass either binds
    # strides to existing size*stride products or allocates one new symbol.
    while any(x is None for x in stride):
        # Maps the concrete value size*stride of every already-bound dimension
        # (with non-negative concrete stride) to its symbolic counterpart.
        candidates = {
            ex_size[i] * ex_stride[i]: size[i] * stride[i]
            for i in range(len(size))
            if stride[i] is not None and ex_stride[i] >= 0
        }

        # iterate over unbound strides in sorted order
        def _nested_int_aware_sort(tup):
            return (
                # Order nested ints by their coefficients.
                # 1 here to order nested ints after non-nested-ints.
                (1, tup[0].node.nested_int_coeff(), tup[1]) if is_nested_int(tup[0])
                else (0, *tup)
            )
        val_list = sorted(
            [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None],
            key=_nested_int_aware_sort,
        )
        for _, i in val_list:
            # An unbound stride whose concrete value matches a known product
            # reuses that product's expression; the newly bound dimension then
            # contributes its own product for later strides in this pass.
            if stride[i] is None and ex_stride[i] in candidates:
                stride[i] = candidates[ex_stride[i]]
                candidates[ex_size[i] * ex_stride[i]] = size[i] * stride[i]

        if any(x is None for x in stride):
            # bind the smallest unbound stride to a new variable
            val, i = min(
                [
                    (ex_stride[i], i)
                    for i in range(len(stride))
                    if stride[i] is None
                ], key=_nested_int_aware_sort
            )
            stride[i] = self.create_symbol(
                val,
                TensorPropertySource(source, TensorProperty.STRIDE, i),
                dynamic_dim=dynamic_strides_offset,
                constraint_dim=None,
                symbolic_context=symbolic_context,
            )
    assert all(x is not None for x in stride)

    # Wrap each size expression, using the concrete size as its hint.
    sym_sizes = [
        self.create_symintnode(
            sym,
            hint=hint,
            source=TensorPropertySource(source, TensorProperty.SIZE, i),
        )
        for i, (sym, hint) in enumerate(zip(size, ex_size))
    ]
    sym_stride = []
    for i, stride_expr in enumerate(stride):
        # NB: Don't duck size the stride; instead use the expression
        # we computed
        assert stride_expr is not None
        sym_stride.append(self.create_symintnode(
            stride_expr, hint=ex_stride[i], source=TensorPropertySource(source, TensorProperty.STRIDE, i)))
    # The storage offset always goes through create_symbol, using the same
    # dynamism setting as the strides.
    sym_storage_offset = self.create_symintnode(
        self.create_symbol(
            ex_storage_offset,
            TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
            dynamic_dim=dynamic_strides_offset,
            constraint_dim=None,
            symbolic_context=symbolic_context
        ),
        hint=ex_storage_offset,
        source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET))
    return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset
|
||
|
|
||
|
@record_shapeenv_event()
def create_symintnode(
        self,
        sym: "sympy.Expr",
        *,
        hint: Optional[int],
        source: Optional[Source] = None,
):
    """Create a SymInt value from a symbolic expression

    If you know what the current hint value of the SymInt to be created
    is, pass it into hint.  Otherwise, pass None and we will make our best
    guess

    Args:
        sym: the expression to wrap.
        hint: current concrete value of ``sym``, or ``None``.
        source: optional source of the value. When given and translation
            validation is enabled, a symbol, FX placeholder and Z3 variable
            are created for it and asserted equal to ``sym``.

    Returns:
        A plain ``int`` when ``sym`` is a ``sympy.Integer``, otherwise a
        ``SymInt`` wrapping a ``SymNode`` for ``sym``.
    """
    if self._translation_validation_enabled and source is not None:
        # Create a new symbol for this source.
        symbol = self._create_symbol_for_source(source)
        assert symbol is not None

        # Create a new FX placeholder and Z3 variable for 'symbol'.
        fx_node = self._create_fx_placeholder_and_z3var(symbol, int)

        # Add an equality assertion for the newly created symbol and 'sym'.
        self._add_assertion(sympy.Eq(symbol, sym))
    else:
        fx_node = None

    if isinstance(sym, sympy.Integer):
        # Constants are returned as plain ints; a provided hint must agree.
        if hint is not None:
            assert int(sym) == hint
        out = int(sym)
    else:
        out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
    return out
|
||
|
|
||
|
@record_shapeenv_event()
def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
    """Allocate a new unspecified symbol and return it wrapped via ``create_symintnode``.

    ``value`` is used both as the symbol's value and as the hint.
    """
    unspecified = self.create_unspecified_symbol(
        value,
        source=source,
        dynamic_dim=dynamic_dim,
    )
    return self.create_symintnode(unspecified, hint=value, source=source)
|
||
|
|
||
|
def create_symboolnode(self, sym: "sympy.Expr"):
    """Wrap a sympy boolean expression in a SymBool (no hint is attached)."""
    # Only serialization uses this function, so it is deliberately not
    # tracked for validation.
    bool_node = SymNode(sym, self, bool, None)
    return SymBool(bool_node)
|
||
|
|
||
|
def _log_create_unbacked_symbol(self, prefix: str, symbol, vr: ValueRanges):
    """Log, at info level, that an unbacked ``symbol`` with range ``vr`` was created.

    Stack info is attached to the record when ``symbol`` is named in the
    comma-separated ``config.extended_debug_create_symbol`` setting.
    """
    watched = config.extended_debug_create_symbol
    is_debug = watched is not None and str(symbol) in watched.split(',')
    frame, user_loc, extra_debug = self._get_stack_summary(is_debug)
    log.info(
        "%s %s [%s, %s]%s (%s)%s",
        prefix,
        symbol,
        vr.lower,
        vr.upper,
        user_loc,
        format_frame(frame),
        extra_debug,
        stack_info=is_debug,
    )
|
||
|
|
||
|
@record_shapeenv_event()
def create_unbacked_symfloat(self):
    """Create a symbolic float that carries no hint value.

    The symbol is named ``f<N>`` with ``N`` drawn from
    ``unbacked_symfloat_counter`` and is given an unknown value range.
    """
    index = next(self.unbacked_symfloat_counter)
    fresh: sympy.Symbol = sympy.Symbol(f"f{index}")
    self.counter["create_unbacked_symbol"] += 1
    self.var_to_stack[fresh] = CapturedTraceback.extract(skip=1)
    value_range = ValueRanges.unknown()
    self.var_to_range[fresh] = value_range

    # FX placeholder and Z3 variable backing the new symbol.
    placeholder = self._create_fx_placeholder_and_z3var(fresh, float)

    self._log_create_unbacked_symbol("create_unbacked_symfloat", fresh, value_range)

    node = SymNode(fresh, self, float, None, fx_node=placeholder)
    return SymFloat(node)
|
||
|
|
||
|
@record_shapeenv_event()
def create_unbacked_symint(self):
    """Create a symbolic integer that carries no hint value.

    The symbol is named ``u<N>`` with ``N`` drawn from
    ``unbacked_symint_counter`` and starts with the default unspecified
    value range.
    """
    index = next(self.unbacked_symint_counter)
    fresh: sympy.Symbol = sympy.Symbol(f"u{index}", integer=True)
    self.counter["create_unbacked_symbol"] += 1
    self.var_to_stack[fresh] = CapturedTraceback.extract(skip=1)
    value_range = self._default_unspecified_value_range()
    self.var_to_range[fresh] = value_range

    # FX placeholder and Z3 variable backing the new symbol.
    placeholder = self._create_fx_placeholder_and_z3var(fresh, int)

    self._log_create_unbacked_symbol("create_unbacked_symint", fresh, value_range)

    node = SymNode(fresh, self, int, None, fx_node=placeholder)
    return SymInt(node)
|
||
|
|
||
|
def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool:
    """Tell whether ``symbol`` is named like an unbacked symbol (leading "u")."""
    # NB: keep synced with free_unbacked_symbols
    rendered = str(symbol)
    return rendered[:1] == "u"
|
||
|
|
||
|
@record_shapeenv_event()
def create_unbacked_symbool(self):
    """Create a symbolic boolean that carries no hint value.

    The boolean is backed by an integer symbol ``u<N>`` (drawn from
    ``unbacked_symint_counter``) restricted to the range [0, 1]; the returned
    SymBool wraps the expression ``Eq(symbol, 1)``.
    """
    index = next(self.unbacked_symint_counter)
    fresh: sympy.Symbol = sympy.Symbol(f"u{index}", integer=True)
    self.counter["create_unbacked_symbol"] += 1
    self.var_to_stack[fresh] = CapturedTraceback.extract(skip=1)
    value_range = ValueRanges(0, 1)
    self.var_to_range[fresh] = value_range

    # FX placeholder and Z3 variable backing the new symbol.
    placeholder = self._create_fx_placeholder_and_z3var(fresh, bool)

    self._log_create_unbacked_symbol("create_unbacked_symbool", fresh, value_range)

    node = SymNode(sympy.Eq(fresh, 1), self, bool, None, fx_node=placeholder)
    return SymBool(node)
|
||
|
|
||
|
@record_shapeenv_event()
def create_unspecified_symbol(
    self,
    val: Union[int, SymInt],
    source: Source,
    dynamic_dim: DimDynamic = DimDynamic.DUCK,
    constraint_dim: DimConstraint = None,  # NB: includes None
) -> "sympy.Expr":
    """Create a symbol whose value is left unspecified.

    Unlike a standard symbol, no positivity assumption is made and values of
    zero or one are not specialized.
    """
    # An unspecified symbol may turn out positive or negative, so no sign is
    # asserted for it.
    unspecified_sign = None
    # Zero/one specialization is switched off so that a new symbol can be
    # obtained regardless of val.
    keep_zero_one_symbolic = True
    return self.create_symbol(
        val,
        source,
        dynamic_dim,
        constraint_dim,
        positive=unspecified_sign,
        do_not_specialize_zero_one=keep_zero_one_symbolic,
        symbolic_context=None,
    )
|
||
|
|
||
|
@record_shapeenv_event()
def create_symbol(
    self,
    val: int,
    source: Source,
    dynamic_dim: DimDynamic = DimDynamic.DUCK,
    constraint_dim: DimConstraint = None,  # NB: includes None
    positive: Optional[bool] = True,
    do_not_specialize_zero_one: bool = False,
    symbolic_context=None,
) -> "sympy.Expr":
    """Create a new symbol which is tracked by this ShapeEnv

    Args:
        val: concrete value the symbol stands for. Non-int values are handled
            on a separate path that the in-body comments describe as
            "only used for jagged layout nested tensors".
        source: where the value comes from; appended to the symbol's sources.
        dynamic_dim: STATIC returns ``sympy.Integer(val)`` directly; DUCK
            reuses the symbol previously allocated for an equal ``val`` when
            ``self.duck_shape`` is set; DYNAMIC never reuses via duck shaping.
        constraint_dim: when not ``None``, forces ``dynamic_dim`` to DYNAMIC;
            a ``StrictMinMaxConstraint`` additionally narrows the new
            symbol's range.
        positive: forwarded to ``sympy.Symbol``; when truthy the new symbol
            is asserted ``> 1`` and gets the default value range.
        do_not_specialize_zero_one: disables reuse of the 0/1 entries of
            ``val_to_var`` for this call.
        symbolic_context: when it is a ``StatefulSymbolicContext``, results
            are cached in it per shape env (``id(self)``) and source name.

    Returns:
        The expression standing for ``val``: a constant, a reused symbol or a
        newly allocated one.

    Raises:
        ConstraintViolationError: if an int ``val`` is outside the range
            computed for the newly created symbol.
    """
    # see note [Tensor Fakification and Symbol Caching]
    source_name = source.name()
    # Make sure this shape env has a cache slot in the stateful context.
    if (isinstance(symbolic_context, StatefulSymbolicContext)
            and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache):
        symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {}

    # Cache hit: return what was produced earlier for this source name.
    if (isinstance(symbolic_context, StatefulSymbolicContext)
            and source_name
            and (source_name in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)])):
        return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name]

    if do_not_specialize_zero_one:
        specialize_zero_one = False
    else:
        specialize_zero_one = self.specialize_zero_one

    assert isinstance(source, Source), f"{type(source)} {source}"
    assert not (positive and val < 0), f"positive set for negative value: {val}"
    # It's always sound to allocate a symbol as DYNAMIC.  If the user
    # constrained the symbol, force the symbolic_context to DYNAMIC, because our
    # constraint code will do weird stuff if, e.g., it's duck shaped
    if constraint_dim is not None:
        dynamic_dim = DimDynamic.DYNAMIC

    if dynamic_dim is DimDynamic.STATIC:
        # Static dims are plain constants; no symbol is allocated.
        out = sympy.Integer(val)
        if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
            symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = out
        return out

    elif dynamic_dim is DimDynamic.DUCK:
        # duck_shape can be used to globally turn off duck shaping, even
        # if it was requested
        duck = self.duck_shape
    elif dynamic_dim is DimDynamic.DYNAMIC:
        duck = False
    else:
        raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}")

    if val in (0, 1) and specialize_zero_one:
        # 0/1 specialization: use the existing val_to_var entry for 0 or 1.
        r = self.val_to_var[val]
    elif not duck or val not in self.val_to_var:
        # If we're not duck shaping, we always create a new symbol
        # Even if we're duck shaping, if we haven't seen this particular
        # value before, we also create a new symbol
        sympy_expr = sympy.Symbol(f"s{len(self.var_to_val)}", positive=positive, integer=True)
        # We always associate vars to vals
        if isinstance(val, int):
            self.var_to_val[sympy_expr] = sympy.Integer(val)
        else:
            # Only used for jagged layout nested tensors
            self.var_to_val[sympy_expr] = SingletonInt(val.node.nested_int(), coeff=val.node.nested_int_coeff())

        # Do the appending later, because we always want to populate this
        self.var_to_sources[sympy_expr] = []
        # Create a Z3 variable for the new symbol.
        self._add_z3var(sympy_expr, int)

        if duck:
            # Make sure to reuse this symbol for subsequent duck shaping
            self.val_to_var[val] = sympy_expr

        if isinstance(val, int):
            if positive:
                # Add assertions for the newly created symbols
                self._add_assertion(sympy_expr > 1)

                # Apply default range, which assumes not zero-one
                self.var_to_range[sympy_expr] = self._default_value_range()
            else:
                self.var_to_range[sympy_expr] = self._default_unspecified_value_range()

            # Small performance optimization: if we have a min-max constraint,
            # we can proactively narrow to that range
            if isinstance(constraint_dim, StrictMinMaxConstraint):
                assert not duck
                self.var_to_range[sympy_expr] &= constraint_dim.vr

            vr = self.var_to_range[sympy_expr]

            # The concrete value itself must satisfy the computed range.
            if val not in vr:
                raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]")

            range_str = f"[{vr.lower}, {vr.upper}]"
        else:
            # Skip var_range logic for SingletonInt
            # Only used for jagged layout nested tensors
            range_str = ""

        r = sympy_expr

        # Attach stack info to the log record only for symbols named in the
        # comma-separated extended_debug_create_symbol setting.
        is_debug = (
            config.extended_debug_create_symbol is not None and
            str(sympy_expr) in config.extended_debug_create_symbol.split(',')
        )
        fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
        self.log.info(
            "create_symbol %s = %s for %s %s%s (%s)%s",
            sympy_expr, val, source.name(), range_str,
            maybe_user_loc, format_frame(fsummary), maybe_extra_debug, stack_info=is_debug
        )

        self.counter["create_symbol"] += 1
    else:
        # This implements duck-shaping: input sizes that match are assigned
        # the same symint
        r = self.val_to_var[val]
        self.log.debug("create_symbol %s duck sized %s", r, source.name())

    if isinstance(r, sympy.Symbol):
        r_sources = self.var_to_sources[r]
        r_sources.append(source)
        if not source.is_ephemeral() and r_sources[0].is_ephemeral():
            # prefer non-ephemeral source first since it may be guarded on later
            r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]

        # This ensures we get zeros in symbol_guard_counts, which makes
        # some queries simpler (since we will accumulate mass on 0 this
        # way)
        self.symbol_guard_counter[r] = 0

    # Remember the result so a later call with the same context and source
    # name returns it directly.
    if isinstance(symbolic_context, StatefulSymbolicContext) and source_name:
        symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][source_name] = r
    return r
|
||
|
|
||
|
def _debug_name(self, source):
    """Return the debug name registered for ``source``, or ``source.name()`` if none."""
    raw_name = source.name()
    known = self.source_name_to_debug_name
    return known[raw_name] if raw_name in known else raw_name
|
||
|
|
||
|
def _render_range_for_constraint_violation(self, source, c):
    """Describe constraint ``c`` on ``source`` for a constraint-violation message.

    For a ``StrictMinMaxConstraint`` the result names the source and lists
    only those bounds that are tighter than the default value range; any other
    constraint is rendered through ``c.render(source)``.
    """
    if not isinstance(c, StrictMinMaxConstraint):
        return c.render(source)

    lower, upper = c.vr.lower, c.vr.upper
    default = self._default_value_range()
    # A bound that is no tighter than the default is not worth printing.
    shown_lower = None if lower <= default.lower else lower
    shown_upper = None if upper >= default.upper else upper

    name = self._debug_name(source)
    pieces = [f"{name} = {source.name()} in the specified range"]
    if shown_lower is not None and shown_upper is not None:
        pieces.append(f" {shown_lower} <= {name} <= {shown_upper}")
    elif shown_upper is not None:
        pieces.append(f" {name} <= {shown_upper}")
    elif shown_lower is not None:
        pieces.append(f" {shown_lower} <= {name}")
    return "".join(pieces)
|
||
|
|
||
|
def produce_guards(
|
||
|
self,
|
||
|
placeholders,
|
||
|
sources,
|
||
|
source_ref=lambda n: n.name(),
|
||
|
*,
|
||
|
input_contexts: Optional[DimList[SymbolicContext]] = None,
|
||
|
# Encodes user-specified input shape equations of the form s = s' and s = fn(s').
|
||
|
# (See docs on EqualityConstraint for details of the encoding.)
|
||
|
equalities_inputs: Optional[EqualityConstraint] = None,
|
||
|
_simplified=False,
|
||
|
# Indicates if we should produce guards for known static values.
|
||
|
ignore_static=True,
|
||
|
) -> List[str]:
|
||
|
"""
|
||
|
Generates a list of guards strings which, when evaluated in a context that
|
||
|
defines tensors for all the sources, returns True or False depending
|
||
|
on if the guards in the list evaluated to True or not. Primarily used by Dynamo,
|
||
|
but this is also helpful for manual testing of guards (see
|
||
|
evaluate_guards_for_args)
|
||
|
|
||
|
For convenience in testing, a source is allowed to be a str,
|
||
|
in which case we will assume it is a LocalSource
|
||
|
|
||
|
simplified lets you omit duck sizing, equality and 0/1 guards.
|
||
|
This is useful for testing when you don't care about the boilerplate
|
||
|
guards, and it may be helpful for user output too (be careful though;
|
||
|
some equality guards are nontrivial! It would be nice to get simplified
|
||
|
output to print them too). It's private because it's not
|
||
|
intended for normal use
|
||
|
"""
|
||
|
self.log.info("produce_guards")
|
||
|
|
||
|
# Check if we get to the same ShapeEnv state by replaying the recorded events.
|
||
|
# This will create a new ShapeEnv instance, and call all recorded function
|
||
|
# calls on this new instance. Finally, it will check whether this new instance
|
||
|
# has equal state.
|
||
|
#
|
||
|
# It's important that we do it in the begining of this function, since it modifies
|
||
|
# self.dim_constraints through its execution. Changes that happen in this method
|
||
|
# aren't interesting, since this is the function call we wish to reproduce at the
|
||
|
# end. If we wish to simply reproduce ShapeEnv instances even after this call,
|
||
|
# this method should also be recorded.
|
||
|
if self.check_recorded_events:
|
||
|
shape_env = replay_shape_env_events(self.events)
|
||
|
self.check_equal(shape_env)
|
||
|
|
||
|
assert len(placeholders) == len(sources), f"len({placeholders}) != len({sources})"
|
||
|
Tensorlike = (torch.Tensor, FakeTensorMeta)
|
||
|
|
||
|
def _create_no_constraints_context(t):
|
||
|
return StatelessSymbolicContext(
|
||
|
# Ignored; only the constraints part is relevant below.
|
||
|
dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(),
|
||
|
constraint_sizes=[None] * t.dim()
|
||
|
)
|
||
|
|
||
|
# Expand optional inputs, or verify invariants are upheld
|
||
|
if input_contexts is None:
|
||
|
input_contexts = [
|
||
|
_create_no_constraints_context(t) if isinstance(t, Tensorlike)
|
||
|
else None for t in placeholders
|
||
|
]
|
||
|
else:
|
||
|
assert len(input_contexts) == len(placeholders)
|
||
|
for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
|
||
|
if isinstance(t, Tensorlike):
|
||
|
if context is None:
|
||
|
input_contexts[i] = _create_no_constraints_context(t)
|
||
|
else:
|
||
|
assert isinstance(t, (SymInt, int))
|
||
|
assert not isinstance(context, list)
|
||
|
|
||
|
# It took a lot of sweat to figure out the algorithm here. Let's
|
||
|
# explain how it works.
|
||
|
#
|
||
|
# The ShapeEnv lifecycle looks something like this:
|
||
|
#
|
||
|
# - For each input, you either generate a fresh Sympy symbol (s0) to
|
||
|
# represent its value (a binding site), or you reuse some
|
||
|
# preexisting symbol or expression, skipping the symbol allocation
|
||
|
# (e.g., duck sizing to a preexisting symbol, or expressing a
|
||
|
# stride as a multiplication of a separate stride and size.)
|
||
|
# Naively, you might expect to bind a fresh Sympy symbol for
|
||
|
# every input, but this is fairly wasteful as most of these
|
||
|
# symbols immediately simplify away, and if you don't eagerly
|
||
|
# specialize, e.g., 0/1 symbols, you end up with very complicated
|
||
|
# expressions that are not optimizable in practice.
|
||
|
#
|
||
|
# - You perform some compute on these symbols, occasionally
|
||
|
# introducing guards on boolean expressions on these symbols.
|
||
|
# In particular, whenever we guard on equality (_maybe_guard_rel),
|
||
|
# we can simplify shapes; e.g., when s0 == s1 * 2, we can now
|
||
|
# replace all occurrences of s0 with s1 * 2. Sometimes, a
|
||
|
# boolean expression evaluation doesn't introduce a guard, as
|
||
|
# the guard is already entailed by the simplifications we have
|
||
|
# applied.
|
||
|
#
|
||
|
# - In the end, you have a bunch of replacements (saying how to
|
||
|
# simplify shapes) and a bunch of guards (all the equality guards
|
||
|
# are trivial, because they're covered by the replacements).
|
||
|
#
|
||
|
# From the ShapeEnv, we must generate a Python expression that, when
|
||
|
# evaluated on a set of inputs, tells us whether or not these boolean
|
||
|
# expressions would have evaluated in the same way. However,
|
||
|
# we cannot easily compute this, as we elide recording boolean
|
||
|
# expressions when we think they are vacuously true. Thus, we seek
|
||
|
# an approximation: we must generate an expression, if true, would have
|
||
|
# produced an "equivalent" ShapeEnv, which would answer guard
|
||
|
# expressions in the same way.
|
||
|
#
|
||
|
# Our notion of equivalence is a bit subtle. For example, consider
|
||
|
# the ShapeEnv created from an input of size (5, 4) versus (4, 4)
|
||
|
# (no other guards.) Duck sizing would generate (s0, s1) in the first
|
||
|
# case but (s0, s0) in the second. We do NOT assume that size
|
||
|
# variables are disjoint; so in fact a graph that assumes the input
|
||
|
# could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not
|
||
|
# vice versa. However, consider an analogous case (1,) versus (2,).
|
||
|
# Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT
|
||
|
# subsume the (1,) graph because we assume that any size variables
|
||
|
# is NOT 0/1 (and make simplifications according to this; e.g., if
|
||
|
# we queried s0 == 0, we would immediately return False without
|
||
|
# returning a guard.)
|
||
|
#
|
||
|
# So, it is perhaps easier to flip things on their head: the guard
|
||
|
# expressions we generate here say what simplifications are valid,
|
||
|
# and what are not. Below, we explain each of the guard expressions
|
||
|
# we generate
|
||
|
|
||
|
# TODO: Make this more efficient by binding all the size/stride/offsets
|
||
|
# to locals before performing tests on them.
|
||
|
|
||
|
from torch._dynamo.source import TensorPropertySource, TensorProperty, NegateSource
|
||
|
|
||
|
# Actual codegen must be delayed as we don't necessarily know what
|
||
|
# the symbol mapping is
|
||
|
input_guards = []
|
||
|
|
||
|
symbol_to_source = collections.defaultdict(list)
|
||
|
symbol_to_constraints = collections.defaultdict(set)
|
||
|
constraint_violations : List[Tuple[bool, Callable[[], str]]] = []
|
||
|
|
||
|
def record_constraint_violation(warn_only, debug_name, msg, hint=None):
|
||
|
constraint_violations.append(
|
||
|
(warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg)
|
||
|
)
|
||
|
|
||
|
def is_dim(src):
|
||
|
return isinstance(src, TensorPropertySource) and src.prop is TensorProperty.SIZE
|
||
|
|
||
|
if equalities_inputs:
|
||
|
source_index = {}
|
||
|
for i, src in enumerate(sources):
|
||
|
source_index[src.name()] = i
|
||
|
|
||
|
def get_expression(tensor_dim_src):
|
||
|
fake = placeholders[source_index[tensor_dim_src.base.name()]]
|
||
|
symint = fake.shape[tensor_dim_src.idx]
|
||
|
if isinstance(symint, torch.SymInt):
|
||
|
return symint.node.expr
|
||
|
else:
|
||
|
assert type(symint) is int, f"Expected int, got {type(symint)}"
|
||
|
return symint
|
||
|
|
||
|
for src1, src2 in equalities_inputs.source_pairs:
|
||
|
expr1, expr2 = get_expression(src1), get_expression(src2)
|
||
|
# Check whether given input shape values satisfy a specified equation s = s'.
|
||
|
# - Raise when the equation was violated by the given input shape values.
|
||
|
# - Otherwise issue a guard to constrain them.
|
||
|
concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2))
|
||
|
if not concrete_val:
|
||
|
raise ConstraintViolationError(
|
||
|
f"{src1.name()} = {expr1.subs(self.var_to_val)}"
|
||
|
" is not equal to "
|
||
|
f"{src2.name()} = {expr2.subs(self.var_to_val)}"
|
||
|
)
|
||
|
|
||
|
for src, root, fn in equalities_inputs.derived_equalities:
|
||
|
expr1 = get_expression(src)
|
||
|
# recall that root is either a phantom symbol or an input source
|
||
|
expr2, debug_name = (
|
||
|
(root, self.var_to_sources[root][0].name()) if isinstance(root, sympy.Symbol)
|
||
|
else (get_expression(root), self._debug_name(root))
|
||
|
)
|
||
|
expr2_ = fn(expr2)
|
||
|
# Check whether given input shape values satisfy a specified equation s = fn(s').
|
||
|
# - Raise when the equation was violated by the given input shape values.
|
||
|
# - Otherwise issue a guard to constrain them.
|
||
|
concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_))
|
||
|
if not concrete_val:
|
||
|
raise ConstraintViolationError(
|
||
|
f"Expected input {src.name()} to be equal to "
|
||
|
f"{fn(sympy.Symbol(debug_name))}, "
|
||
|
f"where {debug_name} = {expr2.subs(self.var_to_val)}, "
|
||
|
f"but got {expr1.subs(self.var_to_val)}"
|
||
|
)
|
||
|
|
||
|
for phantom_symbol in equalities_inputs.phantom_symbols:
|
||
|
# we created additional phantom symbols that are not input shape dimensions
|
||
|
symbol_to_source[phantom_symbol].extend(self.var_to_sources[phantom_symbol])
|
||
|
|
||
|
# How do we know what the value of s0 is? Fresh variables can only be
|
||
|
# bound by inputs, so there MUST be some other input which binds the
|
||
|
# variable. If there is no such input, this is an error in our
|
||
|
# system. We record where all symbols come from, to help you diagnose
|
||
|
# why those symbols didn't occur.
|
||
|
#
|
||
|
# In fact, generally speaking it is only possible for the "outermost"
|
||
|
# user of a ShapeEnv to evaluate the guards, because some inputs may
|
||
|
# not be available to inner levels. For example, Dynamo can guard on
|
||
|
# tensors that never actually become graph arguments (they are
|
||
|
# pruned). In this case, only Dynamo knows about these arguments.
|
||
|
def track_symint(source, val, constraint=None):
|
||
|
log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint)
|
||
|
assert not isinstance(val, SymInt) or is_symbolic(val)
|
||
|
|
||
|
if isinstance(val, SymInt) and val.node.maybe_as_int() is not None:
|
||
|
val = val.node.maybe_as_int()
|
||
|
|
||
|
if isinstance(val, SymInt):
|
||
|
s = val.node.expr
|
||
|
if isinstance(s, sympy.Symbol):
|
||
|
symbol_to_source[s].append(source)
|
||
|
if constraint is not None:
|
||
|
symbol_to_constraints[s].add(constraint)
|
||
|
elif isinstance(-s, sympy.Symbol):
|
||
|
symbol_to_source[-s].append(NegateSource(source))
|
||
|
else:
|
||
|
constraint_violated = False
|
||
|
if isinstance(constraint, StrictMinMaxConstraint):
|
||
|
# try inferring the ranges of the expr s
|
||
|
sym_vrs = {x: self.var_to_range.get(x, None) for x in s.free_symbols}
|
||
|
if all(vr is not None for vr in sym_vrs.values()):
|
||
|
expr_vr = bound_sympy(s, sym_vrs)
|
||
|
if expr_vr != constraint.vr:
|
||
|
# the expr and constrain ranges don't match
|
||
|
constraint_violated = True
|
||
|
else:
|
||
|
# some of the free symbols in s don't have ranges
|
||
|
constraint_violated = True
|
||
|
elif isinstance(constraint, RelaxedUnspecConstraint):
|
||
|
if s.is_number:
|
||
|
i = int(s)
|
||
|
# Don't complain about 0/1 specialization, we
|
||
|
# expect to have to compile in this case anyway
|
||
|
if i not in (0, 1):
|
||
|
constraint_violated = True
|
||
|
if constraint_violated:
|
||
|
def hint(s):
|
||
|
sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(s)
|
||
|
return f"{sexpr}."
|
||
|
|
||
|
var_with_range = self._render_range_for_constraint_violation(source, constraint)
|
||
|
msg = (
|
||
|
f"Not all values of {var_with_range} are valid because "
|
||
|
f"{self._debug_name(source)} was inferred to be equal to "
|
||
|
)
|
||
|
record_constraint_violation(
|
||
|
constraint.warn_only,
|
||
|
self._debug_name(source),
|
||
|
msg,
|
||
|
hint=functools.partial(hint, s),
|
||
|
)
|
||
|
|
||
|
input_guards.append((source, s))
|
||
|
else:
|
||
|
s = sympy.Integer(val)
|
||
|
input_guards.append((source, s))
|
||
|
constraint_violated = False
|
||
|
if isinstance(constraint, StrictMinMaxConstraint):
|
||
|
constraint_violated = True
|
||
|
elif isinstance(constraint, RelaxedUnspecConstraint):
|
||
|
# Don't complain about 0/1 specialization, we
|
||
|
# expect to have to compile in this case anyway
|
||
|
if val not in (0, 1):
|
||
|
constraint_violated = True
|
||
|
if constraint_violated:
|
||
|
var_with_range = self._render_range_for_constraint_violation(source, constraint)
|
||
|
msg = (
|
||
|
f"Not all values of {var_with_range} are valid because "
|
||
|
f"{self._debug_name(source)} was inferred to be a constant ({val})."
|
||
|
)
|
||
|
record_constraint_violation(constraint.warn_only, self._debug_name(source), msg)
|
||
|
|
||
|
for t, source, context in zip(placeholders, sources, input_contexts):
|
||
|
if isinstance(source, str):
|
||
|
from torch._dynamo.source import LocalSource
|
||
|
source = LocalSource(source)
|
||
|
assert isinstance(source, Source)
|
||
|
if t is None:
|
||
|
continue
|
||
|
if isinstance(t, (SymInt, int)):
|
||
|
track_symint(source, t)
|
||
|
continue
|
||
|
assert isinstance(t, Tensorlike)
|
||
|
if is_traceable_wrapper_subclass(t):
|
||
|
from torch._dynamo.source import AttrSource
|
||
|
|
||
|
assert isinstance(context, SubclassSymbolicContext)
|
||
|
|
||
|
# For subclasses, we need to track symints on BOTH the outer
|
||
|
# and inner tensors.
|
||
|
sources_tensors_constraints = [
|
||
|
(source, t, context.constraint_sizes)
|
||
|
]
|
||
|
attrs, _ = t.__tensor_flatten__()
|
||
|
for attr in attrs:
|
||
|
inner_t = getattr(t, attr)
|
||
|
inner_context = context.inner_contexts[attr]
|
||
|
sources_tensors_constraints.append((
|
||
|
AttrSource(source, attr),
|
||
|
inner_t,
|
||
|
inner_context.constraint_sizes
|
||
|
))
|
||
|
else:
|
||
|
sources_tensors_constraints = [(source, t, context.constraint_sizes)]
|
||
|
|
||
|
for src, curr_t, constraint in sources_tensors_constraints:
|
||
|
if is_sparse_any(curr_t):
|
||
|
for i, ss in enumerate(curr_t.size()):
|
||
|
property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
|
||
|
track_symint(property_source, ss, constraint[i])
|
||
|
else:
|
||
|
for i, ss in enumerate(curr_t.size()):
|
||
|
property_source = TensorPropertySource(src, TensorProperty.SIZE, i)
|
||
|
track_symint(property_source, ss, constraint[i])
|
||
|
for i, ss in enumerate(curr_t.stride()):
|
||
|
track_symint(TensorPropertySource(src, TensorProperty.STRIDE, i), ss)
|
||
|
track_symint(TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), curr_t.storage_offset())
|
||
|
|
||
|
# 1. Every input must equal the final simplified symbolic expression
|
||
|
# stored on the placeholder. Given a placeholder (s0*2, s1),
|
||
|
# if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3.
|
||
|
# This does a lot of work: it covers duck sizing and equality guards.
|
||
|
exprs = []
|
||
|
self.dim_constraints = DimConstraints(
|
||
|
symbol_to_source,
|
||
|
self.var_to_val,
|
||
|
set(symbol_to_constraints.keys()),
|
||
|
self.source_name_to_debug_name,
|
||
|
)
|
||
|
|
||
|
if not _simplified:
|
||
|
for source, expr in input_guards:
|
||
|
if self._translation_validation_enabled:
|
||
|
# Ignore sources that were not turned into SymInts.
|
||
|
srcname = source.name()
|
||
|
if srcname in self.source_to_symbol:
|
||
|
self._add_target_expr(sympy.Eq(self.source_to_symbol[srcname], expr))
|
||
|
|
||
|
# Small optimization
|
||
|
if (
|
||
|
isinstance(expr, sympy.Symbol) and
|
||
|
symbol_to_source.get(expr) and
|
||
|
source == symbol_to_source[expr][0]
|
||
|
):
|
||
|
continue
|
||
|
|
||
|
# This logic excludes static values found on tensors from guarding, because
|
||
|
# dynamo's check_tensor_fn does that (see guards.cpp).
|
||
|
# However, for non tensor sources, we still need to guard here.
|
||
|
if ignore_static and isinstance(source, TensorPropertySource):
|
||
|
if expr.is_number:
|
||
|
self.log.debug("Skipping guard %s", f"{source_ref(source)} == {expr}")
|
||
|
continue
|
||
|
|
||
|
if is_dim(source):
|
||
|
self.dim_constraints.add_equality(source, expr)
|
||
|
|
||
|
sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
|
||
|
exprs.append(f"{source_ref(source)} == {sexpr}")
|
||
|
if (
|
||
|
isinstance(source, TensorPropertySource)
|
||
|
and source.prop is TensorProperty.SIZE
|
||
|
and equalities_inputs
|
||
|
and len(expr.free_symbols) == 1
|
||
|
):
|
||
|
symbol = next(iter(expr.free_symbols))
|
||
|
if (
|
||
|
isinstance(expr, sympy.Symbol) and
|
||
|
expr in symbol_to_constraints and
|
||
|
not equalities_inputs.is_equal(source, symbol_to_source[expr][0])
|
||
|
):
|
||
|
msg = (
|
||
|
f"The values of {self._debug_name(source)} = {source.name()} and "
|
||
|
f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} "
|
||
|
"must always be equal."
|
||
|
)
|
||
|
record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)
|
||
|
|
||
|
if (
|
||
|
not isinstance(expr, sympy.Symbol) and
|
||
|
symbol in symbol_to_constraints and
|
||
|
not equalities_inputs.is_derived(source, symbol_to_source[symbol][0], lambda x: expr.subs(symbol, x))
|
||
|
):
|
||
|
src = symbol_to_source[symbol][0]
|
||
|
msg = (
|
||
|
f"The values of {self._debug_name(source)} = {source.name()} must always be related to "
|
||
|
f"the values of {self._debug_name(src)} = {src.name()} by "
|
||
|
f"{self._debug_name(source)} = {expr.subs(symbol, sympy.sympify(self._debug_name(src)))}."
|
||
|
)
|
||
|
record_constraint_violation(equalities_inputs.warn_only, self._debug_name(source), msg)
|
||
|
|
||
|
# NB: Not necessary to report constraint violations here:
|
||
|
# constraints are guaranteed to be on symbols (we've already
|
||
|
# caught constants and non-atomic expressions), so we only
|
||
|
# have relational constraints, but we don't support those
|
||
|
# at the moment
|
||
|
|
||
|
# 2. Every guard must evaluate to True (but remember many guards
|
||
|
# like s0 == s1*2 become trivial due to simplification)
|
||
|
issued = set()
|
||
|
|
||
|
def issue_guard(guard: ShapeGuard) -> None:
|
||
|
expr = self.simplify(guard.expr)
|
||
|
|
||
|
# Avoid re-issuing the same guard.
|
||
|
if expr in issued:
|
||
|
return
|
||
|
|
||
|
issued.add(expr)
|
||
|
|
||
|
try:
|
||
|
is_trivial = False
|
||
|
if any(is_dim(source) for s in expr.free_symbols for source in symbol_to_source[s]):
|
||
|
is_trivial = self.dim_constraints.add(expr)
|
||
|
guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
|
||
|
exprs.append(guard_expr)
|
||
|
self._add_target_expr(expr)
|
||
|
# A non-relational constraint on a single sizevar can violate
|
||
|
# a constraint
|
||
|
if not is_trivial and len(expr.free_symbols) == 1:
|
||
|
symbol = next(iter(expr.free_symbols))
|
||
|
source = symbol_to_source[symbol][0]
|
||
|
constraints = symbol_to_constraints[symbol]
|
||
|
for c in constraints:
|
||
|
if isinstance(c, StrictMinMaxConstraint):
|
||
|
var_with_range = self._render_range_for_constraint_violation(source, c)
|
||
|
msg = (
|
||
|
f"Not all values of {var_with_range} "
|
||
|
f"satisfy the generated guard {guard_expr}."
|
||
|
)
|
||
|
record_constraint_violation(c.warn_only, self._debug_name(source), msg)
|
||
|
elif isinstance(c, RelaxedUnspecConstraint):
|
||
|
# This is fine, we allow guards here as long as it
|
||
|
# didn't constrain it to one value (we don't
|
||
|
# actually know this; this depends on our
|
||
|
# ValueRanges reasoning capability)
|
||
|
pass
|
||
|
else:
|
||
|
raise AssertionError(f"unrecognized constraint {c}")
|
||
|
except Exception:
|
||
|
self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
|
||
|
raise
|
||
|
|
||
|
# First, issue all the non-trivial guards.
|
||
|
for guard in self.guards:
|
||
|
if self._maybe_evaluate_static(guard.expr) is not None:
|
||
|
continue
|
||
|
issue_guard(guard)
|
||
|
|
||
|
# 3. Every symbol must be within its value range (this handles 0/1
|
||
|
# specialization too).
|
||
|
for symbol, sources in symbol_to_source.items():
|
||
|
r = self.var_to_range.get(symbol)
|
||
|
if r is None:
|
||
|
if symbol not in self.var_to_range:
|
||
|
continue
|
||
|
r = self.var_to_range[symbol]
|
||
|
|
||
|
assert sources
|
||
|
assert symbol.is_integer
|
||
|
bounds = []
|
||
|
if r.lower != -sympy.oo:
|
||
|
if any(is_dim(source) for source in sources):
|
||
|
self.dim_constraints.add(sympy.Ge(symbol, r.lower))
|
||
|
# Only print lower bound in simplified mode if it is not the
|
||
|
# default
|
||
|
if not _simplified or r.lower != self._default_value_range().lower:
|
||
|
bounds.append(str(r.lower))
|
||
|
bounds.append(source_ref(sources[0]))
|
||
|
# NB: This looks like an off-by-one error but it's not: the
|
||
|
# upper bound may be sys.maxsize - 1 because we intentionally
|
||
|
# exclude sys.maxsize from our bounds to deal with direct
|
||
|
# == INT_MAX guards, but it's still dumb to actually test it.
|
||
|
# Note that you can be off by a pretty large constant and it
|
||
|
# won't matter because sizes in practice will be nowhere near
|
||
|
# the 64-bit limit.
|
||
|
if r.upper != sympy.oo and r.upper < sys.maxsize - 1:
|
||
|
if any(is_dim(source) for source in sources):
|
||
|
self.dim_constraints.add(sympy.Le(symbol, r.upper))
|
||
|
# nontrivial upper bound is always interesting
|
||
|
bounds.append(str(r.upper))
|
||
|
if len(bounds) > 1:
|
||
|
exprs.append(" <= ".join(bounds))
|
||
|
|
||
|
# Check constraints
|
||
|
constraints = symbol_to_constraints[symbol]
|
||
|
for c in constraints:
|
||
|
if isinstance(c, StrictMinMaxConstraint):
|
||
|
# NB: By default, we have a restrictive range
|
||
|
# 2 <= s0 <= sys.maxsize - 1. But export users generally
|
||
|
# expect to be able to specify nice ranges like [0, oo]
|
||
|
if not (c.vr & self._default_value_range()).issubset(r):
|
||
|
source = sources[0]
|
||
|
|
||
|
expr = sympy.And(sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper))
|
||
|
guard_expr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
|
||
|
var_with_range = self._render_range_for_constraint_violation(source, c)
|
||
|
msg = (
|
||
|
f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}"
|
||
|
)
|
||
|
record_constraint_violation(
|
||
|
c.warn_only,
|
||
|
self._debug_name(source),
|
||
|
msg,
|
||
|
)
|
||
|
|
||
|
if constraint_violations:
|
||
|
warn_msgs = []
|
||
|
error_msgs = []
|
||
|
debug_names = set()
|
||
|
for warn_only, debug_name, msg in constraint_violations:
|
||
|
if warn_only:
|
||
|
msg = f" {len(warn_msgs) + 1}. {msg()}"
|
||
|
warn_msgs.append(msg)
|
||
|
else:
|
||
|
msg = f" - {msg()}"
|
||
|
error_msgs.append(msg)
|
||
|
debug_names.add(debug_name)
|
||
|
if len(error_msgs) > 0:
|
||
|
debug_names = ', '.join(debug_names)
|
||
|
err = '\n'.join(error_msgs)
|
||
|
raise ConstraintViolationError(
|
||
|
f"Constraints violated ({debug_names})! "
|
||
|
"For more information, run with TORCH_LOGS=\"+dynamic\".\n"
|
||
|
f"{err}"
|
||
|
)
|
||
|
elif len(warn_msgs) > 0:
|
||
|
log.debug("%s Warning only constraints violated", len(warn_msgs))
|
||
|
|
||
|
signpost_event(
|
||
|
"dynamic",
|
||
|
"produce_guards",
|
||
|
{
|
||
|
**self.co_fields,
|
||
|
**self.counter,
|
||
|
"num_guards": len(exprs),
|
||
|
"free_symbols": sum(1 for v in symbol_to_source.values() if v),
|
||
|
# The keys are meaningless from an aggregate perspective, so
|
||
|
# don't include them. Biggest first.
|
||
|
"symbol_guard_counts": sorted(self.symbol_guard_counter.values(), reverse=True),
|
||
|
},
|
||
|
)
|
||
|
|
||
|
if self._translation_validation_enabled:
|
||
|
from torch.fx.experimental.validator import PopulateValidator
|
||
|
|
||
|
# Add all deferred runtime assertions; these are not technically
|
||
|
# handled by produce_guards but we need to put them in the target
|
||
|
# set
|
||
|
for ras in self.deferred_runtime_asserts.values():
|
||
|
for ra in ras:
|
||
|
self._add_target_expr(ra.expr)
|
||
|
|
||
|
# Add value range bound guards for all symbols with no trivial bounds.
|
||
|
# Reason: '_maybe_evaluate_static' may eliminate guards based on the
|
||
|
# refined value ranges.
|
||
|
for sym, vr in self.var_to_range.items():
|
||
|
if vr.lower != -sympy.oo:
|
||
|
self._add_target_expr(sympy.Le(vr.lower, sym))
|
||
|
if vr.upper != sympy.oo:
|
||
|
self._add_target_expr(sympy.Le(sym, vr.upper))
|
||
|
|
||
|
# Before validating, populate the input of the validator with the
|
||
|
# built FX graph.
|
||
|
with fx_traceback.preserve_node_meta():
|
||
|
PopulateValidator(self.graph, self.validator).run()
|
||
|
|
||
|
self._check_translation_validate()
|
||
|
return exprs
|
||
|
|
||
|
def produce_guards_expression(self, placeholders, ignore_static=True):
    """
    Build one guard expression string for the given placeholders.

    This is the producer half of a pair with evaluate_guards_expression():
    each placeholder is given a positional local source named ``t0``, ``t1``,
    ..., the guards are generated through produce_guards(), and the resulting
    guard strings are joined with " and ".

    Returns None when produce_guards() yields no guards.
    """
    from torch._dynamo.source import LocalSource

    # One LocalSource per placeholder, named by position.
    positional_sources = [LocalSource(f"t{idx}") for idx in range(len(placeholders))]
    guard_strs = self.produce_guards(placeholders, positional_sources, ignore_static=ignore_static)
    return " and ".join(guard_strs) if guard_strs else None
|
||
|
|
||
|
def evaluate_guards_expression(self, code, args):
    """
    Evaluate a guard expression produced by produce_guards_expression().

    ``args`` are bound by position to the names ``t0``, ``t1``, ... and made
    available to the expression through the ``L`` mapping.

    NOTE: ``code`` is executed with eval() (SYMPY_INTERP as globals), so it
    is expected to come from produce_guards_expression(), not from
    untrusted input.
    """
    positional_names = (f"t{idx}" for idx in range(len(args)))
    local_values = dict(zip(positional_names, args))
    return eval(code, SYMPY_INTERP, {"L": local_values})
|
||
|
|
||
|
def evaluate_guards_for_args(self, placeholders, args, *, ignore_static=True):
    """
    Generate guards for a graph's placeholder values and evaluate them with args.

    Returns True when no guard expression was produced; otherwise returns
    whatever evaluate_guards_expression() returns for the generated code.
    """
    guard_code = self.produce_guards_expression(placeholders, ignore_static=ignore_static)
    if not guard_code:
        # Nothing to check: no guards were generated for these placeholders.
        return True
    return self.evaluate_guards_expression(guard_code, args)
|
||
|
|
||
|
def bind_symbols(self, placeholders, args):
    """
    Map symbols found in ``placeholders`` to the concrete values in ``args``.

    ``placeholders`` and ``args`` are walked pairwise.  A placeholder may be
    None (skipped), a SymInt (bound against the matching argument), or a
    tensor whose sizes, strides and storage offset are bound against the
    same properties of the matching argument.  For example, binding a
    placeholder of size (s0, s1) against a tensor of size (2, 4) yields
    {s0: 2, s1: 4}.

    Only SymInts whose expression is a bare symbol, or the negation of a
    bare symbol, produce a binding; any other expression is ignored, so not
    every symbol of the ShapeEnv is necessarily bound.  If a symbol is seen
    twice, the two values are asserted to agree.

    This partly duplicates guard evaluation and assumes the guards were
    already checked.
    """
    bindings: Dict[sympy.Symbol, int] = {}

    def record(symbol, value):
        # First sighting binds the symbol; later sightings must agree.
        if symbol in bindings:
            assert bindings[symbol] == value, f"{bindings[symbol]} != {value}"
        else:
            bindings[symbol] = value

    def bind_symint(arg, val):
        if not isinstance(val, SymInt):
            return
        sym_expr = val.node.expr
        if isinstance(sym_expr, sympy.Symbol):
            record(sym_expr, arg)
        elif isinstance(-sym_expr, sympy.Symbol):
            # The placeholder holds -s, so s itself is bound to -arg.
            record(-sym_expr, -arg)

    for placeholder, concrete in zip(placeholders, args):
        if placeholder is None:
            continue
        if isinstance(placeholder, SymInt):
            bind_symint(concrete, placeholder)
            continue
        assert isinstance(placeholder, torch.Tensor)
        for dim, sym_size in enumerate(placeholder.size()):
            bind_symint(concrete.size(dim), sym_size)
        for dim, sym_stride in enumerate(placeholder.stride()):
            bind_symint(concrete.stride(dim), sym_stride)
        bind_symint(concrete.storage_offset(), placeholder.storage_offset())

    return bindings
|
||
|
|
||
|
def get_nontrivial_guards(self):
    """Return the simplified guard expressions that are not statically known (i.e. not trivial)."""
    nontrivial = []
    for guard in self.guards:
        # A guard that evaluates statically carries no information; drop it.
        if self._maybe_evaluate_static(guard.expr) is not None:
            continue
        nontrivial.append(self.simplify(guard.expr))
    return nontrivial
|
||
|
|
||
|
def format_guards(self, verbose=False):
    """
    Render this shape env's guards as text, one " - <expr>" entry per guard.

    With ``verbose=True`` each entry is followed by a "Guarded at:" section
    containing the formatted lines of the guard's recorded stack.
    """
    entries = []
    for guard in self.guards:
        entry = f" - {guard.expr}"
        stack = guard.stack
        if verbose:
            frames = ''.join(' ' + frame_line for frame_line in stack.format())
            entry += f"\n Guarded at:\n{frames}"
        entries.append(entry)
    return '\n'.join(entries)
|
||
|
|
||
|
def bound_sympy(self, expr: sympy.Expr, size_oblivious: bool = False) -> ValueRanges:
    """
    Compute a ValueRanges bound on the values ``expr`` can take.

    The range of every free symbol is looked up in ``self.var_to_range``
    (None when the symbol has no entry).  With ``size_oblivious=True`` the
    known ranges of size-like symbols are intersected with [2, oo) before
    the bound is computed.
    """
    symbol_ranges = {}
    for sym in expr.free_symbols:
        symbol_ranges[sym] = self.var_to_range.get(sym, None)

    if size_oblivious:
        # Clamp values of size-like variables
        for sym in self.size_like & symbol_ranges.keys():
            if symbol_ranges[sym] is None:
                continue
            symbol_ranges[sym] &= ValueRanges(2, sympy.oo)

    return bound_sympy(expr, symbol_ranges)
|
||
|
|
||
|
@_lru_cache
def _maybe_evaluate_static(
    self, expr: "sympy.Expr", *, unbacked_only: bool = False, compute_hint: bool = False,
    expect_rational=True, size_oblivious: bool = False
) -> "Optional[sympy.Expr]":
    """
    Try to evaluate ``expr`` without introducing guards.

    If ``unbacked_only`` is True, substitutions are only done on unbacked
    SymInts (regular hinted integers are left alone).  The result may then
    still contain backed SymInts, which the caller could potentially guard
    on; in that mode the rewritten expression is returned instead of None
    when no static value is found.

    Use ``compute_hint=True`` to compute a non-binding hint for the
    particular hint values of backed SymInts, e.g. if s0 happens to be 3
    this run, s0 is substituted with 3.

    With ``size_oblivious=True`` the lower bound of size-like symbols is
    treated as at least 2.

    Returns the evaluated expression when one is found, and None otherwise
    (including when sympy hits a RecursionError during substitution).
    """
    expr = self.simplify(expr)
    if compute_hint:
        expr = expr.xreplace(self.var_to_val)
    expr = canonicalize_bool_expr(expr)

    free_syms = list(expr.free_symbols)

    # Apply known runtime asserts
    for sym in free_syms:
        # Unbacked symints only
        if sym in self.var_to_val:
            continue

        known = {}

        def assume(fact):
            # Record the fact and its negation.
            known[canonicalize_bool_expr(fact)] = sympy.true
            known[canonicalize_bool_expr(sympy.Not(fact))] = sympy.false
            if not isinstance(fact, sympy.Rel):
                return
            # multiplying by -1 changes the direction of the inequality
            dual = type(fact)(-fact.rhs, -fact.lhs)
            known[canonicalize_bool_expr(dual)] = sympy.true
            known[canonicalize_bool_expr(sympy.Not(dual))] = sympy.false

        for entry in itertools.chain(self.guards, self.deferred_runtime_asserts.get(sym, ())):
            rel = entry.expr
            if compute_hint:
                rel = canonicalize_bool_expr(rel.xreplace(self.var_to_val))
            assume(rel)
            # Other relational expressions this expression implies
            if isinstance(rel, sympy.Eq):
                assume(sympy.Le(rel.lhs, rel.rhs))
                assume(sympy.Ge(rel.lhs, rel.rhs))
            elif isinstance(rel, sympy.Lt):
                assume(sympy.Le(rel.lhs, rel.rhs))
                assume(sympy.Ne(rel.lhs, rel.rhs))

        # NB: this helps us deal with And/Or connectives
        expr = expr.subs(known)

    # Simplify making use of value range lower bound
    shifted_symbols = {}
    range_env = {}
    for idx, sym in enumerate(free_syms):
        if isinstance(self.var_to_val.get(sym, None), SingletonInt):
            # Skip var_to_range logic for SingletonInt which is only used
            # for jagged layout NestedTensors today
            continue
        vr = self.var_to_range[sym]
        lower = max(2, vr.lower) if (size_oblivious and sym in self.size_like) else vr.lower
        # Don't do anything if we don't have a nontrivial lower bound
        # Also don't do anything if we asked only to simplify unbacked
        # SymInt
        if lower < (-sys.maxsize - 1) // 2 or (unbacked_only and sym in self.var_to_val):
            range_env[sym] = vr
            continue
        # Positive means >= 1
        # Positive - 1 means >= 0
        # Positive + lower - 1 means >= lower
        # The new symbol is "too low", so when we substitute it in
        # we have to increase it by offset (and conversely, the new
        # variables have to have their value range bounds adjusted as
        # well)
        positive_sym = sympy.Symbol(f"shape_{idx}", positive=True, integer=True)
        offset = lower - 1
        shifted_symbols[sym] = positive_sym + offset
        range_env[positive_sym] = SymPyValueRangeAnalysis.add(vr, -offset)

    try:
        new_expr = expr.xreplace(shifted_symbols)
    except RecursionError:
        log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, shifted_symbols)
        self.counter["sympy_recursion_error"] += 1
        return None

    # Rewrite every FloorDiv(a, b) as floor(a / b) before expanding.
    floor_div_replace = {
        atom: sympy.floor(atom.args[0] / atom.args[1])
        for atom in new_expr.atoms(FloorDiv)
    }
    new_expr = safe_expand(new_expr.xreplace(floor_div_replace))
    # TODO: when unbacked_only, can sometimes early return even when there
    # are still free symbols
    if new_expr.is_number:
        return new_expr

    # Check if the range can solve it statically
    out = bound_sympy(new_expr, range_env)
    if expect_rational:
        _assert_bound_is_rational(new_expr, out)
        if out.is_singleton():
            return out.lower

    if unbacked_only:
        return new_expr
    return None
|
||
|
|
||
|
@_lru_cache
def replace(self, expr: "sympy.Expr") -> "sympy.Expr":
    """
    Apply symbol replacements to every free symbol of ``expr``.

    Each free symbol is mapped to the result of ``self._find`` and the
    substituted expression is passed through ``safe_expand``.
    """
    resolved = {}
    for sym in expr.free_symbols:
        resolved[sym] = self._find(cast(sympy.Symbol, sym))
    return safe_expand(expr.xreplace(resolved))
|
||
|
|
||
|
@_lru_cache
def _update_divisible(self):
    """
    Prune ``self.divisible`` and bump the version counter.

    Entries whose replaced form (``self.replace``) is a number are dropped;
    all other entries are kept unchanged.
    """
    self.divisible = {
        entry for entry in self.divisible if not self.replace(entry).is_number
    }
    self._update_version_counter()
|
||
|
|
||
|
@_lru_cache
def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
    """
    Simplify ``expr`` using known replacements and divisibility facts.

    After applying symbol replacements, two FloorDiv rewrites are attempted:

    1. A nested ``FloorDiv(base, FloorDiv(base, d))`` is replaced by ``d``
       when both ``Mod(base, FloorDiv(base, d))`` and ``Mod(base, d)``
       (after replacement) are in ``self.divisible``.
    2. ``FloorDiv(base, divisor)`` is replaced by ``base / divisor`` when
       ``Mod(base, divisor)`` (after replacement) is in ``self.divisible``.
       The result is only kept if it introduces no new Pow atoms and no new
       non-integer Rational atoms, i.e. the divisions simplified away.
    """
    expr = self.replace(expr)
    # TODO it would seem that this pass is not necessary given the
    # below replacement of // with /, but for nested FloorDivs
    # the non-recursive replacement doesn't work, and
    # recursive makes it hard to look up divisibility,
    # because existing divisibility info has FloorDiv in it, not /
    # for now just do a separate pass to catch common nested case
    if expr.has(FloorDiv):
        self._update_divisible()
        nested_rewrites = {}
        for outer in expr.atoms(FloorDiv):
            numer, denom = outer.args
            if not isinstance(denom, FloorDiv):
                continue
            inner_numer, inner_denom = denom.args
            if (
                self.replace(Mod(numer, denom)) in self.divisible
                and numer == inner_numer
                and self.replace(Mod(inner_numer, inner_denom)) in self.divisible
            ):
                nested_rewrites[outer] = inner_denom
        expr = safe_expand(expr.xreplace(nested_rewrites))

    if not expr.has(FloorDiv):
        return expr

    pows_before = expr.atoms(sympy.Pow)
    rationals_before = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer))
    exact_divisions = {}
    for fd in expr.atoms(FloorDiv):
        numer, denom = fd.args
        if self.replace(Mod(numer, denom)) in self.divisible:
            exact_divisions[fd] = numer / denom
    candidate = safe_expand(expr.xreplace(exact_divisions))
    pows_after = candidate.atoms(sympy.Pow)
    rationals_after = candidate.atoms(sympy.Rational).difference(candidate.atoms(sympy.Integer))
    # divisions simplified away
    if pows_after.issubset(pows_before) and rationals_after.issubset(rationals_before):
        return candidate
    return expr
|
||
|
|
||
|
@lru_cache(256)
def size_hint(self, expr: "sympy.Expr", *, allow_none=False):
    """
    Get a size hint for ``expr`` from the underlying shapes we had.

    Does not introduce a guard, so only use this when you can guarantee that
    your code is still valid for arbitrary shapes (such as optimization
    decisions).

    The hint values in ``self.var_to_val`` are substituted into the expanded
    expression.  If that is not yet a number, a static evaluation with
    ``compute_hint=True`` is attempted.  When no hint can be found, None is
    returned if ``allow_none`` is set (or the result is a SingletonInt);
    otherwise the error built by ``_make_data_dependent_error`` is raised.
    """
    hinted = safe_expand(expr).xreplace(self.var_to_val)
    if hinted.is_number:
        return hinted

    from torch.utils._sympy.singleton_int import SingletonInt

    if isinstance(hinted, SingletonInt):
        return None

    static_value = self._maybe_evaluate_static(hinted, compute_hint=True)
    if static_value is not None:
        return static_value
    if allow_none:
        return None
    raise self._make_data_dependent_error(hinted, expr)
|
||
|
|
||
|
# NB: keep in sync with size_hint
@lru_cache(256)
def has_hint(self, expr: "sympy.Expr"):
    """
    Report whether a hint is available for ``expr``.

    True when substituting ``self.var_to_val`` into the expanded expression
    gives a number, or when ``_maybe_evaluate_static`` can evaluate the
    substituted expression.
    """
    hinted = safe_expand(expr).xreplace(self.var_to_val)
    return (
        hinted.is_number
        or self._maybe_evaluate_static(hinted) is not None
    )
|
||
|
|
||
|
def _make_data_dependent_error(self, expr, unhinted_expr, *, size_oblivious_result: Optional[bool] = None):
    """
    Build (but do not raise) a GuardOnDataDependentSymNode for ``expr``.

    For every free symbol of ``expr`` the stack recorded in
    ``self.var_to_stack`` is logged at debug level, and the size-like
    symbols are collected for the message.  When ``size_oblivious_result``
    is given, the message also says that guard_size_oblivious would have
    evaluated the expression to that value.
    """
    # TODO: in a Dynamo context, having user code, and having the
    # name of the local, will be much better
    size_like_symbols = []
    for sym in expr.free_symbols:
        alloc_trace = ''.join(self.var_to_stack[sym].format())
        self.log.debug("Data dependent variable '%s' allocated at:\n%s", sym, alloc_trace)
        if sym in self.size_like:
            size_like_symbols.append(sym)

    if size_oblivious_result is None:
        size_oblivious_result_msg = ""
    else:
        size_oblivious_result_msg = (
            f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n"
            "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n"
        )

    fsummary, _maybe_user_loc, maybe_extra_debug = self._get_stack_summary(True)
    message = (
        f"Could not guard on data-dependent expression {expr} (unhinted: {unhinted_expr}). "
        f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n"
        f"{size_oblivious_result_msg}"
        "Potential framework code culprit (scroll up for full backtrace):\n"
        f"{''.join(traceback.StackSummary.from_list([fsummary]).format())}\n"
        "For more information, run with TORCH_LOGS=\"dynamic\"\n"
        "For extended logs when we create symbols, also add "
        f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n"
        "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
        "For more debugging help, see "
        "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n"
    )
    # TODO: Help text about how to use our runtime tests to fix this
    # problem
    return GuardOnDataDependentSymNode(message + maybe_extra_debug)
|
||
|
|
||
|
def _set_replacement(self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str) -> None:
|
||
|
"""
|
||
|
Adds or updates a replacement for a symbol.
|
||
|
Use this instead of `self.replacements[a] = tgt`.
|
||
|
"""
|
||
|
|
||
|
# Precondition: a == tgt
|
||
|
assert isinstance(a, sympy.Symbol)
|
||
|
|
||
|
# Handles nested tensor symbolic variables which don't have
|
||
|
# var_to_range bounds
|
||
|
tgt_bound = None
|
||
|
if a in self.var_to_range:
|
||
|
src_bound = self.var_to_range[a]
|
||
|
|
||
|
# If you have x in [2, maxint], then 2*x in [4, 2*maxint].
|
||
|
# But we don't really care that the max bound says we can
|
||
|
# go beyond the maximum integer size, because we aren't
|
||
|
# using bigints anyway. Arguably, ValueRanges should know
|
||
|
# to do this truncation automatically (to avoid doing
|
||
|
# bigint compute in range analysis), but right now it doesn't
|
||
|
# so we need to get rid of some unnecessary precision.
|
||
|
int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1)
|
||
|
|
||
|
def issubset(x, y):
|
||
|
return (x & int_range).issubset(y & int_range)
|
||
|
|
||
|
# First, refine the value range of a based on the computed value range
|
||
|
# of tgt. This is always OK to do, even if we decide not to do the
|
||
|
# substitution in the end. This might be a no-op, if a already has
|
||
|
# a tighter bound
|
||
|
tgt_bound = self.bound_sympy(tgt)
|
||
|
self.var_to_range[a] = src_bound & tgt_bound
|
||
|
|
||
|
# Next, check if we can update the range of free symbols in tgt
|
||
|
# based on the range in a. But only do it if:
|
||
|
# - the source bound non-trivially improves over what we get out of
|
||
|
# the existing bounds.
|
||
|
# - the replacement is univariate and we can invert the tgt expression
|
||
|
if not issubset(tgt_bound, src_bound) and len(tgt.free_symbols) == 1:
|
||
|
b = next(iter(tgt.free_symbols))
|
||
|
# Try to invert the equality
|
||
|
r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False)
|
||
|
if r is not None:
|
||
|
b_bound = self.bound_sympy(r[1])
|
||
|
self.var_to_range[b] = b_bound & self.var_to_range[b]
|
||
|
tgt_bound = self.bound_sympy(tgt)
|
||
|
assert issubset(tgt_bound, src_bound)
|
||
|
|
||
|
# TODO: Should we propagate size-like-ness?
|
||
|
#
|
||
|
# Pros: if u0 is size-like, intuitively u0 == u1 should cause u1
|
||
|
# to become size-like.
|
||
|
#
|
||
|
# Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T
|
||
|
# propagate in this case, because what if u0 == 0, then u1 is negative
|
||
|
# and clearly isn't a size. So, at minimum, any f(x) whose value
|
||
|
# range isn't [0, inf] given x in [0, inf] cannot propagate
|
||
|
# size-like-ness. But there are many situations where you could
|
||
|
# imagine u1 is going to be size-like and actually you just didn't
|
||
|
# have a refined enough value range on u0. Since even innocuous
|
||
|
# looking arithmetic operations can destroy size-like-ness, it's
|
||
|
# best to not propagate it at all and force the user to annotate it
|
||
|
# as necessary.
|
||
|
#
|
||
|
# Compromise: we preserve size-like-ness only for exact equality
|
||
|
# and nothing else.
|
||
|
if a in self.size_like and isinstance(tgt, sympy.Symbol):
|
||
|
self.size_like.add(tgt)
|
||
|
elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like:
|
||
|
self.size_like.add(a)
|
||
|
|
||
|
# Now, decide if we will do the substitution.
|
||
|
#
|
||
|
# - If the source has a non-trivial range, only substitute if
|
||
|
# we preserve this range. Note that we may have propagated
|
||
|
# the src_range to free variables in tgt when tgt is univariate
|
||
|
# and we could find an inverse, which helps us achieve this.
|
||
|
# This ensures we never "forget" about user defined ranges,
|
||
|
# even if they end up being defined on composite formulas
|
||
|
# like s0 + s1.
|
||
|
#
|
||
|
# - If the variable is unbacked, only substitute if the substitution
|
||
|
# would preserve the bounds also under size-like-ness conditions.
|
||
|
|
||
|
if not issubset(tgt_bound, src_bound):
|
||
|
self.log.debug("skipped set_replacement %s = %s (%s) [%s not subset of %s]", a, tgt, msg, tgt_bound, src_bound)
|
||
|
return
|
||
|
elif a in self.size_like:
|
||
|
tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True)
|
||
|
# This is morally equivalent to self.bound_sympy(a, size_oblivious=True)
|
||
|
# but handles substitutions like u0 == 0
|
||
|
src_bound_so = self.var_to_range[a]
|
||
|
if src_bound_so.upper >= 2:
|
||
|
src_bound_so &= ValueRanges(2, sympy.oo)
|
||
|
if not issubset(tgt_bound_so, src_bound_so):
|
||
|
self.log.debug("skipped set_replacement %s = %s (%s) "
|
||
|
"[%s not subset of %s (size-oblivious conditions)]", a, tgt, msg, tgt_bound_so, src_bound_so)
|
||
|
return
|
||
|
|
||
|
if config.print_specializations and isinstance(tgt, (sympy.Integer, sympy.Float)):
|
||
|
# specializing to a constant, which is likely unexpected
|
||
|
|
||
|
# NOTE(avik): It is possible that we try logging the same specialization multiple times, e.g.,
|
||
|
# when adding a to self.replacements, and again when simplifying an expression containing a.
|
||
|
# Thus to avoid duplication, checking whether a is in self.replacements isn't enough; if it is,
|
||
|
# it must not already map to `tgt`. Fortunately this check is cheap because `tgt` is a constant.
|
||
|
if a not in self.replacements or tgt != self.replacements[a]:
|
||
|
self.log.warning("Specializing %s to %s", self.var_to_sources[a][0].name(), tgt)
|
||
|
self.log.debug("SPECIALIZATION", stack_info=True)
|
||
|
log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
|
||
|
self.replacements[a] = tgt
|
||
|
self._update_version_counter()
|
||
|
|
||
|
# When specializing 'a == tgt', the equality should be also conveyed to
|
||
|
# Z3, in case an expression uses 'a'.
|
||
|
self._add_target_expr(sympy.Eq(a, tgt))
|
||
|
|
||
|
def _add_divisible(self, expr: "sympy.Expr"):
|
||
|
self.divisible.add(expr)
|
||
|
self._update_version_counter()
|
||
|
|
||
|
@_lru_cache
@record_shapeenv_event()
def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
    """
    Resolve ``a`` through ``self.replacements`` (a DSU-like lookup).

    Transitive, non-identity replacements are followed as well. With

        a: b + c
        c: d

    the free symbols of ``a``'s replacement are resolved recursively, the
    rewritten expression is stored back through ``_set_replacement`` and the
    (possibly updated) entry for ``a`` is returned. A symbol without any
    replacement is returned unchanged.
    """
    if a in self.replacements:
        current = self.replacements[a]
        resolved = {sym: self._find(sym) for sym in current.free_symbols}
        # Re-read the entry rather than reusing `current`: the store below
        # may be skipped by _set_replacement, so the dict is the source of
        # truth for what is returned.
        self._set_replacement(a, self.replacements[a].xreplace(resolved), "find")
        return self.replacements[a]
    return a
|
||
|
|
||
|
@lru_cache(256)
def _maybe_guard_rel(self, expr: "sympy.Rel") -> None:
    """
    Exploit a relational guard that is known to hold.

    The relation is used to simplify shapes (i.e. a == b or a % 5 == 0):
    value ranges are refined for any relation, and for equalities we
    additionally try to record symbol replacements or divisibility facts.
    ``sympy.Ne`` relations are ignored.
    """
    assert isinstance(expr, sympy.Rel)

    # A good example of what goes wrong if you don't do this is
    # python test/functorch/test_aotdispatch.py -k
    # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32
    if isinstance(expr, sympy.Ne):
        return

    free = list(expr.free_symbols)

    assert len(free) > 0, f"The expression should not be static by this point: {expr}"
    # In case of really gnarly expression, we don't blow up
    if len(free) > 5:
        return

    # Prioritize unbacked symints for solving by ordering them last.
    # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
    #   (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
    # Prefer to simplify out symbols with ephemeral sources.
    def _elimination_priority(sym):
        only_ephemeral = (
            sym in self.var_to_sources and all(src.is_ephemeral() for src in self.var_to_sources[sym])
        )
        hint = self.size_hint(sym, allow_none=True) or sys.maxsize
        # 1 puts ephemeral sourced symbols first when sorting in reverse
        return (1 if only_ephemeral else 0, hint, sym.name)

    free = sorted(free, key=_elimination_priority, reverse=True)  # type: ignore[attr-defined]
    lhs = expr.lhs
    rhs = expr.rhs

    self._refine_ranges(expr)

    # The rest of this stuff is for equality only
    if not isinstance(expr, sympy.Eq):
        return

    if not expr.has(Mod):
        try:
            floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv))
            if any(fd.divisor != 1 for fd in floor_div_atoms):
                raise NotImplementedError

            # short-circuit when no solving is needed
            if isinstance(lhs, sympy.Symbol) and free_unbacked_symbols(lhs):
                self._set_replacement(lhs, self._find(rhs), "trivial_lhs")
            elif isinstance(rhs, sympy.Symbol) and free_unbacked_symbols(rhs):
                self._set_replacement(rhs, self._find(lhs), "trivial_rhs")
            else:
                candidate = free[0]
                solved = try_solve(expr, candidate, floordiv_inequality=False)
                if solved is not None and all(t.is_integer for t in sympy.preorder_traversal(solved[1])):
                    new_var = self._find(solved[1])
                    if self.is_unbacked_symint(candidate):
                        # If you have i0 + i1 + i2 = s0, don't substitute i2 =
                        # s0 - i0 - i1. Arguably this should be OK but the
                        # runtime assert machinery is very delicate right now
                        # so this causes things to fail e.g.,
                        # test_split_unbacked_sizes
                        ok = len(free_unbacked_symbols(new_var)) <= 1
                        msg = "solve_unbacked"
                    else:
                        # Never substitute backed with unbacked
                        ok = len(free_unbacked_symbols(new_var)) == 0
                        msg = "solve_backed"
                    if ok:
                        self._set_replacement(cast(sympy.Symbol, candidate), new_var, msg)
        except NotImplementedError:
            pass
    else:
        mod_expr = next(iter(expr.atoms(Mod)))
        try:
            solved = try_solve(expr, mod_expr, floordiv_inequality=False)
            if solved is not None and solved[1] == 0:
                self._add_divisible(mod_expr)
                # This is a little bit of extra logic to make things like
                # torch.empty(i0, q).view(c, -1, q) work out
                p, q = mod_expr.args
                if isinstance(q, sympy.Number) and isinstance(p, sympy.Mul) and len(p.args) == 2:
                    c, i0 = p.args
                    # Given Mod(c * i0, q) == 0
                    if (
                        isinstance(c, sympy.Number)
                        and isinstance(i0, sympy.Symbol)
                        and self.is_unbacked_symint(i0)
                    ):
                        # We have Mod(i0, q / c) == 0, which means we can
                        # rewrite i0 as (q / gcd(q, c)) * i1
                        d = q / sympy.gcd(q, c)
                        i1 = self.create_unbacked_symint().node.expr
                        # Propagate the value ranges. It doesn't really
                        # matter if we use truediv or floordiv, because we
                        # have established divisibility.
                        self.var_to_range[i1] = SymPyValueRangeAnalysis.truediv(
                            self.var_to_range[i0], ValueRanges.wrap(d)
                        )
                        # Propagate size-like-ness
                        if i0 in self.size_like:
                            self.size_like.add(i1)
                        self._set_replacement(i0, d * i1, "divisibility")
        except NotImplementedError:
            pass
|
||
|
|
||
|
# See: Note - On 0/1 specialization
|
||
|
# NB: sys.maxsize is NOT allowed for sizes, because we use MAX_INT
|
||
|
# as a sentinel sometimes. Your sizevar isn't going to be
|
||
|
# anywhere near the max 64-bit integer anyway.
|
||
|
def _default_value_range(self) -> ValueRanges:
|
||
|
lower = 2 if self.specialize_zero_one else 0
|
||
|
return ValueRanges(lower, sys.maxsize - 1)
|
||
|
|
||
|
def _default_unspecified_value_range(self) -> ValueRanges:
|
||
|
return ValueRanges(-sys.maxsize - 1, sys.maxsize)
|
||
|
|
||
|
@_lru_cache
def _simplify_floor_div(self, expr):
    """
    Simplify ``expr`` after guarding that each of its FloorDivs is exact.

    We expect floor_divs to be exact, so for every ``FloorDiv(base, divisor)``
    atom we evaluate ``Mod(base, divisor) == 0`` through ``evaluate_expr``
    (adding the guard even if tracing would not otherwise require it) and
    then return ``self.simplify(expr)``.
    """
    floor_divs = tuple(expr.atoms(FloorDiv))
    for floor_div in reversed(floor_divs):
        base, divisor = floor_div.args
        # add necessary mod guards
        self.evaluate_expr(sympy.Eq(Mod(base, divisor), 0))
    return self.simplify(expr)
|
||
|
|
||
|
# We're about to add a guard/runtime assert, check if the ShapeEnv is frozen
|
||
|
# and if so issue a warning
|
||
|
def _check_frozen(self, expr, concrete_val):
|
||
|
if self.frozen:
|
||
|
self.counter["ignored_backward_guard"] += 1
|
||
|
signpost_event(
|
||
|
"dynamic",
|
||
|
"evaluate_expr_frozen",
|
||
|
{
|
||
|
**self.co_fields,
|
||
|
"ignored_guard": f"{expr} == {concrete_val}",
|
||
|
# no version = original state (this signpost is expected)
|
||
|
# version 2 = dynamic backwards is eagerly compiled
|
||
|
"version": 2,
|
||
|
},
|
||
|
)
|
||
|
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
|
||
|
|
||
|
|
||
|
def _get_stack_summary(self, is_debug: bool = False):
    """
    Collect location information for a guard log message.

    Args:
        is_debug: When true, also collect the extended debug text (the user
            stack, plus a C++ stack trace if ``config.extended_debug_cpp``
            is set).

    Returns:
        A tuple ``(fsummary, maybe_user_loc, maybe_extra_debug)``:
        ``fsummary`` is a ``traceback.FrameSummary`` for the innermost frame
        whose file is not in ``uninteresting_files()`` (``None`` if there is
        no such frame), ``maybe_user_loc`` is ``" at <frame>"`` for the last
        entry of the ``TracingContext`` stack or ``""``, and
        ``maybe_extra_debug`` is the extended debug text or ``""``.
    """
    fsummary = None
    # Walk outwards from this very frame; the walk is intentionally kept
    # inline so the inspected stack is not altered by a helper call.
    current = inspect.currentframe()
    try:
        while current is not None:
            code = current.f_code
            if code.co_filename not in uninteresting_files():
                fsummary = traceback.FrameSummary(
                    code.co_filename,
                    current.f_lineno,
                    code.co_name,
                )
                break
            current = current.f_back
    finally:
        # Drop the frame reference explicitly.
        del current

    # NB: this stack is truncated, but it's fine because the main
    # stack_info will give you the rest of the info you need
    user_tb = TracingContext.extract_stack()
    maybe_user_loc = " at " + format_frame(user_tb[-1]) if user_tb else ""

    maybe_extra_debug = ""
    if is_debug and user_tb:
        maybe_extra_debug = (
            '\nUser Stack (most recent call last):\n' +
            ' (snipped, see stack below for prefix)\n' +
            ''.join(traceback.format_list(user_tb))
        )
    if is_debug and config.extended_debug_cpp:
        cpp_stack = CapturedTraceback.extract(cpp=True)
        maybe_extra_debug += "\nC++ stack trace:\n" + ''.join(cpp_stack.format())

    return fsummary, maybe_user_loc, maybe_extra_debug
|
||
|
|
||
|
def _log_guard(self, prefix: str, g, forcing_spec: bool):
|
||
|
if self.log.isEnabledFor(logging.INFO):
|
||
|
str_g = str(g)
|
||
|
is_debug = config.extended_debug_guard_added is not None and str_g == config.extended_debug_guard_added
|
||
|
fsummary, maybe_user_loc, maybe_extra_debug = self._get_stack_summary(is_debug)
|
||
|
self.log.info(
|
||
|
"%s %s [guard added]%s (%s)%s",
|
||
|
prefix if not forcing_spec else f"{prefix} (forcing_spec)",
|
||
|
str_g,
|
||
|
maybe_user_loc,
|
||
|
format_frame(fsummary),
|
||
|
maybe_extra_debug,
|
||
|
stack_info=is_debug,
|
||
|
)
|
||
|
|
||
|
@lru_cache(256)
@record_shapeenv_event(save_tracked_fakes=True)
def evaluate_expr(self, orig_expr: "sympy.Expr", hint=None, fx_node=None,
                  expect_rational=True, size_oblivious: bool = False, *, forcing_spec: bool = False):
    """
    Given an expression, evaluates it, adding guards if necessary

    Args:
        orig_expr: The expression to evaluate.
        hint: Optional concrete value for ``orig_expr``. When ``None`` the
            concrete value is obtained from ``self.size_hint(orig_expr)``;
            otherwise ``sympy.sympify(hint)`` is used.
        fx_node: Optional FX node for the expression. When translation
            validation is enabled (and guards are not suppressed and the
            evaluation is not size-oblivious) an assertion node is created
            from it.
        expect_rational: Forwarded to ``_maybe_evaluate_static``.
        size_oblivious: Forwarded to ``_maybe_evaluate_static``.
        forcing_spec: Set on the recursive call that forces specialization of
            a symbol; marks the log line and prevents further recursion.

    Returns:
        ``orig_expr`` itself if it is a number, the statically known value if
        there is one, and otherwise the concrete value the guard was
        installed for.

    Raises:
        The error built by ``_make_data_dependent_error`` when the expression
        still has free symbols outside ``self.var_to_val`` after attempting
        to eliminate them.
    """

    # TODO: split conjunctions and evaluate them separately

    @lru_cache(None)
    def compute_concrete_val():
        if hint is None:
            return self.size_hint(orig_expr)
        else:
            return sympy.sympify(hint)

    # Check if:
    #   1. 'translation_validation' is set
    #   2. the corresponding 'fx_node' is not 'None'
    #   3. the guard should not be suppressed
    #
    # If all of the above check, we create an FX node representing the
    # actual expression to be guarded.
    node = None
    fresh = False
    if (
        self._translation_validation_enabled
        and fx_node is not None
        and not self._suppress_guards_tls()
        and not size_oblivious
    ):
        concrete_val = compute_concrete_val()
        if concrete_val is sympy.true:
            node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
        elif concrete_val is sympy.false:
            neg, _ = self._create_fx_call_function(operator.not_, (fx_node,))
            node, fresh = self._create_fx_call_function(torch._assert, (neg,))
        else:
            eql, _ = self._create_fx_call_function(operator.eq, (fx_node, concrete_val))
            node, fresh = self._create_fx_call_function(torch._assert, (eql,))

        assert node is not None
        # If this is a fresh node, we have to remember the event index that
        # corresponds to this assertion node.
        # Reason: so that, given an assertion node, we can replay the ShapeEnv
        # events until the point where this assertion node was freshly created.
        if fresh:
            self._add_fx_node_metadata(node)

    # After creating the FX node corresponding to orig_expr, we must make sure that
    # no error will be raised until the end of this function.
    #
    # Reason: the translation validation may become invalid otherwise.
    #
    # If an error is raised before the end of this function, we remove the FX node
    # inserted, and re-raise the error.
    guard = None

    try:
        if orig_expr.is_number:
            self.log.debug("eval %s [trivial]", orig_expr)
            # NB: don't test float as there may be precision issues
            if isinstance(hint, (int, bool)):
                assert orig_expr == hint, f"{orig_expr} != {hint}"
            return orig_expr

        expr = orig_expr

        static_expr = self._maybe_evaluate_static(expr,
                                                  expect_rational=expect_rational,
                                                  size_oblivious=size_oblivious)
        if static_expr is not None:
            self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
            # NB: don't test float as there may be precision issues
            if isinstance(hint, (int, bool)):
                assert static_expr == hint, f"{static_expr} != {hint}"
            return static_expr

        if not (expr.free_symbols <= self.var_to_val.keys()):
            # TODO: dedupe this with _maybe_evaluate_static
            # Attempt to eliminate the unbacked SymInt
            new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
            if not (new_expr.free_symbols <= self.var_to_val.keys()):
                # Also try under size-oblivious conditions, purely so the
                # error message can report what that would have produced.
                size_oblivious_result = None
                if not size_oblivious:
                    size_oblivious_result = self._maybe_evaluate_static(
                        expr,
                        expect_rational=expect_rational,
                        size_oblivious=True
                    )

                raise self._make_data_dependent_error(
                    expr.xreplace(self.var_to_val),
                    expr,
                    size_oblivious_result=size_oblivious_result
                )
            expr = new_expr

        concrete_val = compute_concrete_val()
        self._check_frozen(expr, concrete_val)

        if (
            config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY
            and isinstance(hint, bool)
            and isinstance(expr, (sympy.Eq, sympy.Ne))
        ):
            expr = sympy.Not(expr)

        # Turn this into a boolean expression, no longer need to consult
        # concrete_val
        suppress_maybe_guard_rel = False
        if concrete_val is sympy.true:
            g = expr
        elif concrete_val is sympy.false:
            g = sympy.Not(expr)
        else:
            # WARNING: we cannot actually do simplifications on guards
            # on floating point values, because Sympy generally does not
            # think expressions on integers can ever be equal to floating
            # point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False). Without
            # very clear algebraic laws that hold for floating point, such
            # simplifications are error prone anyway, so be sure not to
            # maybe_guard_rel in those cases.
            if not isinstance(concrete_val, sympy.Integer):
                suppress_maybe_guard_rel = True
            g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]

        # The flag used to be computed and then ignored, so the
        # simplification ran even for the non-integer case described in the
        # warning above. The guard itself is still recorded below.
        if isinstance(g, sympy.Rel) and not suppress_maybe_guard_rel:
            # TODO: If we successfully eliminate a symbol via equality, it
            # is not actually necessary to save a guard for the equality,
            # as we will implicitly generate a guard when we match that
            # input against the symbol. Probably the easiest way to
            # implement this is to have maybe_guard_rel return a bool
            # saying if it "subsumed" the guard (and therefore the guard
            # is no longer necessary)
            self._maybe_guard_rel(g)

        if not self._suppress_guards_tls():
            stack = CapturedTraceback.extract(skip=1)
            guard = ShapeGuard(g, stack)
            # TODO: deal with duplicate guards somehow
            self.guards.append(guard)
    except Exception:
        # Undo the FX node created above so translation validation stays
        # consistent, then propagate the error unchanged.
        if fresh:
            self._remove_fx_node(node)
        raise
    else:
        if not self._suppress_guards_tls():
            assert guard is not None

            self._log_guard("eval", g, forcing_spec=forcing_spec)

            for s in g.free_symbols:
                self.symbol_guard_counter[s] += 1
                # Forcing_spec to avoid infinite recursion
                if (
                    not forcing_spec and
                    config.symbol_guard_limit_before_specialize is not None and
                    self.symbol_guard_counter[s] > config.symbol_guard_limit_before_specialize
                ):
                    # Force specialization
                    self.log.info(
                        "symbol_guard_limit_before_specialize=%s exceeded on %s",
                        config.symbol_guard_limit_before_specialize,
                        s
                    )
                    self.evaluate_expr(s, forcing_spec=True)
        else:
            self.log.debug("eval %s [guard suppressed]", g)

    return concrete_val
|
||
|
|
||
|
def cleanup(self):
    """
    Break reference cycles.

    Calls ``cleanup()`` on the stack of every guard, on every entry of
    ``var_to_stack`` and on the stack of every deferred runtime assert, in
    that order. This destroys the stacks. If you really want to keep them,
    we just need some way to break references on code objects.
    """
    all_stacks = itertools.chain(
        (guard.stack for guard in self.guards),
        self.var_to_stack.values(),
        (
            runtime_assert.stack
            for runtime_assert in itertools.chain.from_iterable(self.deferred_runtime_asserts.values())
        ),
    )
    for stack in all_stacks:
        stack.cleanup()
|
||
|
|
||
|
@record_shapeenv_event(save_tracked_fakes=True)
def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
    """Create an assert that is checked at runtime

    If the expression is statically known, that value is returned and
    nothing is recorded. If it can be reduced to symbols present in
    ``self.var_to_val``, an ordinary guard is installed through
    ``evaluate_expr`` instead. Otherwise a ``RuntimeAssert`` is stored in
    ``self.deferred_runtime_asserts`` under the free symbol whose name
    starts with ``"u"`` and carries the highest numeric suffix.

    Args:
        orig_expr (sympy.Expr): Boolean expression to assert is true
        msg (str): Message to display on assertion failure
        fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding
            to the expression, if applicable

    Returns:
        The static value, the result of ``evaluate_expr``, or ``True`` when
        the assert was deferred (or suppressed).
    """
    expr = orig_expr

    # TODO: split conjunctions and evaluate them separately

    static_expr = self._maybe_evaluate_static(expr)
    if static_expr is not None:
        self.log.debug("runtime_assert %s == %s [statically known]", orig_expr, static_expr)
        return static_expr

    # Attempt to eliminate the unbacked SymInt
    reduced = self._maybe_evaluate_static(expr, unbacked_only=True)
    if reduced.free_symbols <= self.var_to_val.keys():
        # Do a normal guard
        return self.evaluate_expr(reduced, fx_node=fx_node)
    # NB: Don't use the reduced expression as expr; it could contain gunk
    # like shape0 which we don't want to guard on

    # OK, we're definitely doing a runtime assert now
    wants_fx_assert = (
        self._translation_validation_enabled
        and fx_node is not None
        and not self._suppress_guards_tls()
    )
    if wants_fx_assert:
        node, fresh = self._create_fx_call_function(torch._assert, (fx_node,))
        assert node is not None
        if fresh:
            self._add_fx_node_metadata(node)

    self._check_frozen(expr, sympy.true)

    # eliminate symbols on equality tests / refine ranges
    if isinstance(expr, sympy.Rel):
        self._maybe_guard_rel(expr)

    if self._suppress_guards_tls():
        self.log.debug("runtime_assert %s [guard suppressed]", expr)
    else:
        # canonicalise to remove equations that are trivially equal; the
        # pre-canonicalisation form is kept for the log line.
        orig_expr = expr
        expr = canonicalize_bool_expr(expr)
        stack = CapturedTraceback.extract(skip=1)
        ra = RuntimeAssert(expr, msg, stack)
        # TODO: Do this in a way that is less janky than int(s.name[1:])
        cands = sorted(
            (s for s in expr.free_symbols if s.name.startswith("u")),
            key=lambda s: int(s.name[1:]),
        )
        self.deferred_runtime_asserts.setdefault(cands[-1], []).append(ra)
        self.num_deferred_runtime_asserts += 1
        self._update_version_counter()
        self._log_guard("runtime_assert", orig_expr, forcing_spec=False)

    return True
|
||
|
|
||
|
# Refines the ranges of the variables present in 'guard'.
|
||
|
#
|
||
|
# This function tries to refine the range of the variables inside
|
||
|
# 'guard' by reasoning about it. Specifically, when 'guard' is a
|
||
|
# 'sympy.Relational' operation.
|
||
|
#
|
||
|
# It does mainly 3 things:
|
||
|
# 1. Tries to isolate a variable in the left-hand side
|
||
|
# 2. Compute the value range of the right-hand side
|
||
|
# 3. Update the value range of the variable, if better
|
||
|
def _refine_ranges(self, expr: sympy.Expr) -> None:
|
||
|
expr = self.simplify(expr)
|
||
|
|
||
|
for symbol in expr.free_symbols:
|
||
|
assert isinstance(symbol, sympy.Symbol)
|
||
|
|
||
|
if isinstance(self.var_to_val.get(symbol, None), SingletonInt):
|
||
|
# Skip var_to_range logic for SingletonInt which is only used
|
||
|
# for jagged layout NestedTensors today
|
||
|
continue
|
||
|
|
||
|
r = try_solve(expr, symbol)
|
||
|
|
||
|
if r is None or not (symbol.is_integer and r[1].is_integer):
|
||
|
# Range refinement only supports integer symbols for now.
|
||
|
# There are lots of SymPy bugs when it comes to comparing
|
||
|
# reals and integers, so we skip that for now.
|
||
|
continue
|
||
|
|
||
|
r_expr, rhs = r
|
||
|
vr = self.var_to_range[symbol]
|
||
|
lower, upper = vr.lower, vr.upper
|
||
|
|
||
|
rhs_vr = bound_sympy(rhs, self.var_to_range)
|
||
|
_assert_bound_is_rational(rhs, rhs_vr)
|
||
|
|
||
|
# Let's suppose that we have a preexisting range for x [0, 100].
|
||
|
# Now, we issue a guard x > y, where the range for y is [50, 150].
|
||
|
# Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen,
|
||
|
# refining x to [51, 100], since x must be greater than y, but the lowest
|
||
|
# y could be is 50.
|
||
|
#
|
||
|
# sympy.Eq may update both lower and upper bounds.
|
||
|
# sympy.G{t,e} may update the lower bound, only.
|
||
|
# sympy.L{t,e} may update the upper bound, only.
|
||
|
if lower < rhs_vr.lower and isinstance(r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)):
|
||
|
# Strictly greater relations allow us to refine a bit more, since
|
||
|
# x < y implies that the lower bound for x is: y + 1.
|
||
|
lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
|
||
|
if upper > rhs_vr.upper and isinstance(r_expr, (sympy.Eq, sympy.Le, sympy.Lt)):
|
||
|
upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))
|
||
|
|
||
|
# Do nothing if the new value range is no better than what we already have.
|
||
|
if vr == ValueRanges(lower, upper):
|
||
|
continue
|
||
|
|
||
|
# Updates the range and the guards corresponding to each bound of the symbol.
|
||
|
self.var_to_range[symbol] = ValueRanges(lower, upper)
|
||
|
# Clears the cache, since this update can change the result.
|
||
|
self._maybe_evaluate_static.cache_clear()
|
||
|
|
||
|
def _is_int(expr):
|
||
|
return isinstance(expr, SymInt) and expr.node.expr.is_number
|
||
|
|
||
|
# WARNING: This is legacy, DO NOT USE
|
||
|
def _is_dim_dynamic(t, d):
|
||
|
return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices
|