# ai-content-maker/.venv/Lib/site-packages/torch/fx/experimental/sym_node.py
#
# 1331 lines
# 41 KiB
# Python

"""
This file does three things:
- Contains the definition of SymNode
- Installs all the magic methods into SymBool, SymInt, SymFloat at import time
- Does not depend on sympy at import time
As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
to avoid having to load SymPy at import time, as doing so is *very* slow.
"""
import builtins
import itertools
import logging
import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, Type, TYPE_CHECKING, Union
import torch
# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import ( # noqa: F401
sym_float,
sym_ite,
sym_max,
sym_min,
sym_not,
SymBool,
SymFloat,
SymInt,
)
from torch.fx.experimental._sym_dispatch_mode import (
handle_sym_dispatch,
sym_function_mode,
)
if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import ShapeEnv
# Module logger; used below to report failures while evaluating or converting
# symbolic expressions.
log = logging.getLogger(__name__)
# Artifact logger registered under the "sym_node" name; receives the
# per-operation debug traces emitted by the generated magic methods.
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
# Public names of this module.  The math-op registration loop further down
# appends one "sym_<op>" entry per math op.
__all__ = ["SymNode", "method_to_operator", "magic_methods"]
# The user-facing symbolic wrapper types, grouped for isinstance() checks.
SymTypes = (SymInt, SymFloat, SymBool)
def _to_symtype(t):
if t is bool:
return SymBool
if t is int:
return SymInt
if t is float:
return SymFloat
return t
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
    """
    Type-erased node that sits behind SymInt/SymFloat/SymBool and on which the
    actual symbolic operations are carried out.  End users are not expected to
    touch it; Python magic methods are NOT defined on this object.
    """

    def __init__(
        self,
        expr,
        shape_env,
        pytype,
        hint: Optional[Union[int, float, bool]],
        constant=None,
        fx_node=None,
    ):
        self._expr = expr
        self.shape_env = shape_env
        self.pytype = pytype
        # What's the difference between hint and constant?
        #
        # - A constant is known to be invariant across invocations of the model;
        #   it will always be this value.  We only really know this when we
        #   encounter an honest-to-goodness literal (when wrapping it into
        #   a SymNode, we set constant.)  Most of the time, constant is None
        #
        # - A hint is a *particular* value from the particular run we are
        #   tracing, but it may vary the next time around.  It's useful to
        #   keep this around, as if we need a concrete value from a SymNode,
        #   we will return the hint and guard on the expression that produced
        #   it giving the same hint next time around.  The hint is not
        #   guaranteed to be set either: if you have an unbacked SymNode,
        #   there won't be any hint; it was the result of some tensor-dependent
        #   computation, but we don't know what it actually is because we
        #   haven't actually run the tensor computation.
        #
        #   If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
        #   in hopes that we've learned enough about the unbacked symints to
        #   discharge the hint; otherwise, you're likely to just error out.
        #
        #   (A previous version of this system had some optimizations to only
        #   recompute when it was possible we had learned enough about the
        #   unbacked symint that a hint was now possible, but as we added more
        #   potential refinements to unbacked symints this got harder to keep
        #   in sync, so we've deleted it for now.)
        if hint is not None:
            hint_type = type(hint)
            assert hint_type is pytype or hint_type is _to_symtype(pytype), (
                "Cannot create SymNode of type "
                f"{pytype} with incompatible hint of type {type(hint)}"
            )
        self._hint = hint
        self.constant: Optional[Union[int, float, bool]] = constant

        # Record the FX node of the current node if we are doing translation
        # validation.  They will be used for building the input assertions for
        # the translation validation problem.
        if self.shape_env._translation_validation_enabled:
            self.fx_node = fx_node
        else:
            self.fx_node = None

    def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
        """Return a new SymNode with the same fields but bound to `shape_env`."""
        return SymNode(
            self._expr,
            shape_env,
            self.pytype,
            self._hint,
            constant=self.constant,
            fx_node=self.fx_node,
        )

    @property
    def expr(self):
        """The stored expression after passing it through shape_env.replace()."""
        return self.shape_env.replace(self._expr)

    # Recompute the hint and see if we've got it now
    # Precondition: self._hint is None
    def _update_hint(self):
        computed = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
        if computed is None:
            return
        if isinstance(computed, SymTypes):
            self._hint = computed
        else:
            self._hint = self.pytype(computed)

    @property
    def hint(self):
        """The hint, after one attempt to recompute it if it is missing."""
        if self._hint is None:
            self._update_hint()
        return self._hint

    def has_hint(self):
        """Whether a hint is available (tries to recompute a missing one first)."""
        if self._hint is None:
            self._update_hint()
        return self._hint is not None

    def require_hint(self, fallback=None):
        """
        Return the hint; if none can be computed, return `fallback` when it is
        not None, otherwise defer to shape_env.size_hint().
        """
        if self._hint is None:
            self._update_hint()
        if self._hint is not None:
            return self._hint
        if fallback is not None:
            return fallback
        # NB: we expect this to raise
        return self.shape_env.size_hint(self.expr)

    def maybe_as_int(self):
        """Return int(expr) when the expression is a number, else None."""
        return int(self.expr) if self.expr.is_number else None

    def is_int(self):
        return self.pytype is int

    def is_float(self):
        return self.pytype is float

    def is_bool(self):
        return self.pytype is bool

    def is_nested_int(self):
        """True when the stored hint is a SymInt whose node is a nested int."""
        # Unbacked SymInts cannot be nested int today
        stored_hint = self._hint
        if stored_hint is None or not isinstance(stored_hint, SymInt):
            return False
        return stored_hint.node.is_nested_int()

    def wrap_int(self, num):
        """Wrap a Python int literal as a constant SymNode in this shape env."""
        assert type(num) is int
        import sympy

        literal = sympy.Integer(num)
        return SymNode(literal, self.shape_env, int, num, constant=num, fx_node=num)

    def wrap_float(self, num):
        """Wrap a Python float literal as a constant SymNode in this shape env."""
        assert type(num) is float
        import sympy

        literal = sympy.Float(num)
        return SymNode(literal, self.shape_env, float, num, constant=num, fx_node=num)

    def wrap_bool(self, num):
        """Wrap a Python bool literal as a constant SymNode in this shape env."""
        assert type(num) is bool
        import sympy

        literal = sympy.true if num else sympy.false
        return SymNode(literal, self.shape_env, bool, num, constant=num, fx_node=num)

    def clone(self):
        return self

    def str(self):
        return f"{self.expr}"

    def __str__(self):
        return self.str()

    def __repr__(self):
        return self.str()

    # These methods call the metaprogrammed methods, they're hand written
    # here so we get good stack traces
    def abs(self) -> "SymNode":
        return self._abs()  # type: ignore[attr-defined]

    def pos(self) -> "SymNode":
        return self._pos()  # type: ignore[attr-defined]

    def round(self, ndigits=None) -> "SymNode":
        return self._round(ndigits)  # type: ignore[attr-defined]

    def add(self, other) -> "SymNode":
        return self._add(other)  # type: ignore[attr-defined]

    def sub(self, other) -> "SymNode":
        return self._sub(other)  # type: ignore[attr-defined]

    def mul(self, other) -> "SymNode":
        return self._mul(other)  # type: ignore[attr-defined]

    def mod(self, other) -> "SymNode":
        return self._mod(other)  # type: ignore[attr-defined]

    def pow(self, other) -> "SymNode":
        return self._pow(other)  # type: ignore[attr-defined]

    def and_(self, other) -> "SymNode":
        return self._and_(other)  # type: ignore[attr-defined]

    def or_(self, other) -> "SymNode":
        return self._or_(other)  # type: ignore[attr-defined]

    def truediv(self, other) -> "SymNode":
        return self._truediv(other)  # type: ignore[attr-defined]

    def floordiv(self, other) -> "SymNode":
        return self._floordiv(other)  # type: ignore[attr-defined]

    def lshift(self, other) -> "SymNode":
        return self._lshift(other)  # type: ignore[attr-defined]

    def rshift(self, other) -> "SymNode":
        return self._rshift(other)  # type: ignore[attr-defined]

    def sym_not(self) -> "SymNode":  # noqa: F811
        return self._sym_not()  # type: ignore[attr-defined]

    def eq(self, other) -> "SymNode":
        return self._eq(other)  # type: ignore[attr-defined]

    def ne(self, other) -> "SymNode":
        return self._ne(other)  # type: ignore[attr-defined]

    def gt(self, other) -> "SymNode":
        return self._gt(other)  # type: ignore[attr-defined]

    def lt(self, other) -> "SymNode":
        return self._lt(other)  # type: ignore[attr-defined]

    def le(self, other) -> "SymNode":
        return self._le(other)  # type: ignore[attr-defined]

    def ge(self, other) -> "SymNode":
        return self._ge(other)  # type: ignore[attr-defined]

    def floor(self) -> "SymNode":
        return self._floor()  # type: ignore[attr-defined]

    def is_integer(self) -> "SymNode":
        return self._is_integer()  # type: ignore[attr-defined]

    def sym_float(self) -> "SymNode":  # noqa: F811
        return self._sym_float()  # type: ignore[attr-defined]

    def sym_int(self) -> "SymNode":
        return self._sym_int()  # type: ignore[attr-defined]

    def ceil(self) -> "SymNode":
        return self._ceil()  # type: ignore[attr-defined]

    def neg(self) -> "SymNode":
        return self._neg()  # type: ignore[attr-defined]

    def sym_min(self, other) -> "SymNode":  # noqa: F811
        return self._sym_min(other)  # type: ignore[attr-defined]

    def sym_max(self, other) -> "SymNode":  # noqa: F811
        return self._sym_max(other)  # type: ignore[attr-defined]

    def sym_ite(self, then_val, else_val) -> "SymNode":
        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]

    def is_contiguous(self, sizes, strides) -> "SymNode":
        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]

    # Make C++ happy
    def sym_or(self, other):
        return self.or_(other)

    def sym_and(self, other):
        return self.and_(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        """Compare the integer indicator against 1 to obtain a boolean node."""
        indicator = self.is_non_overlapping_and_dense_indicator(sizes, strides)
        return indicator.eq(to_node(self, 1))  # type: ignore[attr-defined]

    def int_(self):
        return self.guard_int("", 0)  # NB: uses Python backtrace

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        """Evaluate the expression through the shape env and return it as int."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        result = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node
        )
        try:
            return int(result)
        except Exception:
            log.warning("Failed to convert to int: %s", result)
            raise

    def guard_float(self, file, line):
        """Evaluate the expression through the shape env and return it as float."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        result = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node, expect_rational=False
        )
        try:
            return float(result)
        except Exception:
            log.warning("Failed to convert to float: %s", result)
            raise

    def guard_bool(self, file, line):
        """Evaluate the expression through the shape env and return it as bool."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        result = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node
        )
        try:
            return bool(result)
        except Exception:
            log.warning("Failed to convert to bool: %s", result)
            raise

    def expect_true(self, file, line):
        """
        Guard on the expression when it has a hint and no free unbacked
        symbols; otherwise register a deferred runtime assert with the shape env.
        """
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        can_guard = self.has_hint() and not free_unbacked_symbols(self.expr)
        if can_guard:
            # OK to generate guards
            return self.guard_bool(file, line)
        # Generate a deferred runtime assert (this might actually end up doing
        # a regular guard if we can!)
        # TODO: file/line here is very important, because the assert has been
        # deferred so you can't backtrace easily
        return self.shape_env.defer_runtime_assert(
            self.expr, f"{file}:{line}", fx_node=self.fx_node
        )

    def expect_size(self, file, line):
        """Expect `self >= 0`; for hint-less nodes also advise that it is a size."""
        from torch.fx.experimental.symbolic_shapes import _advise_is_size

        non_negative = self.ge(self.wrap_int(0))
        # Generate a deferred runtime assert
        outcome = non_negative.expect_true(file, line)
        # Refine compile time range, but only if it's unbacked.
        # If you refine range for hinted variables, you can end up making
        # improper deductions since compile time reasoning may be
        # incompatible with runtime reasoning.
        if outcome and not self.has_hint():
            _advise_is_size(SymInt(self))
        return outcome

    def guard_size_oblivious(self, file, line):
        """
        Like guard_bool, but if we encounter unbacked symbols, if those symbols
        are size-like, we will treat them as >= 2 for the purposes of the analysis.

        This CHANGES the runtime semantics, but all size-oblivious sites have been
        audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
        an unbacked one size, or a tensor reporting as non-contiguous even if it's
        contiguous if it would have been reported contiguous due to being empty.
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        result = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
        )
        try:
            return bool(result)
        except Exception:
            log.warning("Failed to convert to bool: %s", result)
            raise

    def bool_(self):
        return self.guard_bool("", 0)

    def is_symbolic(self):
        return True

    def nested_int(self):
        return None

    def is_constant(self):
        return False
