import sympy
from sympy import S
from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or

__all__ = [
    "FloorDiv",
    "ModularIndexing",
    "CleanDiv",
    "CeilDiv",
    "Pow",
    "TrueDiv",
    "LShift",
    "RShift",
    "IsNonOverlappingAndDenseIndicator",
    "Round",
    "RoundDecimal",
]


def fuzzy_eq(x, y):
    """Three-valued equality on fuzzy booleans.

    Returns ``None`` when either operand is ``None`` (unknown); otherwise
    returns ``x == y``.
    """
    if None in (x, y):
        return None
    return x == y


class FloorDiv(sympy.Function):
    """
    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
    """

    nargs = (2,)
    precedence = 50  # precedence of mul  # noqa: F811

    # Default return type for SymPy assumptions.
    # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
    is_real = True

    @property
    def base(self):
        """The dividend (first argument)."""
        return self.args[0]

    @property
    def divisor(self):
        """The divisor (second argument)."""
        return self.args[1]

    def _sympystr(self, printer):
        """Render as ``(base//divisor)``, parenthesizing operands at mul precedence."""
        base = printer.parenthesize(self.base, self.precedence)
        divisor = printer.parenthesize(self.divisor, self.precedence)
        return f"({base}//{divisor})"

    # SymPy assumptions based on argument types.
    def _eval_is_real(self):
        """Fuzzy OR of the operands' ``is_real`` assumptions."""
        return fuzzy_or([self.base.is_real, self.divisor.is_real])

    def _eval_is_integer(self):
        """Integer only when both operands are known to be integers."""
        return fuzzy_and([self.base.is_integer, self.divisor.is_integer])

    # Automatic evaluation.
    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
    @classmethod
    def eval(cls, base, divisor):
        """Simplify ``base // divisor`` where possible.

        Returns a simplified expression, or ``None`` (implicitly) to leave the
        call unevaluated.

        Raises:
            TypeError: if an operand is a Boolean or is known to be complex
                but neither integer nor real.
            ZeroDivisionError: if ``divisor`` is known to be zero.
        """

        def check_supported_type(x):
            if (
                x.is_integer is False and x.is_real is False and x.is_complex
            ) or x.is_Boolean:
                raise TypeError(
                    f"unsupported operand type(s) for //: "
                    f"'{type(base).__name__}' and '{type(divisor).__name__}'"
                    f", expected integer or real")

        check_supported_type(base)
        check_supported_type(divisor)

        # We don't provide the same error message as in Python because SymPy
        # makes it difficult to check the types.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        if base.is_zero:
            return sympy.S.Zero
        if base.is_integer and divisor == 1:
            return base
        if base.is_real and divisor == 1:
            return sympy.floor(base)
        if base.is_integer and divisor == -1:
            return sympy.Mul(base, -1)
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return base // divisor
        if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(
            divisor, (sympy.Integer, sympy.Float)
        ):
            return sympy.floor(base / divisor)
        if isinstance(base, FloorDiv):
            # (a // b) // c  ->  a // (b * c)
            return FloorDiv(base.args[0], base.args[1] * divisor)
        if isinstance(divisor, sympy.Rational) and divisor.p == 1:
            # Dividing by 1/q is multiplying by q.
            return sympy.floor(base * divisor.q)

        if isinstance(base, sympy.Add):
            # Pull out any summand that is an exact multiple of the divisor.
            for a in base.args:
                gcd = sympy.gcd(a, divisor)
                if gcd == divisor:
                    return FloorDiv(base - a, divisor) + a / gcd

        try:
            # Cancel a common factor between base and divisor.
            gcd = sympy.gcd(base, divisor)
            if gcd != 1:
                return FloorDiv(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276


class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        """Simplify ``(base // divisor) % modulus`` where possible.

        Returns a simplified expression, or ``None`` (implicitly) to leave the
        call unevaluated.
        """
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        try:
            if divisor != 1:
                # Cancel a common factor between base and divisor.
                gcd = sympy.gcd(base, divisor)
                if gcd != 1:
                    return ModularIndexing(
                        sympy.simplify(base / gcd), sympy.simplify(divisor / gcd), modulus
                    )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276

        if isinstance(base, sympy.Add):
            # Drop summands that are multiples of modulus * divisor; keep the rest.
            new_terms = []
            all_positive = True
            for term in base.args:
                if sympy.gcd(term, modulus * divisor) != modulus * divisor:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        # workaround for https://github.com/openai/triton/issues/619,
                        # if there are negative terms, // produces wrong result
                        # TODO if https://github.com/openai/triton/issues/619 is fixed
                        # this optimization would become valid
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)

            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        if isinstance(base, FloorDiv):
            # ((a // b) // c) % m  ->  (a // (b * c)) % m
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)

    def _eval_is_nonnegative(self):
        """Nonnegative when base and divisor agree on ``is_nonnegative``; unknown if either is."""
        p, q = self.args[:2]
        return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)  # type: ignore[attr-defined]

    def _eval_is_positive(self):
        """Positive when base and divisor agree on ``is_positive``; unknown if either is."""
        # NOTE(review): (a // b) % c can be 0 even for positive a and b
        # (e.g. a=2, b=1, c=2), so a True result here may be too strong --
        # confirm callers tolerate this before relying on it.
        p, q = self.args[:2]
        return fuzzy_eq(p.is_positive, q.is_positive)  # type: ignore[attr-defined]


