354 lines
12 KiB
Python
354 lines
12 KiB
Python
|
import sympy
|
||
|
from sympy import S
|
||
|
from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or
|
||
|
|
||
|
# Names exported by a star-import of this module.
__all__ = [
    "FloorDiv",
    "ModularIndexing",
    "CleanDiv",
    "CeilDiv",
    "Pow",
    "TrueDiv",
    "LShift",
    "RShift",
    "IsNonOverlappingAndDenseIndicator",
    "Round",
    "RoundDecimal",
]
|
||
|
|
||
|
|
||
|
def fuzzy_eq(x, y):
    """
    Compare two fuzzy (three-valued) truth values for equality.

    Args:
        x: ``True``, ``False`` or ``None`` (meaning "unknown").
        y: ``True``, ``False`` or ``None`` (meaning "unknown").

    Returns:
        ``None`` when either operand is ``None``; otherwise ``x == y``.
    """
    # Use identity checks rather than ``None in (x, y)``: the membership test
    # falls back to ``==``, so an operand with a permissive ``__eq__`` would be
    # mistaken for ``None`` and the result wrongly reported as unknown.
    if x is None or y is None:
        return None
    return x == y
|
||
|
|
||
|
|
||
|
class FloorDiv(sympy.Function):
    """
    Symbolic floor division of ``base`` by ``divisor``, printed as ``(base//divisor)``.

    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
    """

    nargs = (2,)
    precedence = 50  # precedence of mul  # noqa: F811

    # Default return type for SymPy assumptions.
    # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
    is_real = True

    @property
    def base(self):
        """First argument: the value being divided."""
        return self.args[0]

    @property
    def divisor(self):
        """Second argument: the value divided by."""
        return self.args[1]

    def _sympystr(self, printer):
        """Render as ``(base//divisor)``, parenthesizing operands at this node's precedence."""
        lhs = printer.parenthesize(self.base, self.precedence)
        rhs = printer.parenthesize(self.divisor, self.precedence)
        return f"({lhs}//{rhs})"

    # SymPy assumptions based on argument types.
    def _eval_is_real(self):
        """Fuzzy-or of the operands' ``is_real`` flags."""
        return fuzzy_or([self.base.is_real, self.divisor.is_real])

    def _eval_is_integer(self):
        """Fuzzy-and of the operands' ``is_integer`` flags."""
        return fuzzy_and([self.base.is_integer, self.divisor.is_integer])

    # Automatic evaluation.
    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
    @classmethod
    def eval(cls, base, divisor):
        """
        Try to simplify ``base // divisor`` at construction time.

        Returns a simplified expression when one of the rules below applies,
        and ``None`` when none does.

        Raises:
            TypeError: if an operand is a Boolean, or is known to be neither
                integer nor real while being complex.
            ZeroDivisionError: if ``divisor`` is known to be zero.
        """
        for operand in (base, divisor):
            complex_only = (
                operand.is_integer is False
                and operand.is_real is False
                and operand.is_complex
            )
            if complex_only or operand.is_Boolean:
                # The message names both operands, whichever one failed.
                raise TypeError(
                    f"unsupported operand type(s) for //: "
                    f"'{type(base).__name__}' and '{type(divisor).__name__}'"
                    f", expected integer or real")

        # We don't provide the same error message as in Python because SymPy
        # makes it difficult to check the types.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        # Trivial operands.
        if base.is_zero:
            return sympy.S.Zero
        if divisor == 1:
            if base.is_integer:
                return base
            if base.is_real:
                return sympy.floor(base)
        if base.is_integer and divisor == -1:
            return sympy.Mul(base, -1)

        # Numeric literals can be computed outright.
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return base // divisor
        literal_types = (sympy.Integer, sympy.Float)
        if isinstance(base, literal_types) and isinstance(divisor, literal_types):
            return sympy.floor(base / divisor)

        # (a // b) // c  ->  a // (b * c)
        if isinstance(base, FloorDiv):
            inner_base, inner_divisor = base.args
            return FloorDiv(inner_base, inner_divisor * divisor)

        # Dividing by 1/q is multiplying by q.
        if isinstance(divisor, sympy.Rational) and divisor.p == 1:
            return sympy.floor(base * divisor.q)

        # Pull a summand that is an exact multiple of the divisor out of the floor.
        if isinstance(base, sympy.Add):
            for term in base.args:
                common = sympy.gcd(term, divisor)
                if common == divisor:
                    return FloorDiv(base - term, divisor) + term / common

        # Cancel a factor shared by both operands.
        try:
            common = sympy.gcd(base, divisor)
            if common != 1:
                return FloorDiv(
                    sympy.simplify(base / common), sympy.simplify(divisor / common)
                )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276

        return None