# TODO: this probably needs the sizes-strides eval functions
# Maps a method name to the plain-Python callable that performs the same
# operation on concrete values; looked up through method_to_operator().
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "pow": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "truediv": operator.truediv,
}
# Single-operand methods that are installed as `__<name>__` on the user types.
# The math-op loop below adds the `sym_<op>` names to this set.
unary_magic_methods = {
    "abs",
    "sym_float",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
}
# Adding math ops: sqrt, cos, sin, ...
def _get_sym_node_fn(name):
def fn(self):
return getattr(self, f"_sym_{name}")()
return fn
# Names of the math functions that get a `sym_<name>` method on SymNode.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
)
# For every math op: install SymNode.sym_<name>, map the method name to
# torch._sym_<name>, and register it as a unary magic method and public name.
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)
# Unary methods that are not magic methods
unary_nonmagic_methods = {
    "is_integer",
}
unary_methods = unary_magic_methods | unary_nonmagic_methods
# Most methods are only registered on SymInt and SymFloat
# Some methods are only registered on SymBool
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that implicitly convert SymBool into SymInt
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
# Methods that are only for float
only_float_magic_methods = {"is_integer"}
# Methods whose SymNode attribute carries a trailing underscore (`and_`, `or_`).
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
# Methods whose result pytype is always float, regardless of the operand types.
always_float_magic_methods = {"truediv", "sym_float", "pow"}
for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)
# Methods whose result pytype is always int.
always_int_magic_methods = {"ceil", "floor"}
# Methods whose result pytype is always bool.
always_bool_magic_methods = {
    "eq",
    "ne",
    "gt",
    "lt",
    "le",
    "ge",
    "and",
    "or",
    "sym_not",
    "is_non_overlapping_and_dense",
    "is_integer",
}
# Methods that have a `__foo__` as well as `__rfoo__`
def _sympy_truediv(a, b):
    """Build torch's symbolic TrueDiv(a, b); the import is deferred to call time."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.TrueDiv(a, b)
def _sympy_floordiv(a, b):
from torch.utils._sympy.functions import FloorDiv
return FloorDiv(a, b)
def _sympy_mod(a, b):
from torch.utils._sympy.functions import Mod
return Mod(a, b)
def _sympy_pow(a, b):
    """Build torch's symbolic Pow(a, b); the import is deferred to call time."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.Pow(a, b)
