2234 lines
71 KiB
Python
2234 lines
71 KiB
Python
|
"""Base class for all the objects in SymPy"""
|
||
|
from __future__ import annotations
|
||
|
|
||
|
from collections import defaultdict
|
||
|
from collections.abc import Mapping
|
||
|
from itertools import chain, zip_longest
|
||
|
|
||
|
from .assumptions import _prepare_class_assumptions
|
||
|
from .cache import cacheit
|
||
|
from .core import ordering_of_classes
|
||
|
from .sympify import _sympify, sympify, SympifyError, _external_converter
|
||
|
from .sorting import ordered
|
||
|
from .kind import Kind, UndefinedKind
|
||
|
from ._print_helpers import Printable
|
||
|
|
||
|
from sympy.utilities.decorator import deprecated
|
||
|
from sympy.utilities.exceptions import sympy_deprecation_warning
|
||
|
from sympy.utilities.iterables import iterable, numbered_symbols
|
||
|
from sympy.utilities.misc import filldedent, func_name
|
||
|
|
||
|
from inspect import getmro
|
||
|
|
||
|
|
||
|
def as_Basic(expr):
    """Convert ``expr`` to a Basic instance using strict sympification.

    This is a thin wrapper around ``_sympify``; the only difference is that a
    failed conversion is reported as a ``TypeError`` instead of a
    ``SympifyError``.
    """
    try:
        converted = _sympify(expr)
    except SympifyError:
        raise TypeError(
            'Argument must be a Basic object, not `%s`' % func_name(expr))
    return converted
|
||
|
|
||
|
|
||
|
def _old_compare(x: type, y: type) -> int:
    """Three-way comparison of two classes by their canonical ordering.

    Returns -1, 0 or 1.  Classes are ranked by the position of their name in
    ``ordering_of_classes``; when neither name appears there, the names are
    compared alphabetically instead.
    """
    # If the other object is not a Basic subclass, then we are not equal to it.
    if not issubclass(y, Basic):
        return -1

    name_x, name_y = x.__name__, y.__name__
    if name_x == name_y:
        return 0

    unknown = len(ordering_of_classes) + 1

    def position(name):
        # Rank of ``name`` in the canonical ordering, ``unknown`` if unlisted.
        try:
            return ordering_of_classes.index(name)
        except ValueError:
            return unknown

    pos_x = position(name_x)
    pos_y = position(name_y)
    if pos_x == unknown and pos_y == unknown:
        # Neither class is ranked: fall back to alphabetical order of names.
        return (name_x > name_y) - (name_x < name_y)
    return (pos_x > pos_y) - (pos_x < pos_y)
|
||
|
|
||
|
|
||
|
class Basic(Printable):
|
||
|
"""
|
||
|
Base class for all SymPy objects.
|
||
|
|
||
|
Notes and conventions
|
||
|
=====================
|
||
|
|
||
|
1) Always use ``.args``, when accessing parameters of some instance:
|
||
|
|
||
|
>>> from sympy import cot
|
||
|
>>> from sympy.abc import x, y
|
||
|
|
||
|
>>> cot(x).args
|
||
|
(x,)
|
||
|
|
||
|
>>> cot(x).args[0]
|
||
|
x
|
||
|
|
||
|
>>> (x*y).args
|
||
|
(x, y)
|
||
|
|
||
|
>>> (x*y).args[1]
|
||
|
y
|
||
|
|
||
|
|
||
|
2) Never use internal methods or variables (the ones prefixed with ``_``):
|
||
|
|
||
|
>>> cot(x)._args # do not use this, use cot(x).args instead
|
||
|
(x,)
|
||
|
|
||
|
|
||
|
3) By "SymPy object" we mean something that can be returned by
|
||
|
``sympify``. But not all objects one encounters using SymPy are
|
||
|
subclasses of Basic. For example, mutable objects are not:
|
||
|
|
||
|
>>> from sympy import Basic, Matrix, sympify
|
||
|
>>> A = Matrix([[1, 2], [3, 4]]).as_mutable()
|
||
|
>>> isinstance(A, Basic)
|
||
|
False
|
||
|
|
||
|
>>> B = sympify(A)
|
||
|
>>> isinstance(B, Basic)
|
||
|
True
|
||
|
"""
|
||
|
__slots__ = (
    '_mhash',        # cached hash value; None until __hash__ computes it
    '_args',         # tuple of arguments
    '_assumptions',  # assumptions, initialised from cls.default_assumptions
)

# Types of the slot attributes, for static type checkers.
_args: tuple[Basic, ...]
_mhash: int | None
|
||
|
|
||
|
@property
def __sympy__(self):
    """Marker property: every Basic instance reports ``True``."""
    marker = True
    return marker
|
||
|
|
||
|
def __init_subclass__(cls):
    """Set up assumption handling for each newly created subclass.

    Initializes the ``default_assumptions`` FactKB and any assumptions
    property methods on ``cls``.  Python invokes this hook only for
    subclasses of Basic, never for Basic itself, so
    ``_prepare_class_assumptions(Basic)`` has to be called separately below
    the class definition.
    """
    # Chain to the next ``__init_subclass__`` in the MRO so that hooks defined
    # by other bases/mixins still run for subclasses of Basic; without this
    # call they are silently skipped.
    super().__init_subclass__()
    _prepare_class_assumptions(cls)
|
||
|
|
||
|
# Type-identification flags.  Each defaults to False on Basic and is
# overridden with True by the appropriate subclass.
is_number = False
is_Atom = False
# symbol-like objects
is_Symbol = False
is_symbol = False
is_Indexed = False
is_Dummy = False
is_Wild = False
# functions and arithmetic operations
is_Function = False
is_Add = False
is_Mul = False
is_Pow = False
# numbers
is_Number = False
is_Float = False
is_Rational = False
is_Integer = False
is_NumberSymbol = False
# series, calculus and special containers
is_Order = False
is_Derivative = False
is_Piecewise = False
is_Poly = False
is_AlgebraicNumber = False
# relations and logic
is_Relational = False
is_Equality = False
is_Boolean = False
is_Not = False
# matrices, vectors and geometry
is_Matrix = False
is_Vector = False
is_Point = False
is_MatAdd = False
is_MatMul = False

# Assumption queries: annotated for type checkers only, no value is assigned
# here (presumably provided by the assumption property methods set up in
# _prepare_class_assumptions -- confirm there).
is_real: bool | None
is_extended_real: bool | None
is_zero: bool | None
is_negative: bool | None
is_commutative: bool | None

kind: Kind = UndefinedKind
|
||
|
|
||
|
def __new__(cls, *args):
    """Allocate a bare instance of ``cls`` that stores ``args`` verbatim."""
    instance = object.__new__(cls)
    instance._assumptions = cls.default_assumptions
    # The hash is filled in lazily by __hash__.
    instance._mhash = None
    # No conversion happens here: every item of ``args`` must already be a
    # Basic object.
    instance._args = args
    return instance
|
||
|
|
||
|
def copy(self):
    """Return a new object rebuilt from ``self.func`` and ``self.args``."""
    constructor = self.func
    arguments = self.args
    return constructor(*arguments)
|
||
|
|
||
|
def __getnewargs__(self):
    """Pickle support: positional arguments to pass to ``__new__``."""
    new_args = self.args
    return new_args
|
||
|
|
||
|
def __getstate__(self):
    """Pickle support: Basic itself contributes no extra state."""
    state = None
    return state
|
||
|
|
||
|
def __setstate__(self, state):
    """Pickle support: assign every ``name -> value`` item of ``state``."""
    for attr_name, attr_value in state.items():
        setattr(self, attr_name, attr_value)
|
||
|
|
||
|
def __reduce_ex__(self, protocol):
    """Pickle support: reject protocols below 2, otherwise defer upward."""
    if protocol < 2:
        raise NotImplementedError(
            "Only pickle protocol 2 or higher is supported by SymPy")
    return super().__reduce_ex__(protocol)
|
||
|
|
||
|
def __hash__(self) -> int:
    """Hash of the type name plus ``_hashable_content()``, memoised."""
    # ``cacheit`` cannot be used here: the cache needs hashes for its own
    # dictionary keys, which would recurse infinitely.  The value is cached
    # on the instance in ``_mhash`` instead.
    cached = self._mhash
    if cached is not None:
        return cached
    computed = hash((type(self).__name__,) + self._hashable_content())
    self._mhash = computed
    return computed
|
||
|
|
||
|
def _hashable_content(self):
    """Return the tuple of data that identifies ``self`` for hashing.

    By default this is just the arguments.  A class that carries extra
    identifying attributes (such as ``name`` on Symbol) should override this
    to include them.  A class that defines ``__eq__`` must define more than
    just ``_hashable_content``; see the note in ``Basic.__eq__``.
    """
    content = self._args
    return content
|
||
|
|
||
|
@property
def assumptions0(self):
    """
    The assumptions that form part of the object's type information.

    For instance ``Symbol('x', real=True)`` and ``Symbol('x', integer=True)``
    are distinct objects: beyond the Python type (Symbol here), the initial
    assumptions are part of their identity.  Basic itself has none, so an
    empty dict is returned; subclasses override this.

    Examples
    ========

    >>> from sympy import Symbol
    >>> from sympy.abc import x
    >>> x.assumptions0
    {'commutative': True}
    >>> x = Symbol("x", positive=True)
    >>> x.assumptions0
    {'commutative': True, 'complex': True, 'extended_negative': False,
     'extended_nonnegative': True, 'extended_nonpositive': False,
     'extended_nonzero': True, 'extended_positive': True, 'extended_real':
     True, 'finite': True, 'hermitian': True, 'imaginary': False,
     'infinite': False, 'negative': False, 'nonnegative': True,
     'nonpositive': False, 'nonzero': True, 'positive': True, 'real':
     True, 'zero': False}
    """
    no_assumptions = {}
    return no_assumptions
