ai-content-maker/.venv/Lib/site-packages/torch/fx/experimental/unify_refinements.py

121 lines
3.0 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
from torch.fx.experimental.graph_gradual_typechecker import Refine
from torch.fx.tensor_type import TensorType
from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined]
def infer_symbolic_types_single_pass(traced):
    """
    Run exactly one round of the symbolic type inferencer on ``traced``.

    The round has three steps:

    1. A ``Refine`` pass is run over ``traced``.
    2. The constraints it gathered are solved by unification.
    3. The resulting substitution is written back onto the ``type`` of every
       node in ``traced.graph``.

    Args:
        traced: The traced module; its graph's node types are updated in place.
    """
    refiner = Refine(traced)
    refiner.refine()
    solution = unify_eq(refiner.constraints)
    substitute_all_types(traced.graph, solution)
def infer_symbolic_types(traced):
    """
    Run the symbolic type inferencer on ``traced`` twice in a row.

    A single round is not always enough to infer all the information; for
    example, broadcasting can need a second round. Each round does three
    things:

    - runs a fresh ``Refine`` pass;
    - unifies its constraints;
    - substitutes the solution into the node types of ``traced.graph``.

    After the second round, ``symbolic_relations`` is invoked on the
    ``Refine`` instance created in that round.

    Args:
        traced: The traced module; its graph's node types are updated in place.
    """
    refiner = None
    for _ in range(2):
        # Each round re-reads the node types produced by the previous round.
        refiner = Refine(traced)
        refiner.refine()
        solution = unify_eq(refiner.constraints)
        substitute_all_types(traced.graph, solution)
    refiner.symbolic_relations()
def convert_eq(list_of_eq):
    """
    Reshape equality constraints into the form the unification library takes.

    Args:
        list_of_eq: Iterable of constraint objects exposing ``lhs`` and ``rhs``.

    Returns:
        A pair ``(lhs, rhs)`` of tuples. Position ``i`` of each tuple holds
        the corresponding side of the ``i``-th constraint. Empty input gives
        ``((), ())``.
    """
    # Materialize once so single-use iterators are consumed exactly once.
    equations = list(list_of_eq)
    left_sides = tuple(eq.lhs for eq in equations)
    right_sides = tuple(eq.rhs for eq in equations)
    return left_sides, right_sides
def unify_eq(list_of_eq):
    """
    Solve a collection of equality constraints by unification.

    The two sides of every constraint are first gathered into two parallel
    tuples by ``convert_eq``. Those tuples are then unified against each
    other.

    Args:
        list_of_eq: Iterable of constraint objects exposing ``lhs`` and ``rhs``.

    Returns:
        The result of ``unify`` on the two tuples. Callers in this module pass
        it on to ``substitute_all_types`` as the substitution mapping.
    """
    sides = convert_eq(list_of_eq)
    return unify(*sides)
def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier ``mapping`` to a single type ``t``.

    - A ``Var`` is replaced by its image in ``mapping`` if it has one.
    - A ``TensorType`` is rebuilt with each dimension in ``__args__`` replaced
      by its image in ``mapping``. This is a direct lookup only; the
      dimensions themselves are not recursed into.
    - A ``list`` or ``tuple`` is processed element by element, recursively.
      The result is a plain ``list`` or plain ``tuple`` respectively.
    - Any other value is returned as is.

    Args:
        mapping: Substitution from type variables/dimensions to replacements.
        t: The type to rewrite.

    Returns:
        The rewritten type.
    """
    if isinstance(t, Var):
        return mapping[t] if t in mapping else t

    if isinstance(t, TensorType):
        dims = tuple(mapping[d] if d in mapping else d for d in t.__args__)
        return TensorType(dims)

    if isinstance(t, (list, tuple)):
        substituted = [substitute_solution_one_type(mapping, item) for item in t]
        return substituted if isinstance(t, list) else tuple(substituted)

    return t
def substitute_all_types(graph, mapping):
    """
    Apply the most general unifier ``mapping`` to every node type in ``graph``.

    ``mapping`` may contain chains, i.e. a key whose value is itself a key.
    The chains are collapsed in place first: each value that is also a key
    is replaced by what that key maps to. Full sweeps over ``mapping`` repeat
    until one sweep leaves every value unchanged (the fixed point).

    The resolved substitution is then applied to the ``type`` of every node
    in ``graph``, which is reassigned in place.

    Note: ``mapping`` itself is mutated by this function.

    Args:
        graph: Object whose ``nodes`` carry a writable ``type`` attribute.
        mapping: Substitution to resolve and apply.
    """
    changed = True
    while changed:
        changed = False
        for key in mapping:
            before = mapping[key]
            if before in mapping:
                # Follow one link of the chain: key -> before -> mapping[before].
                mapping[key] = mapping[before]
            if before != mapping[key]:
                changed = True

    for node in graph.nodes:
        node.type = substitute_solution_one_type(mapping, node.type)
def check_for_type_equality(g1, g2):
    """
    Check whether two graphs carry the same node types, position by position.

    This is intended as a fixed-point test. The graphs are not compared
    structurally; only the ``type`` of each node is compared.

    Args:
        g1: First graph; only its ``nodes`` iterable is read.
        g2: Second graph; only its ``nodes`` iterable is read.

    Returns:
        True if both graphs have the same number of nodes and every pair of
        nodes at the same position has equal ``type``; False otherwise.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    # zip() silently stops at the shorter sequence. Without this check, graphs
    # of different length would be reported as type-equal whenever their
    # common prefix matched.
    if len(nodes1) != len(nodes2):
        return False
    for n, m in zip(nodes1, nodes2):
        if n.type != m.type:
            return False
    return True