def _sympy_and(a, b):
import sympy
return sympy.And(a, b)
def _sympy_or(a, b):
import sympy
return sympy.Or(a, b)
def _sympy_lshift(a, b):
from torch.utils._sympy.functions import LShift
return LShift(a, b)
def _sympy_rshift(a, b):
from torch.utils._sympy.functions import RShift
return RShift(a, b)
# Binary methods that get both a `__foo__` and a reflected `__rfoo__` on the
# user types.  Each value is the function applied to the operands' expressions.
reflectable_magic_methods = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow": _sympy_pow,
    "and": _sympy_and,
    "or": _sympy_or,
    "truediv": _sympy_truediv,
    "floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}
def _floor_ceil_helper(a, fn):
import sympy
if isinstance(a, sympy.Mul):
aa = a.args
if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
coef = sympy.Integer(aa[0])
if aa[0] == coef: # structural equality test
return coef * aa[1]
if (
isinstance(a, sympy.Float)
and a == sympy.Integer(a)
or isinstance(a, sympy.Integer)
):
return sympy.Integer(a)
return fn(a)
def _sympy_floor(a):
    """Floor of `a`, routed through _floor_ceil_helper with sympy.floor."""
    from sympy import floor

    return _floor_ceil_helper(a, floor)
def _sympy_ceil(a):
    """Ceiling of `a`, routed through _floor_ceil_helper with sympy.ceiling."""
    from sympy import ceiling

    return _floor_ceil_helper(a, ceiling)
