307 lines
8.8 KiB
Python
307 lines
8.8 KiB
Python
from .basic import Basic
|
|
from .sorting import ordered
|
|
from .sympify import sympify
|
|
from sympy.utilities.iterables import iterable
|
|
|
|
|
|
|
|
def iterargs(expr):
    """Generate ``expr`` and all of its nested args, level by level.

    The traversal is breadth-first: ``expr`` comes first, then its args,
    then the args of those args, and so on. An object whose ``args`` is
    empty, or cannot be iterated over, contributes no further objects.

    Examples
    ========

    >>> from sympy import Integral, Function
    >>> from sympy.abc import x
    >>> f = Function('f')
    >>> from sympy.core.traversal import iterargs
    >>> list(iterargs(Integral(f(x), (f(x), 1))))
    [Integral(f(x), (f(x), 1)), f(x), (f(x), 1), x, f(x), 1, x]

    See Also
    ========
    iterfreeargs, preorder_traversal
    """
    # ``pending`` doubles as the FIFO queue and the record of everything
    # scheduled so far; ``position`` is the read head into it.
    pending = [expr]
    position = 0
    while position < len(pending):
        current = pending[position]
        position += 1
        yield current
        try:
            pending.extend(current.args)
        except TypeError:
            # ``args`` exists but is not iterable (e.g. an undefined
            # function class such as ``f`` appearing as an arg).
            pass
|
|
|
|
|
|
def iterfreeargs(expr, _first=True):
    """Generate ``expr`` and its nested args breadth-first, treating
    objects that bind variables specially.

    As with ``iterargs``, descent stops at an object whose ``args`` is
    empty or not iterable. The bound objects of an expression will be
    returned as canonical variables: when an object exposes
    ``bound_symbols``, its ``as_dummy()`` form is traversed instead and
    only the pieces free of the canonical variables are produced.

    ``_first`` is an internal flag; the recursive call passes ``False`` so
    that the nested traversal does not apply the bound-symbol handling
    again.

    Examples
    ========

    >>> from sympy import Integral, Function
    >>> from sympy.abc import x
    >>> f = Function('f')
    >>> from sympy.core.traversal import iterfreeargs
    >>> list(iterfreeargs(Integral(f(x), (f(x), 1))))
    [Integral(f(x), (f(x), 1)), 1]

    See Also
    ========
    iterargs, preorder_traversal
    """
    queue = [expr]
    cursor = 0
    while cursor < len(queue):
        node = queue[cursor]
        cursor += 1
        yield node
        if _first and hasattr(node, 'bound_symbols'):
            canonical = node.canonical_variables.values()
            for piece in iterfreeargs(node.as_dummy(), _first=False):
                # Deliberately rebind ``node``: once this loop finishes it
                # refers to the last object of the dummy traversal, so the
                # args queued below are that object's args and not those
                # of the variable-binding expression itself.
                node = piece
                if not piece.has(*canonical):
                    yield piece
        try:
            queue.extend(node.args)
        except TypeError:
            # ``args`` is present but not iterable (e.g. ``f`` as an arg).
            pass
|
|
|
|
|
|
class preorder_traversal:
    """
    Iterate over a tree in pre-order.

    Each node is produced before any of its children; the iterator then
    descends depth-first, producing the complete pre-order traversal of
    the first child before moving on to the next one.

    For an expression, the order of the traversal depends on the order of
    .args, which in many cases can be arbitrary.

    Parameters
    ==========
    node : SymPy expression
        The expression to traverse.
    keys : (default None) sort key(s)
        The key(s) used to sort args of Basic objects. When None, args of
        Basic objects are processed in arbitrary order. If ``keys`` is
        given, it is passed along to ordered() as the only key(s) to use
        to sort the arguments; if ``keys`` is simply True then the default
        keys of ordered will be used.

    Yields
    ======
    subtree : SymPy expression
        All of the subtrees in the tree.

    Examples
    ========

    >>> from sympy import preorder_traversal, symbols
    >>> x, y, z = symbols('x y z')

    The nodes are returned in the order that they are encountered unless
    keys is given; simply passing keys=True will guarantee that the
    traversal is unique.

    >>> list(preorder_traversal((x + y)*z, keys=None)) # doctest: +SKIP
    [z*(x + y), z, x + y, y, x]
    >>> list(preorder_traversal((x + y)*z, keys=True))
    [z*(x + y), z, x + y, x, y]

    """
    def __init__(self, node, keys=None):
        # Set by skip(); consumed (and cleared) by the generator right
        # after it resumes from yielding a node.
        self._skip_flag = False
        # The underlying generator that __next__ delegates to.
        self._pt = self._preorder_traversal(node, keys)

    def _preorder_traversal(self, node, keys):
        """Generator doing the actual work: yield ``node``, then recurse
        into its children unless skip() was called in the meantime."""
        yield node
        if self._skip_flag:
            # The consumer asked to prune this subtree; reset the flag so
            # that only this one node is affected.
            self._skip_flag = False
            return

        if isinstance(node, Basic):
            if keys:
                children = node.args
                # ``keys`` may be True, a callable or a collection of
                # callables, hence the equality comparison against True.
                children = (ordered(children, keys, default=False)
                            if keys != True else ordered(children))
            elif hasattr(node, '_argset'):
                # LatticeOp keeps args as a set. Order is irrelevant here,
                # so use the set directly and avoid unnecessary sorting.
                children = node._argset
            else:
                children = node.args
        elif iterable(node):
            children = node
        else:
            return

        for child in children:
            yield from self._preorder_traversal(child, keys)

    def skip(self):
        """
        Do not descend into the subtrees of the most recently yielded node.

        Examples
        ========

        >>> from sympy import preorder_traversal, symbols
        >>> x, y, z = symbols('x y z')
        >>> pt = preorder_traversal((x + y*z)*z)
        >>> for i in pt:
        ...     print(i)
        ...     if i == x + y*z:
        ...             pt.skip()
        z*(x + y*z)
        z
        x + y*z
        """
        self._skip_flag = True

    def __next__(self):
        return next(self._pt)

    def __iter__(self):
        return self
