from typing import Dict, Tuple

from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.nn import Module


__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]


@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    Container module used to hold the attributes copied from the original
    module for the submodules that use them.

    Args:
        d: mapping from child name to the module registered under that name
    """

    def __init__(self, d):
        super().__init__()
        for child_name, child in d.items():
            self.add_module(child_name, child)


@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for subgraph, which copies the necessary attributes
    from the original parent graph_module.

    Args:
        gm (GraphModule): parent graph module

        subgraph (Graph): a valid subgraph that contains copied nodes from the
            parent graph

        comp_name (str): name for the new component

        class_name (str): name for the submodule

    Returns:
        A tuple ``(lifted, orig_to_split_fqn_mapping)`` where ``lifted`` is the
        GraphModule built from ``subgraph`` and the mapping sends the target of
        every ``call_module`` / ``get_attr`` node in ``subgraph`` to that target
        prefixed with ``comp_name``. When ``comp_name`` is empty the target is
        mapped to itself, so no name starts with a stray ".".
    """

    # Loop through all module calls (call_module) and param fetches (get_attr)
    # in this component, creating HolderModules as necessary to match the path.
    # e.g. if in the original module there's a get_attr node that fetches
    # "conv.weight", we create a HolderModule as root -> add a HolderModule
    # named "conv" -> make "weight" an attribute of the "conv" HolderModule and
    # point it to conv.weight in the original module.
    submodule = HolderModule({})
    orig_to_split_fqn_mapping: Dict[str, str] = {}

    for n in subgraph.nodes:
        if n.op not in ("call_module", "get_attr"):
            continue

        target = n.target
        assert isinstance(target, str)
        *parent_names, leaf_node_name = target.split(".")

        # Walk the holder tree and the original module in lockstep, creating
        # intermediate holders on demand.
        curr = submodule
        orig_gm = gm
        for name in parent_names:
            if not hasattr(curr, name):
                curr.add_module(name, HolderModule({}))
            curr = getattr(curr, name)
            orig_gm = getattr(orig_gm, name)

        leaf_node = getattr(orig_gm, leaf_node_name)

        # Only prefix when a component name was given; an empty comp_name would
        # otherwise yield a malformed qualified name such as ".conv.weight".
        orig_to_split_fqn_mapping[target] = (
            f"{comp_name}.{target}" if comp_name else target
        )

        # Relies on custom __setattr__ magic: torch.nn.Module.__setattr__
        # registers Parameters and Modules assigned this way.
        setattr(curr, leaf_node_name, leaf_node)

    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping


@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Return True if two graphs are identical, i.e they
        - have the same number of outputs in the same order
        - have the same number of inputs in the same order
        - have the same set of nodes, and identical connectivity

    ``left`` is used as the pattern of a SubgraphMatcher that also matches
    placeholder and output nodes; the graphs are considered identical when that
    pattern is found at least once in ``right``.
    """

    matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
    matches = matcher.match(right)
    return bool(matches)