""" Generic Unification algorithm for expression trees with lists of children This implementation is a direct translation of Artificial Intelligence: A Modern Approach by Stuart Russel and Peter Norvig Second edition, section 9.2, page 276 It is modified in the following ways: 1. We allow associative and commutative Compound expressions. This results in combinatorial blowup. 2. We explore the tree lazily. 3. We provide generic interfaces to symbolic algebra libraries in Python. A more traditional version can be found here http://aima.cs.berkeley.edu/python/logic.html """ from sympy.utilities.iterables import kbins class Compound: """ A little class to represent an interior node in the tree This is analogous to SymPy.Basic for non-Atoms """ def __init__(self, op, args): self.op = op self.args = args def __eq__(self, other): return (type(self) is type(other) and self.op == other.op and self.args == other.args) def __hash__(self): return hash((type(self), self.op, self.args)) def __str__(self): return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args))) class Variable: """ A Wild token """ def __init__(self, arg): self.arg = arg def __eq__(self, other): return type(self) is type(other) and self.arg == other.arg def __hash__(self): return hash((type(self), self.arg)) def __str__(self): return "Variable(%s)" % str(self.arg) class CondVariable: """ A wild token that matches conditionally. arg - a wild token. valid - an additional constraining function on a match. """ def __init__(self, arg, valid): self.arg = arg self.valid = valid def __eq__(self, other): return (type(self) is type(other) and self.arg == other.arg and self.valid == other.valid) def __hash__(self): return hash((type(self), self.arg, self.valid)) def __str__(self): return "CondVariable(%s)" % str(self.arg) def unify(x, y, s=None, **fns): """ Unify two expressions. Parameters ========== x, y - expression trees containing leaves, Compounds and Variables. s - a mapping of variables to subtrees. Returns ======= lazy sequence of mappings {Variable: subtree} Examples ======== >>> from sympy.unify.core import unify, Compound, Variable >>> expr = Compound("Add", ("x", "y")) >>> pattern = Compound("Add", ("x", Variable("a"))) >>> next(unify(expr, pattern, {})) {Variable(a): 'y'} """ s = s or {} if x == y: yield s elif isinstance(x, (Variable, CondVariable)): yield from unify_var(x, y, s, **fns) elif isinstance(y, (Variable, CondVariable)): yield from unify_var(y, x, s, **fns) elif isinstance(x, Compound) and isinstance(y, Compound): is_commutative = fns.get('is_commutative', lambda x: False) is_associative = fns.get('is_associative', lambda x: False) for sop in unify(x.op, y.op, s, **fns): if is_associative(x) and is_associative(y): a, b = (x, y) if len(x.args) < len(y.args) else (y, x) if is_commutative(x) and is_commutative(y): combs = allcombinations(a.args, b.args, 'commutative') else: combs = allcombinations(a.args, b.args, 'associative') for aaargs, bbargs in combs: aa = [unpack(Compound(a.op, arg)) for arg in aaargs] bb = [unpack(Compound(b.op, arg)) for arg in bbargs] yield from unify(aa, bb, sop, **fns) elif len(x.args) == len(y.args): yield from unify(x.args, y.args, sop, **fns) elif is_args(x) and is_args(y) and len(x) == len(y): if len(x) == 0: yield s else: for shead in unify(x[0], y[0], s, **fns): yield from unify(x[1:], y[1:], shead, **fns) def unify_var(var, x, s, **fns): if var in s: yield from unify(s[var], x, s, **fns) elif occur_check(var, x): pass elif isinstance(var, CondVariable) and var.valid(x): yield assoc(s, var, x) elif isinstance(var, Variable): yield assoc(s, var, x) def occur_check(var, x): """ var occurs in subtree owned by x? """ if var == x: return True elif isinstance(x, Compound): return occur_check(var, x.args) elif is_args(x): if any(occur_check(var, xi) for xi in x): return True return False def assoc(d, key, val): """ Return copy of d with key associated to val """ d = d.copy() d[key] = val return d def is_args(x): """ Is x a traditional iterable? """ return type(x) in (tuple, list, set) def unpack(x): if isinstance(x, Compound) and len(x.args) == 1: return x.args[0] else: return x def allcombinations(A, B, ordered): """ Restructure A and B to have the same number of elements. Parameters ========== ordered must be either 'commutative' or 'associative'. A and B can be rearranged so that the larger of the two lists is reorganized into smaller sublists. Examples ======== >>> from sympy.unify.core import allcombinations >>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x) (((1,), (2, 3)), ((5,), (6,))) (((1, 2), (3,)), ((5,), (6,))) >>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x) (((1,), (2, 3)), ((5,), (6,))) (((1, 2), (3,)), ((5,), (6,))) (((1,), (3, 2)), ((5,), (6,))) (((1, 3), (2,)), ((5,), (6,))) (((2,), (1, 3)), ((5,), (6,))) (((2, 1), (3,)), ((5,), (6,))) (((2,), (3, 1)), ((5,), (6,))) (((2, 3), (1,)), ((5,), (6,))) (((3,), (1, 2)), ((5,), (6,))) (((3, 1), (2,)), ((5,), (6,))) (((3,), (2, 1)), ((5,), (6,))) (((3, 2), (1,)), ((5,), (6,))) """ if ordered == "commutative": ordered = 11 if ordered == "associative": ordered = None sm, bg = (A, B) if len(A) < len(B) else (B, A) for part in kbins(list(range(len(bg))), len(sm), ordered=ordered): if bg == B: yield tuple((a,) for a in A), partition(B, part) else: yield partition(A, part), tuple((b,) for b in B) def partition(it, part): """ Partition a tuple/list into pieces defined by indices. Examples ======== >>> from sympy.unify.core import partition >>> partition((10, 20, 30, 40), [[0, 1, 2], [3]]) ((10, 20, 30), (40,)) """ return type(it)([index(it, ind) for ind in part]) def index(it, ind): """ Fancy indexing into an indexable iterable (tuple, list). Examples ======== >>> from sympy.unify.core import index >>> index([10, 20, 30], (1, 2, 0)) [20, 30, 10] """ return type(it)([it[i] for i in ind])