|
|
|
|
|
|
def use(expr, func, level=0, args=(), kwargs={}):
    """
    Apply ``func`` to the sub-expressions of ``expr`` found ``level``
    levels below the top.

    With ``level=0`` the function is applied to ``expr`` itself. For a
    non-zero level, each arg is processed one level down and the
    expression is rebuilt from the results with ``expr.__class__``; atoms
    met before the requested level is reached are returned unchanged.
    ``args`` and ``kwargs`` are forwarded to every call of ``func``.

    Examples
    ========

    >>> from sympy import use, expand
    >>> from sympy.abc import x, y

    >>> f = (x + y)**2*x + 1

    >>> use(f, expand, level=2)
    x*(x**2 + 2*x*y + y**2) + 1
    >>> expand(f)
    x**3 + 2*x**2*y + x*y**2 + 1

    """
    # NOTE: the ``kwargs`` default is a shared dict, but it is only ever
    # unpacked below and never modified, so sharing it is harmless.
    def _descend(node, depth):
        if not depth:
            # Reached the requested level: this is where ``func`` acts.
            return func(node, *args, **kwargs)
        if node.is_Atom:
            # Nothing below an atom to descend into.
            return node
        rebuilt = [_descend(child, depth - 1) for child in node.args]
        return node.__class__(*rebuilt)

    return _descend(sympify(expr), level)
|
|
|
|
|
|
def walk(e, *target):
    """Generate ``e`` and those of its nested args that are instances of
    the given types (``target``), in pre-order.

    An object that is not an instance of one of the target types is
    neither produced nor descended into, so matching objects hidden
    beneath it are not reached.

    Examples
    ========

    >>> from sympy.core.traversal import walk
    >>> from sympy import Min, Max
    >>> from sympy.abc import x, y, z
    >>> list(walk(Min(x, Max(y, Min(1, z))), Min))
    [Min(x, Max(y, Min(1, z)))]
    >>> list(walk(Min(x, Max(y, Min(1, z))), Min, Max))
    [Min(x, Max(y, Min(1, z))), Max(y, Min(1, z)), Min(1, z)]

    See Also
    ========

    bottom_up
    """
    if not isinstance(e, target):
        return
    yield e

    # One live iterator per matched ancestor; the top of the stack is the
    # arg list currently being scanned.
    trail = [iter(e.args)]
    while trail:
        for candidate in trail[-1]:
            if isinstance(candidate, target):
                yield candidate
                # Descend: scan the match's own args before resuming the
                # remaining siblings.
                trail.append(iter(candidate.args))
                break
        else:
            # Current arg list is exhausted; go back up one level.
            trail.pop()
|
|
|
|
|
|
def bottom_up(rv, F, atoms=False, nonbasic=False):
    """Transform an expression tree with ``F``, innermost parts first.

    The args of ``rv`` are processed recursively; if any of them changed,
    ``rv`` is rebuilt with ``rv.func`` before ``F`` is applied to it.

    If ``atoms`` is True, apply ``F`` even if there are no args;
    if ``nonbasic`` is True, try to apply ``F`` to objects that have no
    ``args`` attribute, leaving them as they are when ``F`` raises
    ``TypeError``.
    """
    children = getattr(rv, 'args', None)

    if children is None:
        # Not an expression-like object at all.
        if not nonbasic:
            return rv
        try:
            return F(rv)
        except TypeError:
            return rv

    if not children:
        # A leaf: only touched on request.
        return F(rv) if atoms else rv

    processed = []
    for child in children:
        processed.append(bottom_up(child, F, atoms, nonbasic))
    processed = tuple(processed)
    if processed != rv.args:
        # Something below changed, so rebuild before applying ``F``.
        rv = rv.func(*processed)
    return F(rv)
|
|
|
|
|
|
def postorder_traversal(node, keys=None):
    """
    Iterate over a tree in post-order.

    The generator descends depth-first and produces the complete
    post-order traversal of every child of a node before producing the
    node itself.

    Parameters
    ==========

    node : SymPy expression
        The expression to traverse.
    keys : (default None) sort key(s)
        The key(s) used to sort args of Basic objects. When None, args of
        Basic objects are processed in arbitrary order. If ``keys`` is
        given, it is passed along to ordered() as the only key(s) to use
        to sort the arguments; if ``keys`` is simply True then the default
        keys of ``ordered`` will be used (node count and default_sort_key).

    Yields
    ======
    subtree : SymPy expression
        All of the subtrees in the tree.

    Examples
    ========

    >>> from sympy import postorder_traversal
    >>> from sympy.abc import w, x, y, z

    The nodes are returned in the order that they are encountered unless
    keys is given; simply passing keys=True will guarantee that the
    traversal is unique.

    >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP
    [z, y, x, x + y, z*(x + y), w, w + z*(x + y)]
    >>> list(postorder_traversal(w + (x + y)*z, keys=True))
    [w, z, x, y, x + y, z*(x + y), w + z*(x + y)]


    """
    if isinstance(node, Basic):
        children = node.args
        if keys:
            # ``keys`` may be True, a callable or a collection of
            # callables, hence the equality comparison against True.
            children = (ordered(children, keys, default=False)
                        if keys != True else ordered(children))
    elif iterable(node):
        children = node
    else:
        children = ()

    for child in children:
        yield from postorder_traversal(child, keys)
    yield node
|