|
||
|
|
||
|
|
||
|
class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        """
        Try to simplify ``(base // divisor) % modulus`` at construction time.

        Returns a simplified expression when a rule applies, ``None`` otherwise.
        """
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        # All-literal arguments can be computed outright.
        if all(isinstance(arg, sympy.Integer) for arg in (base, divisor, modulus)):
            return (base // divisor) % modulus

        # Cancel a factor shared by base and divisor.
        try:
            if divisor != 1:
                common = sympy.gcd(base, divisor)
                if common != 1:
                    return ModularIndexing(
                        sympy.simplify(base / common), sympy.simplify(divisor / common), modulus
                    )
        except sympy.PolynomialError:
            pass  # https://github.com/pytorch/pytorch/issues/108276

        if isinstance(base, sympy.Add):
            period = modulus * divisor
            kept_terms = []
            saw_negative = False
            for term in base.args:
                if sympy.gcd(term, period) == period:
                    # Exact multiple of modulus * divisor: candidate for dropping.
                    continue
                if (isinstance(term, sympy.Integer) and term < 0) or (
                    isinstance(term, sympy.Mul)
                    and isinstance(term.args[0], sympy.Integer)
                    and term.args[0] < 0
                ):
                    # workaround for https://github.com/openai/triton/issues/619,
                    # if there are negative terms, // produces wrong result
                    # TODO if https://github.com/openai/triton/issues/619 is fixed
                    # this optimization would become valid
                    saw_negative = True
                    break
                kept_terms.append(term)

            # Rebuild only if something was dropped and no negative term was seen.
            if len(kept_terms) != len(base.args) and not saw_negative:
                return ModularIndexing(sum(kept_terms), divisor, modulus)

        # (a // b) // c  ->  a // (b * c), keeping the same modulus.
        if isinstance(base, FloorDiv):
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)

        return None

    def _eval_is_nonnegative(self):
        """Fuzzy equality of the ``is_nonnegative`` flags of base and divisor."""
        base, divisor = self.args[:2]
        return fuzzy_eq(base.is_nonnegative, divisor.is_nonnegative)  # type: ignore[attr-defined]

    def _eval_is_positive(self):
        """Fuzzy equality of the ``is_positive`` flags of base and divisor."""
        base, divisor = self.args[:2]
        return fuzzy_eq(base.is_positive, divisor.is_positive)  # type: ignore[attr-defined]
|
||
|
|
||
|
|
||
|
class Where(sympy.Function):
    """
    Good ol' ternary operator: Where(c, p, q) is p when c is true, q when c is false.
    """

    nargs = (3,)

    @classmethod
    def eval(cls, c, p, q):
        """Pick a branch when the condition is a literal true/false; otherwise return ``None``."""
        if c == sympy.true:
            return p
        if c == sympy.false:
            return q
        return None