class Where(sympy.Function):
    """
    Good ol' ternary operator
    """

    nargs = (3,)

    @classmethod
    def eval(cls, c, p, q):
        """Return ``p`` if ``c`` is true, ``q`` if false; otherwise stay unevaluated."""
        if c == sympy.true:
            return p
        elif c == sympy.false:
            return q


class Mod(sympy.Function):
    """
    We maintain this so that we avoid SymPy correctness issues, such as:
    https://github.com/sympy/sympy/issues/25146
    """

    nargs = (2,)

    @classmethod
    def eval(cls, p, q):
        """Simplify ``p % q`` where possible.

        Returns a simplified expression, or ``None`` (implicitly) to leave the
        call unevaluated.

        Raises:
            ZeroDivisionError: if ``q`` is known to be zero.
        """
        # This was adapted from: sympy/core/mod.py

        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")
        # If either of them is NaN or infinite.
        if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False:
            return S.NaN
        # Three cases:
        #   1. p == 0
        #   2. p is either q or -q
        #   3. p is integer and q == 1
        if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1):
            return S.Zero

        # Evaluate if they are both literals.
        if q.is_Number and p.is_Number:
            return p % q

        # If q == 2, it's a matter of whether p is odd or even.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        # If p is a multiple of q.
        r = p / q
        if r.is_integer:
            return S.Zero

        # If p < q and its ratio is positive, then:
        #   - floor(p / q) = 0
        #   - p % q = p - floor(p / q) * q = p
        less = p < q
        if less.is_Boolean and bool(less) and r.is_positive:
            return p

    def _eval_is_integer(self):
        """Integer when both operands are integers and ``q`` is known nonzero."""
        p, q = self.args
        return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)])  # type: ignore[attr-defined]

    def _eval_is_nonnegative(self):
        """Known nonnegative when the modulus is positive; otherwise unknown."""
        return True if self.args[1].is_positive else None  # type: ignore[attr-defined]

    def _eval_is_nonpositive(self):
        """Known nonpositive when the modulus is negative; otherwise unknown."""
        return True if self.args[1].is_negative else None  # type: ignore[attr-defined]


class CleanDiv(FloorDiv):
    """
    Div where we can assume no rounding.
    This is to enable future optimizations.
    """

    pass


class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        """Build the rounded-up quotient.

        Never returns a ``CeilDiv`` instance: the result is ``CleanDiv`` when
        ``gcd(base, divisor) == divisor``, otherwise
        ``FloorDiv(base + (divisor - 1), divisor)``.
        """
        if sympy.gcd(base, divisor) == divisor:
            return CleanDiv(base, divisor)
        else:
            return FloorDiv(base + (divisor - 1), divisor)


class LShift(sympy.Function):
    """Left shift, rewritten as ``base * 2 ** shift``."""

    @classmethod
    def eval(cls, base, shift):
        """Return ``base * 2 ** shift``; raise ``ValueError`` for a negative shift."""
        # NOTE(review): `shift < 0` is used as a Python bool, so a shift whose
        # sign SymPy cannot decide will raise TypeError here -- confirm callers
        # only pass shifts with a known sign.
        if shift < 0:
            raise ValueError('negative shift count')
        return base * 2 ** shift