|
||
|
|
||
|
def compare(self, other):
    """
    Three-way structural comparison: return -1, 0 or 1.

    The ordering is not mathematical.  Objects of different classes are
    ordered by their classes' position in ``ordering_of_classes`` (see
    ``_old_compare``); objects of the same class are ordered by their
    hashable content, element by element.

    Examples
    ========

    >>> from sympy.abc import x, y
    >>> x.compare(y)
    -1
    >>> x.compare(x)
    0
    >>> y.compare(x)
    1

    """
    # All redefinitions of this comparison should begin with these checks.
    if self is other:
        return 0
    by_class = _old_compare(self.__class__, other.__class__)
    if by_class:
        return by_class

    mine = self._hashable_content()
    theirs = other._hashable_content()
    by_length = (len(mine) > len(theirs)) - (len(mine) < len(theirs))
    if by_length:
        return by_length

    for left, right in zip(mine, theirs):
        # frozensets are wrapped in a Basic so they can use .compare()
        if isinstance(left, frozenset):
            left = Basic(*left)
        if isinstance(right, frozenset):
            right = Basic(*right)
        if isinstance(left, Basic):
            outcome = left.compare(right)
        else:
            outcome = (left > right) - (left < right)
        if outcome:
            return outcome
    return 0
|
||
|
|
||
|
@staticmethod
def _compare_pretty(a, b):
    """Three-way ordering helper (presumably used for printing order).

    Order terms sort after everything else, two Rationals compare by
    value, and two terms matching ``p1*p2**p3`` compare by their exponent
    first.  Everything else falls back to ``Basic.compare``.
    """
    from sympy.series.order import Order
    a_is_order = isinstance(a, Order)
    b_is_order = isinstance(b, Order)
    if a_is_order and not b_is_order:
        return 1
    if b_is_order and not a_is_order:
        return -1

    if a.is_Rational and b.is_Rational:
        # Cross-multiply to compare p1/q1 with p2/q2 exactly.
        lhs = a.p * b.q
        rhs = b.p * a.q
        return (lhs > rhs) - (lhs < rhs)

    from .symbol import Wild
    p1, p2, p3 = Wild("p1"), Wild("p2"), Wild("p3")
    pattern = p1 * p2**p3
    match_a = a.match(pattern)
    if match_a and p3 in match_a:
        match_b = b.match(pattern)
        if match_b and p3 in match_b:
            by_exponent = Basic.compare(match_a[p3], match_b[p3])
            if by_exponent != 0:
                return by_exponent
    return Basic.compare(a, b)
|
||
|
|
||
|
@classmethod
def fromiter(cls, args, **assumptions):
    """
    Build a new instance from the items of an arbitrary iterable.

    A convenience so callers need not convert generators or other iterables
    to a tuple or list themselves.

    Examples
    ========

    >>> from sympy import Tuple
    >>> Tuple.fromiter(i for i in range(5))
    (0, 1, 2, 3, 4)

    """
    materialized = tuple(args)
    return cls(*materialized, **assumptions)
|
||
|
|
||
|
@classmethod
def class_key(cls):
    """Sort key that gives classes a nice order."""
    key = (5, 0, cls.__name__)
    return key
|
||
|
|
||
|
@cacheit
def sort_key(self, order=None):
    """
    Return a key suitable for canonically sorting expressions.

    Examples
    ========

    >>> from sympy import S, I

    >>> sorted([S(1)/2, I, -I], key=lambda x: x.sort_key())
    [1/2, -I, I]

    >>> S("[x, 1/x, 1/x**2, x**2, x**(1/2), x**(1/4), x**(3/2)]")
    [x, 1/x, x**(-2), x**2, sqrt(x), x**(1/4), x**(3/2)]
    >>> sorted(_, key=lambda x: x.sort_key())
    [x**(-2), 1/x, x**(1/4), sqrt(x), x, x**(3/2), x**2]

    """

    # XXX: remove this when issue 5169 is fixed
    def inner_key(arg):
        return arg.sort_key(order) if isinstance(arg, Basic) else arg

    sorted_args = self._sorted_args
    arg_keys = tuple([inner_key(arg) for arg in sorted_args])
    return (self.class_key(), (len(sorted_args), arg_keys),
            S.One.sort_key(), S.One)
|
||
|
|
||
|
def _do_eq_sympify(self, other):
    """Decide ``self == other`` when ``other`` is not a Basic.

    Conversion is attempted only for types registered in
    ``_external_converter`` (looked up along ``other``'s MRO), or for
    objects defining ``_sympy_``.  Other non-SymPy types that should compare
    equal to SymPy objects are expected to be handled directly in the
    ``__eq__`` methods of the relevant Basic subclasses.  After conversion
    ``==`` is applied again, because it is not clear in general whether
    ``self``'s or the converted object's ``__eq__`` should decide.
    Returns ``NotImplemented`` when no conversion applies.
    """
    converter = next(
        (conv for conv in map(_external_converter.get, type(other).__mro__)
         if conv is not None),
        None)
    if converter is not None:
        return self == converter(other)
    if hasattr(other, '_sympy_'):
        return self == other._sympy_()
    return NotImplemented
|
||
|
|
||
|
def __eq__(self, other):
    """Structural equality: do ``self`` and ``other`` have the same tree?

    Equivalent to ``a.compare(b) == 0`` but faster.

    Notes
    =====

    If a class that overrides __eq__() needs to retain the
    implementation of __hash__() from a parent class, the
    interpreter must be told this explicitly by setting
    __hash__ : Callable[[object], int] = <ParentClass>.__hash__.
    Otherwise the inheritance of __hash__() will be blocked,
    just as if __hash__ had been explicitly set to None.

    References
    ==========

    from https://docs.python.org/dev/reference/datamodel.html#object.__hash__
    """
    if self is other:
        return True
    if not isinstance(other, Basic):
        return self._do_eq_sympify(other)

    # Objects of different types are unequal unless both are numbers.
    if not (self.is_Number and other.is_Number) and type(self) != type(other):
        return False

    mine, theirs = self._hashable_content(), other._hashable_content()
    if mine != theirs:
        return False

    # Equal content may still pair a number of one type with an
    # equal-comparing object of another type; treat that as unequal.
    return not any(
        isinstance(left, Basic) and left.is_Number and type(left) != type(right)
        for left, right in zip(mine, theirs))
|
||
|
|
||
|
def __ne__(self, other):
    """``a != b``: are the two symbolic trees different?

    Equivalent to ``a.compare(b) != 0`` but faster.
    """
    is_equal = self == other
    return not is_equal
|
||
|
|
||
|
def dummy_eq(self, other, symbol=None):
    """
    Compare two expressions and handle dummy symbols.

    Both sides are first canonicalized with ``as_dummy``.  If ``self`` then
    has exactly one free Dummy, that Dummy is matched against ``symbol``
    (or, when ``symbol`` is None, against the single free symbol of
    ``other``).  In every other situation the result is plain ``==`` of the
    canonicalized expressions.

    Examples
    ========

    >>> from sympy import Dummy
    >>> from sympy.abc import x, y

    >>> u = Dummy('u')

    >>> (u**2 + 1).dummy_eq(x**2 + 1)
    True
    >>> (u**2 + 1) == (x**2 + 1)
    False

    >>> (u**2 + y).dummy_eq(x**2 + y, x)
    True
    >>> (u**2 + y).dummy_eq(x**2 + y, y)
    False

    """
    s = self.as_dummy()
    o = _sympify(other).as_dummy()

    dummy_symbols = [i for i in s.free_symbols if i.is_Dummy]
    if len(dummy_symbols) != 1:
        return s == o
    dummy = dummy_symbols[0]

    if symbol is None:
        symbols = o.free_symbols
        if len(symbols) != 1:
            return s == o
        # Read the sole element without ``pop()``: ``free_symbols`` is a
        # property whose set is not guaranteed to be a fresh copy in every
        # subclass, so mutating it could corrupt state owned by ``o``.
        symbol = next(iter(symbols))

    # Replace both sides' variable with one common fresh Dummy and compare.
    tmp = dummy.__class__()
    return s.xreplace({dummy: tmp}) == o.xreplace({symbol: tmp})
|
||
|
|
||
|
def atoms(self, *types):
    """Return the atoms that make up the current object.

    With no arguments, only truly atomic objects are returned: nodes with
    no ``args`` (symbols, numbers and number symbols like I and pi).  When
    types are given, every node of those types found anywhere in the tree
    is returned, whether atomic or not.

    Examples
    ========

    >>> from sympy import I, pi, sin
    >>> from sympy.abc import x, y
    >>> (1 + x + 2*sin(y + I*pi)).atoms()
    {1, 2, I, pi, x, y}

    If one or more types are given, the results will contain only
    those types of atoms.

    >>> from sympy import Number, NumberSymbol, Symbol
    >>> (1 + x + 2*sin(y + I*pi)).atoms(Symbol)
    {x, y}

    >>> (1 + x + 2*sin(y + I*pi)).atoms(Number)
    {1, 2}

    >>> (1 + x + 2*sin(y + I*pi)).atoms(Number, NumberSymbol)
    {1, 2, pi}

    >>> (1 + x + 2*sin(y + I*pi)).atoms(Number, NumberSymbol, I)
    {1, 2, I, pi}

    Note that I (imaginary unit) and zoo (complex infinity) are special
    types of number symbols and are not part of the NumberSymbol class.

    An instance may be given instead of a type; its type is used:

    >>> (1 + x + 2*sin(y + I*pi)).atoms(x) # x is a Symbol
    {x, y}

    Be careful with this implicit form: ``S(1).is_Integer`` is True, yet
    ``type(S(1))`` is ``One``, a special type of SymPy atom, whereas
    ``type(S(2))`` is ``Integer`` and so matches every integer:

    >>> from sympy import S
    >>> (1 + x + 2*sin(y + I*pi)).atoms(S(1))
    {1}

    >>> (1 + x + 2*sin(y + I*pi)).atoms(S(2))
    {1, 2}

    Finally, any SymPy type (loaded in core/__init__.py) may be requested,
    and all subexpressions of that type found while recursively scanning
    the arguments are returned:

    >>> from sympy import Function, Mul
    >>> from sympy.core.function import AppliedUndef
    >>> f = Function('f')
    >>> (1 + f(x) + 2*sin(y + I*pi)).atoms(Function)
    {f(x), sin(y + I*pi)}
    >>> (1 + f(x) + 2*sin(y + I*pi)).atoms(AppliedUndef)
    {f(x)}

    >>> (1 + x + 2*sin(y + I*pi)).atoms(Mul)
    {I*pi, 2*sin(y + I*pi)}

    """
    if types:
        # Instances stand for their own type, e.g. ``x`` means ``Symbol``.
        wanted = tuple(t if isinstance(t, type) else type(t) for t in types)
        return {node for node in _preorder_traversal(self)
                if isinstance(node, wanted)}
    return {node for node in _preorder_traversal(self) if not node.args}
