""" This file does three things: - Contains the definition of SymNode - Installs all the magic methods into SymBool, SymFloat, SymFloat at import time - Does not depend on sympy at import time As this file is imported from within torch/__init__.py we do not want it to depend on SymPy to avoid having to load SymPy at import time, as doing so is *very* slow. """ import builtins import itertools import logging import math import operator import sys from functools import lru_cache, update_wrapper from typing import Optional, Type, TYPE_CHECKING, Union import torch # NB: The sym_* functions are used via getattr() and must be imported here. from torch import ( # noqa: F401 sym_float, sym_ite, sym_max, sym_min, sym_not, SymBool, SymFloat, SymInt, ) from torch.fx.experimental._sym_dispatch_mode import ( handle_sym_dispatch, sym_function_mode, ) if TYPE_CHECKING: from torch.fx.experimental.symbolic_shapes import ShapeEnv log = logging.getLogger(__name__) sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") __all__ = ["SymNode", "method_to_operator", "magic_methods"] SymTypes = (SymInt, SymFloat, SymBool) def _to_symtype(t): if t is bool: return SymBool if t is int: return SymInt if t is float: return SymFloat return t # TODO: An incomplete list # 1. Set variables to be equal when we do equality # 2. Specialize on 0/1 when we do subtraction class SymNode: """ This is a type erased SymInt/SymFloat which we use to do actual operations. End users don't touch this. Magic methods are NOT defined on this object. """ def __init__( self, expr, shape_env, pytype, hint: Optional[Union[int, float, bool]], constant=None, fx_node=None, ): self._expr = expr self.shape_env = shape_env self.pytype = pytype # What's the difference between hint and constant? # # - A constant is known to be invariant across invocations of the model; # it will always be this value. We only really know this when we # encounter an honest-to-goodness literal (when wrapping it into # a SymNode, we set constant.) Most of the time, constant is None # # - A hint is a *particular* value from the particular run we are # tracing, but it may vary the next time around. It's useful to # keep this around, as if we need a concrete value from a SymNode, # we will return the hint and guard on the expression that produced # it giving the same hint next time around. The hint is not # guaranteed to be set either: if you have an unbacked SymNode, # there won't be any hint; it was the result of some tensor-dependent # computation, but we don't know what it actually is because we # haven't actually run the tensor computation. # # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) # in hopes that we've learned enough about the unbacked symints to # discharge the hint; otherwise, you're likely to just error out. # # (A previous version of this system had some optimizations to only # recompute when it was possible we had learned enough about the # unbacked symint that a hint was now possible, but as we added more # potential refinements to unbacked symints this got harder to keep # in sync, so we've deleted it for now.) if hint is not None: assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( "Cannot create SymNode of type " f"{pytype} with incompatible hint of type {type(hint)}" ) self._hint = hint self.constant: Optional[Union[int, float, bool]] = constant # Record the FX node of the current node if we are doing translation # validation. They will be used for building the input assertions for # the translation validation problem. self.fx_node = ( fx_node if self.shape_env._translation_validation_enabled else None ) def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode": return SymNode( self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node ) @property def expr(self): return self.shape_env.replace(self._expr) # Recompute the hint and see if we've got it now # Precondition: self._hint is None def _update_hint(self): r = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) if r is not None: self._hint = self.pytype(r) if not isinstance(r, SymTypes) else r @property def hint(self): if self._hint is None: self._update_hint() return self._hint def has_hint(self): if self._hint is None: self._update_hint() return self._hint is not None def require_hint(self, fallback=None): if self._hint is None: self._update_hint() if self._hint is None: if fallback is not None: return fallback # NB: we expect this to raise return self.shape_env.size_hint(self.expr) return self._hint def maybe_as_int(self): if self.expr.is_number: return int(self.expr) else: return None def is_int(self): return self.pytype is int def is_float(self): return self.pytype is float def is_bool(self): return self.pytype is bool def is_nested_int(self): # Unbacked SymInts cannot be nested int today return ( self._hint is not None and isinstance(self._hint, SymInt) and self._hint.node.is_nested_int() ) def wrap_int(self, num): assert type(num) is int import sympy return SymNode( sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num ) def wrap_float(self, num): assert type(num) is float import sympy return SymNode( sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num ) def wrap_bool(self, num): assert type(num) is bool import sympy return SymNode( sympy.true if num else sympy.false, self.shape_env, bool, num, constant=num, fx_node=num, ) def clone(self): return self def str(self): return f"{self.expr}" def __str__(self): return self.str() def __repr__(self): return self.str() # These methods call the metaprogrammed methods, they're hand written # here so we get good stack traces def abs(self) -> "SymNode": return self._abs() # type: ignore[attr-defined] def pos(self) -> "SymNode": return self._pos() # type: ignore[attr-defined] def round(self, ndigits=None) -> "SymNode": return self._round(ndigits) # type: ignore[attr-defined] def add(self, other) -> "SymNode": return self._add(other) # type: ignore[attr-defined] def sub(self, other) -> "SymNode": return self._sub(other) # type: ignore[attr-defined] def mul(self, other) -> "SymNode": return self._mul(other) # type: ignore[attr-defined] def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] def pow(self, other) -> "SymNode": return self._pow(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] def truediv(self, other) -> "SymNode": return self._truediv(other) # type: ignore[attr-defined] def floordiv(self, other) -> "SymNode": return self._floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] def rshift(self, other) -> "SymNode": return self._rshift(other) # type: ignore[attr-defined] def sym_not(self) -> "SymNode": # noqa: F811 return self._sym_not() # type: ignore[attr-defined] def eq(self, other) -> "SymNode": return self._eq(other) # type: ignore[attr-defined] def ne(self, other) -> "SymNode": return self._ne(other) # type: ignore[attr-defined] def gt(self, other) -> "SymNode": return self._gt(other) # type: ignore[attr-defined] def lt(self, other) -> "SymNode": return self._lt(other) # type: ignore[attr-defined] def le(self, other) -> "SymNode": return self._le(other) # type: ignore[attr-defined] def ge(self, other) -> "SymNode": return self._ge(other) # type: ignore[attr-defined] def floor(self) -> "SymNode": return self._floor() # type: ignore[attr-defined] def is_integer(self) -> "SymNode": return self._is_integer() # type: ignore[attr-defined] def sym_float(self) -> "SymNode": # noqa: F811 return self._sym_float() # type: ignore[attr-defined] def sym_int(self) -> "SymNode": return self._sym_int() # type: ignore[attr-defined] def ceil(self) -> "SymNode": return self._ceil() # type: ignore[attr-defined] def neg(self) -> "SymNode": return self._neg() # type: ignore[attr-defined] def sym_min(self, other) -> "SymNode": # noqa: F811 return self._sym_min(other) # type: ignore[attr-defined] def sym_max(self, other) -> "SymNode": # noqa: F811 return self._sym_max(other) # type: ignore[attr-defined] def sym_ite(self, then_val, else_val) -> "SymNode": return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] def is_contiguous(self, sizes, strides) -> "SymNode": return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode": return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode": return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode": return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode": return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode": return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] # Make C++ happy def sym_or(self, other): return self.or_(other) def sym_and(self, other): return self.and_(other) def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] def int_(self): return self.guard_int("", 0) # NB: uses Python backtrace # You can manually trigger a guard with this function def guard_int(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) try: return int(r) except Exception: log.warning("Failed to convert to int: %s", r) raise def guard_float(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred r = self.shape_env.evaluate_expr( self.expr, self.hint, fx_node=self.fx_node, expect_rational=False ) try: return float(r) except Exception: log.warning("Failed to convert to float: %s", r) raise def guard_bool(self, file, line): # TODO: use the file/line for some useful diagnostic on why a # guard occurred r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node) try: return bool(r) except Exception: log.warning("Failed to convert to bool: %s", r) raise def expect_true(self, file, line): from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols if self.has_hint() and not free_unbacked_symbols(self.expr): # OK to generate guards return self.guard_bool(file, line) # Generate a deferred runtime assert (this might actually end up doing # a regular guard if we can!) # TODO: file/line here is very important, because the assert has been # deferred so you can't backtrace easily return self.shape_env.defer_runtime_assert( self.expr, f"{file}:{line}", fx_node=self.fx_node ) def expect_size(self, file, line): from torch.fx.experimental.symbolic_shapes import _advise_is_size b = self.ge(self.wrap_int(0)) # Generate a deferred runtime assert r = b.expect_true(file, line) # Refine compile time range, but only if it's unbacked. # If you refine range for hinted variables, you can end up making # improper deductions since compile time reasoning may be # incompatible with runtime reasoning. if r and not self.has_hint(): _advise_is_size(SymInt(self)) return r def guard_size_oblivious(self, file, line): """ Like guard_bool, but if we encounter unbacked symbols, if those symbols are size-like, we will treat them as >= 2 for the purposes of the analysis. This CHANGES the runtime semantics, but all size-oblivious sites have been audited to ensure that the runtime semantics don't change in a material way. Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping an unbacked one size, or a tensor reporting as non-contiguous even if it's contiguous if it would have been reported contiguous due to being empty. """ # TODO: use the file/line for some useful diagnostic on why a # guard occurred r = self.shape_env.evaluate_expr( self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True ) try: return bool(r) except Exception: log.warning("Failed to convert to bool: %s", r) raise def bool_(self): return self.guard_bool("", 0) def is_symbolic(self): return True def nested_int(self): return None def is_constant(self): return False # TODO: this probably needs the sizes-strides eval functions METHOD_TO_OPERATOR = { "pos": operator.pos, "abs": operator.abs, "add": operator.add, "and": operator.and_, "ceil": math.ceil, "eq": operator.eq, "floor": math.floor, "floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), "le": operator.le, "lshift": operator.lshift, "lt": operator.lt, "mod": operator.mod, "mul": operator.mul, "ne": operator.ne, "neg": operator.neg, "or": operator.or_, "pow": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, "sym_float": sym_float, "sym_ite": sym_ite, "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, "truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", "ceil", "floor", "neg", "sym_not", "pos", } # Adding math ops: sqrt, cos, sin, ... def _get_sym_node_fn(name): def fn(self): return getattr(self, f"_sym_{name}")() return fn math_op_names = ( "sqrt", "cos", "cosh", "sin", "sinh", "tan", "tanh", "asin", "acos", "atan", ) for name in math_op_names: sym_name = f"sym_{name}" priv_sym_name = f"_{sym_name}" setattr(SymNode, sym_name, _get_sym_node_fn(name)) METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) unary_magic_methods.add(sym_name) __all__.append(sym_name) # Unary methods that are not magic methods unary_nonmagic_methods = { "is_integer", } unary_methods = unary_magic_methods | unary_nonmagic_methods # Most methods are only registered on SymInt and SymFloat # Some methods are only be registered on SymBool only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} # Methods that implicitly convert SymBool into SymInt bool_becomes_int_magic_methods = {"add", "sub", "mul"} # Methods that are also on SymBool, in addition to on SymInt and SymFloat also_bool_magic_methods = {"eq"} bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float only_float_magic_methods = {"is_integer"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} always_float_magic_methods = {"truediv", "sym_float", "pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) always_int_magic_methods = {"ceil", "floor"} always_bool_magic_methods = { "eq", "ne", "gt", "lt", "le", "ge", "and", "or", "sym_not", "is_non_overlapping_and_dense", "is_integer", } # Methods that have a `__foo__` as well as `__rfoo__` def _sympy_truediv(a, b): from torch.utils._sympy.functions import TrueDiv return TrueDiv(a, b) def _sympy_floordiv(a, b): from torch.utils._sympy.functions import FloorDiv return FloorDiv(a, b) def _sympy_mod(a, b): from torch.utils._sympy.functions import Mod return Mod(a, b) def _sympy_pow(a, b): from torch.utils._sympy.functions import Pow return Pow(a, b) def _sympy_and(a, b): import sympy return sympy.And(a, b) def _sympy_or(a, b): import sympy return sympy.Or(a, b) def _sympy_lshift(a, b): from torch.utils._sympy.functions import LShift return LShift(a, b) def _sympy_rshift(a, b): from torch.utils._sympy.functions import RShift return RShift(a, b) reflectable_magic_methods = { "add": operator.add, "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, "pow": _sympy_pow, "and": _sympy_and, "or": _sympy_or, "truediv": _sympy_truediv, "floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } def _floor_ceil_helper(a, fn): import sympy if isinstance(a, sympy.Mul): aa = a.args if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: coef = sympy.Integer(aa[0]) if aa[0] == coef: # structural equality test return coef * aa[1] if ( isinstance(a, sympy.Float) and a == sympy.Integer(a) or isinstance(a, sympy.Integer) ): return sympy.Integer(a) return fn(a) def _sympy_floor(a): import sympy return _floor_ceil_helper(a, sympy.floor) def _sympy_ceil(a): import sympy return _floor_ceil_helper(a, sympy.ceiling) def _sympy_eq(a, b): import sympy return sympy.Eq(a, b) def _sympy_ne(a, b): import sympy return sympy.Ne(a, b) def _sympy_gt(a, b): import sympy return sympy.Gt(a, b) def _sympy_lt(a, b): import sympy return sympy.Lt(a, b) def _sympy_le(a, b): import sympy return sympy.Le(a, b) def _sympy_ge(a, b): import sympy return sympy.Ge(a, b) def _sympy_min(a, b): import sympy return sympy.Min(a, b) def _sympy_max(a, b): import sympy return sympy.Max(a, b) def _sympy_ite(a, t, f): import sympy return sympy.Piecewise((t, a), (f, True)) current_module = sys.modules[__name__] def _get_sym_math_fn(name): def fn(a): import sympy return getattr(sympy, name)(a) return fn for name in math_op_names: priv_sympy_name = f"_sympy_{name}" fn = _get_sym_math_fn(name) fn.__qualname__ = fn.__name__ = priv_sympy_name setattr(current_module, priv_sympy_name, fn) del fn, name, priv_sympy_name # type: ignore[possibly-undefined] def _sympy_abs(a): import sympy return sympy.Abs(a) def _sympy_round(number, ndigits=None): from torch.utils._sympy.functions import Round, RoundDecimal if ndigits is None: return Round(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): # Cannot use sympy.Float(a) here, coz it expects python literals # Multiply by 1.0 to cast to float. This is needed when the input # is a SymInt which has the assumption that it is integer and # SymPy will otherwise assume that return value cannot be a float. return a * 1.0 def _sympy_is_integer(a): import sympy return sympy.Eq(sympy.floor(a), a) magic_methods = { **reflectable_magic_methods, "sym_not": operator.invert, "pos": operator.pos, "eq": _sympy_eq, "ne": _sympy_ne, "gt": _sympy_gt, "lt": _sympy_lt, "le": _sympy_le, "ge": _sympy_ge, "floor": _sympy_floor, "sym_float": _sympy_sym_float, "ceil": _sympy_ceil, "neg": operator.neg, "sym_min": _sympy_min, "sym_max": _sympy_max, "sym_ite": _sympy_ite, "abs": _sympy_abs, "round": _sympy_round, "is_integer": _sympy_is_integer, } for name in math_op_names: sym_name = f"sym_{name}" magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] def sympy_is_contiguous(sizes, strides): dim = len(sizes) return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) def sympy_is_contiguous_generic(sizes, strides, dim_order): import sympy dim = len(sizes) if len(dim_order) != dim: return sympy.false is_contiguous = sympy.true z = sympy.Integer(1) # Contiguous if the strides make sense (or the dim is size 1) for d in dim_order: is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z) z *= sizes[d] # OR if any size is zero for d in range(dim): is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0)) return is_contiguous # NB: There is a TODO in C++ to allow omitting the batch dim. If that # happens you will need to refactor this def sympy_is_channels_last_contiguous_2d(sizes, strides): return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) def sympy_is_channels_last_contiguous_3d(sizes, strides): return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): import sympy dim = len(sizes) if dim != len(dim_order): return sympy.false m = sympy.Integer(0) r = sympy.true # special case for trivial C dimension. default to NCHW r &= sympy.Ne(strides[1], 0) for d in dim_order: r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) # Fallback to NCHW as default layout for ambiguous cases # This is the flaw of implicit memory_format from strides. # N111 tensor with identical strides for size 1 dimension; # Two cases could lead us here: # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) # b. N11W contiguous Tensor sliced on the W-dimension. # ([N,1,1,1]@[W,W,W,W]) if d == 0: r &= sympy.Ne(m, strides[1]) # This is necessary to: # 1. distinguish the memory_format of N1H1; # [H, 1, 1, 1] channels_last stride # [H, H, 1, 1] contiguous stride # 2. permutation of 1C1W: # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as # channels_last m = strides[d] * sympy.Max(sizes[d], 1) return r def sympy_is_channels_last_strides_2d(sizes, strides): return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) def sympy_is_channels_last_strides_3d(sizes, strides): return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator return IsNonOverlappingAndDenseIndicator(*sizes, *strides) sizes_strides_methods = { # TODO: These could also be done with indicators, maybe it is better # for reasoning to do it that way "is_contiguous": sympy_is_contiguous, "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, } alternate_impl_if_hinted_methods = { "sym_min": builtins.min, "sym_max": builtins.max, } def to_node(self, num): if isinstance(num, SymTypes): return num.node elif type(num) is bool: return self.wrap_bool(num) elif type(num) is int: return self.wrap_int(num) elif type(num) is float: return self.wrap_float(num) else: # NotImplemented is important so that Python tries the # other magic method return NotImplemented def wrap_node(x): # TODO: let C++ also take advantage of this if isinstance(x, SymNode) and x.constant is not None: return x.constant if x.is_int(): return SymInt(x) elif x.is_float(): return SymFloat(x) elif x.is_bool(): return SymBool(x) else: raise AssertionError(f"unrecognized return type {x}") def method_to_operator(method): return METHOD_TO_OPERATOR[method] def _make_node_magic(method, func): func = lru_cache(256)(func) if method in magic_methods_on_operator_with_trailing_underscore: method_attr = f"{method}_" else: method_attr = method def binary_magic_impl(self, other): from torch.fx.experimental.symbolic_shapes import safe_expand op = method_to_operator(method) out_hint = None if self.hint is not None and other.hint is not None: out_hint = op(self.hint, other.hint) alternate_impl = alternate_impl_if_hinted_methods.get(method) if alternate_impl and out_hint is not None: return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) if sym_function_mode(): return to_node( self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) # TODO: consider constant prop here try: out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise out = safe_expand(out) sym_node_log.debug("%s %s %s -> %s", func, self.expr, other.expr, out) pytype: Type # This is not strictly correct. In Python, a**b may return complex when # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This # returns a float while both arguments are ints: 2**(-1). Also, max and # min do not type promote. To avoid having data-dependent control flow # here, we just set the type to float if one of the args is a float. In # case of a type mismatch, we assume that it will be detected during # evaluation. if method in always_float_magic_methods: pytype = float elif method in always_bool_magic_methods: pytype = bool elif self.pytype is float or other.pytype is float: pytype = float else: pytype = self.pytype if ( pytype is not None and out_hint is not None and not isinstance(out_hint, SymTypes) ): out_hint = pytype(out_hint) # Create a FX node that corresponds to the operation being applied to # this node. fx_node, _ = self.shape_env._create_fx_call_function( op, (self.fx_node, other.fx_node) ) return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) def unary_magic_impl(self): from torch.fx.experimental.symbolic_shapes import safe_expand op = method_to_operator(method) if sym_function_mode(): return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) # TODO: consider constant prop here expr = self.expr if method == "floor" or method == "ceiling": expr = self.shape_env._simplify_floor_div(expr) try: out = func(expr) except Exception: log.warning("failed to eval %s(%s)", method, expr) raise sym_node_log.debug("%s %s -> %s", func, expr, out) out_hint = None if self.hint is not None: out_hint = op(self.hint) out = safe_expand(out) pytype: Type if method in always_int_magic_methods: pytype = int elif method in always_bool_magic_methods: pytype = bool elif method in always_float_magic_methods: pytype = float else: pytype = self.pytype fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) if method in unary_methods: setattr(SymNode, f"_{method_attr}", unary_magic_impl) elif method == "sym_ite": def sym_ite_impl(pred_node, then_node, else_node): from torch.fx.experimental.symbolic_shapes import safe_expand out_hint = then_node.hint if pred_node.hint else else_node.hint if sym_function_mode(): return to_node( pred_node, handle_sym_dispatch( sym_ite, ( wrap_node(pred_node), wrap_node(then_node), wrap_node(else_node), ), {}, ), ) try: out = func(pred_node.expr, then_node.expr, else_node.expr) except Exception: log.warning( "failed to eval %s(%s, %s, %s)", method, pred_node.expr, then_node.expr, else_node.expr, ) raise out = safe_expand(out) fx_node, _ = pred_node.shape_env._create_fx_call_function( sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) ) return SymNode( out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node ) setattr(SymNode, f"_{method_attr}", sym_ite_impl) elif method == "round": def round_impl(self, ndigits=None): from torch.fx.experimental.symbolic_shapes import safe_expand op = builtins.round if sym_function_mode(): return to_node( self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) ) expr = self.expr try: out = func(expr, ndigits) except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise out = safe_expand(out) pytype = int if ndigits is None else self.pytype out_hint = None if self.hint is not None: out_hint = op(self.hint, ndigits) # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. args = [self.fx_node] if ndigits is not None: args.append(ndigits) fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) setattr(SymNode, f"_{method_attr}", round_impl) else: setattr(SymNode, f"_{method_attr}", binary_magic_impl) def _make_node_sizes_strides(method, func): # NB: don't LRU cache, lots of arguments def sizes_strides_impl(self, sizes, strides): op = getattr(sys.modules[__name__], method) if sym_function_mode(): return to_node( self, handle_sym_dispatch( op, ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), {}, ), ) size_exprs = [s.expr for s in sizes] stride_exprs = [s.expr for s in strides] try: out = func(size_exprs, stride_exprs) except Exception: log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) raise # bool is never expandable size_hints = [] out_hint = None for s in sizes: if s.hint is None: break size_hints.append(s.hint) else: stride_hints = [] for s in strides: if s.hint is None: break stride_hints.append(s.hint) else: out_hint = op(size_hints, stride_hints) # NB: This is the indicator function, not the actual bool! pytype: Type if method.endswith("_indicator"): pytype = int else: pytype = bool return SymNode(out, self.shape_env, pytype, out_hint) setattr(SymNode, f"_{method}", sizes_strides_impl) # TODO: This is technically hotpath, but in the ideal end state # guards on this will resolve at a higher level so you never # spend time in this code def sizes_strides_user(sizes, strides): import sympy from torch.fx.experimental.symbolic_shapes import ( eval_is_non_overlapping_and_dense, ) for a in itertools.chain(sizes, strides): if isinstance(a, SymInt): return wrap_node( getattr(a.node, method)( [to_node(a.node, b) for b in sizes], [to_node(a.node, b) for b in strides], ) ) if method == "is_non_overlapping_and_dense_indicator": return eval_is_non_overlapping_and_dense(sizes, strides) else: # TODO: this is an awful implementation return bool( func( [sympy.sympify(a) for a in sizes], [sympy.sympify(a) for a in strides], ) ) # Skip for is_non_overlapping_and_dense_indicator if not hasattr(sys.modules[__name__], method): setattr(sys.modules[__name__], method, sizes_strides_user) for method, func in magic_methods.items(): _make_node_magic(method, func) for method, func in sizes_strides_methods.items(): _make_node_sizes_strides(method, func) def _make_user_magic(method, user_type): # User magic takes care of wrapping the other operand into a node, # so that our internal logic can assume everything is nodes if method in magic_methods_on_operator_with_trailing_underscore: method_attr = f"sym_{method}" else: method_attr = method def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): if isinstance(x, (int, float, bool)): return x if isinstance(x, SymBool): return x.node.guard_bool("", 0) raise AssertionError("expect to be called with constant SymBools") def is_constant(x): if isinstance(x, (int, float, bool)): return True if isinstance(x, (SymInt, SymFloat, SymBool)): return x.node.is_constant() return False if method in bool_becomes_int_magic_methods: def promote(x): """Implements True+True=2, which works in python but not sympy""" if isinstance(x, SymBool): return SymInt(x.node.wrap_int(int(x))) return x else: def promote(x): return x # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes # this means that operations involving SymBool return plain bools. # Alternatively, we could also rewrap into constant Symbool (i.e. by # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that # today for no particular reason. def unary_magic_impl(self): self = promote(self) if is_constant(self): return (method_to_operator(method))(get_constant(self)) return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): other = get_constant(other) other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented ret = wrap_node(getattr(self.node, method_attr)(other_node)) return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): self = promote(self) other = promote(other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): other = get_constant(other) other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented ret = wrap_node(getattr(other_node, method_attr)(self.node)) return get_constant(ret) if is_constant(ret) else ret if method in unary_magic_methods: setattr(user_type, f"__{method}__", unary_magic_impl) elif method in unary_nonmagic_methods: orig = getattr(user_type, method) setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) elif method == "sym_ite": def sym_ite_magic_impl(pred, then_val, else_val): pred_node = pred.node then_node = to_node(pred_node, then_val) else_node = to_node(pred_node, else_val) if then_node is NotImplemented or else_node is NotImplemented: return NotImplemented assert ( isinstance(then_node, SymNode) and isinstance(else_node, SymNode) and then_node.pytype == else_node.pytype ) ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) return get_constant(ret) if ret.node.is_constant() else ret setattr(user_type, f"__{method}__", sym_ite_magic_impl) elif method == "round": def round_magic_impl(self, ndigits=None): if is_constant(self): return builtins.round(get_constant(self), ndigits) return wrap_node(getattr(self.node, method)(ndigits)) setattr(user_type, f"__{method}__", round_magic_impl) else: setattr(user_type, f"__{method}__", binary_magic_impl) if method in reflectable_magic_methods: setattr(user_type, f"__r{method}__", rbinary_magic_impl) for method, func in magic_methods.items(): # type: ignore[assignment] if method in only_bool_magic_methods: _make_user_magic(method, SymBool) continue if method in only_float_magic_methods: _make_user_magic(method, SymFloat) continue if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: _make_user_magic(method, SymBool) _make_user_magic(method, SymInt) _make_user_magic(method, SymFloat) del method del func