def _sympy_eq(a, b):
import sympy
return sympy.Eq(a, b)
def _sympy_ne(a, b):
import sympy
return sympy.Ne(a, b)
def _sympy_gt(a, b):
import sympy
return sympy.Gt(a, b)
def _sympy_lt(a, b):
import sympy
return sympy.Lt(a, b)
def _sympy_le(a, b):
import sympy
return sympy.Le(a, b)
def _sympy_ge(a, b):
import sympy
return sympy.Ge(a, b)
def _sympy_min(a, b):
import sympy
return sympy.Min(a, b)
def _sympy_max(a, b):
import sympy
return sympy.Max(a, b)
def _sympy_ite(a, t, f):
import sympy
return sympy.Piecewise((t, a), (f, True))
# Handle on this module object, so the generated `_sympy_<op>` functions can
# be installed on it and read back by name below.
current_module = sys.modules[__name__]
def _get_sym_math_fn(name):
def fn(a):
import sympy
return getattr(sympy, name)(a)
return fn
# Generate one module-level `_sympy_<op>` function per math op, renaming each
# closure so its __name__/__qualname__ match the attribute it is stored under.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)
# Drop the loop temporaries from the module namespace.
del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
def _sympy_abs(a):
import sympy
return sympy.Abs(a)
def _sympy_round(number, ndigits=None):
    """Symbolic round: Round(number) without `ndigits`, else RoundDecimal."""
    from torch.utils._sympy.functions import Round, RoundDecimal

    if ndigits is not None:
        return RoundDecimal(number, ndigits)
    return Round(number)