|
||
|
|
||
|
@property
def free_symbols(self) -> set[Basic]:
    """Return the free symbols among the atoms of ``self``.

    Free symbols need not be ``Symbol`` instances, e.g.
    ``IndexedBase('I')[0].free_symbols``.

    For most expressions every symbol is free.  Some classes bind
    variables, however: Integral uses Symbols as bound integration
    variables and Derivative tracks the symbols it differentiates with
    respect to, so both define their own ``free_symbols``.  This default
    simply collects the free symbols of all arguments; any other class with
    bound variables should likewise override it.
    """
    collected: set[Basic] = set()
    for arg in self.args:
        collected.update(arg.free_symbols)
    return collected
|
||
|
|
||
|
@property
def expr_free_symbols(self):
    # Deprecated since SymPy 1.9: warns and always returns an empty set.
    message = """
The expr_free_symbols property is deprecated. Use free_symbols to get
the free symbols of an expression.
"""
    sympy_deprecation_warning(
        message,
        deprecated_since_version="1.9",
        active_deprecations_target="deprecated-expr-free-symbols")
    return set()
|
||
|
|
||
|
def as_dummy(self):
    """Return the expression with canonical names for bound symbols.

    Every object that has structurally bound symbols gets them replaced by
    unique, canonical symbols within that object.  The replacements carry
    only the default commutativity assumption (True).  Applied to a symbol,
    a new symbol with only the same commutativity is returned.

    Examples
    ========

    >>> from sympy import Integral, Symbol
    >>> from sympy.abc import x
    >>> r = Symbol('r', real=True)
    >>> Integral(r, (r, x)).as_dummy()
    Integral(_0, (_0, x))
    >>> _.variables[0].is_real is None
    True
    >>> r.as_dummy()
    _r

    Notes
    =====

    Any object that has structurally bound variables should have
    a property, `bound_symbols` that returns those symbols
    appearing in the object.
    """
    from .symbol import Dummy, Symbol

    def canonicalize(expr):
        free = expr.free_symbols
        bound = set(expr.bound_symbols)
        # Temporarily mask free symbols that are shadowed by bound ones.
        masks = {sym: Dummy() for sym in bound & free}
        expr = expr.subs(masks)
        # Rename the bound symbols to their canonical names.
        expr = expr.xreplace(expr.canonical_variables)
        # Undo the masking.
        return expr.xreplace({mask: sym for sym, mask in masks.items()})

    if not self.has(Symbol):
        return self
    return self.replace(
        lambda node: hasattr(node, 'bound_symbols'),
        canonicalize,
        simultaneous=False)
|
||
|
|
||
|
@property
def canonical_variables(self):
    """Map each symbol in ``self.bound_symbols`` to a fresh canonical symbol.

    For bound Symbols, the fresh names are chosen so they do not clash with
    any free symbol of the expression.  Objects without ``bound_symbols``
    give an empty dict.

    Examples
    ========

    >>> from sympy import Lambda
    >>> from sympy.abc import x
    >>> Lambda(x, 2*x).canonical_variables
    {x: _0}
    """
    if not hasattr(self, 'bound_symbols'):
        return {}
    fresh = numbered_symbols('_')
    bound = self.bound_symbols
    # Only free symbols that are *not* bound can clash; the bound ones are
    # the ones about to be renamed.
    taken = {sym.name for sym in self.free_symbols - set(bound)}
    mapping = {}
    for b in bound:
        candidate = next(fresh)
        if b.is_Symbol:
            while candidate.name in taken:
                candidate = next(fresh)
        mapping[b] = candidate
    return mapping
|
||
|
|
||
|
def rcall(self, *args):
    """Apply a call to ``args`` recursively through the expression tree.

    This emulates a common abuse of operator notation.  In SymPy,
    ``(x+Lambda(y, 2*y))(z) == x+2*z`` does not work, but the following
    does:

    >>> from sympy import Lambda
    >>> from sympy.abc import x, y, z
    >>> (x + Lambda(y, 2*y)).rcall(z)
    x + 2*z
    """
    call_args = args
    return Basic._recursive_call(self, call_args)
|
||
|
|
||
|
@staticmethod
def _recursive_call(expr_to_call, on_args):
    """Helper for rcall method."""
    from .symbol import Symbol

    def overrides_call(expr):
        # Whether the first class in the MRO that defines __call__ is not
        # Basic; None when no class in the MRO defines it.
        return next((klass != Basic for klass in getmro(type(expr))
                     if '__call__' in klass.__dict__), None)

    if callable(expr_to_call) and overrides_call(expr_to_call):
        # XXX Calling a Symbol transforms it into an UndefFunction, so
        # Symbols are returned untouched.
        if isinstance(expr_to_call, Symbol):
            return expr_to_call
        return expr_to_call(*on_args)
    if not expr_to_call.args:
        return expr_to_call
    new_args = [Basic._recursive_call(sub, on_args)
                for sub in expr_to_call.args]
    return type(expr_to_call)(*new_args)
|
||
|
|
||
|
def is_hypergeometric(self, k):
    """Test via ``hypersimp`` whether ``self`` is hypergeometric in ``k``.

    Returns None (undecided) for expressions containing Piecewise.
    """
    from sympy.simplify.simplify import hypersimp
    from sympy.functions.elementary.piecewise import Piecewise
    if not self.has(Piecewise):
        return hypersimp(self, k) is not None
    return None
|
||
|
|
||
|
@property
def is_comparable(self):
    """True if ``self`` is, or can be evaluated to, a real number with
    precision; otherwise False.

    Examples
    ========

    >>> from sympy import exp_polar, pi, I
    >>> (I*exp_polar(I*pi/2)).is_comparable
    True
    >>> (I*exp_polar(I*pi*2)).is_comparable
    False

    False does not mean ``self`` cannot be rewritten into a comparable
    form.  The difference below is zero, but without simplification it
    does not evaluate to a zero with precision:

    >>> e = 2**pi*(1 + 2**pi)
    >>> dif = e - e.expand()
    >>> dif.is_comparable
    False
    >>> dif.n(2)._prec
    1

    """
    if self.is_extended_real is False:
        return False
    if not self.is_number:
        return False
    # Numbers that are already evaluated are left alone: re-evaluating
    # them would create spurious precision.
    re_part, im_part = [part if part.is_Number else part.evalf(2)
                        for part in self.as_real_imag()]
    if not (im_part.is_Number and re_part.is_Number):
        return False
    if im_part:
        # A nonzero imaginary part means the value cannot be compared; if
        # its _prec is 1 the question is undecidable.  Either way: False.
        return False
    return re_part._prec != 1
|
||
|
|
||
|
@property
def func(self):
    """
    The top-level function (the class) of an expression.

    Every object should satisfy::

        >> x == x.func(*x.args)

    Examples
    ========

    >>> from sympy.abc import x
    >>> a = 2*x
    >>> a.func
    <class 'sympy.core.mul.Mul'>
    >>> a.args
    (2, x)
    >>> a.func(*a.args)
    2*x
    >>> a == a.func(*a.args)
    True

    """
    head = self.__class__
    return head
|
||
|
|
||
|
@property
def args(self) -> tuple[Basic, ...]:
    """The tuple of arguments of ``self``.

    Examples
    ========

    >>> from sympy import cot
    >>> from sympy.abc import x, y

    >>> cot(x).args
    (x,)

    >>> cot(x).args[0]
    x

    >>> (x*y).args
    (x, y)

    >>> (x*y).args[1]
    y

    Notes
    =====

    Always use ``self.args``, never ``self._args``; ``_args`` should only be
    touched in ``__new__`` when creating a new function.  Do not override
    ``args`` from Basic, so that the interface stays easy to change later.
    """
    stored = self._args
    return stored
|
||
|
|
||
|
@property
def _sorted_args(self):
    """The arguments in sorted form; identical to ``args`` here.

    Subclasses whose arguments have no fixed order should override this to
    return a sorted representation.
    """
    ordered_args = self.args
    return ordered_args
|
||
|
|
||
|
def as_content_primitive(self, radical=False, clear=True):
    """Stub so non-Expr Basic args (like Tuple) are skipped when content and
    primitive parts of an expression are computed.

    Always returns ``(S.One, self)``; ``radical`` and ``clear`` are ignored.

    See Also
    ========

    sympy.core.expr.Expr.as_content_primitive
    """
    content, primitive = S.One, self
    return content, primitive
