886 lines
37 KiB
Python
886 lines
37 KiB
Python
|
import copy
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from torch.ao.quantization import (
|
||
|
QConfigAny,
|
||
|
QuantType,
|
||
|
)
|
||
|
from torch.ao.quantization.backend_config import (
|
||
|
DTypeWithConstraints,
|
||
|
)
|
||
|
from torch.ao.quantization.fake_quantize import (
|
||
|
FakeQuantizeBase,
|
||
|
FixedQParamsFakeQuantize,
|
||
|
)
|
||
|
from torch.ao.quantization.observer import (
|
||
|
FixedQParamsObserver,
|
||
|
ObserverBase,
|
||
|
)
|
||
|
from torch.ao.quantization.qconfig import (
|
||
|
float16_static_qconfig,
|
||
|
float16_dynamic_qconfig,
|
||
|
qconfig_equals,
|
||
|
)
|
||
|
from torch.ao.quantization.stubs import DeQuantStub
|
||
|
from torch.ao.quantization.utils import (
|
||
|
activation_is_statically_quantized,
|
||
|
)
|
||
|
from torch.ao.quantization.observer import _is_activation_post_process
|
||
|
from torch.ao.quantization.qconfig_mapping import QConfigMapping
|
||
|
|
||
|
from torch.fx import GraphModule, map_arg
|
||
|
|
||
|
from torch.fx.graph import (
|
||
|
Graph,
|
||
|
Node,
|
||
|
)
|
||
|
from .custom_config import PrepareCustomConfig
|
||
|
# importing the lib so that the quantized_decomposed ops are registered
|
||
|
from ._decomposed import quantized_decomposed_lib # noqa: F401
|
||
|
|
||
|
from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type
|
||
|
from dataclasses import dataclass
|
||
|
from collections import namedtuple
|
||
|
import operator
|
||
|
import warnings
|
||
|
|
||
|
# TODO: revisit this list. Many helper methods shouldn't be public
|
||
|
__all__ = [
|
||
|
"all_node_args_except_first",
|
||
|
"all_node_args_have_no_tensors",
|
||
|
"assert_and_get_unique_device",
|
||
|
"collect_producer_nodes",
|
||
|
"create_getattr_from_value",
|
||
|
"create_node_from_old_node_preserve_meta",
|
||
|
"EMPTY_ARG_DICT",
|
||
|
"get_custom_module_class_keys",
|
||
|
"get_linear_prepack_op_for_dtype",
|
||
|
"get_new_attr_name_with_prefix",
|
||
|
"get_non_observable_arg_indexes_and_types",
|
||
|
"get_qconv_prepack_op",
|
||
|
"get_skipped_module_name_and_classes",
|
||
|
"graph_module_from_producer_nodes",
|
||
|
"maybe_get_next_module",
|
||
|
"NodeInfo",
|
||
|
"node_arg_is_bias",
|
||
|
"node_arg_is_weight",
|
||
|
"NON_OBSERVABLE_ARG_DICT",
|
||
|
"NON_QUANTIZABLE_WEIGHT_OPS",
|
||
|
"return_arg_list",
|
||
|
"ObservedGraphModuleAttrs",
|
||
|
]
|
||
|
|
||
|
NON_QUANTIZABLE_WEIGHT_OPS = {torch.nn.functional.layer_norm, torch.nn.functional.group_norm, torch.nn.functional.instance_norm}
|
||
|
|
||
|
@dataclass
|
||
|
class ObservedGraphModuleAttrs:
|
||
|
node_name_to_qconfig: Dict[str, QConfigAny]
|
||
|
node_name_to_scope: Dict[str, Tuple[str, type]]
|
||
|
prepare_custom_config: PrepareCustomConfig
|
||
|
equalization_node_name_to_qconfig: Dict[str, Any]
|
||
|
qconfig_mapping: QConfigMapping
|
||
|
is_qat: bool
|
||
|
observed_node_names: Set[str]
|
||
|
is_observed_standalone_module: bool = False
|
||
|
standalone_module_input_quantized_idxs: Optional[List[int]] = None
|
||
|
standalone_module_output_quantized_idxs: Optional[List[int]] = None
|
||
|
|
||
|
def node_arg_is_weight(node: Node, arg: Any) -> bool:
|
||
|
"""Returns if node arg is weight"""
|
||
|
weight_index = None
|
||
|
if "target_dtype_info" in node.meta:
|
||
|
weight_index = node.meta["target_dtype_info"].get("weight_index", None)
|
||
|
if weight_index is not None and weight_index < len(node.args) and node.args[weight_index] is arg:
|
||
|
return True
|
||
|
return node.kwargs.get("weight") is arg
|
||
|
|
||
|
def node_arg_is_bias(node: Node, arg: Any) -> bool:
|
||
|
"""Returns if node arg is bias"""
|
||
|
bias_index = None
|
||
|
if "target_dtype_info" in node.meta:
|
||
|
bias_index = node.meta["target_dtype_info"].get("bias_index", None)
|
||
|
if bias_index is not None and bias_index < len(node.args) and node.args[bias_index] is arg:
|
||
|
return True
|
||
|
return node.kwargs.get("bias") is arg
|
||
|
|
||
|
def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
|
||
|
r""" Get all the unique custom module keys in the custom config dict
|
||
|
e.g.
|
||
|
Input:
|
||
|
{
|
||
|
QuantType.STATIC: {
|
||
|
CustomModule1: ObservedCustomModule
|
||
|
},
|
||
|
QuantType.DYNAMIC: {
|
||
|
CustomModule2: DynamicObservedCustomModule
|
||
|
},
|
||
|
QuantType.WEIGHT_ONLY: {
|
||
|
CustomModule3: WeightOnlyObservedCustomModule
|
||
|
},
|
||
|
}
|
||
|
|
||
|
Output:
|
||
|
# extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
|
||
|
[CustomModule1, CustomModule2, CustomModule3]
|
||
|
"""
|
||
|
# using set to dedup
|
||
|
float_custom_module_classes : Set[Any] = set()
|
||
|
for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
|
||
|
quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
|
||
|
quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
|
||
|
float_custom_module_classes |= quant_mode_custom_module_classes
|
||
|
return list(float_custom_module_classes)
|
||
|
|
||
|
def get_linear_prepack_op_for_dtype(dtype):
|
||
|
if dtype == torch.float16:
|
||
|
return torch.ops.quantized.linear_prepack_fp16
|
||
|
elif dtype == torch.qint8:
|
||
|
return torch.ops.quantized.linear_prepack
|
||
|
else:
|
||
|
raise Exception("can't get linear prepack op for dtype:", dtype)
|
||
|
|
||
|
def get_qconv_prepack_op(conv_op: Callable) -> Callable:
|
||
|
prepack_ops = {
|
||
|
torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
|
||
|
torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
|
||
|
torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
|
||
|
torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
|
||
|
torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
|
||
|
torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
|
||
|
}
|
||
|
prepack_op = prepack_ops.get(conv_op, None)
|
||
|
assert prepack_op, f"Didn't find prepack op for {conv_op}"
|
||
|
return prepack_op
|
||
|
|
||
|
# Returns a function that can get a new attribute name for module with given
|
||
|
# prefix, for example,
|
||
|
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
|
||
|
# >> new_name = get_new_observer_name(module)
|
||
|
# new_name will be an unused attribute name on module, e.g. `_observer_1`
|
||
|
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
|
||
|
prefix = prefix.replace(".", "_")
|
||
|
|
||
|
def get_new_attr_name(module: torch.nn.Module):
|
||
|
def get_attr_name(i: int):
|
||
|
return prefix + str(i)
|
||
|
i = 0
|
||
|
attr_name = get_attr_name(i)
|
||
|
while hasattr(module, attr_name):
|
||
|
i += 1
|
||
|
attr_name = get_attr_name(i)
|
||
|
return attr_name
|
||
|
return get_new_attr_name
|
||
|
|
||
|
def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
|
||
|
r''' Starting from a target node, trace back until we hit inpu or
|
||
|
getattr node. This is used to extract the chain of operators
|
||
|
starting from getattr to the target node, for example
|
||
|
def forward(self, x):
|
||
|
observed = self.observer(self.weight)
|
||
|
return F.linear(x, observed)
|
||
|
collect_producer_nodes(observed) will either return a list of nodes that
|
||
|
produces the observed node or None if we can't extract a self contained
|
||
|
graph without free variables(inputs of the forward function).
|
||
|
'''
|
||
|
nodes = [node]
|
||
|
frontier = [node]
|
||
|
while frontier:
|
||
|
node = frontier.pop()
|
||
|
all_args = list(node.args) + list(node.kwargs.values())
|
||
|
for arg in all_args:
|
||
|
if not isinstance(arg, Node):
|
||
|
continue
|
||
|
if arg.op == 'placeholder':
|
||
|
# hit input, can't fold in this case
|
||
|
return None
|
||
|
nodes.append(arg)
|
||
|
if not (arg.op == 'call_function' and arg.target == getattr):
|
||
|
frontier.append(arg)
|
||
|
return nodes
|
||
|
|
||
|
def graph_module_from_producer_nodes(
|
||
|
root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
|
||
|
r''' Construct a graph module from extracted producer nodes
|
||
|
from `collect_producer_nodes` function
|
||
|
Args:
|
||
|
root: the root module for the original graph
|
||
|
producer_nodes: a list of nodes we use to construct the graph
|
||
|
Return:
|
||
|
A graph module constructed from the producer nodes
|
||
|
'''
|
||
|
assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
|
||
|
# since we traced back from node to getattr
|
||
|
producer_nodes.reverse()
|
||
|
graph = Graph()
|
||
|
env: Dict[Any, Any] = {}
|
||
|
|
||
|
def load_arg(a):
|
||
|
return map_arg(a, lambda node: env[node])
|
||
|
for producer_node in producer_nodes:
|
||
|
env[producer_node] = graph.node_copy(producer_node, load_arg)
|
||
|
graph.output(load_arg(producer_nodes[-1]))
|
||
|
graph_module = GraphModule(root, graph)
|
||
|
return graph_module
|
||
|
|
||
|
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
|
||
|
"""
|
||
|
Returns the unique device for a module, or None if no device is found.
|
||
|
Throws an error if multiple devices are detected.
|
||
|
"""
|
||
|
devices = {p.device for p in module.parameters()} | \
|
||
|
{p.device for p in module.buffers()}
|
||
|
"""
|
||
|
As a temp workaround for AIMP HHC publish we added CPU check.remove it later. T163614564
|
||
|
"""
|
||
|
if {torch.device("cpu"), torch.device("meta")} == devices:
|
||
|
warnings.warn("Both 'meta' and 'cpu' are present in the list of devices. Module can have one device. We Select 'cpu'.")
|
||
|
devices = {torch.device("cpu")}
|
||
|
""
|
||
|
assert len(devices) <= 1, (
|
||
|
"prepare only works with cpu or single-device CUDA modules, "
|
||
|
f"but got devices {devices}"
|
||
|
)
|
||
|
device = next(iter(devices)) if len(devices) > 0 else None
|
||
|
return device
|
||
|
|
||
|
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
|
||
|
"""
|
||
|
Given a value of any type, creates a getattr node corresponding to the value and
|
||
|
registers the value as a buffer to the module.
|
||
|
"""
|
||
|
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
|
||
|
attr_name = get_new_attr_name(module)
|
||
|
device = assert_and_get_unique_device(module)
|
||
|
new_value = value.clone().detach() if isinstance(value, torch.Tensor) \
|
||
|
else torch.tensor(value, device=device)
|
||
|
module.register_buffer(attr_name, new_value)
|
||
|
# Create get_attr with value
|
||
|
attr_node = graph.create_node("get_attr", attr_name)
|
||
|
return attr_node
|
||
|
|
||
|
def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
|
||
|
"""
|
||
|
If we know for sure that all of this node's args have no
|
||
|
tensors (are primitives), return True. If we either
|
||
|
find a tensor or are not sure, return False. Note: this
|
||
|
function is not exact.
|
||
|
"""
|
||
|
if cache and node in cache:
|
||
|
return cache[node]
|
||
|
|
||
|
result = False # will be overwritten
|
||
|
if not isinstance(node, Node):
|
||
|
result = True
|
||
|
elif node.op == 'placeholder':
|
||
|
result = False
|
||
|
elif node.op == 'call_module':
|
||
|
assert isinstance(node.target, str)
|
||
|
if _is_activation_post_process(modules[node.target]):
|
||
|
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
|
||
|
elif node.op == 'call_module':
|
||
|
result = False
|
||
|
elif node.op == 'call_function' and node.target is operator.getitem:
|
||
|
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
|
||
|
elif node.op == 'get_attr':
|
||
|
result = False
|
||
|
elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
|
||
|
# x1 = x0.ndim
|
||
|
result = True
|
||
|
elif node.op == 'call_method' and node.target == 'size':
|
||
|
# x1 = x0.size(0)
|
||
|
result = True
|
||
|
else:
|
||
|
found_one_tensor = False
|
||
|
for arg in node.args:
|
||
|
if isinstance(arg, list):
|
||
|
for list_el in arg:
|
||
|
if isinstance(list_el, Node):
|
||
|
this_list_el_args_have_no_tensors = \
|
||
|
all_node_args_have_no_tensors(list_el, modules, cache)
|
||
|
found_one_tensor = found_one_tensor or \
|
||
|
(not this_list_el_args_have_no_tensors)
|
||
|
# If found_one_tensor is True, there is no point in
|
||
|
# recursing further as the end result will always
|
||
|
# be True.
|
||
|
# TODO(future PR): remove this entire function and
|
||
|
# change to dtype inference without recursion.
|
||
|
if found_one_tensor:
|
||
|
result = not found_one_tensor
|
||
|
if cache:
|
||
|
cache[node] = result
|
||
|
return result
|
||
|
elif isinstance(arg, int):
|
||
|
pass
|
||
|
else:
|
||
|
if isinstance(arg, Node):
|
||
|
this_arg_args_have_no_tensors = all_node_args_have_no_tensors(arg, modules, cache)
|
||
|
found_one_tensor = found_one_tensor or \
|
||
|
(not this_arg_args_have_no_tensors)
|
||
|
# If found_one_tensor is True, there is no point in
|
||
|
# recursing further as the end result will always
|
||
|
# be True.
|
||
|
# TODO(future PR): remove this entire function and
|
||
|
# change to dtype inference without recursion.
|
||
|
if found_one_tensor:
|
||
|
result = not found_one_tensor
|
||
|
if cache:
|
||
|
cache[node] = result
|
||
|
return result
|
||
|
else:
|
||
|
found_one_tensor = True
|
||
|
result = not found_one_tensor
|
||
|
if cache:
|
||
|
cache[node] = result
|
||
|
return result
|
||
|
|
||
|
def all_node_args_except_first(node: Node) -> List[int]:
|
||
|
"""
|
||
|
Returns all node arg indices after first
|
||
|
"""
|
||
|
return list(range(1, len(node.args)))
|
||
|
|
||
|
def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
|
||
|
"""
|
||
|
Constructs a function that takes a node as arg and returns the arg_indices
|
||
|
that are valid for node.args
|
||
|
"""
|
||
|
def arg_indices_func(node: Node) -> List[int]:
|
||
|
return [i for i in arg_indices if i < len(node.args)]
|
||
|
return arg_indices_func
|
||
|
|
||
|
NodeInfo = namedtuple("NodeInfo", "op target")
|
||
|
|
||
|
# this dict identifies which indices of a node are non tensors
|
||
|
# so that they can be propagated correctly since inserting observers
|
||
|
# for them would cause errors
|
||
|
|
||
|
NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
|
||
|
NodeInfo("call_method", "masked_fill") : {
|
||
|
torch.bool: return_arg_list([1]),
|
||
|
float: return_arg_list([2])
|
||
|
},
|
||
|
NodeInfo("call_method", "permute") : {
|
||
|
int: all_node_args_except_first
|
||
|
},
|
||
|
NodeInfo("call_method", "repeat") : {
|
||
|
int: all_node_args_except_first
|
||
|
},
|
||
|
NodeInfo("call_method", "reshape") : {
|
||
|
int: all_node_args_except_first
|
||
|
},
|
||
|
NodeInfo("call_method", "size") : {
|
||
|
int: return_arg_list([1])
|
||
|
},
|
||
|
NodeInfo("call_method", "transpose") : {
|
||
|
int: all_node_args_except_first
|
||
|
},
|
||
|
NodeInfo("call_method", torch.transpose) : {
|
||
|
int: all_node_args_except_first
|
||
|
},
|
||
|
NodeInfo("call_method", "unsqueeze") : {
|
||
|
int: return_arg_list([1])
|
||
|
},
|
||
|
NodeInfo("call_method", "unsqueeze_") : {
|
||
|
int: return_arg_list([1])
|
||
|
},
|
||
|
NodeInfo("call_method", torch.unsqueeze) : {
|
||
|
int: return_arg_list([1])
|
||
|
},
|
||
|
NodeInfo("call_method", "view") : {
|
||
|
int: all_node_args_except_first
|
||
|
},
|
||
|
}
|
||
|
|
||
|
EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}
|
||
|
|
||
|
def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
|
||
|
"""
|
||
|
Returns a dict with of non float tensor types as keys and values which correspond to a
|
||
|
function to retrieve the list (which takes the node as an argument)
|
||
|
"""
|
||
|
info = NodeInfo(node.op, node.target)
|
||
|
|
||
|
return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
|
||
|
|
||
|
def maybe_get_next_module(
|
||
|
node: Node,
|
||
|
modules: Dict[str, nn.Module],
|
||
|
target_module_type: Optional[Type[nn.Module]] = None,
|
||
|
target_functional_type: Any = None,
|
||
|
) -> Optional[Node]:
|
||
|
""" Gets the next module that matches what is needed in
|
||
|
is_target_module_type if it exists
|
||
|
|
||
|
Args:
|
||
|
node: The node whose users we want to look at
|
||
|
target_module_type: Module type that we want to check
|
||
|
target_functional_type: Functional type that we want to check
|
||
|
"""
|
||
|
|
||
|
for user in node.users.keys():
|
||
|
if user.op == 'call_module' and target_module_type is not None and \
|
||
|
isinstance(modules[str(user.target)], target_module_type):
|
||
|
return user
|
||
|
elif (user.op == 'call_function' and target_functional_type is not None and
|
||
|
user.target == target_functional_type):
|
||
|
return user
|
||
|
|
||
|
return None
|
||
|
|
||
|
def create_node_from_old_node_preserve_meta(
|
||
|
quantized_graph: Graph,
|
||
|
create_node_args: Tuple[Any, ...],
|
||
|
old_node: Node,
|
||
|
) -> Node:
|
||
|
"""
|
||
|
Creates `new_node` and copies the necessary metadata to it from `old_node`.
|
||
|
"""
|
||
|
new_node = quantized_graph.create_node(*create_node_args)
|
||
|
new_node.stack_trace = old_node.stack_trace
|
||
|
return new_node
|
||
|
|
||
|
def get_skipped_module_name_and_classes(
|
||
|
prepare_custom_config: PrepareCustomConfig,
|
||
|
is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]:
|
||
|
skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
|
||
|
skipped_module_classes = copy.copy(prepare_custom_config.non_traceable_module_classes)
|
||
|
if not is_standalone_module:
|
||
|
# standalone module and custom module config are applied in top level module
|
||
|
skipped_module_names += list(prepare_custom_config.standalone_module_names.keys())
|
||
|
skipped_module_classes += list(prepare_custom_config.standalone_module_classes.keys())
|
||
|
skipped_module_classes += get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
|
||
|
|
||
|
return skipped_module_names, skipped_module_classes
|
||
|
|
||
|
def _is_custom_module_lstm(
|
||
|
node: Node,
|
||
|
named_modules: Dict[str, torch.nn.Module],
|
||
|
qconfig: QConfigAny = None,
|
||
|
# QuantizeHandler, but we cannot include the type here due to circular imports
|
||
|
qhandler: Optional[Any] = None,
|
||
|
) -> bool:
|
||
|
"""
|
||
|
Return whether this refers to the custom module LSTM flow.
|
||
|
"""
|
||
|
mod = _get_module(node, named_modules)
|
||
|
if qconfig is not None and qhandler is not None:
|
||
|
assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined]
|
||
|
return isinstance(mod, torch.nn.LSTM) and \
|
||
|
activation_is_statically_quantized(qconfig) and \
|
||
|
qhandler.is_custom_module()
|
||
|
else:
|
||
|
return isinstance(mod, torch.ao.nn.quantizable.LSTM)
|
||
|
|
||
|
def _is_custom_module_mha(
|
||
|
node: Node,
|
||
|
named_modules: Dict[str, torch.nn.Module],
|
||
|
qconfig: QConfigAny = None,
|
||
|
# QuantizeHandler, but we cannot include the type here due to circular imports
|
||
|
qhandler: Optional[Any] = None,
|
||
|
) -> bool:
|
||
|
"""
|
||
|
Return whether this refers to the custom module MultiheadAttention flow.
|
||
|
"""
|
||
|
mod = _get_module(node, named_modules)
|
||
|
if qconfig is not None and qhandler is not None:
|
||
|
assert isinstance(qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler) # type: ignore[attr-defined]
|
||
|
return isinstance(mod, torch.nn.MultiheadAttention) and \
|
||
|
activation_is_statically_quantized(qconfig) and \
|
||
|
qhandler.is_custom_module()
|
||
|
else:
|
||
|
return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
|
||
|
|
||
|
def _get_module(node: Node, named_modules: Dict[str, torch.nn.Module]) -> Optional[torch.nn.Module]:
|
||
|
"""
|
||
|
If `node` refers to a call_module node, return the module, else None.
|
||
|
"""
|
||
|
if node.op == "call_module" and str(node.target) in named_modules:
|
||
|
return named_modules[str(node.target)]
|
||
|
else:
|
||
|
return None
|
||
|
|
||
|
def _insert_dequant_stub(
|
||
|
node: Node,
|
||
|
model: torch.nn.Module,
|
||
|
named_modules: Dict[str, torch.nn.Module],
|
||
|
graph: Graph,
|
||
|
) -> Node:
|
||
|
"""
|
||
|
Attach a `DeQuantStub` to the model and create a node that calls this
|
||
|
`DeQuantStub` on the output of `node`, similar to how observers are inserted.
|
||
|
"""
|
||
|
prefix = "dequant_stub_"
|
||
|
get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
|
||
|
dequant_stub_name = get_new_dequant_stub_name(model)
|
||
|
dequant_stub = DeQuantStub()
|
||
|
setattr(model, dequant_stub_name, dequant_stub)
|
||
|
named_modules[dequant_stub_name] = dequant_stub
|
||
|
with graph.inserting_after(node):
|
||
|
return graph.call_module(dequant_stub_name, (node,))
|
||
|
|
||
|
def _insert_dequant_stubs_for_custom_module_lstm_output(
|
||
|
node: Node,
|
||
|
model: torch.nn.Module,
|
||
|
named_modules: Dict[str, torch.nn.Module],
|
||
|
graph: Graph,
|
||
|
) -> Node:
|
||
|
"""
|
||
|
Insert DeQuantStubs after each internal output node of custom module LSTM.
|
||
|
|
||
|
Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)),
|
||
|
Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
|
||
|
components through `getitem`. This function transforms the graph as follows:
|
||
|
|
||
|
(1) Split the LSTM node into (output, (hidden0, hidden1))
|
||
|
(2) Insert a DeQuantStub after each internal node
|
||
|
(3) Recombine the DeQuantStubs into the same structure as before
|
||
|
(4) Reroute all consumers of the original LSTM node and its sub-nodes
|
||
|
(e.g. lstm[0])
|
||
|
|
||
|
Before:
|
||
|
lstm_output
|
||
|
|
|
||
|
v
|
||
|
original_user(s)
|
||
|
After:
|
||
|
lstm_output
|
||
|
/ \\
|
||
|
/ (getitem) \\
|
||
|
/ \\
|
||
|
v v
|
||
|
output hidden
|
||
|
| / \\
|
||
|
(DeQuantStub) (getitem)
|
||
|
| / \\
|
||
|
v v v
|
||
|
output_dq hidden0 hidden1
|
||
|
| | |
|
||
|
| (DeQuantStub) (DeQuantStub)
|
||
|
| | |
|
||
|
| v v
|
||
|
| hidden0_dq hidden1_dq
|
||
|
| \\ /
|
||
|
| (tuple)
|
||
|
| \\ /
|
||
|
| v v
|
||
|
| hidden_dq
|
||
|
\\ /
|
||
|
\\ (tuple) /
|
||
|
v v
|
||
|
lstm_output_dq
|
||
|
|
|
||
|
v
|
||
|
original_user(s)
|
||
|
|
||
|
For step (4), reroute all users of the original LSTM node(s) as follows:
|
||
|
lstm_output -> lstm_output_dq
|
||
|
lstm_output[0] -> output_dq
|
||
|
lstm_output[1] -> hidden_dq
|
||
|
lstm_output[1][0] -> hidden0_dq
|
||
|
lstm_output[1][1] -> hidden1_dq
|
||
|
|
||
|
Return the node `lstm_output_dq`.
|
||
|
"""
|
||
|
# (1) Split the LSTM node into (output, (hidden0, hidden1))
|
||
|
# (2) Insert a DeQuantStub after each internal node
|
||
|
with graph.inserting_after(node):
|
||
|
output = graph.call_function(operator.getitem, (node, 0))
|
||
|
output_dq = _insert_dequant_stub(output, model, named_modules, graph)
|
||
|
with graph.inserting_after(output_dq):
|
||
|
hidden = graph.call_function(operator.getitem, (node, 1))
|
||
|
with graph.inserting_after(hidden):
|
||
|
hidden0 = graph.call_function(operator.getitem, (hidden, 0))
|
||
|
hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
|
||
|
with graph.inserting_after(hidden0_dq):
|
||
|
hidden1 = graph.call_function(operator.getitem, (hidden, 1))
|
||
|
hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)
|
||
|
|
||
|
# (3) Recombine the DeQuantStubs into the same structure as before
|
||
|
with graph.inserting_after(hidden1_dq):
|
||
|
hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
|
||
|
with graph.inserting_after(hidden_dq):
|
||
|
lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))
|
||
|
|
||
|
# (4) Reroute all consumers of the original LSTM node and its sub-nodes
|
||
|
for user in list(node.users.keys()):
|
||
|
if user != output and user != hidden:
|
||
|
user.replace_input_with(node, lstm_output_dq)
|
||
|
# The getitem and tuple nodes we added here may interfere with reference quantized
|
||
|
# pattern matching, so we need to redirect the consumers of internal nodes to the
|
||
|
# corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
|
||
|
# in order to preserve reference patterns like "dequantize - consumer - quantize".
|
||
|
_reroute_tuple_getitem_pattern(graph)
|
||
|
return lstm_output_dq
|
||
|
|
||
|
def _maybe_get_custom_module_lstm_from_node_arg(
|
||
|
arg: Node,
|
||
|
named_modules: Dict[str, torch.nn.Module],
|
||
|
) -> Optional[Node]:
|
||
|
"""
|
||
|
Given an argument of a node, if the argument refers to the path through which the node
|
||
|
is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.
|
||
|
|
||
|
This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
|
||
|
skip inserting input observers for this node. This is because custom module LSTM produces
|
||
|
quantized outputs, so inserting an input observer for the consumer of custom module LSTM
|
||
|
would unnecessarily quantize the outputs again.
|
||
|
|
||
|
lstm -> consumer
|
||
|
|
||
|
In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
|
||
|
DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
|
||
|
This tuple can be consumed in one of four ways:
|
||
|
|
||
|
lstm -> getitem -> DeQuantStub -> consumer # consume lstm[0]
|
||
|
lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm[1]
|
||
|
lstm -> getitem -> getitem -> DeQuantStub -> consumer # consume lstm[1][0] or lstm[1][1]
|
||
|
lstm -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm
|
||
|
|
||
|
Thus, we must match against the above patterns instead of simply checking the parent node
|
||
|
to determine whether this node is a consumer of a custom module LSTM.
|
||
|
"""
|
||
|
def match_dq(a):
|
||
|
return isinstance(_get_module(a, named_modules), DeQuantStub)
|
||
|
|
||
|
def match_lstm(a):
|
||
|
return _is_custom_module_lstm(a, named_modules)
|
||
|
|
||
|
def match_getitem(a):
|
||
|
return a.op == "call_function" and a.target == operator.getitem
|
||
|
|
||
|
def match_tuple(a):
|
||
|
return a.op == "call_function" and a.target == tuple
|
||
|
|
||
|
def _match_pattern(match_pattern: List[Callable]) -> Optional[Node]:
|
||
|
"""
|
||
|
Traverse up the graph and match the args one by one.
|
||
|
If there is a match, return the last matched node, or None otherwise.
|
||
|
"""
|
||
|
a = arg
|
||
|
for i, match in enumerate(match_pattern):
|
||
|
if not match(a):
|
||
|
return None
|
||
|
# Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
|
||
|
if i < len(match_pattern) - 1:
|
||
|
if match == match_tuple:
|
||
|
a = a.args[0][0] # type: ignore[assignment,index]
|
||
|
else:
|
||
|
a = a.args[0] # type: ignore[assignment]
|
||
|
return a
|
||
|
|
||
|
all_match_patterns = [
|
||
|
[match_dq, match_getitem, match_lstm],
|
||
|
[match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
|
||
|
[match_dq, match_getitem, match_getitem, match_lstm],
|
||
|
[match_tuple, match_dq, match_getitem, match_lstm],
|
||
|
]
|
||
|
|
||
|
for p in all_match_patterns:
|
||
|
matched_node = _match_pattern(p)
|
||
|
if matched_node is not None:
|
||
|
return matched_node
|
||
|
return None
|
||
|
|
||
|
def _reroute_tuple_getitem_pattern(graph: Graph):
|
||
|
"""
|
||
|
Search for patterns where N consecutive `tuple` call_function nodes are followed by
|
||
|
N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
|
||
|
If we find this pattern, reroute the consumers of the last `getitem` to skip these
|
||
|
N `tuple` and `getitem` nodes.
|
||
|
|
||
|
Before:
|
||
|
|
||
|
a b c
|
||
|
| \\ /
|
||
|
\\ tuple
|
||
|
\\ /
|
||
|
tuple
|
||
|
|
|
||
|
getitem(1)
|
||
|
|
|
||
|
getitem(0)
|
||
|
|
|
||
|
d
|
||
|
|
||
|
After:
|
||
|
|
||
|
b
|
||
|
|
|
||
|
d
|
||
|
"""
|
||
|
def find_patterns(
|
||
|
node: Node,
|
||
|
index_stack: List[int],
|
||
|
current_pattern: List[Node],
|
||
|
matched_patterns: List[List[Node]],
|
||
|
seen: Set[Tuple[Node, Tuple[int, ...]]]):
|
||
|
"""
|
||
|
Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
|
||
|
starting at the given node.
|
||
|
|
||
|
We use a stack to keep track of the expected `getitem` indices, since these are
|
||
|
reversed from the `tuple` indices. In the above example, the stack after
|
||
|
(b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
|
||
|
and then by getitem(0).
|
||
|
|
||
|
TODO: traverse upwards from the output and handle the case when tuple is not a
|
||
|
separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
|
||
|
"""
|
||
|
if len(index_stack) == 0 and len(current_pattern) > 0:
|
||
|
matched_patterns.append(copy.copy(current_pattern))
|
||
|
current_pattern.clear()
|
||
|
|
||
|
# Avoid duplicating work
|
||
|
state = (node, tuple(index_stack))
|
||
|
if state in seen:
|
||
|
return
|
||
|
seen.add(state)
|
||
|
|
||
|
# Iterate through users of this node to find tuple/getitem nodes to match
|
||
|
for user in node.users:
|
||
|
if user.op == "call_function" and user.target == tuple:
|
||
|
for i, user_arg in enumerate(user.args[0]): # type: ignore[arg-type]
|
||
|
if user_arg == node:
|
||
|
index_stack.append(i)
|
||
|
current_pattern.append(user)
|
||
|
find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
|
||
|
elif user.op == "call_function" and user.target == operator.getitem:
|
||
|
if len(index_stack) > 0:
|
||
|
if user.args[1] == index_stack[-1]:
|
||
|
index_stack.pop()
|
||
|
current_pattern.append(user)
|
||
|
find_patterns(user, index_stack, current_pattern, matched_patterns, seen)
|
||
|
return matched_patterns
|
||
|
|
||
|
# Collect all matched patterns
|
||
|
matched_patterns: List[List[Node]] = []
|
||
|
seen: Set[Tuple[Node, Tuple[int, ...]]] = set() # (node, index_stack)
|
||
|
for node in graph.nodes:
|
||
|
find_patterns(node, [], [], matched_patterns, seen)
|
||
|
|
||
|
# For each pattern, redirect all consumers of the last getitem node to the correct input
|
||
|
# of the first tuple node
|
||
|
for pattern in matched_patterns:
|
||
|
first_tuple = pattern[0]
|
||
|
last_getitem = pattern[-1]
|
||
|
assert first_tuple.op == "call_function" and first_tuple.target == tuple
|
||
|
assert last_getitem.op == "call_function" and last_getitem.target == operator.getitem
|
||
|
last_getitem_index = last_getitem.args[1]
|
||
|
new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index]
|
||
|
for user in list(last_getitem.users.keys()):
|
||
|
user.replace_input_with(last_getitem, new_input)
|
||
|
|
||
|
def _get_observer_from_activation_post_process(
|
||
|
activation_post_process: Union[ObserverBase, FakeQuantizeBase],
|
||
|
) -> ObserverBase:
|
||
|
"""
|
||
|
If `activation_post_process` is an observer, return the observer.
|
||
|
If `activation_post_process` is a fake quantize, return the internal observer.
|
||
|
"""
|
||
|
if isinstance(activation_post_process, ObserverBase):
|
||
|
return activation_post_process
|
||
|
else:
|
||
|
assert isinstance(activation_post_process, FakeQuantizeBase)
|
||
|
return activation_post_process.activation_post_process # type: ignore[return-value]
|
||
|
|
||
|
def _qconfig_satisfies_dtype_config_constraints(
|
||
|
qconfig: QConfigAny,
|
||
|
dtype_with_constraints: DTypeWithConstraints,
|
||
|
is_activation: bool = True) -> bool:
|
||
|
"""
|
||
|
Return whether `qconfig` satisfies the following constraints from the backend,
|
||
|
specified through the activation and weight DTypeWithConstraints.
|
||
|
|
||
|
1. QConfig specified a quantization range that falls within the backend's, if any
|
||
|
2. QConfig specified a min scale value that is >= the backend's, if any
|
||
|
3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
|
||
|
scale and zero point that match the backend's, if any
|
||
|
|
||
|
If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
|
||
|
If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
|
||
|
"""
|
||
|
# TODO: log warnings only when the user enabled a debug flag
|
||
|
def _activation_post_process_satisfies_dtype_config_constraints(
|
||
|
activation_post_process: Union[ObserverBase, FakeQuantizeBase],
|
||
|
dtype_with_constraints: DTypeWithConstraints,
|
||
|
debug_string: str) -> bool:
|
||
|
observer = _get_observer_from_activation_post_process(activation_post_process)
|
||
|
app_quant_min = getattr(observer, "quant_min", None)
|
||
|
app_quant_max = getattr(observer, "quant_max", None)
|
||
|
# TODO: for now, just use the existing eps value as scale_min. In the future, we should
|
||
|
# resolve the differences between the two, either by renaming eps or some other way
|
||
|
app_scale_min = getattr(observer, "eps", None)
|
||
|
backend_quant_min = dtype_with_constraints.quant_min_lower_bound
|
||
|
backend_quant_max = dtype_with_constraints.quant_max_upper_bound
|
||
|
backend_scale_min = dtype_with_constraints.scale_min_lower_bound
|
||
|
backend_scale_exact_match = dtype_with_constraints.scale_exact_match
|
||
|
backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
|
||
|
# check quantization ranges
|
||
|
if backend_quant_min is not None and backend_quant_max is not None:
|
||
|
if app_quant_min is None or app_quant_max is None:
|
||
|
warnings.warn(f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}")
|
||
|
return False
|
||
|
elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
|
||
|
warnings.warn(
|
||
|
f"QConfig {debug_string} quantization range must fall within the backend's:\n"
|
||
|
f"QConfig range = ({app_quant_min}, {app_quant_max}), "
|
||
|
f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
|
||
|
f"ignoring {qconfig}"
|
||
|
)
|
||
|
return False
|
||
|
# check scale min
|
||
|
if backend_scale_min is not None:
|
||
|
if app_scale_min is None:
|
||
|
warnings.warn(f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}")
|
||
|
return False
|
||
|
if app_scale_min < backend_scale_min:
|
||
|
warnings.warn(
|
||
|
f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
|
||
|
f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
|
||
|
)
|
||
|
return False
|
||
|
# check fixed scale and zero point
|
||
|
if backend_scale_exact_match is not None and backend_zero_point_exact_match is not None:
|
||
|
# For tests only, accept the following qconfigs for now
|
||
|
# TODO: handle fp16 qconfigs properly
|
||
|
for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
|
||
|
if qconfig_equals(qconfig, accepted_qconfig):
|
||
|
return True
|
||
|
suggestion_str = (
|
||
|
"Please use torch.ao.quantization.get_default_qconfig_mapping or "
|
||
|
"torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
|
||
|
" qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n"
|
||
|
" model = prepare_fx(model, qconfig_mapping, example_inputs)"
|
||
|
)
|
||
|
if not isinstance(activation_post_process, FixedQParamsObserver) and \
|
||
|
not isinstance(activation_post_process, FixedQParamsFakeQuantize):
|
||
|
warnings.warn(
|
||
|
f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
|
||
|
f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
|
||
|
)
|
||
|
return False
|
||
|
if observer.scale != backend_scale_exact_match or observer.zero_point != backend_zero_point_exact_match:
|
||
|
warnings.warn(
|
||
|
f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
|
||
|
f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
|
||
|
f"ignoring {qconfig}.\n{suggestion_str}"
|
||
|
)
|
||
|
return False
|
||
|
return True
|
||
|
|
||
|
if qconfig is None or dtype_with_constraints.dtype is None:
|
||
|
return True
|
||
|
|
||
|
activation_post_process_ctr = qconfig.activation if is_activation else qconfig.weight
|
||
|
debug_string = "activation" if is_activation else "weight"
|
||
|
satisfies_constraints = True
|
||
|
if activation_post_process_ctr is not None:
|
||
|
activation_post_process = activation_post_process_ctr()
|
||
|
assert _is_activation_post_process(activation_post_process)
|
||
|
# If dtypes don't match, don't check the activation_post_process and return True early
|
||
|
if activation_post_process.dtype != dtype_with_constraints.dtype:
|
||
|
return True
|
||
|
satisfies_constraints = _activation_post_process_satisfies_dtype_config_constraints(
|
||
|
activation_post_process, dtype_with_constraints, debug_string)
|
||
|
return satisfies_constraints
|