class RShift(sympy.Function):
    """Right shift, rewritten as ``base // 2 ** shift``."""

    @classmethod
    def eval(cls, base, shift):
        """Return ``base // 2 ** shift``; raise ``ValueError`` for a negative shift."""
        # NOTE(review): same caveat as LShift -- `shift < 0` must evaluate to
        # a concrete boolean.
        if shift < 0:
            raise ValueError('negative shift count')
        return base // 2 ** shift


# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class Pow(sympy.Function):
    """Power that raises ``ZeroDivisionError`` for zero to a negative exponent."""

    @classmethod
    def eval(cls, base, exp):
        """Return ``1`` for a zero exponent, otherwise ``base ** exp``.

        Raises:
            ZeroDivisionError: if ``base`` is zero and ``exp`` is negative.
        """
        if exp.is_zero:
            return sympy.Integer(1)
        # NOTE(review): `exp < 0` is only evaluated when base is known zero;
        # it must then evaluate to a concrete boolean.
        elif base.is_zero and exp < 0:
            raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
        else:
            return base ** exp


# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class TrueDiv(sympy.Function):
    """True division that raises ``ZeroDivisionError`` for a zero divisor."""

    @classmethod
    def eval(cls, base, divisor):
        """Return ``base / divisor``; raise ``ZeroDivisionError`` if ``divisor`` is zero."""
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")
        else:
            return base / divisor


# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
# Because we do not have the ability to guard on the stride permutation
# at the moment, it is hard to make further inferences when this is true,
# as although we know the tensor is contiguous in *some* layout, we don't
# know which one (however, you could, for example, make the inference that
# reshaping this to a 1D tensor can be guard-free.)
class IsNonOverlappingAndDenseIndicator(sympy.Function):
    """Indicator over ``(*sizes, *strides)``.

    The first half of the arguments is treated as sizes and the second half
    as strides.
    """

    is_integer = True

    @classmethod
    def eval(cls, *args):
        """Evaluate only when every argument is a concrete ``sympy.Integer``.

        Otherwise returns ``None`` so the call stays unevaluated.
        """
        assert len(args) % 2 == 0
        dim = len(args) // 2
        # TODO: it is possible to make progress evaluating this guard
        # even if not all of the inputs are known.  For example, a 2D
        # tensor with non-0/1 sizes but strides (0, 1) is definitely
        # false, because we know its numel > 1 but it's broadcasted
        # in dim 0.
        if all(isinstance(a, sympy.Integer) for a in args):
            # sym_node imported in torch.__init__. Local import to avoid an import cycle
            from torch.fx.experimental.symbolic_shapes import eval_is_non_overlapping_and_dense

            size_args = args[0:dim]
            stride_args = args[dim:]
            return eval_is_non_overlapping_and_dense(
                [int(a) for a in size_args],
                [int(a) for a in stride_args]
            )
        return None


class Round(sympy.Function):
    """Round to the nearest integer using Python's built-in ``round``."""

    is_integer = True

    @classmethod
    def eval(cls, number):
        """Return integers unchanged and round concrete numbers; else stay unevaluated."""
        if number.is_integer:
            return number
        elif isinstance(number, sympy.Number):
            return sympy.Integer(round(float(number)))

    def __int__(self):
        # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and
        # no longer an expression. If it were, the float call would fail and the caller would handle this further.
        return round(float(self.args[0]))  # type: ignore[arg-type]


class RoundDecimal(sympy.Function):
    """Round to ``ndigits`` decimal places using Python's built-in ``round``."""

    @classmethod
    def eval(cls, number, ndigits):
        """Return integers unchanged for nonnegative ``ndigits``; round concrete numbers.

        Integer inputs are rounded as ``int`` and returned as ``sympy.Integer``;
        other numbers are rounded as ``float`` and returned as ``sympy.Float``.
        Anything else stays unevaluated.
        """
        # NOTE(review): `ndigits >= 0` is only evaluated when number is known
        # to be an integer; it must then evaluate to a concrete boolean.
        if number.is_integer and ndigits >= 0:
            return number
        elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
            value_type, output_type = (
                (int, sympy.Integer)
                if isinstance(number, sympy.Integer)
                else (float, sympy.Float)
            )
            return output_type(round(value_type(number), int(ndigits)))