|
||
|
|
||
|
def subs(self, *args, **kwargs):
    """
    Substitutes old for new in an expression after sympifying args.

    `args` is either:
      - two arguments, e.g. foo.subs(old, new)
      - one iterable argument, e.g. foo.subs(iterable). The iterable may be
         o an iterable container with (old, new) pairs. In this case the
           replacements are processed in the order given with successive
           patterns possibly affecting replacements already made.
         o a dict or set whose key/value items correspond to old/new pairs.
           In this case the old/new pairs will be sorted by op count and in
           case of a tie, by number of args and the default_sort_key. Unless
           ``simultaneous`` is True, pairs whose new value is infinite or
           undefined are then moved to the front, keeping their relative
           order. The resulting sorted list is then processed as an
           iterable container (see previous).

    If the keyword ``simultaneous`` is True, the subexpressions will not be
    evaluated until all the substitutions have been made.

    Examples
    ========

    >>> from sympy import pi, exp, limit, oo
    >>> from sympy.abc import x, y
    >>> (1 + x*y).subs(x, pi)
    pi*y + 1
    >>> (1 + x*y).subs({x:pi, y:2})
    1 + 2*pi
    >>> (1 + x*y).subs([(x, pi), (y, 2)])
    1 + 2*pi
    >>> reps = [(y, x**2), (x, 2)]
    >>> (x + y).subs(reps)
    6
    >>> (x + y).subs(reversed(reps))
    x**2 + 2

    >>> (x**2 + x**4).subs(x**2, y)
    y**2 + y

    To replace only the x**2 but not the x**4, use xreplace:

    >>> (x**2 + x**4).xreplace({x**2: y})
    x**4 + y

    To delay evaluation until all substitutions have been made,
    set the keyword ``simultaneous`` to True:

    >>> (x/y).subs([(x, 0), (y, 0)])
    0
    >>> (x/y).subs([(x, 0), (y, 0)], simultaneous=True)
    nan

    This has the added feature of not allowing subsequent substitutions
    to affect those already made:

    >>> ((x + y)/y).subs({x + y: y, y: x + y})
    1
    >>> ((x + y)/y).subs({x + y: y, y: x + y}, simultaneous=True)
    y/(x + y)

    In order to obtain a canonical result, unordered iterables are
    sorted by count_op length, number of arguments and by the
    default_sort_key to break any ties. All other iterables are left
    unsorted.

    >>> from sympy import sqrt, sin, cos
    >>> from sympy.abc import a, b, c, d, e

    >>> A = (sqrt(sin(2*x)), a)
    >>> B = (sin(2*x), b)
    >>> C = (cos(2*x), c)
    >>> D = (x, d)
    >>> E = (exp(x), e)

    >>> expr = sqrt(sin(2*x))*sin(exp(x)*x)*cos(2*x) + sin(2*x)

    >>> expr.subs(dict([A, B, C, D, E]))
    a*c*sin(d*e) + b

    The resulting expression represents a literal replacement of the
    old arguments with the new arguments. This may not reflect the
    limiting behavior of the expression:

    >>> (x**3 - 3*x).subs({x: oo})
    nan

    >>> limit(x**3 - 3*x, x, oo)
    oo

    If the substitution will be followed by numerical
    evaluation, it is better to pass the substitution to
    evalf as

    >>> (1/x).evalf(subs={x: 3.0}, n=21)
    0.333333333333333333333

    rather than

    >>> (1/x).subs({x: 3.0}).evalf(21)
    0.333333333333333314830

    as the former will ensure that the desired level of precision is
    obtained.

    See Also
    ========
    replace: replacement capable of doing wildcard-like matching,
             parsing of match, and conditional replacements
    xreplace: exact node replacement in expr tree; also capable of
              using matching rules
    sympy.core.evalf.EvalfMixin.evalf: calculates the given formula to a desired level of precision

    """
    from .containers import Dict
    from .symbol import Dummy, Symbol
    from .numbers import _illegal

    unordered = False
    if len(args) == 1:

        sequence = args[0]
        if isinstance(sequence, set):
            unordered = True
        elif isinstance(sequence, (Dict, Mapping)):
            unordered = True
            sequence = sequence.items()
        elif not iterable(sequence):
            raise ValueError(filldedent("""
               When a single argument is passed to subs
               it should be a dictionary of old: new pairs or an iterable
               of (old, new) tuples."""))
    elif len(args) == 2:
        sequence = [args]
    else:
        raise ValueError("subs accepts either 1 or 2 arguments")

    def sympify_old(old):
        if isinstance(old, str):
            # Use Symbol rather than parse_expr for old
            return Symbol(old)
        elif isinstance(old, type):
            # Allow a type e.g. Function('f') or sin
            return sympify(old, strict=False)
        else:
            return sympify(old, strict=True)

    def sympify_new(new):
        if isinstance(new, (str, type)):
            # Allow a type or parse a string input
            return sympify(new, strict=False)
        else:
            return sympify(new, strict=True)

    sequence = [(sympify_old(s1), sympify_new(s2)) for s1, s2 in sequence]

    # skip if there is no change
    sequence = [(s1, s2) for s1, s2 in sequence if not _aresame(s1, s2)]

    simultaneous = kwargs.pop('simultaneous', False)

    if unordered:
        from .sorting import _nodes, default_sort_key
        sequence = dict(sequence)
        # order so more complex items are first and items
        # of identical complexity are ordered so
        # f(x) < f(y) < x < y
        # \___ 2 __/    \_1_/  <- number of nodes
        #
        # For full control over the ordering, pass an ordered sequence
        # (e.g. a list of (old, new) pairs) instead.
        keys_in_order = list(ordered(sequence, default=False, keys=(
            lambda x: -_nodes(x),
            default_sort_key,
            )))
        sequence = [(key, sequence[key]) for key in keys_in_order]
        # do infinities first
        if not simultaneous:
            # Stable partition: pairs whose new value is in _illegal go to
            # the front, relative order preserved within both groups.  (The
            # former insert(0, pop(i)) loop shifted indices after every
            # insertion and could move the wrong pairs.)
            infinite = [pair for pair in sequence if pair[1] in _illegal]
            finite = [pair for pair in sequence if pair[1] not in _illegal]
            sequence = infinite + finite

    if simultaneous:  # XXX should this be the default for dict subs?
        reps = {}
        rv = self
        kwargs['hack2'] = True
        m = Dummy('subs_m')
        for old, new in sequence:
            com = new.is_commutative
            if com is None:
                com = True
            d = Dummy('subs_d', commutative=com)
            # using d*m so Subs will be used on dummy variables
            # in things like Derivative(f(x, y), x) in which x
            # is both free and bound
            rv = rv._subs(old, d*m, **kwargs)
            if not isinstance(rv, Basic):
                break
            reps[d] = new
        reps[m] = S.One  # get rid of m
        return rv.xreplace(reps)
    else:
        rv = self
        for old, new in sequence:
            rv = rv._subs(old, new, **kwargs)
            if not isinstance(rv, Basic):
                break
        return rv
|
||
|
|
||
|
@cacheit
def _subs(self, old, new, **hints):
    """Substitute ``new`` for the single target ``old`` inside ``self``.

    When ``self`` is structurally identical to ``old`` the result is simply
    ``new``. Otherwise ``self._eval_subs(old, new)`` gets the first chance
    to do a class-specific replacement; if it returns None, a generic
    fallback is used that retries the substitution on each argument of
    ``self`` and rebuilds ``self`` if any argument changed.

    Examples
    ========

    >>> from sympy import Add
    >>> from sympy.abc import x, y, z

    Add's _eval_subs knows how to target x + y in the following
    so it makes the change:

    >>> (x + y + z).subs(x + y, 1)
    z + 1

    Add's _eval_subs does not need to know how to find x + y in
    the following:

    >>> Add._eval_subs(z*(x + y) + 3, x + y, 1) is None
    True

    The returned None will cause the fallback routine to traverse the args and
    pass the z*(x + y) arg to Mul where the change will take place and the
    substitution will succeed:

    >>> (z*(x + y) + 3).subs(x + y, 1)
    z + 3

    ** Developers Notes **

    An _eval_subs routine for a class should be written if:

        1) any arguments are not instances of Basic (e.g. bool, tuple);

        2) some arguments should not be targeted (as in integration
           variables);

        3) if there is something other than a literal replacement
           that should be attempted (as in Piecewise where the condition
           may be updated without doing a replacement).

    If it is overridden, here are some special cases that might arise:

        1) If it turns out that no special change was made and all
           the original sub-arguments should be checked for
           replacements then None should be returned.

        2) If it is necessary to do substitutions on a portion of
           the expression then _subs should be called. _subs will
           handle the case of any sub-expression being equal to old
           (which usually would not be the case) while its fallback
           will handle the recursion into the sub-arguments. For
           example, after Add's _eval_subs removes some matching terms
           it must process the remaining terms so it calls _subs
           on each of the un-matched terms and then adds them
           onto the terms previously obtained.

        3) If the initial expression should remain unchanged then
           the original expression should be returned. (Whenever an
           expression is returned, modified or not, no further
           substitution of old -> new is attempted.) Sum's _eval_subs
           routine uses this strategy when a substitution is attempted
           on any of its summation variables.
    """

    def fallback(self, old, new):
        """
        Try to replace old with new in any of self's arguments.
        """
        originals = self.args
        updated = list(originals)
        touched = False
        for pos, sub in enumerate(originals):
            # non-Basic arguments (no _eval_subs) are left as they are
            if not hasattr(sub, '_eval_subs'):
                continue
            replaced = sub._subs(old, new, **hints)
            if not _aresame(replaced, sub):
                touched = True
                updated[pos] = replaced
        if not touched:
            return self

        rebuilt = self.func(*updated)
        if not (hints.get('hack2', False) and self.is_Mul
                and not rebuilt.is_Mul):
            return rebuilt

        # 2-arg hack: when a Mul collapsed into a non-Mul, keep the numeric
        # coefficient separate from the remaining (non-number) factors
        coeff = S.One
        others = []
        for factor in updated:
            if factor.is_Number:
                coeff *= factor
            else:
                others.append(factor)
        others = self.func(*others)
        if coeff is S.One:
            return others
        return self.func(coeff, others, evaluate=False)

    if _aresame(self, old):
        return new

    rv = self._eval_subs(old, new)
    if rv is None:
        rv = fallback(self, old, new)
    return rv
