""" Functions to support rewriting of SymPy expressions """ from sympy.core.expr import Expr from sympy.assumptions import ask from sympy.strategies.tools import subs from sympy.unify.usympy import rebuild, unify def rewriterule(source, target, variables=(), condition=None, assume=None): """ Rewrite rule. Transform expressions that match source into expressions that match target treating all ``variables`` as wilds. Examples ======== >>> from sympy.abc import w, x, y, z >>> from sympy.unify.rewrite import rewriterule >>> from sympy import default_sort_key >>> rl = rewriterule(x + y, x**y, [x, y]) >>> sorted(rl(z + 3), key=default_sort_key) [3**z, z**3] Use ``condition`` to specify additional requirements. Inputs are taken in the same order as is found in variables. >>> rl = rewriterule(x + y, x**y, [x, y], lambda x, y: x.is_integer) >>> list(rl(z + 3)) [3**z] Use ``assume`` to specify additional requirements using new assumptions. >>> from sympy.assumptions import Q >>> rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x)) >>> list(rl(z + 3)) [3**z] Assumptions for the local context are provided at rule runtime >>> list(rl(w + z, Q.integer(z))) [z**w] """ def rewrite_rl(expr, assumptions=True): for match in unify(source, expr, {}, variables=variables): if (condition and not condition(*[match.get(var, var) for var in variables])): continue if (assume and not ask(assume.xreplace(match), assumptions)): continue expr2 = subs(match)(target) if isinstance(expr2, Expr): expr2 = rebuild(expr2) yield expr2 return rewrite_rl