ai-content-maker/.venv/Lib/site-packages/torch/fx/experimental/unification/match.py

122 lines
3.3 KiB
Python

from .core import unify, reify # type: ignore[attr-defined]
from .variable import isvar
from .utils import _toposort, freeze
from .unification_tools import groupby, first # type: ignore[import]
class Dispatcher:
    """Dispatch a call to the first registered function whose signature unifies
    with the call's positional arguments.

    Signatures are tried in the order computed by the module-level
    ``ordering`` function (more specific signatures first, per ``edge`` /
    ``supercedes``); only signatures with the same length as the argument
    tuple are considered.
    """

    def __init__(self, name):
        self.name = name
        # Maps frozen signature -> implementation.
        self.funcs = {}
        # Signatures in the order ``resolve`` tries them; recomputed on every ``add``.
        self.ordering = []

    def add(self, signature, func):
        """Register ``func`` for ``signature`` and recompute the try order."""
        self.funcs[freeze(signature)] = func
        # Inside a method body this name resolves to the module-level
        # ``ordering`` function, not to the ``self.ordering`` attribute.
        self.ordering = ordering(self.funcs)

    def __call__(self, *args, **kwargs):
        """Call the implementation matching ``args``, forwarding ``args`` and ``kwargs``."""
        func, _ = self.resolve(args)
        return func(*args, **kwargs)

    def resolve(self, args):
        """Return ``(func, s)`` for the first signature that unifies with ``args``.

        ``s`` is whatever ``unify`` returned for that signature (anything other
        than ``False``).

        Raises:
            NotImplementedError: if no registered signature of length
                ``len(args)`` unifies with ``args``.
        """
        n = len(args)
        # The frozen form of ``args`` does not depend on the signature being
        # tried, so compute it once instead of once per candidate.
        frozen_args = freeze(args)
        for signature in self.ordering:
            if len(signature) != n:
                continue
            s = unify(frozen_args, signature)
            if s is not False:
                return self.funcs[signature], s
        raise NotImplementedError("No match found. \nKnown matches: "
                                  + str(self.ordering) + "\nInput: " + str(args))

    def register(self, *signature):
        """Decorator form of ``add``.

        Note that the decorator returns the dispatcher itself, so the
        decorated name is rebound to the dispatcher rather than to ``func``.
        """
        def _(func):
            self.add(signature, func)
            return self
        return _
class VarDispatcher(Dispatcher):
    """A dispatcher that calls the matched function with variable names.

    Matching works exactly as in ``Dispatcher``. The chosen function, however,
    is called with keyword arguments only: each variable bound by unification
    is passed under its ``token``. The positional arguments themselves and any
    keyword arguments given to the call are not forwarded.

    >>> # xdoctest: +SKIP
    >>> d = VarDispatcher('d')
    >>> x = var('x')
    >>> @d.register('inc', x)
    ... def f(x):
    ...     return x + 1
    >>> @d.register('double', x)
    ... def f(x):
    ...     return x * 2
    >>> d('inc', 10)
    11
    >>> d('double', 10)
    20
    """

    def __call__(self, *args, **kwargs):
        func, bindings = self.resolve(args)
        # Re-key the substitution by each variable's token so the bindings can
        # be splatted as keyword arguments (e.g. var('x') -> parameter ``x``).
        keyword_args = {}
        for variable, value in bindings.items():
            keyword_args[variable.token] = value
        return func(**keyword_args)
global_namespace = {}  # type: ignore[var-annotated]


def match(*signature, **kwargs):
    """Return a decorator that registers a function under ``signature``.

    Functions are grouped by ``func.__name__``: the first registration of a
    name creates a dispatcher for it, and later registrations add to it.

    Keyword arguments:
        namespace: mapping from function name to dispatcher; defaults to the
            module-level ``global_namespace``.
        Dispatcher: class used to create a dispatcher for a new name;
            defaults to ``Dispatcher``.

    The decorator returns the dispatcher, not the decorated function.
    """
    target = kwargs.get('namespace', global_namespace)
    factory = kwargs.get('Dispatcher', Dispatcher)

    def decorate(func):
        key = func.__name__
        if key not in target:
            target[key] = factory(key)
        dispatcher = target[key]
        dispatcher.add(signature, func)
        return dispatcher

    return decorate
def supercedes(a, b):
    """Return True if ``a`` is a more specific match than ``b``, else False.

    * If ``b`` is a variable and ``a`` is not, ``a`` is more specific.
    * If ``a`` and ``b`` do not unify, ``a`` is not more specific.
    * Otherwise the unifier is filtered and applied to both sides: ``a`` is
      more specific if reifying ``a`` leaves it unchanged.

    The spelling of the public name is kept as-is; renaming it would break
    callers.
    """
    if isvar(b) and not isvar(a):
        return True
    s = unify(a, b)
    if s is False:
        return False
    # Keep only bindings where at least one side is not a variable, i.e. drop
    # pure variable-to-variable bindings (presumably because a renaming says
    # nothing about which side is more specific).
    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
    if reify(a, s) == a:
        return True
    if reify(b, s) == b:
        return False
    # Neither side survives reification unchanged. Return an explicit False
    # (this used to fall off the end and yield None) so the predicate always
    # returns a bool.
    return False
# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
    """Return whether signature ``a`` should be checked before ``b``.

    ``a`` goes first when it supersedes ``b`` and ``b`` does not supersede
    ``a``. When each supersedes the other, the result is
    ``tie_breaker(a) > tie_breaker(b)``; ``tie_breaker`` defaults to ``hash``.
    """
    if not supercedes(a, b):
        return False
    if not supercedes(b, a):
        return True
    # Mutual supersession: fall back to the tie breaker.
    return tie_breaker(a) > tie_breaker(b)
# Taken from multipledispatch
def ordering(signatures):
    """Return ``signatures`` in a sane order to check them, first to last.

    Each signature is converted to a tuple. A graph is built with an edge
    ``a -> b`` whenever ``edge(a, b)`` holds (``a`` should be checked before
    ``b``), and the result is that graph's topological sort.
    """
    candidates = [tuple(sig) for sig in signatures]
    pairs = [(a, b) for a in candidates for b in candidates if edge(a, b)]
    grouped = groupby(first, pairs)
    # Signatures with no outgoing edge still need a node in the graph.
    for sig in candidates:
        if sig not in grouped:
            grouped[sig] = []
    graph = {src: [dst for _, dst in group] for src, group in grouped.items()}
    return _toposort(graph)