|
||
|
|
||
|
def _eval_subs(self, old, new):
|
||
|
"""Override this stub if you want to do anything more than
|
||
|
attempt a replacement of old with new in the arguments of self.
|
||
|
|
||
|
See also
|
||
|
========
|
||
|
|
||
|
_subs
|
||
|
"""
|
||
|
return None
|
||
|
|
||
|
def xreplace(self, rule):
    """
    Replace occurrences of objects within the expression.

    Parameters
    ==========

    rule : dict-like
        Expresses a replacement rule

    Returns
    =======

    xreplace : the result of the replacement

    Examples
    ========

    >>> from sympy import symbols, pi, exp
    >>> x, y, z = symbols('x y z')
    >>> (1 + x*y).xreplace({x: pi})
    pi*y + 1
    >>> (1 + x*y).xreplace({x: pi, y: 2})
    1 + 2*pi

    Replacements occur only if an entire node in the expression tree is
    matched:

    >>> (x*y + z).xreplace({x*y: pi})
    z + pi
    >>> (x*y*z).xreplace({x*y: pi})
    x*y*z
    >>> (2*x).xreplace({2*x: y, x: z})
    y
    >>> (2*2*x).xreplace({2*x: y, x: z})
    4*z
    >>> (x + y + 2).xreplace({x + y: 2})
    x + y + 2
    >>> (x + 2 + exp(x + 2)).xreplace({x + 2: y})
    x + exp(y) + 2

    xreplace does not differentiate between free and bound symbols. In the
    following, subs(x, y) would not change x since it is a bound symbol,
    but xreplace does:

    >>> from sympy import Integral
    >>> Integral(x, (x, 1, 2*x)).xreplace({x: y})
    Integral(y, (y, 1, 2*y))

    Trying to replace x with an expression raises an error:

    >>> Integral(x, (x, 1, 2*x)).xreplace({x: 2*y}) # doctest: +SKIP
    ValueError: Invalid limits given: ((2*y, 1, 4*y),)

    See Also
    ========

    replace: replacement capable of doing wildcard-like matching,
             parsing of match, and conditional replacements
    subs: substitution of subexpressions as defined by the objects
          themselves.

    """
    # _xreplace also reports whether anything changed; callers of the
    # public API only want the resulting expression
    outcome = self._xreplace(rule)
    return outcome[0]
|
||
|
|
||
|
def _xreplace(self, rule):
|
||
|
"""
|
||
|
Helper for xreplace. Tracks whether a replacement actually occurred.
|
||
|
"""
|
||
|
if self in rule:
|
||
|
return rule[self], True
|
||
|
elif rule:
|
||
|
args = []
|
||
|
changed = False
|
||
|
for a in self.args:
|
||
|
_xreplace = getattr(a, '_xreplace', None)
|
||
|
if _xreplace is not None:
|
||
|
a_xr = _xreplace(rule)
|
||
|
args.append(a_xr[0])
|
||
|
changed |= a_xr[1]
|
||
|
else:
|
||
|
args.append(a)
|
||
|
args = tuple(args)
|
||
|
if changed:
|
||
|
return self.func(*args), True
|
||
|
return self, False
|
||
|
|
||
|
@cacheit
def has(self, *patterns):
    """
    Return True if any subexpression of ``self`` matches any of
    ``patterns``; otherwise return False.

    Examples
    ========

    >>> from sympy import sin
    >>> from sympy.abc import x, y, z
    >>> (x**2 + sin(x*y)).has(z)
    False
    >>> (x**2 + sin(x*y)).has(x, y, z)
    True
    >>> x.has(x)
    True

    Note ``has`` is a structural algorithm with no knowledge of
    mathematics. Consider the following half-open interval:

    >>> from sympy import Interval
    >>> i = Interval.Lopen(0, 5); i
    Interval.Lopen(0, 5)
    >>> i.args
    (0, 5, True, False)
    >>> i.has(4)  # there is no "4" in the arguments
    False
    >>> i.has(0)  # there *is* a "0" in the arguments
    True

    Instead, use ``contains`` to determine whether a number is in the
    interval or not:

    >>> i.contains(4)
    True
    >>> i.contains(0)
    False

    Note that ``expr.has(*patterns)`` is exactly equivalent to
    ``any(expr.has(p) for p in patterns)``. In particular, ``False`` is
    returned when the list of patterns is empty.

    >>> x.has()
    False

    """
    # the search visits arguments through the module-level ``iterargs``
    # walker (``has_free`` passes ``iterfreeargs`` instead)
    walker = iterargs
    return self._has(walker, *patterns)
|
||
|
|
||
|
def has_xfree(self, s: set[Basic] | frozenset[Basic]):
    """Return True if self has any of the patterns in s as a
    free argument, else False. This is like `Basic.has_free`
    but this will only report exact argument matches.

    ``s`` must be a ``set`` or a ``frozenset``; any other container
    (list, tuple, ...) raises ``TypeError``.

    Examples
    ========

    >>> from sympy import Function
    >>> from sympy.abc import x, y
    >>> f = Function('f')
    >>> f(x).has_xfree({f})
    False
    >>> f(x).has_xfree({f(x)})
    True
    >>> f(x + 1).has_xfree({x})
    True
    >>> f(x + 1).has_xfree({x + 1})
    True
    >>> f(x + y + 1).has_xfree({x + 1})
    False
    >>> f(x).has_xfree(frozenset({f(x)}))
    True
    """
    # protect the O(1) containment check by requiring a hash-based set;
    # a frozenset gives the same guarantee as a set, so accept it too
    if type(s) not in (set, frozenset):
        raise TypeError('expecting set argument')
    return any(a in s for a in iterfreeargs(self))
|
||
|
|
||
|
@cacheit
def has_free(self, *patterns):
    """Return True if any of ``patterns`` occurs in ``self`` as a free
    (i.e. not bound) expression, else False.

    Examples
    ========

    >>> from sympy import Integral, Function
    >>> from sympy.abc import x, y
    >>> f = Function('f')
    >>> g = Function('g')
    >>> expr = Integral(f(x), (f(x), 1, g(y)))
    >>> expr.free_symbols
    {y}
    >>> expr.has_free(g(y))
    True
    >>> expr.has_free(*(x, f(x)))
    False

    This works for subexpressions and types, too:

    >>> expr.has_free(g)
    True
    >>> (x + y + 1).has_free(y + 1)
    True
    """
    if not patterns:
        return False

    first = patterns[0]
    if len(patterns) == 1 and iterable(first) and not isinstance(first, Basic):
        # Basic can contain iterables (though not non-Basic, ideally)
        # but don't encourage mixed passing patterns
        raise TypeError(filldedent('''
            Expecting 1 or more Basic args, not a single
            non-Basic iterable. Don't forget to unpack
            iterables: `eq.has_free(*patterns)`'''))

    # cheap exact-argument check first, then the slower structural search
    return self.has_xfree(set(patterns)) or self._has(iterfreeargs, *patterns)
