162 lines
6.4 KiB
Python
162 lines
6.4 KiB
Python
|
from torch.fx import (
|
||
|
GraphModule,
|
||
|
Node,
|
||
|
map_arg
|
||
|
)
|
||
|
from torch.fx.graph import Graph
|
||
|
from .match_utils import (
|
||
|
_is_match,
|
||
|
MatchAllNode,
|
||
|
)
|
||
|
from .pattern_utils import (
|
||
|
_sorted_patterns_dict,
|
||
|
)
|
||
|
|
||
|
from ..backend_config import (
|
||
|
BackendConfig,
|
||
|
get_native_backend_config,
|
||
|
)
|
||
|
from ..backend_config.utils import (
|
||
|
get_fuser_method_mapping,
|
||
|
get_fusion_pattern_to_root_node_getter,
|
||
|
get_fusion_pattern_to_extra_inputs_getter,
|
||
|
)
|
||
|
|
||
|
from .custom_config import FuseCustomConfig
|
||
|
|
||
|
from .fuse_handler import (
|
||
|
_get_fusion_pattern_to_fuse_handler_cls,
|
||
|
FuseHandler,
|
||
|
)
|
||
|
|
||
|
from typing import Any, Callable, Dict, List, Tuple, Union
|
||
|
import warnings
|
||
|
|
||
|
from torch.ao.quantization.utils import Pattern, NodePattern
|
||
|
|
||
|
|
||
|
# Public names exported by this module.
# TODO: We should make "FuseHandler" private in the future
# It is currently needed for test_public_bindings for some reason
__all__ = ["fuse", "FuseHandler"]
|
||
|
|
||
|
|
||
|
def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    """Return a new ``GraphModule`` in which matched fusion patterns are fused.

    The graph of ``model`` is searched for the fusion patterns derived from
    ``backend_config``. A fresh graph is then built node by node:

    * the node a match is anchored on is replaced by whatever the matched
      handler's ``fuse`` method returns,
    * nodes that are not part of any match are copied over unchanged,
    * the remaining nodes of a match are dropped (they are not copied).

    Args:
        model: Module whose ``graph`` is searched and rebuilt.
        is_qat: Forwarded unchanged to each handler's ``fuse`` call.
        fuse_custom_config: Forwarded to each handler's ``fuse`` call.
            ``None`` means a default-constructed ``FuseCustomConfig``.
            A plain dict is still accepted but deprecated; it is converted
            with ``FuseCustomConfig.from_dict`` after a warning is issued.
        backend_config: Configuration the fusion patterns, fuser methods,
            root-node getters and extra-input getters are read from.
            ``None`` means ``get_native_backend_config()``. A plain dict is
            still accepted but deprecated; it is converted with
            ``BackendConfig.from_dict`` after a warning is issued.

    Returns:
        ``GraphModule(model, fused_graph)`` where ``fused_graph`` is the
        newly constructed graph.
    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.")
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    if isinstance(backend_config, dict):
        # The warning names `fuse` (the function actually being called) so the
        # user is pointed at the right call site.
        warnings.warn(
            "Passing a backend_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    named_modules = dict(model.named_modules())

    if backend_config is None:
        backend_config = get_native_backend_config()

    fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))
    fuser_method_mapping = get_fuser_method_mapping(backend_config)
    fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
    fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)

    # find fusion
    fusion_pairs = _find_matches(
        model, model.graph, fusion_pattern_to_fuse_handler_cls)
    # TODO: change this to inplace changes to graph, since we no longer construct
    # new GraphModule anymore
    fused_graph = Graph()
    # Maps the name of a node in the original graph to the value that stands
    # in for it in ``fused_graph``.
    env: Dict[Any, Any] = {}

    def load_arg(a):
        """Translate argument ``a`` by looking up every node in ``env`` by name."""
        return map_arg(a, lambda node: env[node.name])

    def default_root_node_getter(node_pattern):
        """Follow the last element of nested patterns until it is a ``Node``."""
        while not isinstance(node_pattern[-1], Node):
            node_pattern = node_pattern[-1]
        return node_pattern[-1]

    for node in model.graph.nodes:
        # Nodes without a match get an all-None tuple so they fall through to
        # the plain copy branch below.
        maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
            fusion_pairs.get(node.name, (None, None, None, None, None))
        # get the corresponding subpattern for the current node
        if node_to_subpattern is not None:
            node_subpattern = node_to_subpattern.get(node, None)
        else:
            node_subpattern = None
        if maybe_last_node is node:
            # This node is the one the match is anchored on: let the handler
            # emit the fused replacement into ``fused_graph``.
            assert obj is not None
            root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
            root_node = root_node_getter(matched_node_pattern)  # type: ignore[index]
            extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
            extra_inputs = []
            if extra_inputs_getter is not None:
                extra_inputs = extra_inputs_getter(matched_node_pattern)
            # TODO: add validation that root_node is a module and has the same type
            # as the root_module in the configuration
            env[node.name] = obj.fuse(
                load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern,  # type: ignore[arg-type]
                fuse_custom_config, fuser_method_mapping, is_qat)
        elif maybe_last_node is None or node_subpattern is MatchAllNode:
            env[node.name] = fused_graph.node_copy(node, load_arg)
        # node matched in patterns and is not root is removed here

    model = GraphModule(model, fused_graph)
    return model
|
||
|
|
||
|
def _find_matches(
    root: GraphModule,
    graph: Graph,
    pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
    """Match fusion patterns against ``graph`` and index the results by node name.

    Nodes are visited in reverse graph order. For each node not yet claimed
    by an earlier match, the patterns are tried in the iteration order of
    ``pattern_to_fuse_handler_cls``; the first pattern accepted by
    ``_is_match`` wins and its handler class is instantiated with that node.

    Args:
        root: Module whose ``named_modules()`` are handed to ``_is_match``.
        graph: Graph whose nodes are matched.
        pattern_to_fuse_handler_cls: Maps each pattern to a callable that is
            invoked with the matched node to build the handler.

    Returns:
        A dict keyed by node name. Each value is the tuple
        ``(anchor node, matched pattern, matched node pattern, handler,
        node -> subpattern map)``. The node -> subpattern map is one dict
        shared by all entries.
    """
    modules = dict(root.named_modules())
    # node name -> (anchor node, pattern, matched node pattern, handler, node -> subpattern)
    match_map: Dict[
        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
    # a map from node to the matched subpattern
    node_to_subpattern: Dict[Node, Any] = {}

    # TODO: dedup with quantization matching function in match_utils.py
    def _record(sub_pattern, current, match_info, collected, subpattern_of):
        """Walk ``sub_pattern`` alongside ``current`` and register matched nodes."""
        if not isinstance(sub_pattern, tuple):
            # the first pattern matches will take precedence
            if current.name in match_map:
                return
            collected.append(current)
            # MatchAllNode here is actually MatchAllInputNode which should not
            # be added to match_map
            if sub_pattern is MatchAllNode:
                return
            subpattern_of[current] = sub_pattern
            anchor, whole_pattern, handler = match_info
            match_map[current.name] = (anchor, whole_pattern, collected, handler, subpattern_of)
            return

        # Tuple pattern: the head describes ``current`` itself, the remaining
        # entries describe its positional arguments.
        head, *operand_patterns = sub_pattern
        nested: List[Node] = []
        _record(head, current, match_info, nested, subpattern_of)
        for operand_pattern, operand in zip(operand_patterns, current.args):
            _record(operand_pattern, operand, match_info, nested, subpattern_of)
        collected.append(tuple(nested))

    for candidate in reversed(graph.nodes):
        if candidate.name in match_map:
            continue
        for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
            if not _is_match(modules, candidate, pattern):
                continue
            top_level: List[Node] = []
            match_info = (candidate, pattern, fuse_handler_cls(candidate))
            _record(pattern, candidate, match_info, top_level, node_to_subpattern)
            break

    return match_map
|