|
||
|
|
||
|
class Mod(sympy.Function):
    """
    Symbolic ``p % q``.

    We maintain this so that we avoid SymPy correctness issues, such as:
    https://github.com/sympy/sympy/issues/25146
    """

    nargs = (2,)

    @classmethod
    def eval(cls, p, q):
        """
        Try to simplify ``p % q`` at construction time.

        Returns a simplified expression when a rule applies, ``None`` otherwise.

        Raises:
            ZeroDivisionError: if ``q`` is known to be zero.
        """
        # This was adapted from: sympy/core/mod.py

        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")

        # If either of them is NaN or infinite.
        if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False:
            return S.NaN

        # Three cases:
        #   1. p == 0
        #   2. p is either q or -q
        #   3. p is integer and q == 1
        if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1):
            return S.Zero

        # Evaluate if they are both literals.
        if q.is_Number and p.is_Number:
            return p % q

        # If q == 2, it's a matter of whether p is odd or even.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        # If p is a multiple of q.
        ratio = p / q
        if ratio.is_integer:
            return S.Zero

        # If p < q and its ratio is positive, then:
        #   - floor(p / q) = 0
        #   - p % q = p - floor(p / q) * q = p
        # NOTE(review): the is_Boolean check presumably filters out comparisons
        # that stay symbolic before bool() is applied - confirm against SymPy.
        is_less = p < q
        if is_less.is_Boolean and bool(is_less) and ratio.is_positive:
            return p

        return None

    def _eval_is_integer(self):
        """Integer when both operands are integers and ``q`` is not zero (fuzzy-and)."""
        dividend, modulus = self.args
        return fuzzy_and([dividend.is_integer, modulus.is_integer, fuzzy_not(modulus.is_zero)])  # type: ignore[attr-defined]

    def _eval_is_nonnegative(self):
        """``True`` when ``q`` is positive, otherwise unknown (``None``)."""
        return True if self.args[1].is_positive else None  # type: ignore[attr-defined]

    def _eval_is_nonpositive(self):
        """``True`` when ``q`` is negative, otherwise unknown (``None``)."""
        return True if self.args[1].is_negative else None  # type: ignore[attr-defined]
|
||
|
|
||
|
|
||
|
class CleanDiv(FloorDiv):
    """
    Div where we can assume no rounding.

    Adds nothing over FloorDiv; the separate type exists
    to enable future optimizations.
    """
|
||
|
|
||
|
|
||
|
class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.

    Constructing a CeilDiv never yields a CeilDiv instance: ``__new__`` returns
    a CleanDiv or a FloorDiv instead.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        # gcd equal to the divisor: build the no-rounding CleanDiv.
        if sympy.gcd(base, divisor) == divisor:
            return CleanDiv(base, divisor)
        # Otherwise ceil(a / b) is spelled as floor((a + b - 1) / b).
        return FloorDiv(base + (divisor - 1), divisor)
|
||
|
|
||
|
|
||
|
class LShift(sympy.Function):
    """Left shift: ``base << shift`` expressed as ``base * 2 ** shift``."""

    @classmethod
    def eval(cls, base, shift):
        """
        Rewrite the shift as a multiplication by a power of two.

        Raises:
            ValueError: if ``shift`` is negative.
        """
        if shift < 0:
            raise ValueError('negative shift count')
        factor = 2 ** shift
        return base * factor
|
||
|
|
||
|
|
||
|
class RShift(sympy.Function):
    """Right shift: ``base >> shift`` expressed as ``base // 2 ** shift``."""

    @classmethod
    def eval(cls, base, shift):
        """
        Rewrite the shift as a floor division by a power of two.

        Raises:
            ValueError: if ``shift`` is negative.
        """
        if shift < 0:
            raise ValueError('negative shift count')
        factor = 2 ** shift
        return base // factor