|
||
|
|
||
|
def _has(self, iterargs, *patterns):
|
||
|
# separate out types and unhashable objects
|
||
|
type_set = set() # only types
|
||
|
p_set = set() # hashable non-types
|
||
|
for p in patterns:
|
||
|
if isinstance(p, type) and issubclass(p, Basic):
|
||
|
type_set.add(p)
|
||
|
continue
|
||
|
if not isinstance(p, Basic):
|
||
|
try:
|
||
|
p = _sympify(p)
|
||
|
except SympifyError:
|
||
|
continue # Basic won't have this in it
|
||
|
p_set.add(p) # fails if object defines __eq__ but
|
||
|
# doesn't define __hash__
|
||
|
types = tuple(type_set) #
|
||
|
for i in iterargs(self): #
|
||
|
if i in p_set: # <--- here, too
|
||
|
return True
|
||
|
if isinstance(i, types):
|
||
|
return True
|
||
|
|
||
|
# use matcher if defined, e.g. operations defines
|
||
|
# matcher that checks for exact subset containment,
|
||
|
# (x + y + 1).has(x + 1) -> True
|
||
|
for i in p_set - type_set: # types don't have matchers
|
||
|
if not hasattr(i, '_has_matcher'):
|
||
|
continue
|
||
|
match = i._has_matcher()
|
||
|
if any(match(arg) for arg in iterargs(self)):
|
||
|
return True
|
||
|
|
||
|
# no success
|
||
|
return False
|
||
|
|
||
|
def replace(self, query, value, map=False, simultaneous=True, exact=None):
    """
    Replace matching subexpressions of ``self`` with ``value``.

    If ``map = True`` then also return the mapping {old: new} where ``old``
    was a sub-expression found with query and ``new`` is the replacement
    value for it. If the expression itself does not match the query, then
    the returned value will be ``self.xreplace(map)`` otherwise it should
    be ``self.subs(ordered(map.items()))``.

    Traverses an expression tree and performs replacement of matching
    subexpressions from the bottom to the top of the tree. The default
    approach is to do the replacement in a simultaneous fashion so
    changes made are targeted only once. If this is not desired or causes
    problems, ``simultaneous`` can be set to False.

    In addition, if an expression containing more than one Wild symbol
    is being used to match subexpressions and the ``exact`` flag is None
    it will be set to True so the match will only succeed if all non-zero
    values are received for each Wild that appears in the match pattern.
    Setting it to False accepts matches in which some Wild is 0, while
    setting it to True rejects every match that has a 0 in it. See the
    examples below for cautions.

    The list of possible combinations of queries and replacement values
    is listed below:

    Examples
    ========

    Initial setup

    >>> from sympy import log, sin, cos, tan, Wild, Mul, Add
    >>> from sympy.abc import x, y
    >>> f = log(sin(x)) + tan(sin(x**2))

    1.1. type -> type
        obj.replace(type, newtype)

        When object of type ``type`` is found, replace it with the
        result of passing its argument(s) to ``newtype``.

        >>> f.replace(sin, cos)
        log(cos(x)) + tan(cos(x**2))
        >>> sin(x).replace(sin, cos, map=True)
        (cos(x), {sin(x): cos(x)})
        >>> (x*y).replace(Mul, Add)
        x + y

    1.2. type -> func
        obj.replace(type, func)

        When object of type ``type`` is found, apply ``func`` to its
        argument(s). ``func`` must be written to handle the number
        of arguments of ``type``.

        >>> f.replace(sin, lambda arg: sin(2*arg))
        log(sin(2*x)) + tan(sin(2*x**2))
        >>> (x*y).replace(Mul, lambda *args: sin(2*Mul(*args)))
        sin(2*x*y)

    2.1. pattern -> expr
        obj.replace(pattern(wild), expr(wild))

        Replace subexpressions matching ``pattern`` with the expression
        written in terms of the Wild symbols in ``pattern``.

        >>> a, b = map(Wild, 'ab')
        >>> f.replace(sin(a), tan(a))
        log(tan(x)) + tan(tan(x**2))
        >>> f.replace(sin(a), tan(a/2))
        log(tan(x/2)) + tan(tan(x**2/2))
        >>> f.replace(sin(a), a)
        log(x) + tan(x**2)
        >>> (x*y).replace(a*x, a)
        y

        Matching is exact by default when more than one Wild symbol
        is used: matching fails unless the match gives non-zero
        values for all Wild symbols:

        >>> (2*x + y).replace(a*x + b, b - a)
        y - 2
        >>> (2*x).replace(a*x + b, b - a)
        2*x

        When set to False, the results may be non-intuitive:

        >>> (2*x).replace(a*x + b, b - a, exact=False)
        2/x

    2.2. pattern -> func
        obj.replace(pattern(wild), lambda wild: expr(wild))

        All behavior is the same as in 2.1 but now a function in terms of
        pattern variables is used rather than an expression:

        >>> f.replace(sin(a), lambda a: sin(2*a))
        log(sin(2*x)) + tan(sin(2*x**2))

    3.1. func -> func
        obj.replace(filter, func)

        Replace subexpression ``e`` with ``func(e)`` if ``filter(e)``
        is True.

        >>> g = 2*sin(x**3)
        >>> g.replace(lambda expr: expr.is_Number, lambda expr: expr**2)
        4*sin(x**9)

    The expression itself is also targeted by the query but is done in
    such a fashion that changes are not made twice.

        >>> e = x*(x*y + 1)
        >>> e.replace(lambda x: x.is_Mul, lambda x: 2*x)
        2*x*(2*x*y + 1)

    When the pattern contains a single Wild symbol, ``exact`` defaults to
    False, but this may or may not be the behavior that is desired:

    Here, we want `exact=False`:

    >>> from sympy import Function
    >>> f = Function('f')
    >>> e = f(1) + f(0)
    >>> q = f(a), lambda a: f(a + 1)
    >>> e.replace(*q, exact=False)
    f(1) + f(2)
    >>> e.replace(*q, exact=True)
    f(0) + f(2)

    But here, the nature of matching makes selecting
    the right setting tricky:

    >>> e = x**(1 + y)
    >>> (x**(1 + y)).replace(x**(1 + a), lambda a: x**-a, exact=False)
    x
    >>> (x**(1 + y)).replace(x**(1 + a), lambda a: x**-a, exact=True)
    x**(-x - y + 1)
    >>> (x**y).replace(x**(1 + a), lambda a: x**-a, exact=False)
    x
    >>> (x**y).replace(x**(1 + a), lambda a: x**-a, exact=True)
    x**(1 - y)

    It is probably better to use a different form of the query
    that describes the target expression more precisely:

    >>> (1 + x**(1 + y)).replace(
    ... lambda x: x.is_Pow and x.exp.is_Add and x.exp.args[0] == 1,
    ... lambda x: x.base**(1 - (x.exp - 1)))
    ...
    x**(1 - y) + 1

    See Also
    ========

    subs: substitution of subexpressions as defined by the objects
          themselves.
    xreplace: exact node replacement in expr tree; also capable of
              using matching rules

    """

    try:
        query = _sympify(query)
    except SympifyError:
        pass
    try:
        value = _sympify(value)
    except SympifyError:
        pass

    if isinstance(query, type):
        _query = lambda expr: isinstance(expr, query)

        # a type is itself callable, so this single branch covers both
        # "type -> type" and "type -> func": the node's args are passed on
        if callable(value):
            _value = lambda expr, result: value(*expr.args)
        else:
            raise TypeError(
                "given a type, replace() expects another "
                "type or a callable")
    elif isinstance(query, Basic):
        _query = lambda expr: expr.match(query)
        if exact is None:
            # exact matching is the default only for multi-Wild patterns
            from .symbol import Wild
            exact = (len(query.atoms(Wild)) > 1)

        if isinstance(value, Basic):
            if exact:
                # reject the match (keep expr) if any Wild matched 0
                _value = lambda expr, result: (value.subs(result)
                    if all(result.values()) else expr)
            else:
                _value = lambda expr, result: value.subs(result)
        elif callable(value):
            # match dictionary keys get the trailing underscore stripped
            # from them and are then passed as keywords to the callable;
            # if ``exact`` is True, only accept match if there are no null
            # values amongst those matched.
            if exact:
                _value = lambda expr, result: (value(**
                    {str(k)[:-1]: v for k, v in result.items()})
                    if all(val for val in result.values()) else expr)
            else:
                _value = lambda expr, result: value(**
                    {str(k)[:-1]: v for k, v in result.items()})
        else:
            raise TypeError(
                "given an expression, replace() expects "
                "another expression or a callable")
    elif callable(query):
        _query = query

        if callable(value):
            _value = lambda expr, result: value(expr)
        else:
            raise TypeError(
                "given a callable, replace() expects "
                "another callable")
    else:
        raise TypeError(
            "first argument to replace() must be a "
            "type, an expression or a callable")

    def walk(rv, F):
        """Apply ``F`` to args and then to result.
        """
        args = getattr(rv, 'args', None)
        if args is not None:
            if args:
                newargs = tuple([walk(a, F) for a in args])
                if args != newargs:
                    rv = rv.func(*newargs)
                    if simultaneous:
                        # if rv is something that was already
                        # matched (that was changed) then skip
                        # applying F again
                        for i, e in enumerate(args):
                            if rv == e and e != newargs[i]:
                                return rv
            rv = F(rv)
        return rv

    mapping = {}  # changes that took place

    def rec_replace(expr):
        result = _query(expr)
        # an empty dict is a successful match with no Wild bindings
        if result or result == {}:
            v = _value(expr, result)
            if v is not None and v != expr:
                if map:
                    mapping[expr] = v
                expr = v
        return expr

    rv = walk(self, rec_replace)
    return (rv, mapping) if map else rv
|
||
|
|
||
|
def find(self, query, group=False):
    """Find all subexpressions matching a query.

    With ``group=False`` a set of the matching subexpressions is returned;
    with ``group=True`` a dict maps each match to the number of times it
    was encountered in the preorder traversal of ``self``.
    """
    matcher = _make_find_query(query)
    hits = list(filter(matcher, _preorder_traversal(self)))

    if group:
        tally = {}
        for hit in hits:
            tally[hit] = tally.get(hit, 0) + 1
        return tally
    return set(hits)
|
||
|
|
||
|
def count(self, query):
    """Count the number of matching subexpressions.

    Every node visited by the preorder traversal of ``self`` that
    satisfies the (normalized) ``query`` contributes one to the total.
    """
    matcher = _make_find_query(query)
    return sum(1 for node in _preorder_traversal(self) if matcher(node))
|
||
|
|
||
|
def matches(self, expr, repl_dict=None, old=False):
    """
    Helper method for match() that looks for a match between Wild symbols
    in self and expressions in expr.

    Returns a (new) dict of Wild bindings, or None when there is no match.
    A passed-in ``repl_dict`` is copied, never mutated.

    Examples
    ========

    >>> from sympy import symbols, Wild, Basic
    >>> a, b, c = symbols('a b c')
    >>> x = Wild('x')
    >>> Basic(a + x, x).matches(Basic(a + b, c)) is None
    True
    >>> Basic(a + x, x).matches(Basic(a + b + c, b + c))
    {x_: b + c}
    """
    expr = sympify(expr)
    if not isinstance(expr, self.__class__):
        return None

    d = {} if repl_dict is None else repl_dict.copy()

    if self == expr:
        return d
    if len(self.args) != len(expr.args):
        return None

    for arg, other_arg in zip(self.args, expr.args):
        if arg == other_arg:
            continue
        relational = arg.is_Relational
        try:
            d = arg.xreplace(d).matches(other_arg, d, old=old)
        except TypeError:  # Should be InvalidComparisonError when introduced
            # only relationals may fail to compare; anything else propagates
            if not relational:
                raise
            d = None
        if d is None:
            return None
    return d
