235 lines
6.9 KiB
Python
235 lines
6.9 KiB
Python
|
""" Generic Unification algorithm for expression trees with lists of children

This implementation is a direct translation of

Artificial Intelligence: A Modern Approach by Stuart Russell and Peter Norvig
Second edition, section 9.2, page 276

It is modified in the following ways:

1.  We allow associative and commutative Compound expressions. This results in
    combinatorial blowup.
2.  We explore the tree lazily.
3.  We provide generic interfaces to symbolic algebra libraries in Python.

A more traditional version can be found at
http://aima.cs.berkeley.edu/python/logic.html
"""
|
||
|
|
||
|
from sympy.utilities.iterables import kbins
|
||
|
|
||
|
class Compound:
    """ A little class to represent an interior node in the tree

    This is analogous to SymPy's ``Basic`` for non-Atoms.

    Parameters
    ==========

    op - the operator labelling this node.
    args - the children of this node. ``__hash__`` hashes ``args`` directly,
        so it must itself be hashable (e.g. a tuple) for the node to be
        hashable.
    """

    def __init__(self, op, args):
        self.op = op
        self.args = args

    def __eq__(self, other):
        # Structural equality: same class, same operator, same children.
        return (type(self) is type(other) and self.op == other.op and
                self.args == other.args)

    def __hash__(self):
        # Consistent with __eq__: hashes exactly the fields compared there.
        return hash((type(self), self.op, self.args))

    def __str__(self):
        return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args)))

    # Containers (lists, dicts of bindings) and the REPL use repr(); without
    # this they would show an opaque default object repr instead.
    __repr__ = __str__
|
||
|
|
||
|
class Variable:
    """ A Wild token

    Two Variables are equal (and hash equally) exactly when they share the
    same ``arg``.
    """

    def __init__(self, arg):
        self.arg = arg

    def __eq__(self, other):
        return type(self) is type(other) and self.arg == other.arg

    def __hash__(self):
        return hash((type(self), self.arg))

    def __str__(self):
        return "Variable(%s)" % str(self.arg)

    # Binding dicts returned by ``unify`` print their keys with repr(); the
    # ``unify`` doctest expects ``{Variable(a): 'y'}``, which needs this.
    __repr__ = __str__
|
||
|
|
||
|
class CondVariable:
    """ A wild token that matches conditionally.

    arg - a wild token.
    valid - an additional constraining function on a match. It is called
        with the candidate subtree, and the match is accepted only when it
        returns a truthy value.

    Equality and hashing take both ``arg`` and ``valid`` into account, so
    ``valid`` must be hashable (plain functions are).
    """

    def __init__(self, arg, valid):
        self.arg = arg
        self.valid = valid

    def __eq__(self, other):
        return (type(self) is type(other) and
                self.arg == other.arg and
                self.valid == other.valid)

    def __hash__(self):
        return hash((type(self), self.arg, self.valid))

    def __str__(self):
        return "CondVariable(%s)" % str(self.arg)

    # Show the readable form inside containers (e.g. binding dicts) too.
    __repr__ = __str__
|
||
|
|
||
|
def unify(x, y, s=None, **fns):
    """ Unify two expressions.

    Parameters
    ==========

    x, y - expression trees containing leaves, Compounds and Variables.
    s - a mapping of variables to subtrees.
    fns - optional keyword callables ``is_associative`` and
        ``is_commutative``. Each is called with a Compound, and a truthy
        result makes that node associative or commutative. When omitted,
        no node is treated as associative or commutative.

    Returns
    =======

    lazy sequence of mappings {Variable: subtree}

    Examples
    ========

    >>> from sympy.unify.core import unify, Compound, Variable
    >>> expr = Compound("Add", ("x", "y"))
    >>> pattern = Compound("Add", ("x", Variable("a")))
    >>> next(unify(expr, pattern, {}))
    {Variable(a): 'y'}
    """
    bindings = s or {}
    wilds = (Variable, CondVariable)

    # Identical trees unify trivially under the current bindings.
    if x == y:
        yield bindings
        return

    # A wild token on either side is delegated to unify_var; x is tried first.
    if isinstance(x, wilds):
        yield from unify_var(x, y, bindings, **fns)
        return
    if isinstance(y, wilds):
        yield from unify_var(y, x, bindings, **fns)
        return

    if isinstance(x, Compound) and isinstance(y, Compound):
        commutative = fns.get('is_commutative', lambda expr: False)
        associative = fns.get('is_associative', lambda expr: False)
        # Unify the operators first, then extend each operator binding by
        # unifying the children.
        for op_bindings in unify(x.op, y.op, bindings, **fns):
            if not (associative(x) and associative(y)):
                # Ordinary compounds: children must pair up one-to-one.
                if len(x.args) == len(y.args):
                    yield from unify(x.args, y.args, op_bindings, **fns)
                continue
            # Associative compounds: regroup the longer argument list so it
            # has as many pieces as the other one (ties put y first).
            small, big = (x, y) if len(x.args) < len(y.args) else (y, x)
            if commutative(x) and commutative(y):
                mode = 'commutative'
            else:
                mode = 'associative'
            for small_groups, big_groups in allcombinations(
                    small.args, big.args, mode):
                lhs = [unpack(Compound(small.op, grp)) for grp in small_groups]
                rhs = [unpack(Compound(big.op, grp)) for grp in big_groups]
                yield from unify(lhs, rhs, op_bindings, **fns)
        return

    # Equal-length sequences of children: unify the heads, then the tails.
    if is_args(x) and is_args(y) and len(x) == len(y):
        if len(x) == 0:
            yield bindings
            return
        for head_bindings in unify(x[0], y[0], bindings, **fns):
            yield from unify(x[1:], y[1:], head_bindings, **fns)
