ai-content-maker/.venv/Lib/site-packages/torch/fx/passes/utils/common.py

96 lines
3.0 KiB
Python
Raw Permalink Normal View History

2024-05-03 04:18:51 +03:00
from typing import Dict, Tuple
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.nn import Module
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]
@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    Plain ``Module`` container that re-hosts attributes taken from another module.

    Every ``(name, value)`` pair of ``d`` is registered on the new instance as a
    child via ``add_module``. ``lift_subgraph_as_module`` builds trees of these
    holders so that dotted targets such as ``"conv.weight"`` resolve inside the
    lifted submodule.

    Args:
        d: mapping from child name to the object registered under that name.
            Values go through ``add_module``, which (per the torch.nn.Module
            docs) accepts a ``Module`` or ``None``.
    """

    def __init__(self, d):
        super().__init__()
        for attr_name, attr_value in d.items():
            self.add_module(attr_name, attr_value)
@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for ``subgraph``, copying the attributes it needs from ``gm``.

    Each ``call_module`` and ``get_attr`` node in ``subgraph`` names a dotted
    target. For every such target, a chain of ``HolderModule`` containers is
    built that mirrors the path, and the leaf is bound to the object found at
    that same path in ``gm``. For example, a ``get_attr`` node fetching
    ``"conv.weight"`` produces a root HolderModule with a child HolderModule
    ``"conv"``. That child's ``"weight"`` attribute points to
    ``gm.conv.weight``.

    Args:
        gm (GraphModule): parent graph module the attributes are read from.
        subgraph (Graph): a valid subgraph that contains copied nodes from the
            parent graph.
        comp_name (str): name for the new component. It prefixes the values of
            the returned FQN mapping. When empty, each target maps to itself.
        class_name (str): passed to ``GraphModule`` as the class name of the
            returned module.

    Returns:
        A tuple ``(module, mapping)``. ``module`` is the new GraphModule built
        from ``subgraph``. ``mapping`` maps each original ``call_module`` /
        ``get_attr`` target to its name under the new component: either
        ``f"{comp_name}.{target}"``, or ``target`` itself when ``comp_name`` is
        empty.
    """
    # Root of the attribute tree handed to GraphModule. Its children mirror the
    # dotted paths referenced by the subgraph's call_module / get_attr nodes.
    submodule = HolderModule({})
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for n in subgraph.nodes:
        if n.op not in ("call_module", "get_attr"):
            continue

        target = n.target
        assert isinstance(target, str)
        *path, leaf_node_name = target.split(".")

        # Walk the path in lockstep through the new holder tree and the
        # original module, creating empty HolderModules where missing.
        curr = submodule
        orig_gm = gm
        for name in path:
            if not hasattr(curr, name):
                curr.add_module(name, HolderModule({}))
            curr = getattr(curr, name)
            orig_gm = getattr(orig_gm, name)

        leaf_node = getattr(orig_gm, leaf_node_name)

        # An empty comp_name means the component sits at the root, so the
        # target already is its FQN. Avoid producing a malformed ".target".
        orig_to_split_fqn_mapping[target] = (
            f"{comp_name}.{target}" if comp_name else target
        )
        # Relies on Module.__setattr__ magic: Parameters and Modules are
        # registered on `curr`, other values become plain attributes.
        setattr(curr, leaf_node_name, leaf_node)

    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Report whether two graphs are identical.

    Two graphs count as identical when they
      - have the same number of outputs in the same order,
      - have the same number of inputs in the same order,
      - have the same set of nodes, and identical connectivity.

    ``left`` is used as a pattern anchored at both its placeholders and its
    output, and is searched for in ``right``. Any match means the graphs agree.
    """
    return bool(
        SubgraphMatcher(left, match_output=True, match_placeholder=True).match(right)
    )