|
||
|
|
||
|
def match(self, pattern, old=False):
    """
    Pattern matching.

    Wild symbols match all.

    Return ``None`` when expression (self) does not match
    with pattern. Otherwise return a dictionary such that::

      pattern.xreplace(self.match(pattern)) == self

    Examples
    ========

    >>> from sympy import Wild, Sum
    >>> from sympy.abc import x, y
    >>> p = Wild("p")
    >>> q = Wild("q")
    >>> r = Wild("r")
    >>> e = (x+y)**(x+y)
    >>> e.match(p**p)
    {p_: x + y}
    >>> e.match(p**q)
    {p_: x + y, q_: x + y}
    >>> e = (2*x)**2
    >>> e.match(p*q**r)
    {p_: 4, q_: x, r_: 2}
    >>> (p*q**r).xreplace(e.match(p*q**r))
    4*x**2

    Structurally bound symbols are ignored during matching:

    >>> Sum(x, (x, 1, 2)).match(Sum(y, (y, 1, p)))
    {p_: 2}

    But they can be identified if desired:

    >>> Sum(x, (x, 1, 2)).match(Sum(q, (q, 1, p)))
    {p_: 2, q_: x}

    The ``old`` flag will give the old-style pattern matching where
    expressions and patterns are essentially solved to give the
    match. Both of the following give None unless ``old=True``:

    >>> (x - 2).match(p - x, old=True)
    {p_: 2*x - 2}
    >>> (2/x).match(p*x, old=True)
    {p_: 2/x**2}

    """
    pattern = sympify(pattern)

    def canonical(expr):
        # as_dummy() makes bound symbols canonical so only non-bound
        # symbols take part in the first matching pass
        return expr if expr.is_Symbol else expr.as_dummy()

    m = canonical(pattern).matches(canonical(self), old=old)
    if m is None:
        return m

    from .symbol import Wild
    from .function import WildFunction
    from ..tensor.tensor import WildTensor, WildTensorIndex, WildTensorHead
    wild = pattern.atoms(Wild, WildFunction, WildTensor, WildTensorIndex, WildTensorHead)

    # sanity check: only Wild-like keys may appear in the result
    if set(m) - wild:
        raise ValueError(filldedent('''
        Some `matches` routine did not use a copy of repl_dict
        and injected unexpected symbols. Report this as an
        error at https://github.com/sympy/sympy/issues'''))

    # Wilds not bound yet can only be bound ones; nothing left to do
    # unless some were requested
    unbound = wild - set(m)
    if not unbound:
        return m

    # substitute the free-Wild results so they match but are not
    # reported again, then identify the remaining (bound) Wilds
    extra = pattern.xreplace(m).matches(self, old=old)
    if extra:
        m.update(extra)
    return m
|
||
|
|
||
|
def count_ops(self, visual=None):
    """Return the operation count of ``self``.

    Thin wrapper that delegates to ``count_ops`` from ``.function``,
    forwarding ``visual`` unchanged.
    """
    from .function import count_ops as _count_ops
    return _count_ops(self, visual)
|
||
|
|
||
|
def doit(self, **hints):
    """Evaluate objects that are not evaluated by default like limits,
    integrals, sums and products. All objects of this kind will be
    evaluated recursively, unless some species were excluded via 'hints'
    or unless the 'deep' hint was set to 'False'.

    >>> from sympy import Integral
    >>> from sympy.abc import x

    >>> 2*Integral(x, x)
    2*Integral(x, x)

    >>> (2*Integral(x, x)).doit()
    x**2

    >>> (2*Integral(x, x)).doit(deep=False)
    2*Integral(x, x)

    """
    # with deep=False the base class has nothing to evaluate
    if not hints.get('deep', True):
        return self

    # evaluate Basic arguments; leave any other argument untouched
    evaluated = [
        arg.doit(**hints) if isinstance(arg, Basic) else arg
        for arg in self.args
    ]
    return self.func(*evaluated)
|
||
|
|
||
|
def simplify(self, **kwargs):
    """See the simplify function in sympy.simplify

    All keyword arguments are forwarded unchanged.
    """
    from sympy.simplify.simplify import simplify as _simplify
    return _simplify(self, **kwargs)
|
||
|
|
||
|
def refine(self, assumption=True):
    """See the refine function in sympy.assumptions

    ``assumption`` is forwarded unchanged (default ``True``).
    """
    from sympy.assumptions.refine import refine as _refine
    return _refine(self, assumption)
|
||
|
|
||
|
def _eval_derivative_n_times(self, s, n):
    # This is the default evaluator for derivatives (as called by `diff`
    # and `Derivative`), it will attempt a loop to derive the expression
    # `n` times by calling the corresponding `_eval_derivative` method,
    # while leaving the derivative unevaluated (returning None) if `n` is
    # symbolic. This method should be overridden if the object has a
    # closed form for its symbolic n-th derivative.
    from .numbers import Integer
    if not isinstance(n, (int, Integer)):
        return None
    if n < 0:
        # no meaningful negative order; leave unevaluated instead of
        # falling through to an unbound `obj2`
        return None
    if n == 0:
        # the loop below would not run and `obj2` would be unbound;
        # the zeroth derivative is the expression itself
        return self
    obj = self
    for _ in range(n):
        obj2 = obj._eval_derivative(s)
        # stop at a fixed point or when differentiation is not
        # implemented (None)
        if obj == obj2 or obj2 is None:
            break
        obj = obj2
    return obj2
|
||
|
|
||
|
def rewrite(self, *args, deep=True, **hints):
    """
    Rewrite *self* using a defined rule.

    Rewriting transforms an expression to another, which is mathematically
    equivalent but structurally different. For example you can rewrite
    trigonometric functions as complex exponentials or combinatorial
    functions as gamma function.

    This method takes a *pattern* and a *rule* as positional arguments.
    *pattern* is optional parameter which defines the types of expressions
    that will be transformed. If it is not passed, all possible expressions
    will be rewritten. *rule* defines how the expression will be rewritten.

    Parameters
    ==========

    args : one or two positional arguments
        A *rule*, or *pattern* and *rule*.
        - *pattern* is a type or an iterable of types.
        - *rule* can be any object.

    deep : bool, optional
        If ``True``, subexpressions are recursively transformed. Default is
        ``True``.

    Examples
    ========

    If *pattern* is unspecified, all possible expressions are transformed.

    >>> from sympy import cos, sin, exp, I
    >>> from sympy.abc import x
    >>> expr = cos(x) + I*sin(x)
    >>> expr.rewrite(exp)
    exp(I*x)

    Pattern can be a type or an iterable of types.

    >>> expr.rewrite(sin, exp)
    exp(I*x)/2 + cos(x) - exp(-I*x)/2
    >>> expr.rewrite([cos,], exp)
    exp(I*x)/2 + I*sin(x) + exp(-I*x)/2
    >>> expr.rewrite([cos, sin], exp)
    exp(I*x)

    Rewriting behavior can be implemented by defining ``_eval_rewrite()``
    method.

    >>> from sympy import Expr, sqrt, pi
    >>> class MySin(Expr):
    ...     def _eval_rewrite(self, rule, args, **hints):
    ...         x, = args
    ...         if rule == cos:
    ...             return cos(pi/2 - x, evaluate=False)
    ...         if rule == sqrt:
    ...             return sqrt(1 - cos(x)**2)
    >>> MySin(MySin(x)).rewrite(cos)
    cos(-cos(-x + pi/2) + pi/2)
    >>> MySin(x).rewrite(sqrt)
    sqrt(1 - cos(x)**2)

    Defining ``_eval_rewrite_as_[...]()`` method is supported for backwards
    compatibility reason. This may be removed in the future and using it is
    discouraged.

    >>> class MySin(Expr):
    ...     def _eval_rewrite_as_cos(self, *args, **hints):
    ...         x, = args
    ...         return cos(pi/2 - x, evaluate=False)
    >>> MySin(x).rewrite(cos)
    cos(-x + pi/2)

    """
    if not args:
        return self

    hints.update(deep=deep)

    # the last positional argument is the rule; anything before it is
    # the (optional) pattern
    pattern, rule = args[:-1], args[-1]

    # name of the legacy _eval_rewrite_as_[...] method for this rule
    if isinstance(rule, str):
        target = rule
    elif hasattr(rule, "__name__"):
        target = rule.__name__  # rule is a class or function
    else:
        target = rule.__class__.__name__  # rule is an instance
    method = "_eval_rewrite_as_%s" % target

    if pattern:
        candidates = pattern[0] if iterable(pattern[0]) else pattern
        pattern = tuple(p for p in candidates if self.has(p))
        if not pattern:
            # none of the requested types occur in self
            return self
    # from here on an empty pattern means "every type"

    return self._rewrite(pattern, rule, method, **hints)
|
||
|
|
||
|
def _rewrite(self, pattern, rule, method, **hints):
    """Recursive worker behind ``rewrite``.

    Parameters
    ==========

    pattern : tuple of classes
        Only nodes that are instances of one of these classes are offered
        for rewriting; an empty tuple means every node is a candidate.
    rule :
        The rewrite target handed to ``_eval_rewrite``.
    method : str
        Name of a legacy ``_eval_rewrite_as_<name>`` method; when ``self``
        has it, it is called instead of ``_eval_rewrite``.
    **hints :
        Forwarded to the rewrite methods. ``deep`` (default True) is
        consumed here: when true the arguments are rewritten first.

    Returns
    =======

    The rewritten node if a rewrite method returned something other than
    None; otherwise ``self`` for argument-less nodes, or ``self.func``
    rebuilt from the (possibly rewritten) arguments.
    """
    recurse = hints.pop('deep', True)
    if recurse:
        new_args = [child._rewrite(pattern, rule, method, **hints)
                    for child in self.args]
    else:
        new_args = self.args

    candidate = not pattern or any(isinstance(self, kind) for kind in pattern)
    if candidate:
        legacy = getattr(self, method, None)
        if legacy is None:
            result = self._eval_rewrite(rule, new_args, **hints)
        else:
            result = legacy(*new_args, **hints)
        if result is not None:
            return result

    return self.func(*new_args) if new_args else self
|
||
|
|
||
|
def _eval_rewrite(self, rule, args, **hints):
    """Default rewrite hook: decline to rewrite.

    Subclasses override this to return a rewritten expression for ``rule``
    built from ``args``; returning None tells ``_rewrite`` that no rewrite
    applies, so the node is rebuilt from ``args`` unchanged.
    """
    return None