|
||
|
|
||
|
def unify_var(var, x, s, **fns):
    """ Unify the wild token ``var`` with the expression ``x`` under ``s``.

    * If ``var`` is already bound in ``s``, its binding is unified with ``x``
      and every resulting mapping is yielded.
    * Otherwise, if ``var`` occurs inside ``x`` (see ``occur_check``),
      nothing is yielded.
    * A ``CondVariable`` is bound to ``x`` only when ``var.valid(x)`` is
      truthy.
    * A plain ``Variable`` is always bound to ``x``.
    * Any other kind of ``var`` yields nothing.

    Bindings are added with ``assoc``, so ``s`` itself is never mutated.
    """
    if var in s:
        yield from unify(s[var], x, s, **fns)
        return
    if occur_check(var, x):
        return
    if isinstance(var, CondVariable):
        accepted = var.valid(x)
    else:
        accepted = isinstance(var, Variable)
    if accepted:
        yield assoc(s, var, x)
|
||
|
|
||
|
def occur_check(var, x):
    """ var occurs in subtree owned by x?

    Returns True when ``var`` equals ``x`` or appears anywhere below it.
    A Compound is searched through both its operator and its arguments.
    ``unify`` unifies operators as well (``unify(x.op, y.op, ...)``), so a
    variable in operator position would otherwise slip past this check.
    That would allow a cyclic binding such as ``a -> Compound(a, ...)``.
    Tuples, lists and sets (see ``is_args``) are searched element-wise.
    Any other leaf that is not ``var`` returns False.
    """
    if var == x:
        return True
    if isinstance(x, Compound):
        return occur_check(var, x.op) or occur_check(var, x.args)
    if is_args(x):
        return any(occur_check(var, xi) for xi in x)
    return False
|
||
|
|
||
|
def assoc(d, key, val):
    """ Return copy of d with key associated to val

    The copy is made with ``d.copy()``, so ``d`` itself is left untouched and
    the result has whatever type ``d.copy()`` returns.
    """
    extended = d.copy()
    extended[key] = val
    return extended
|
||
|
|
||
|
def is_args(x):
    """ Is x a traditional iterable?

    Only exact ``tuple``, ``list`` and ``set`` instances qualify. Subclasses
    of those types and other iterables (strings, frozensets, generators)
    do not.
    """
    return type(x) in {tuple, list, set}
|
||
|
|
||
|
def unpack(x):
    """ Collapse a single-argument Compound to that argument.

    Every other value is returned unchanged, including Compounds with zero
    or several arguments. ``unify`` applies this to the regrouped pieces of
    associative arguments, so a one-element group is matched as the element
    itself.
    """
    is_singleton = isinstance(x, Compound) and len(x.args) == 1
    return x.args[0] if is_singleton else x
|
||
|
|
||
|
def allcombinations(A, B, ordered):
    """
    Restructure A and B to have the same number of elements.

    Parameters
    ==========

    ordered must be either 'commutative' or 'associative'.

    A and B can be rearranged so that the larger of the two lists is
    reorganized into smaller sublists.

    Each yielded pair holds the shorter sequence split into singletons and
    the longer sequence split into that many groups. The groups come from
    ``kbins`` (``ordered=11`` for 'commutative', ``ordered=None`` for
    'associative'); any other value of ``ordered`` is passed to ``kbins``
    unchanged.

    Examples
    ========

    >>> from sympy.unify.core import allcombinations
    >>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))

    >>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))
    (((1,), (3, 2)), ((5,), (6,)))
    (((1, 3), (2,)), ((5,), (6,)))
    (((2,), (1, 3)), ((5,), (6,)))
    (((2, 1), (3,)), ((5,), (6,)))
    (((2,), (3, 1)), ((5,), (6,)))
    (((2, 3), (1,)), ((5,), (6,)))
    (((3,), (1, 2)), ((5,), (6,)))
    (((3, 1), (2,)), ((5,), (6,)))
    (((3,), (2, 1)), ((5,), (6,)))
    (((3, 2), (1,)), ((5,), (6,)))
    """

    # Translate the mode name into kbins' ``ordered`` flag.
    if ordered == "commutative":
        ordered = 11
    if ordered == "associative":
        ordered = None

    sm, bg = (A, B) if len(A) < len(B) else (B, A)
    # Decide once which side gets partitioned. ``bg == B`` is an
    # element-wise sequence comparison, and kbins can yield a combinatorial
    # number of partitions, so it must not be recomputed on every iteration.
    split_b = bg == B
    for part in kbins(list(range(len(bg))), len(sm), ordered=ordered):
        if split_b:
            yield tuple((a,) for a in A), partition(B, part)
        else:
            yield partition(A, part), tuple((b,) for b in B)
|
||
|
|
||
|
def partition(it, part):
    """ Partition a tuple/list into pieces defined by indices.

    Each entry of ``part`` is a sequence of positions into ``it``. Every
    piece, and the outer container, is rebuilt with ``type(it)``.

    Examples
    ========

    >>> from sympy.unify.core import partition
    >>> partition((10, 20, 30, 40), [[0, 1, 2], [3]])
    ((10, 20, 30), (40,))
    """
    container = type(it)
    pieces = [index(it, group) for group in part]
    return container(pieces)
|
||
|
|
||
|
def index(it, ind):
    """ Fancy indexing into an indexable iterable (tuple, list).

    Returns the elements of ``it`` at the positions listed in ``ind``, in
    that order, packed into a new container of the same type as ``it``.

    Examples
    ========

    >>> from sympy.unify.core import index
    >>> index([10, 20, 30], (1, 2, 0))
    [20, 30, 10]
    """
    selected = [it[pos] for pos in ind]
    return type(it)(selected)
|