121 lines
3.0 KiB
Python
121 lines
3.0 KiB
Python
|
from torch.fx.experimental.graph_gradual_typechecker import Refine
|
||
|
from torch.fx.tensor_type import TensorType
|
||
|
from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined]
|
||
|
|
||
|
|
||
|
def infer_symbolic_types_single_pass(traced):
    """Run one round of the symbolic type inferencer on ``traced``.

    A ``Refine`` pass collects equality constraints for the traced module.
    These are solved by unification, and the resulting substitution is
    written back into the ``type`` of every node of ``traced.graph``, in
    place. Returns ``None``.
    """
    refiner = Refine(traced)
    refiner.refine()
    solution = unify_eq(refiner.constraints)
    substitute_all_types(traced.graph, solution)
|
||
|
|
||
|
def infer_symbolic_types(traced):
    """Run the symbolic type inferencer twice, then derive symbolic relations.

    A single pass is not always enough to infer all the information (for
    example, in the case of broadcasting). The refine -> unify -> substitute
    cycle is therefore run two times, each time on the types produced by
    the previous pass. Afterwards, ``symbolic_relations`` is called on the
    ``Refine`` instance from the final pass.

    The node types of ``traced.graph`` are updated in place; returns ``None``.
    """
    # Each iteration builds a fresh Refine so that it sees the node types
    # substituted by the previous iteration.
    for _ in range(2):
        r = Refine(traced)
        r.refine()
        mgu = unify_eq(r.constraints)
        substitute_all_types(traced.graph, mgu)

    # Uses the refiner from the last pass, as the original code did.
    r.symbolic_relations()
|
||
|
|
||
|
def convert_eq(list_of_eq):
    """Split equality constraints into two parallel tuples.

    Every constraint contributes its ``lhs`` to the first tuple and its
    ``rhs`` to the second, preserving order. The pair can then be handed to
    the unification library as a whole (see ``unify_eq``).

    Returns:
        ``(lhs_tuple, rhs_tuple)``; both are empty when there are no
        constraints.
    """
    # Materialise once so that one-shot iterables are only consumed a single time.
    constraints = list(list_of_eq)
    left_sides = tuple(eq.lhs for eq in constraints)
    right_sides = tuple(eq.rhs for eq in constraints)
    return left_sides, right_sides
|
||
|
|
||
|
|
||
|
def unify_eq(list_of_eq):
    """Solve a collection of equality constraints by unification.

    The constraints are turned into a pair of parallel tuples by
    ``convert_eq``, and the two tuples are unified against each other. The
    return value is whatever ``unify`` produces for that pair. Presumably
    this is a dict mapping variables to their solutions, or ``False`` when
    unification fails — NOTE(review): confirm against the unification
    library.
    """
    return unify(*convert_eq(list_of_eq))
|
||
|
|
||
|
|
||
|
def substitute_solution_one_type(mapping, t):
    """Apply the unifier ``mapping`` to a single type ``t``.

    - A ``Var`` is replaced by its binding, if it has one.
    - A ``TensorType`` gets each of its dimensions replaced by its binding,
      if any. This is a single, non-recursive lookup per dimension.
    - Lists and tuples are rewritten element by element, recursively, and
      keep their container kind (a plain ``list`` or ``tuple``).
    - Anything else is returned unchanged.
    """

    def resolve(term):
        # Look up a single term; unbound terms pass through untouched.
        return mapping[term] if term in mapping.keys() else term

    if isinstance(t, Var):
        return resolve(t)

    if isinstance(t, TensorType):
        return TensorType(tuple(resolve(dim) for dim in t.__args__))

    if isinstance(t, (list, tuple)):
        rewritten = [substitute_solution_one_type(mapping, item) for item in t]
        return rewritten if isinstance(t, list) else tuple(rewritten)

    return t
|
||
|
|
||
|
|
||
|
def substitute_all_types(graph, mapping):
    """Apply the unifier ``mapping`` to the type of every node in ``graph``.

    First, ``mapping`` is normalised in place. Whenever a value is itself a
    key of ``mapping``, it is replaced by that key's value. Full passes
    repeat until one pass changes nothing (a fixed point), so chains such
    as ``a -> b -> 3`` collapse to ``a -> 3``. Then every node's ``type``
    is rewritten with ``substitute_solution_one_type`` using the
    normalised mapping.

    Both ``mapping`` and the nodes of ``graph`` are mutated in place; the
    function returns ``None``. The membership test hashes each value, so
    the values of ``mapping`` must be hashable.
    """

    def collapse_once():
        # One full sweep over the mapping; reports whether any entry changed.
        updated = False
        for var in mapping:
            before = mapping[var]
            if before in mapping.keys():
                mapping[var] = mapping[before]
            if before != mapping[var]:
                updated = True
        return updated

    # Always performs at least one sweep, then stops at the first sweep
    # that leaves the mapping unchanged.
    while collapse_once():
        pass

    for node in graph.nodes:
        node.type = substitute_solution_one_type(mapping, node.type)
|
||
|
|
||
|
def check_for_type_equality(g1, g2):
    """Return whether two graphs agree on the type of every node.

    Intended for fixed-point checks. Instead of full graph equality, only
    the ``type`` attributes of corresponding nodes are compared, with nodes
    paired in iteration order.

    Args:
        g1: first graph; only its ``nodes`` iterable is read.
        g2: second graph; only its ``nodes`` iterable is read.

    Returns:
        ``True`` if both graphs have the same number of nodes and every pair
        of corresponding nodes has equal types; ``False`` otherwise.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    # zip() would silently drop the trailing nodes of the longer graph and
    # report a spurious match, so differing lengths are an explicit mismatch.
    if len(nodes1) != len(nodes2):
        return False
    for n, m in zip(nodes1, nodes2):
        if n.type != m.type:
            return False
    return True
|