from torch.fx import (
    GraphModule,
    Node,
    map_arg
)
from torch.fx.graph import Graph
from .match_utils import (
    _is_match,
    MatchAllNode,
)
from .pattern_utils import (
    _sorted_patterns_dict,
)

from ..backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from ..backend_config.utils import (
    get_fuser_method_mapping,
    get_fusion_pattern_to_root_node_getter,
    get_fusion_pattern_to_extra_inputs_getter,
)

from .custom_config import FuseCustomConfig

from .fuse_handler import (
    _get_fusion_pattern_to_fuse_handler_cls,
    FuseHandler,
)

from typing import Any, Callable, Dict, List, Tuple, Union
import warnings

from torch.ao.quantization.utils import Pattern, NodePattern


__all__ = [
    "fuse",
    # TODO: We should make this private in the future
    # This is currently needed for test_public_bindings for some reason
    "FuseHandler",
]


def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    """Fuse the patterns configured in ``backend_config`` inside ``model``.

    The graph of ``model`` is walked node by node and a new ``Graph`` is
    built. For every matched pattern, the pattern's FuseHandler is invoked
    when the walk reaches the match's last node. Unmatched nodes are copied
    through unchanged. Nodes that belong to a match but are not its last node
    are not copied.

    Args:
        model: the GraphModule to fuse.
        is_qat: forwarded unchanged to each FuseHandler's ``fuse`` call.
        fuse_custom_config: a ``FuseCustomConfig``, or a legacy dict
            (deprecated; converted with ``FuseCustomConfig.from_dict`` after
            emitting a warning). ``None`` means an empty ``FuseCustomConfig``.
        backend_config: a ``BackendConfig``, or a legacy dict (deprecated;
            converted with ``BackendConfig.from_dict`` after emitting a
            warning). ``None`` means ``get_native_backend_config()``.

    Returns:
        A new GraphModule constructed from ``model`` and the fused graph.
    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.")
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    named_modules = dict(model.named_modules())

    if backend_config is None:
        backend_config = get_native_backend_config()

    # NOTE(review): _sorted_patterns_dict presumably orders patterns so that the
    # more specific ones are tried first by _find_matches (first match wins) -- confirm.
    fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(
        _get_fusion_pattern_to_fuse_handler_cls(backend_config))
    fuser_method_mapping = get_fuser_method_mapping(backend_config)
    fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
    fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)

    # find fusion
    fusion_pairs = _find_matches(
        model, model.graph, fusion_pattern_to_fuse_handler_cls)
    # TODO: change this to inplace changes to graph, since we no longer construct
    # new GraphModule anymore
    fused_graph = Graph()
    # original node name -> value produced for it in ``fused_graph``
    env: Dict[Any, Any] = {}

    def load_arg(a):
        """Remap Node references in ``a`` to their counterparts in ``fused_graph``."""
        return map_arg(a, lambda node: env[node.name])

    def default_root_node_getter(node_pattern):
        """Return the innermost last Node of a (possibly nested) matched node pattern."""
        # Descend through the last element of each nested tuple until a Node is reached.
        while not isinstance(node_pattern[-1], Node):
            node_pattern = node_pattern[-1]
        return node_pattern[-1]

    for node in model.graph.nodes:
        maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
            fusion_pairs.get(node.name, (None, None, None, None, None))
        # get the corresponding subpattern for the current node
        if node_to_subpattern is not None:
            node_subpattern = node_to_subpattern.get(node, None)
        else:
            node_subpattern = None
        if maybe_last_node is node:
            # This node is the one at which the match was found: emit the fused op(s).
            assert obj is not None
            root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
            root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
            extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
            extra_inputs: List[Any] = []
            if extra_inputs_getter is not None:
                extra_inputs = extra_inputs_getter(matched_node_pattern)
            # TODO: add validation that root_node is a module and has the same type
            # as the root_module in the configuration
            env[node.name] = obj.fuse(
                load_arg, named_modules, fused_graph, root_node, extra_inputs,
                matched_node_pattern,  # type: ignore[arg-type]
                fuse_custom_config, fuser_method_mapping, is_qat)
        elif maybe_last_node is None or node_subpattern is MatchAllNode:
            # Not part of any match: copy it through unchanged.
            env[node.name] = fused_graph.node_copy(node, load_arg)
        # node matched in patterns and is not root is removed here

    model = GraphModule(model, fused_graph)
    return model


def _find_matches(
        root: GraphModule, graph: Graph,
        pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
    """Find fusion pattern matches in ``graph``.

    Nodes are visited in reverse graph order, and each unmatched node is
    tested against the patterns in the iteration order of
    ``pattern_to_fuse_handler_cls``. The first pattern that matches wins. A
    node already claimed by an earlier match is never re-matched.

    Returns:
        A dict mapping node name to a tuple of:

        - the node at which the match was found (the match's last node);
        - the matched pattern;
        - the matched node pattern (nodes nested like the pattern);
        - a FuseHandler instance built from that last node;
        - a node -> subpattern map.

        The node -> subpattern map is one dict object shared by every entry.
        Nodes matched by ``MatchAllNode`` are not recorded in the result.
    """
    modules = dict(root.named_modules())
    # node name -> (root_node, match_value)
    match_map : Dict[
        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
    # a map from node to the matched subpattern
    node_to_subpattern: Dict[Node, Any] = {}

    # TODO: dedup with quantization matching function in match_utils.py
    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
        """Record ``node`` (and, recursively, its args) as matched by ``pattern``.

        ``match`` is the ``(last_node, pattern, handler)`` triple for the whole
        match. The nodes visited are appended to ``matched_node_pattern``,
        nested like the pattern tuple.
        """
        if isinstance(pattern, tuple):
            s, *args = pattern
            current_node_pattern: List[Node] = []
            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
            # Sub-patterns pair positionally with the node's args.
            for subpattern, arg in zip(args, node.args):
                apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern)
            matched_node_pattern.append(tuple(current_node_pattern))
        else:
            # the first pattern matches will take precedence
            if node.name not in match_map:
                matched_node_pattern.append(node)
                # MatchAllNode here is actually MatchAllInputNode which should not
                # be added to match_map
                if pattern is not MatchAllNode:
                    node_to_subpattern[node] = pattern
                    root_node, matched_pattern, handler = match
                    match_map[node.name] = (
                        root_node, matched_pattern, matched_node_pattern, handler, node_to_subpattern)

    for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
                matched_node_pattern: List[Node] = []
                if _is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)),
                                matched_node_pattern, node_to_subpattern)
                    break

    return match_map