def _sympy_sym_float(a):
# Cannot use sympy.Float(a) here, coz it expects python literals
# Multiply by 1.0 to cast to float. This is needed when the input
# is a SymInt which has the assumption that it is integer and
# SymPy will otherwise assume that return value cannot be a float.
return a * 1.0
def _sympy_is_integer(a):
import sympy
return sympy.Eq(sympy.floor(a), a)
# Method name -> function applied to the operands' expressions.  Every entry
# is later installed on SymNode and on the user-facing Sym* types.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}
# Add the generated `_sympy_<op>` functions under their `sym_<op>` names.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
# These names were only needed while building the tables above.
del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
def sympy_is_contiguous(sizes, strides):
    """Contiguity test using the dimension order last-to-first."""
    innermost_first = list(reversed(range(len(sizes))))
    return sympy_is_contiguous_generic(sizes, strides, innermost_first)
def sympy_is_contiguous_generic(sizes, strides, dim_order):
    """
    Build a boolean expression stating that `strides` are contiguous for
    `sizes` when dimensions are visited in `dim_order`.

    Returns sympy.false when `dim_order` does not have one entry per dimension.
    """
    import sympy

    if len(dim_order) != len(sizes):
        return sympy.false

    one = sympy.Integer(1)
    zero = sympy.Integer(0)

    result = sympy.true
    expected_stride = sympy.Integer(1)
    # Contiguous if the strides make sense (or the dim is size 1)
    for d in dim_order:
        dim_ok = sympy.Eq(sizes[d], one) | sympy.Eq(strides[d], expected_stride)
        result = result & dim_ok
        expected_stride = expected_stride * sizes[d]
    # OR if any size is zero
    for size in sizes:
        result = result | sympy.Eq(size, zero)
    return result
# NB: There is a TODO in C++ to allow omitting the batch dim. If that
# happens you will need to refactor this
def sympy_is_channels_last_contiguous_2d(sizes, strides):
    """Contiguity test for the fixed 4-d dimension order [1, 3, 2, 0]."""
    dim_order = [1, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, dim_order)
def sympy_is_channels_last_contiguous_3d(sizes, strides):
    """Contiguity test for the fixed 5-d dimension order [1, 4, 3, 2, 0]."""
    dim_order = [1, 4, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, dim_order)
def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
    """
    Build a boolean expression stating that `strides` look channels-last for
    `sizes`, visiting dimensions in `dim_order`.

    Returns sympy.false when `dim_order` does not have one entry per dimension.
    """
    import sympy

    if len(sizes) != len(dim_order):
        return sympy.false

    min_stride = sympy.Integer(0)
    # special case for trivial C dimension. default to NCHW
    result = sympy.true & sympy.Ne(strides[1], 0)

    for d in dim_order:
        dim_ok = sympy.Ne(sizes[d], 0) & (strides[d] >= min_stride)
        result = result & dim_ok
        # Fallback to NCHW as default layout for ambiguous cases
        # This is the flaw of implicit memory_format from strides.
        # N111 tensor with identical strides for size 1 dimension;
        # Two cases could lead us here:
        # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
        # b. N11W contiguous Tensor sliced on the W-dimension.
        # ([N,1,1,1]@[W,W,W,W])
        if d == 0:
            result = result & sympy.Ne(min_stride, strides[1])
        # This is necessary to:
        # 1. distinguish the memory_format of N1H1;
        #     [H, 1, 1, 1] channels_last stride
        #     [H, H, 1, 1] contiguous stride
        # 2. permutation of 1C1W:
        #     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
        #     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
        #     channels_last
        min_stride = strides[d] * sympy.Max(sizes[d], 1)
    return result
def sympy_is_channels_last_strides_2d(sizes, strides):
    """Channels-last stride test for the fixed 4-d dimension order [1, 3, 2, 0]."""
    dim_order = [1, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, dim_order)
def sympy_is_channels_last_strides_3d(sizes, strides):
    """Channels-last stride test for the fixed 5-d dimension order [1, 4, 3, 2, 0]."""
    dim_order = [1, 4, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, dim_order)