|
||
|
|
||
|
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class Pow(sympy.Function):
    """Exponentiation ``base ** exp`` with explicit zero-exponent and zero-base handling."""

    @classmethod
    def eval(cls, base, exp):
        """
        Evaluate ``base ** exp``.

        Raises:
            ZeroDivisionError: if ``base`` is zero and ``exp`` is negative.
        """
        # Anything to the power zero is 1; checked before the zero-base case.
        if exp.is_zero:
            return sympy.Integer(1)
        if base.is_zero and exp < 0:
            raise ZeroDivisionError(f"{base} cannot be raised to a negative power")
        return base ** exp
|
||
|
|
||
|
# Overloaded to be compatible with regular Python.
# https://github.com/pytorch/pytorch/issues/90900
class TrueDiv(sympy.Function):
    """True division ``base / divisor`` that raises on a zero divisor."""

    @classmethod
    def eval(cls, base, divisor):
        """
        Evaluate ``base / divisor``.

        Raises:
            ZeroDivisionError: if ``divisor`` is known to be zero.
        """
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")
        return base / divisor
|
||
|
|
||
|
|
||
|
# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
# Because we do not have the ability to guard on the stride permutation
# at the moment, it is hard to make further inferences when this is true,
# as although we know the tensor is contiguous in *some* layout, we don't
# know which one (however, you could, for example, make the inference that
# reshaping this to a 1D tensor can be guard-free.)
class IsNonOverlappingAndDenseIndicator(sympy.Function):
    """
    Integer-valued indicator over ``(*sizes, *strides)``.

    The first half of the arguments is passed on as sizes and the second half
    as strides.
    """

    is_integer = True

    @classmethod
    def eval(cls, *args):
        """
        Evaluate the indicator once every argument is a concrete integer.

        Returns the result of ``eval_is_non_overlapping_and_dense(sizes, strides)``
        when all arguments are ``sympy.Integer``; otherwise ``None``.
        """
        assert len(args) % 2 == 0
        dim = len(args) // 2

        # TODO: it is possible to make progress evaluating this guard
        # even if not all of the inputs are known.  For example, a 2D
        # tensor with non-0/1 sizes but strides (0, 1) is definitely
        # false, because we know its numel > 1 but it's broadcasted
        # in dim 0.
        if not all(isinstance(a, sympy.Integer) for a in args):
            return None

        # sym_node imported in torch.__init__. Local import to avoid an import cycle
        from torch.fx.experimental.symbolic_shapes import eval_is_non_overlapping_and_dense

        sizes = [int(a) for a in args[:dim]]
        strides = [int(a) for a in args[dim:]]
        return eval_is_non_overlapping_and_dense(sizes, strides)
|
||
|
|
||
|
|
||
|
class Round(sympy.Function):
    """Round a number to an integer using Python's built-in ``round``."""

    is_integer = True

    @classmethod
    def eval(cls, number):
        """
        Return ``number`` itself if it is an integer, the rounded value if it
        is a numeric literal, and ``None`` otherwise.
        """
        if number.is_integer:
            return number
        if isinstance(number, sympy.Number):
            as_float = float(number)
            return sympy.Integer(round(as_float))
        return None

    def __int__(self):
        # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and
        # no longer an expression. If it were, the float call would fail and the caller would handle this further.
        operand = self.args[0]
        return round(float(operand))  # type: ignore[arg-type]
|
||
|
|
||
|
|
||
|
class RoundDecimal(sympy.Function):
    """Round a number to ``ndigits`` decimal places using Python's built-in ``round``."""

    @classmethod
    def eval(cls, number, ndigits):
        """
        Return ``number`` unchanged when it is an integer and ``ndigits >= 0``,
        the rounded literal when both arguments are concrete, and ``None``
        otherwise.
        """
        # Rounding an integer to a non-negative number of decimals is a no-op.
        if number.is_integer and ndigits >= 0:
            return number

        if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
            # Round in the matching Python type and wrap back into the same SymPy kind.
            if isinstance(number, sympy.Integer):
                value_type, output_type = int, sympy.Integer
            else:
                value_type, output_type = float, sympy.Float
            return output_type(round(value_type(number), int(ndigits)))

        return None
|