|
||
|
|
||
|
# Experimental registry read by ``_exec_constructor_postprocessors``: maps a
# class to a dict of {expression class name: [postprocessing callables]}.
_constructor_postprocessor_mapping = {}  # type: ignore
|
||
|
|
||
|
@classmethod
def _exec_constructor_postprocessors(cls, obj):
    """Run the constructor postprocessors registered for *obj*.

    WARNING: This API is experimental.

    ``Basic._constructor_postprocessor_mapping`` maps a class to a dict whose
    keys are expression class names and whose values are sequences of
    postprocessing callables. For every argument of *obj*, each class in the
    argument's type MRO that appears in that mapping contributes its
    callables. They are gathered per class name, each callable at most once,
    in first-seen order. The callables gathered under
    ``type(obj).__name__`` are then applied in turn, each one receiving the
    result of the previous one.

    Parameters
    ==========

    obj : Basic
        The expression being constructed.

    Returns
    =======

    The object returned by the last applied postprocessor, or *obj* itself
    when none are registered for its class name.
    """
    mapping = Basic._constructor_postprocessor_mapping
    clsname = obj.__class__.__name__
    postprocessors = defaultdict(list)
    for arg in obj.args:
        try:
            # ``klass`` (not ``cls``) so the classmethod's own argument is
            # not shadowed.
            for klass in type(arg).mro():
                if klass not in mapping:
                    continue
                for name, funcs in mapping[klass].items():
                    registered = postprocessors[name]
                    # Build the additions first, then extend, so duplicates
                    # inside ``funcs`` itself are dropped too (the original
                    # check only looked at already-registered callables).
                    fresh = []
                    for func in funcs:
                        if func not in registered and func not in fresh:
                            fresh.append(func)
                    registered.extend(fresh)
        except TypeError:
            # ``type(arg).mro()`` raises TypeError when e.g. ``arg`` is itself
            # a class (``mro`` is then an unbound method of its metaclass);
            # such arguments contribute nothing. Kept broad, as before, so
            # any other TypeError while gathering is also treated as
            # best-effort.
            pass

    for func in postprocessors.get(clsname, []):
        obj = func(obj)

    return obj
|
||
|
|
||
|
def _sage_(self):
    """
    Convert *self* to a SageMath symbolic expression.

    This base implementation is only a placeholder. It asks Sage to install
    its SymPy conversion hooks; if that replaced ``_sage_`` for this object,
    the newly installed method is called.

    Raises
    ======

    NotImplementedError
        If calling ``sympy_init`` did not install a different ``_sage_``.
    """
    before = self._sage_
    from sage.interfaces.sympy import sympy_init
    sympy_init()  # may monkey-patch _sage_ onto self's class or a superclass
    if before != self._sage_:
        # dispatch to the freshly monkey-patched method
        return self._sage_()
    raise NotImplementedError('conversion to SageMath is not implemented')
|
||
|
|
||
|
def could_extract_minus_sign(self):
    """Return False: a generic Basic object offers no minus sign to extract.

    See Expr.could_extract_minus_sign for the expression-level version.
    """
    return False
|
||
|
|
||
|
|
||
|
# For all Basic subclasses _prepare_class_assumptions is called by
# Basic.__init_subclass__ but that method is not called for Basic itself so we
# call the function here instead.
#
# This is a module-level call, so it runs once, when this module is imported,
# after the Basic class has been defined.
_prepare_class_assumptions(Basic)
|
||
|
|
||
|
|
||
|
class Atom(Basic):
    """
    Base class for atomic expressions: expressions that have no
    subexpressions.

    Examples
    ========

    Symbol, Number, Rational, Integer, ...
    But not: Add, Mul, Pow, ...
    """

    is_Atom = True

    __slots__ = ()

    def matches(self, expr, repl_dict=None, old=False):
        # An atom matches only something equal to itself. On success return
        # a fresh copy of the replacement dict (or a new empty dict); on
        # failure return None.
        if not self == expr:
            return None
        if repl_dict is None:
            return {}
        return repl_dict.copy()

    def xreplace(self, rule, hack2=False):
        # No args to recurse into: either replace this atom or keep it.
        return rule.get(self, self)

    def doit(self, **hints):
        return self

    @classmethod
    def class_key(cls):
        return 2, 0, cls.__name__

    @cacheit
    def sort_key(self, order=None):
        cls_key = self.class_key()
        args_key = (1, (str(self),))
        return cls_key, args_key, S.One.sort_key(), S.One

    def _eval_simplify(self, **kwargs):
        return self

    @property
    def _sorted_args(self):
        # Safeguard against accidentally using _sorted_args on Atoms: they
        # cannot be rebuilt as atom.func(*atom._sorted_args) because they
        # have no args, so callers must check for Atoms before using this.
        message = ('Atoms have no args. It might be necessary'
                   ' to make a check for Atoms in the calling code.')
        raise AttributeError(message)
|
||
|
|
||
|
|
||
|
def _aresame(a, b):
    """Return True if a and b are structurally the same, else False.

    Besides equality, every pair of corresponding nodes must also have the
    same type. The one exception is undefined functions (and applications
    of them), which are accepted when their ``class_key()`` values agree.

    Examples
    ========

    In SymPy (as in Python) two numbers compare the same if they
    have the same underlying base-2 representation even though
    they may not be the same type:

    >>> from sympy import S
    >>> 2.0 == S(2)
    True
    >>> 0.5 == S.Half
    True

    This routine was written to provide a query for such cases that
    would give false when the types do not match:

    >>> from sympy.core.basic import _aresame
    >>> _aresame(S(2.0), S(2))
    False

    """
    from .numbers import Number
    from .function import AppliedUndef, UndefinedFunction as UndefFunc

    if isinstance(a, Number) and isinstance(b, Number):
        return a == b and a.__class__ == b.__class__

    pairs = zip_longest(_preorder_traversal(a), _preorder_traversal(b))
    for left, right in pairs:
        if left != right or type(left) != type(right):
            same_undef_kind = any(
                isinstance(left, kind) and isinstance(right, kind)
                for kind in (UndefFunc, AppliedUndef))
            if not same_undef_kind or left.class_key() != right.class_key():
                return False
    return True
|
||
|
|
||
|
|
||
|
def _ne(a, b):
    """Return True if *a* and *b* are SymPy Numbers of different classes.

    Use this as a second test after ``a != b`` when numbers that may compare
    equal but differ in type (e.g. a Float and a Rational) must be treated
    as distinct::

        a, b = S(0.5), S.Half   # Float and Rational
        a != b or _ne(a, b)     # -> True

    Only SymPy ``Number`` instances are considered. For any other inputs,
    including plain Python ``float``/``int`` values such as ``0.5``, this
    returns False.
    """
    from .numbers import Number
    if isinstance(a, Number) and isinstance(b, Number):
        return a.__class__ != b.__class__
    # Explicit False rather than falling off the end and returning None.
    return False
|
||
|
|
||
|
|
||
|
def _atomic(e, recursive=False):
    """Return the atom-like pieces of *e* as far as substitution cares.

    Free Symbols, Functions and Derivatives are collected. Anything nested
    inside a collected Function or Derivative is not reported unless it
    also appears outside of one, or ``recursive`` is True.

    If *e* is not a Basic instance, an empty set is returned. If it has no
    ``free_symbols`` attribute, ``{e}`` is returned.

    Examples
    ========

    >>> from sympy import Derivative, Function, cos
    >>> from sympy.abc import x, y
    >>> from sympy.core.basic import _atomic
    >>> f = Function('f')
    >>> _atomic(x + y)
    {x, y}
    >>> _atomic(x + f(y))
    {x, f(y)}
    >>> _atomic(Derivative(f(x), x) + cos(x) + y)
    {y, cos(x), Derivative(f(x), x)}

    """
    walker = _preorder_traversal(e)
    visited = set()
    if not isinstance(e, Basic):
        return set()
    free = getattr(e, "free_symbols", None)
    if free is None:
        return {e}

    from .symbol import Symbol
    from .function import Derivative, Function

    found = set()
    for node in walker:
        if node in visited:
            # this subtree was already handled; don't walk it again
            walker.skip()
            continue
        visited.add(node)
        if isinstance(node, Symbol) and node in free:
            found.add(node)
        elif isinstance(node, (Derivative, Function)):
            if not recursive:
                walker.skip()
            found.add(node)
    return found
|
||
|
|
||
|
|
||
|
def _make_find_query(query):
    """Turn the argument of Basic.find() into a predicate callable.

    The query is first strictly sympified when possible. A class becomes
    an isinstance test. A Basic expression becomes a pattern-match test.
    Anything else (e.g. an existing callable) is returned as-is.
    """
    try:
        query = _sympify(query)
    except SympifyError:
        pass

    if isinstance(query, type):
        def _is_instance(expr):
            return isinstance(expr, query)
        return _is_instance

    if isinstance(query, Basic):
        def _matches_pattern(expr):
            return expr.match(query) is not None
        return _matches_pattern

    return query
|
||
|
|
||
|
# Delayed to avoid cyclic import
|
||
|
from .singleton import S
|
||
|
from .traversal import (preorder_traversal as _preorder_traversal,
|
||
|
iterargs, iterfreeargs)
|
||
|
|
||
|
# Backwards-compatible alias: ``preorder_traversal`` can still be imported
# from ``sympy.core.basic``, but it is the traversal helper wrapped by
# SymPy's ``deprecated`` decorator with the message below, which points users
# to the top-level ``sympy.preorder_traversal``.
preorder_traversal = deprecated(
    """
Using preorder_traversal from the sympy.core.basic submodule is
deprecated.

Instead, use preorder_traversal from the top-level sympy namespace, like

sympy.preorder_traversal
""",
    deprecated_since_version="1.10",
    active_deprecations_target="deprecated-traversal-functions-moved",
)(_preorder_traversal)
|