def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
# Method name -> function taking (sizes, strides) lists of expressions.
# Each entry is installed on SymNode by _make_node_sizes_strides.
sizes_strides_methods = {
    # TODO: These could also be done with indicators, maybe it is better
    # for reasoning to do it that way
    "is_contiguous": sympy_is_contiguous,
    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}
# Binary methods that are re-dispatched through a Python builtin on the
# wrapped operands whenever a hint for the result could be computed.
alternate_impl_if_hinted_methods = {
    "sym_min": builtins.min,
    "sym_max": builtins.max,
}
def to_node(self, num):
    """
    Convert `num` to a SymNode: unwrap Sym* values, wrap exact bool/int/float
    literals via `self`, and return NotImplemented for anything else.
    """
    if isinstance(num, SymTypes):
        return num.node

    literal_type = type(num)
    if literal_type is bool:
        return self.wrap_bool(num)
    if literal_type is int:
        return self.wrap_int(num)
    if literal_type is float:
        return self.wrap_float(num)
    # NotImplemented is important so that Python tries the
    # other magic method
    return NotImplemented
def wrap_node(x):
    """
    Turn a node back into a user-facing value: the plain constant when the
    SymNode carries one, otherwise the Sym* wrapper matching its type.
    """
    # TODO: let C++ also take advantage of this
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant
    if x.is_int():
        return SymInt(x)
    if x.is_float():
        return SymFloat(x)
    if x.is_bool():
        return SymBool(x)
    raise AssertionError(f"unrecognized return type {x}")
def method_to_operator(method):
    """Look up the plain-Python callable registered for `method` (KeyError if none)."""
    plain_op = METHOD_TO_OPERATOR[method]
    return plain_op
def _make_node_magic(method, func):
    """
    Install the private implementation `_<method>` on SymNode.

    `func` is the function applied to the operands' expressions; it is wrapped
    in an LRU cache of 256 entries.  Depending on `method`, the unary, sym_ite,
    round or binary implementation is installed.  Methods listed in
    magic_methods_on_operator_with_trailing_underscore are stored with a
    trailing underscore (e.g. `_and_`).
    """
    func = lru_cache(256)(func)

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"{method}_"
    else:
        method_attr = method

    def binary_magic_impl(self, other):
        from torch.fx.experimental.symbolic_shapes import safe_expand

        op = method_to_operator(method)

        # The result only gets a hint when both operands have one.
        out_hint = None
        if self.hint is not None and other.hint is not None:
            out_hint = op(self.hint, other.hint)

        alternate_impl = alternate_impl_if_hinted_methods.get(method)
        if alternate_impl and out_hint is not None:
            return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))

        if sym_function_mode():
            return to_node(
                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            )
        assert isinstance(other, SymNode)
        # TODO: consider constant prop here
        try:
            out = func(self.expr, other.expr)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise
        out = safe_expand(out)
        sym_node_log.debug("%s %s %s -> %s", func, self.expr, other.expr, out)
        pytype: Type
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
        # returns a float while both arguments are ints: 2**(-1). Also, max and
        # min do not type promote. To avoid having data-dependent control flow
        # here, we just set the type to float if one of the args is a float. In
        # case of a type mismatch, we assume that it will be detected during
        # evaluation.
        if method in always_float_magic_methods:
            pytype = float
        elif method in always_bool_magic_methods:
            pytype = bool
        elif self.pytype is float or other.pytype is float:
            pytype = float
        else:
            pytype = self.pytype

        if (
            pytype is not None
            and out_hint is not None
            and not isinstance(out_hint, SymTypes)
        ):
            out_hint = pytype(out_hint)

        # Create a FX node that corresponds to the operation being applied to
        # this node.
        fx_node, _ = self.shape_env._create_fx_call_function(
            op, (self.fx_node, other.fx_node)
        )
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    def unary_magic_impl(self):
        from torch.fx.experimental.symbolic_shapes import safe_expand

        op = method_to_operator(method)
        if sym_function_mode():
            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
        # TODO: consider constant prop here
        expr = self.expr
        # NOTE(review): the registered method name is "ceil", so the "ceiling"
        # comparison below never matches and only "floor" takes this path.
        # Left unchanged because enabling it would make ceil call
        # _simplify_floor_div as well - confirm the intent before changing it.
        if method == "floor" or method == "ceiling":
            expr = self.shape_env._simplify_floor_div(expr)

        try:
            out = func(expr)
        except Exception:
            log.warning("failed to eval %s(%s)", method, expr)
            raise
        sym_node_log.debug("%s %s -> %s", func, expr, out)
        out_hint = None
        if self.hint is not None:
            out_hint = op(self.hint)
        out = safe_expand(out)
        pytype: Type
        if method in always_int_magic_methods:
            pytype = int
        elif method in always_bool_magic_methods:
            pytype = bool
        elif method in always_float_magic_methods:
            pytype = float
        else:
            pytype = self.pytype

        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

    if method in unary_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        def sym_ite_impl(pred_node, then_node, else_node):
            from torch.fx.experimental.symbolic_shapes import safe_expand

            # A predicate without a hint does not tell us which branch was
            # taken, so the result has no hint either.  (Testing the truthiness
            # of a None hint would silently pick the else-branch hint.)
            pred_hint = pred_node.hint
            if pred_hint is None:
                out_hint = None
            else:
                out_hint = then_node.hint if pred_hint else else_node.hint

            if sym_function_mode():
                return to_node(
                    pred_node,
                    handle_sym_dispatch(
                        sym_ite,
                        (
                            wrap_node(pred_node),
                            wrap_node(then_node),
                            wrap_node(else_node),
                        ),
                        {},
                    ),
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning(
                    "failed to eval %s(%s, %s, %s)",
                    method,
                    pred_node.expr,
                    then_node.expr,
                    else_node.expr,
                )
                raise

            out = safe_expand(out)
            fx_node, _ = pred_node.shape_env._create_fx_call_function(
                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
            )
            # The result takes the then-branch's pytype.
            return SymNode(
                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
            )

        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    elif method == "round":

        def round_impl(self, ndigits=None):
            from torch.fx.experimental.symbolic_shapes import safe_expand

            op = builtins.round
            if sym_function_mode():
                return to_node(
                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                )

            expr = self.expr
            try:
                out = func(expr, ndigits)
            except Exception:
                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
                raise
            out = safe_expand(out)

            # round(x) yields an int; round(x, n) keeps the operand's type.
            pytype = int if ndigits is None else self.pytype

            out_hint = None
            if self.hint is not None:
                out_hint = op(self.hint, ndigits)

            # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
            # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
            # hack down below works, because all round function down the line all take ndigits=None as default in their
            # signature.
            # TODO: Remove the args construction below if a different sentinel is used by FX.
            args = [self.fx_node]
            if ndigits is not None:
                args.append(ndigits)
            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

        setattr(SymNode, f"_{method_attr}", round_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
def _make_node_sizes_strides(method, func):
    """
    Install `_<method>` on SymNode for a sizes/strides predicate and, unless
    this module already has an attribute called `method`, publish a
    module-level user function under that name.
    """

    def _hints_or_none(nodes):
        # Collect every node's hint in order; stop at the first missing one.
        collected = []
        for node in nodes:
            node_hint = node.hint
            if node_hint is None:
                return None
            collected.append(node_hint)
        return collected

    # NB: don't LRU cache, lots of arguments
    def sizes_strides_impl(self, sizes, strides):
        op = getattr(sys.modules[__name__], method)
        if sym_function_mode():
            wrapped_sizes = [wrap_node(s) for s in sizes]
            wrapped_strides = [wrap_node(s) for s in strides]
            return to_node(
                self,
                handle_sym_dispatch(op, (wrapped_sizes, wrapped_strides), {}),
            )
        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
            raise
        # bool is never expandable

        # A result hint exists only when every size and every stride has one;
        # strides are not inspected once a size hint is missing.
        out_hint = None
        size_hints = _hints_or_none(sizes)
        if size_hints is not None:
            stride_hints = _hints_or_none(strides)
            if stride_hints is not None:
                out_hint = op(size_hints, stride_hints)

        # NB: This is the indicator function, not the actual bool!
        pytype: Type = int if method.endswith("_indicator") else bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    def sizes_strides_user(sizes, strides):
        import sympy

        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        # The first SymInt found among the arguments supplies the node used
        # for the symbolic path.
        first_sym = next(
            (a for a in itertools.chain(sizes, strides) if isinstance(a, SymInt)),
            None,
        )
        if first_sym is not None:
            node_method = getattr(first_sym.node, method)
            size_nodes = [to_node(first_sym.node, b) for b in sizes]
            stride_nodes = [to_node(first_sym.node, b) for b in strides]
            return wrap_node(node_method(size_nodes, stride_nodes))

        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        # TODO: this is an awful implementation
        concrete_sizes = [sympy.sympify(a) for a in sizes]
        concrete_strides = [sympy.sympify(a) for a in strides]
        return bool(func(concrete_sizes, concrete_strides))

    # Skip for is_non_overlapping_and_dense_indicator
    if not hasattr(sys.modules[__name__], method):
        setattr(sys.modules[__name__], method, sizes_strides_user)
# Install the private `_<method>` implementations on SymNode: first the magic
# methods, then the sizes/strides predicates.
for method, func in magic_methods.items():
    _make_node_magic(method, func)
for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)
def _make_user_magic(method, user_type):
    """
    Install the user-facing implementation of `method` on `user_type`.

    Unary magic methods, sym_ite, round and binary methods are stored as
    `__<method>__`; methods in reflectable_magic_methods additionally get
    `__r<method>__`; unary non-magic methods replace the existing attribute
    of the same name.
    """
    # User magic takes care of wrapping the other operand into a node,
    # so that our internal logic can assume everything is nodes

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"sym_{method}"
    else:
        method_attr = method

    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
        """Return the plain Python value of a constant operand."""
        if isinstance(x, (int, float, bool)):
            return x
        if isinstance(x, SymBool):
            return x.node.guard_bool("", 0)
        raise AssertionError("expect to be called with constant SymBools")

    def is_constant(x):
        """True for plain numbers/bools and for Sym* values whose node is constant."""
        if isinstance(x, (int, float, bool)):
            return True
        if isinstance(x, (SymInt, SymFloat, SymBool)):
            return x.node.is_constant()
        return False

    if method in bool_becomes_int_magic_methods:

        def promote(x):
            """Implements True+True=2, which works in python but not sympy"""
            if isinstance(x, SymBool):
                return SymInt(x.node.wrap_int(int(x)))
            return x

    else:

        def promote(x):
            return x

    # Before and after performing the operation, check if any operands are constant.
    # If so, extract out the constant values first. If `self` itself is a
    # constant, then "redispatch" by calling back into the operator. Sometimes
    # this means that operations involving SymBool return plain bools.
    # Alternatively, we could also rewrap into constant Symbool (i.e. by
    # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
    # today for no particular reason.
    def unary_magic_impl(self):
        self = promote(self)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self))
        return wrap_node(getattr(self.node, method_attr)())

    def binary_magic_impl(self, other):
        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
        self = promote(self)
        other = promote(other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(self.node, method_attr)(other_node))
        return get_constant(ret) if is_constant(ret) else ret

    def rbinary_magic_impl(self, other):
        self = promote(self)
        other = promote(other)
        if is_constant(self):
            # Reflected call: `other` is the LEFT operand and `self` the right
            # one, so the redispatch has to keep that order.
            return (method_to_operator(method))(other, get_constant(self))
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(other_node, method_attr)(self.node))
        return get_constant(ret) if is_constant(ret) else ret

    if method in unary_magic_methods:
        setattr(user_type, f"__{method}__", unary_magic_impl)
    elif method in unary_nonmagic_methods:
        # Keep the metadata of the attribute being replaced.
        orig = getattr(user_type, method)
        setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
    elif method == "sym_ite":

        def sym_ite_magic_impl(pred, then_val, else_val):
            pred_node = pred.node
            then_node = to_node(pred_node, then_val)
            else_node = to_node(pred_node, else_val)
            if then_node is NotImplemented or else_node is NotImplemented:
                return NotImplemented
            assert (
                isinstance(then_node, SymNode)
                and isinstance(else_node, SymNode)
                and then_node.pytype == else_node.pytype
            )
            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
            # wrap_node may hand back a plain constant that has no `.node`;
            # is_constant() covers both cases, as in binary_magic_impl.
            return get_constant(ret) if is_constant(ret) else ret

        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
    elif method == "round":

        def round_magic_impl(self, ndigits=None):
            if is_constant(self):
                return builtins.round(get_constant(self), ndigits)

            return wrap_node(getattr(self.node, method)(ndigits))

        setattr(user_type, f"__{method}__", round_magic_impl)
    else:
        setattr(user_type, f"__{method}__", binary_magic_impl)
        if method in reflectable_magic_methods:
            setattr(user_type, f"__r{method}__", rbinary_magic_impl)
# Install the user-facing methods on the Sym* types.  Bool-only and float-only
# methods go to a single type; everything else goes to SymInt and SymFloat,
# and additionally to SymBool when it is in also_bool_magic_methods or
# bool_becomes_int_magic_methods.
for method, func in magic_methods.items():  # type: ignore[assignment]
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
        continue
    if method in only_float_magic_methods:
        _make_user_magic(method, SymFloat)
        continue
    if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
        _make_user_magic(method, SymBool)
    _make_user_magic(method, SymInt)
    _make_user_magic(method, SymFloat)
# Remove the loop variables from the module namespace.
del method
del func