1132 lines
54 KiB
Python
1132 lines
54 KiB
Python
# mypy: ignore-errors
|
|
|
|
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Type, Callable
|
|
from torch.ao.quantization.quant_type import QuantType
|
|
import torch
|
|
import copy
|
|
import warnings
|
|
from torch.fx import (
|
|
GraphModule,
|
|
)
|
|
from torch.fx.graph import (
|
|
Graph,
|
|
Node,
|
|
Argument,
|
|
)
|
|
from ..utils import (
|
|
activation_is_statically_quantized,
|
|
weight_is_quantized,
|
|
get_qparam_dict,
|
|
_parent_name,
|
|
get_swapped_custom_module_class,
|
|
)
|
|
from ..qconfig import (
|
|
QConfigAny,
|
|
qconfig_equals
|
|
)
|
|
from ..qconfig_mapping import QConfigMapping
|
|
from .qconfig_mapping_utils import (
|
|
_generate_node_name_to_qconfig,
|
|
_compare_prepare_convert_qconfig_mappings,
|
|
_update_qconfig_for_fusion,
|
|
_is_qconfig_supported_by_dtype_configs,
|
|
_update_qconfig_for_qat,
|
|
)
|
|
from torch.ao.quantization.backend_config.utils import (
|
|
get_root_module_to_quantized_reference_module,
|
|
get_pattern_to_dtype_configs,
|
|
get_fused_module_classes,
|
|
get_qat_module_classes,
|
|
)
|
|
from torch.ao.quantization.backend_config import (
|
|
BackendConfig,
|
|
get_native_backend_config,
|
|
)
|
|
from torch.ao.quantization.observer import _is_activation_post_process
|
|
from .graph_module import (
|
|
_is_observed_module,
|
|
_is_observed_standalone_module,
|
|
)
|
|
from ._equalize import update_obs_for_equalization, convert_eq_obs
|
|
from torch.nn.utils.parametrize import type_before_parametrizations
|
|
from .utils import (
|
|
_get_module,
|
|
_is_custom_module_lstm,
|
|
_is_custom_module_mha,
|
|
assert_and_get_unique_device,
|
|
get_custom_module_class_keys,
|
|
create_getattr_from_value,
|
|
collect_producer_nodes,
|
|
graph_module_from_producer_nodes,
|
|
node_arg_is_weight,
|
|
)
|
|
from torch.ao.quantization.utils import (
|
|
is_per_channel,
|
|
to_underlying_dtype,
|
|
)
|
|
from torch.ao.quantization.quantize import (
|
|
_remove_qconfig,
|
|
)
|
|
from torch.ao.quantization.stubs import DeQuantStub
|
|
from .custom_config import (
|
|
ConvertCustomConfig,
|
|
PrepareCustomConfig,
|
|
)
|
|
from .lower_to_fbgemm import lower_to_fbgemm
|
|
# importing the lib so that the quantized_decomposed ops are registered
|
|
from ._decomposed import quantized_decomposed_lib # noqa: F401
|
|
import operator
|
|
|
|
# Public API of this module; every other module-level name is a private helper
# of the FX convert pass.
__all__ = ["convert", "convert_custom_module", "convert_standalone_module", "convert_weighted_module"]
|
|
|
|
_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
|
|
torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
|
|
torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
|
|
}
|
|
|
|
def _replace_observer_with_quantize_dequantize_node_decomposed(
|
|
model: torch.fx.GraphModule,
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module],
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
|
|
""" Replace activation_post_process module call node with quantize and
|
|
dequantize node working with decomposed Tensor
|
|
|
|
Before:
|
|
... -> observer_0(x) -> ...
|
|
After:
|
|
... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
|
|
torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...
|
|
|
|
or quantize_per_channel and dequantize_per_channel
|
|
"""
|
|
graph = model.graph
|
|
assert modules is not None
|
|
assert isinstance(node.target, str)
|
|
module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
|
|
activation_post_process = modules[node.target]
|
|
if hasattr(activation_post_process, "convert"):
|
|
activation_post_process.convert(model, node)
|
|
return
|
|
# skip replacing observers to quant/dequant nodes if the qconfigs of all
|
|
# consumers and producers of this observer are None
|
|
skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
|
|
list(node.args) + list(node.users.keys()))
|
|
if skip_replacement or not _is_conversion_supported(activation_post_process):
|
|
# didn't find corresponding quantize op and info for the activation_post_process
|
|
# so we just remove the observer
|
|
with graph.inserting_before(node):
|
|
node.replace_all_uses_with(node.args[0])
|
|
graph.erase_node(node)
|
|
return
|
|
|
|
# otherwise, we can convert the activation_post_process module call to quantize/dequantize node
|
|
|
|
# 1. extract the information from activation_post_process module for generating
|
|
# the quantize and dequantize operator
|
|
dtype = activation_post_process.dtype # type: ignore[attr-defined]
|
|
|
|
is_dynamic = False
|
|
if hasattr(activation_post_process, "is_dynamic"):
|
|
is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment]
|
|
|
|
if dtype in [torch.quint8, torch.qint8, torch.qint32, torch.uint8, torch.int8, torch.int16, torch.int32] and \
|
|
(not is_dynamic):
|
|
# TODO: probably should cleanup this condition check, it's hard
|
|
# to reason about this if and the following elif
|
|
|
|
# uint8/int8/int32 static quantization branch
|
|
|
|
# 1. extract information for inserting q/dq node from activation_post_process
|
|
node_type = "call_function"
|
|
quantize_op : Optional[Callable] = None
|
|
scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator]
|
|
if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
|
|
ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type]
|
|
quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
|
|
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
|
|
quant_min = activation_post_process.quant_min
|
|
quant_max = activation_post_process.quant_max
|
|
dtype_ = to_underlying_dtype(dtype)
|
|
qparams = {
|
|
"_scale_": scale,
|
|
"_zero_point_": zero_point,
|
|
"_axis_": ch_axis,
|
|
"_quant_min_": quant_min,
|
|
"_quant_max_": quant_max,
|
|
"_dtype_": dtype_
|
|
}
|
|
else:
|
|
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
|
|
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
|
|
scale = float(scale)
|
|
zero_point = int(zero_point)
|
|
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
|
|
quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
|
|
dtype_ = to_underlying_dtype(dtype)
|
|
qparams = {
|
|
"_scale_": scale,
|
|
"_zero_point_": zero_point,
|
|
"_quant_min_": quant_min,
|
|
"_quant_max_": quant_max,
|
|
"_dtype_": dtype_
|
|
}
|
|
|
|
# 2. replace activation_post_process node with quantize and dequantize
|
|
with graph.inserting_before(node):
|
|
input_node = node.args[0]
|
|
quantize_op_inputs = [input_node]
|
|
for key, value_or_node in qparams.items():
|
|
# TODO: we can add the information of whether a value needs to
|
|
# be registered as an attribute in qparams dict itself
|
|
if key in ['_scale_', '_zero_point_'] and (not isinstance(value_or_node, (float, int))):
|
|
# For scale and zero_point values we register them as buffers in the root module.
|
|
# However, note that when the values are not tensors, as in the case of
|
|
# per_tensor quantization, they will be treated as literals.
|
|
# However, registering them as a node seems to cause issue with dynamo
|
|
# tracing where it may consider tensor overload as opposed to default.
|
|
# With extra check of scale and zero_point being scalar, it makes
|
|
# sure that the default overload can be used.
|
|
# TODO: maybe need more complex attr name here
|
|
qparam_node = create_getattr_from_value(
|
|
model, graph, module_path + prefix + key, value_or_node)
|
|
quantize_op_inputs.append(qparam_node)
|
|
else:
|
|
# for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
|
|
quantize_op_inputs.append(value_or_node)
|
|
|
|
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
|
|
# use the same qparams from quantize op
|
|
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
|
|
dequantized_node = graph.call_function(
|
|
dequantize_op,
|
|
tuple(dq_inputs),
|
|
{}
|
|
)
|
|
|
|
def remap_fn(x):
|
|
return dequantized_node if x is node else x
|
|
|
|
# remap numeric_debug_handle
|
|
for user_node in node.users:
|
|
if "numeric_debug_handle" in user_node.meta:
|
|
numeric_debug_handle = user_node.meta["numeric_debug_handle"]
|
|
user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
|
|
node.replace_all_uses_with(dequantized_node)
|
|
graph.erase_node(node)
|
|
elif is_dynamic:
|
|
|
|
# uint8/int8/fp16 dynamic quantization
|
|
|
|
# 1. extract information for inserting q/dq node from activation_post_process
|
|
node_type = "call_function"
|
|
quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
|
|
# we only use choose_qparams for is_decomposed now,
|
|
# but we should probably align the non-decomposed path with this as well,
|
|
# and that can be done after we remove reduce_range flag
|
|
# 1. extract qparams from activation_post_process module
|
|
dtype_ = to_underlying_dtype(dtype)
|
|
assert dtype_ in [torch.uint8, torch.int8], \
|
|
"only uint8 and int8 are supported in reference flow for " \
|
|
"dynamic quantization right now"
|
|
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
|
|
quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
|
|
qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined]
|
|
eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined]
|
|
# note: scale and zero_point are missing for quantize_per_tensor op
|
|
# we'll need to get this from choose_qparams op, which we'll add after
|
|
# this step
|
|
qparams = {
|
|
"_quant_min_": quant_min,
|
|
"_quant_max_": quant_max,
|
|
"_eps_": eps,
|
|
"_dtype_": dtype_
|
|
}
|
|
|
|
choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
|
|
# 2. insert choose_qparams op and update the qparams list
|
|
with graph.inserting_before(node):
|
|
input_node = node.args[0]
|
|
choose_qparams_op_inputs = [node.args[0]]
|
|
for key, value in qparams.items():
|
|
# we have quant_min, quant_max and dtype, all should be stored
|
|
# as literals
|
|
choose_qparams_op_inputs.append(value)
|
|
choose_qparams_node = graph.create_node(
|
|
"call_function",
|
|
choose_qparams_op,
|
|
tuple(choose_qparams_op_inputs),
|
|
{}
|
|
)
|
|
# choose_qparms returns (scale, zero_point)
|
|
scale_node = graph.create_node(
|
|
"call_function",
|
|
operator.getitem,
|
|
(choose_qparams_node, 0),
|
|
{}
|
|
)
|
|
zero_point_node = graph.create_node(
|
|
"call_function",
|
|
operator.getitem,
|
|
(choose_qparams_node, 1),
|
|
{}
|
|
)
|
|
quant_min = qparams["_quant_min_"]
|
|
quant_max = qparams["_quant_max_"]
|
|
dtype = qparams["_dtype_"]
|
|
qparams = {
|
|
"_scale_": scale_node,
|
|
"_zero_point_": zero_point_node,
|
|
"_quant_min_": quant_min,
|
|
"_quant_max_": quant_max,
|
|
"_dtype_": dtype
|
|
}
|
|
|
|
# 3. replace activation_post_process node to quantize and dequantize node
|
|
with graph.inserting_before(node):
|
|
input_node = node.args[0]
|
|
quantize_op_inputs = [input_node]
|
|
for key, value_or_node in qparams.items():
|
|
# TODO: we can add the information of whether a value needs to
|
|
# be registered as an attribute in qparams dict itself
|
|
if key in ['_scale_', '_zero_point_']:
|
|
# in this case we have a node in the graph since it's dynamically
|
|
# computed from the input, with choose_qparams op
|
|
qparam_node = value_or_node
|
|
quantize_op_inputs.append(qparam_node)
|
|
else:
|
|
# for qparams that are not scale/zero_point (like axis, dtype) we
|
|
# store them as literals in the graph.
|
|
quantize_op_inputs.append(value_or_node)
|
|
|
|
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
|
|
# use the same qparams from quantize op
|
|
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
|
|
# need to use the tensor variant of this op, since scale and zero_point
|
|
# from choose_qparam are Tensors, instead of float/int, this is to
|
|
# prevent these nodes being traced away by downstream systems
|
|
dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
|
|
dequantized_node = graph.call_function(
|
|
dequantize_op,
|
|
tuple(dq_inputs),
|
|
{}
|
|
)
|
|
|
|
def remap_fn(x):
|
|
return dequantized_node if x is node else x
|
|
|
|
# remap numeric_debug_handle
|
|
for user_node in node.users:
|
|
if "numeric_debug_handle" in user_node.meta:
|
|
numeric_debug_handle = user_node.meta["numeric_debug_handle"]
|
|
user_node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}
|
|
node.replace_all_uses_with(dequantized_node)
|
|
graph.erase_node(node)
|
|
elif dtype == torch.float16:
|
|
raise NotImplementedError("decomposed to float16 op not implemented yet")
|
|
|
|
# should not reach since we have checks in the beginning to make sure the
|
|
# activation_post_process is supported
|
|
|
|
def _replace_observer_with_quantize_dequantize_node(
|
|
model: torch.fx.GraphModule,
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module],
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
node_name_to_qconfig: Dict[str, QConfigAny]) -> None:
|
|
""" Replace activation_post_process module call node with quantize and
|
|
dequantize node
|
|
|
|
Before:
|
|
... -> observer_0(x) -> ...
|
|
After:
|
|
... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
|
|
"""
|
|
assert modules is not None
|
|
assert isinstance(node.target, str)
|
|
graph = model.graph
|
|
module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
|
|
activation_post_process = modules[node.target]
|
|
# skip replacing observers to quant/dequant nodes if the qconfigs of all
|
|
# consumers and producers of this observer are None
|
|
skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
|
|
list(node.args) + list(node.users.keys()))
|
|
if skip_replacement or not _is_conversion_supported(activation_post_process):
|
|
# didn't find corresponding quantize op and info for the activation_post_process
|
|
# so we just remove the observer
|
|
with graph.inserting_before(node):
|
|
node.replace_all_uses_with(node.args[0])
|
|
graph.erase_node(node)
|
|
return
|
|
|
|
# otherwise, we can convert the activation_post_process module call to quantize/dequantize node
|
|
dtype = activation_post_process.dtype # type: ignore[attr-defined]
|
|
|
|
is_dynamic = False
|
|
if hasattr(activation_post_process, "is_dynamic"):
|
|
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
|
|
|
|
if dtype in [torch.quint8, torch.qint8, torch.qint32] and \
|
|
(not is_dynamic):
|
|
# TODO: probably should cleanup this condition check, it's hard
|
|
# to reason about this if and the following elif
|
|
|
|
# uint8/int8/int32 static quantization branch
|
|
|
|
# 1. extract the information from activation_post_process module for generating
|
|
# the quantize and dequantize operator
|
|
node_type = "call_function"
|
|
quantize_op : Optional[Callable] = None
|
|
scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator]
|
|
if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
|
|
ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type]
|
|
qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype}
|
|
quantize_op = torch.quantize_per_channel
|
|
else:
|
|
scale = float(scale)
|
|
zero_point = int(zero_point)
|
|
qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
|
|
quantize_op = torch.quantize_per_tensor
|
|
|
|
# 2. replace activation_post_process node with quantize and dequantize
|
|
with graph.inserting_before(node):
|
|
input_node = node.args[0]
|
|
quantize_op_inputs = [input_node]
|
|
for key, value_or_node in qparams.items():
|
|
# TODO: we can add the information of whether a value needs to
|
|
# be registered as an attribute in qparams dict itself
|
|
if key in ['_scale_', '_zero_point_']:
|
|
# For scale and zero_point values we register them as buffers in the root module.
|
|
# TODO: maybe need more complex attr name here
|
|
qparam_node = create_getattr_from_value(
|
|
model, graph, module_path + prefix + key, value_or_node)
|
|
quantize_op_inputs.append(qparam_node)
|
|
else:
|
|
# for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
|
|
quantize_op_inputs.append(value_or_node)
|
|
|
|
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
|
|
dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
|
|
node.replace_all_uses_with(dequantized_node)
|
|
graph.erase_node(node)
|
|
elif is_dynamic:
|
|
|
|
# uint8/int8/fp16 dynamic quantization branch
|
|
|
|
node_type = "call_function"
|
|
quantize_op = torch.quantize_per_tensor_dynamic
|
|
# TODO: get reduce range from observer
|
|
# reduce_range = activation_post_process.reduce_range
|
|
reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
|
|
qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}
|
|
|
|
with graph.inserting_before(node):
|
|
input_node = node.args[0]
|
|
quantize_op_inputs = [input_node]
|
|
for key, value in qparams.items():
|
|
quantize_op_inputs.append(value)
|
|
|
|
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
|
|
dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
|
|
node.replace_all_uses_with(dequantized_node)
|
|
graph.erase_node(node)
|
|
elif dtype == torch.float16:
|
|
node_type = "call_method"
|
|
quantize_op = "to" # type: ignore[assignment]
|
|
qparams = {"_dtype_": dtype}
|
|
with graph.inserting_before(node):
|
|
input_node = node.args[0]
|
|
quantize_op_inputs = [input_node]
|
|
for key, value in qparams.items():
|
|
# TODO: we can add the information of whether a value needs to
|
|
# be registered as an attribute in qparams dict itself
|
|
quantize_op_inputs.append(value)
|
|
|
|
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
|
|
dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
|
|
node.replace_all_uses_with(dequantized_node)
|
|
graph.erase_node(node)
|
|
|
|
# should not reach since we have checks in the beginning to make sure the
|
|
# activation_post_process is supported
|
|
|
|
# this is a temporary hack for custom module, we may want to implement
|
|
# this properly after the custom module class design is finalized
|
|
# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
|
|
# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
|
|
# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
|
|
def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph) -> None:
|
|
call_custom_module_node = node.args[0]
|
|
assert isinstance(call_custom_module_node, Node), \
|
|
f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
|
|
node.replace_all_uses_with(call_custom_module_node)
|
|
graph.erase_node(node)
|
|
_insert_dequantize_node(call_custom_module_node, graph)
|
|
|
|
def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
|
|
dtype = activation_post_process.dtype # type: ignore[attr-defined]
|
|
|
|
is_dynamic = False
|
|
if hasattr(activation_post_process, "is_dynamic"):
|
|
is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
|
|
|
|
return (
|
|
(dtype in [
|
|
torch.quint8,
|
|
torch.qint8,
|
|
torch.qint32,
|
|
torch.uint8,
|
|
torch.int8,
|
|
torch.int16,
|
|
torch.int32
|
|
] and (not is_dynamic)) or # type: ignore[return-value]
|
|
is_dynamic or
|
|
dtype == torch.float16
|
|
)
|
|
|
|
def _has_none_qconfig(node: Argument, node_name_to_qconfig: Dict[str, QConfigAny]) -> bool:
|
|
""" Check if a node has a qconfig of None, i.e. user requested to not quantize
|
|
the node
|
|
"""
|
|
return isinstance(node, Node) and node.name in node_name_to_qconfig and node_name_to_qconfig[node.name] is None
|
|
|
|
def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
|
|
""" Extract the subgraph that produces the weight for dynamic quant
|
|
or weight only quant node and run the subgraph to observe the weight.
|
|
Note that the observers of dynamic quant or weight only quant ops are
|
|
run during the convert step.
|
|
"""
|
|
for node in observed.graph.nodes:
|
|
if node.op != "call_function":
|
|
continue
|
|
for node_arg in node.args:
|
|
# node_arg is weight
|
|
if node_arg and node_arg_is_weight(node, node_arg):
|
|
weight_observer_nodes = collect_producer_nodes(node_arg)
|
|
if weight_observer_nodes is None:
|
|
continue
|
|
weight_observer_module = \
|
|
graph_module_from_producer_nodes(
|
|
observed, weight_observer_nodes)
|
|
# run the weight observer
|
|
weight_observer_module()
|
|
|
|
def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
|
|
""" If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
|
|
we'll recursively remove the dequantize Node
|
|
"""
|
|
if isinstance(arg, Node) and \
|
|
arg.op == "call_method" and \
|
|
arg.target == "dequantize":
|
|
quantize_node = arg.args[0]
|
|
# we only replace the specific use since dequantize could be used by other nodes
|
|
# as well
|
|
node.replace_input_with(arg, quantize_node)
|
|
elif isinstance(arg, (list, tuple)):
|
|
for arg_element in arg:
|
|
_maybe_recursive_remove_dequantize(arg_element, node, graph)
|
|
elif isinstance(arg, dict):
|
|
for arg_element in arg.values():
|
|
_maybe_recursive_remove_dequantize(arg_element, node, graph)
|
|
else:
|
|
warnings.warn(f"Unsupported node type in recursive remove dequantize: {type(arg)}")
|
|
|
|
def _get_module_path_and_prefix(
|
|
obs_node: Node,
|
|
node_name_to_scope: Dict[str, Tuple[str, type]],
|
|
node_name_to_qconfig: Dict[str, QConfigAny]) -> Tuple[str, str]:
|
|
""" Given and observer node, get the `Scope` or the fully qualified name for
|
|
the submodule containing the observed node, also return a prefix of "_input"
|
|
when the observed node is an input of a F.linear op, and not the output of another
|
|
quantized op.
|
|
TODO: this logic is hacky, we should think about how to remove it or make it more
|
|
general
|
|
"""
|
|
observed_node = obs_node.args[0]
|
|
# an observer can be inserted for both input of the next operator or output of the previous
|
|
# operator (they can be the same)
|
|
# this flag identifies if the observer is inserted only because the observed node is
|
|
# the input of the next operator
|
|
assert isinstance(observed_node, Node), \
|
|
f"Expecting observed node to be a Node, but got {observed_node}"
|
|
is_input_observer_only = node_name_to_qconfig[observed_node.name] is None \
|
|
if observed_node.name in node_name_to_qconfig else None
|
|
if is_input_observer_only:
|
|
# if the quantize function is at the input of op, then we find the first user of the observer_node
|
|
# to get the path. If a linear call_function is in the user list, we return the first instance
|
|
# of linear node to get the FQN.
|
|
users = list(obs_node.users)
|
|
first_linear_use_or_first_use = users[0] if users else None
|
|
linear_node = None
|
|
for n in users:
|
|
if n.op == "call_function" and n.target == torch.nn.functional.linear:
|
|
linear_node = n
|
|
break
|
|
if linear_node:
|
|
first_linear_use_or_first_use = linear_node
|
|
prefix = "_input"
|
|
else:
|
|
# if the quantize function is at the output of the op, we use the observer input node to get the path
|
|
first_linear_use_or_first_use = observed_node
|
|
prefix = ""
|
|
|
|
if first_linear_use_or_first_use and first_linear_use_or_first_use.name in node_name_to_scope:
|
|
module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
|
|
else:
|
|
# TODO: it's not used, so actually we can skip quantization
|
|
# but this requires changing return type of quantize_node
|
|
# we can fix it later if needed
|
|
module_path = ""
|
|
return module_path, prefix
|
|
|
|
def _insert_dequantize_node(
|
|
node: Node,
|
|
graph: Graph) -> None:
|
|
""" Inserts dequantize node for `node` in `graph`
|
|
"""
|
|
with graph.inserting_after(node):
|
|
dequantize_node = graph.call_method("dequantize", (node,))
|
|
for user_node in dict(node.users):
|
|
if user_node is not dequantize_node:
|
|
user_node.replace_input_with(node, dequantize_node)
|
|
|
|
def _maybe_get_observer_for_node(
|
|
node: Node,
|
|
modules: Dict[str, torch.nn.Module]
|
|
) -> Optional[torch.nn.Module]:
|
|
"""
|
|
If the node is observed, return the observer
|
|
instance. Otherwise, return None.
|
|
"""
|
|
for maybe_obs_node in node.users.keys():
|
|
if maybe_obs_node.op == 'call_module':
|
|
maybe_obs = modules[str(maybe_obs_node.target)]
|
|
if _is_activation_post_process(maybe_obs):
|
|
return maybe_obs
|
|
return None
|
|
|
|
def convert_standalone_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        model: torch.fx.GraphModule,
        is_reference: bool,
        backend_config: Optional[BackendConfig]) -> None:
    """ Converts a observed standalone module to a quantized standalone module by calling
    the fx convert api, currently using the same `is_reference` flag as parent, but we may
    changing this behavior in the future (e.g. separating quantization and lowering for
    standalone module as well)

    Args:
        - node: The call_module node of the observed standalone module
        - modules: named_module of original model; updated in place with the
          converted standalone module
        - model: original model; its graph is edited around ``node``
        - is_reference: a flag from parent provided by user to decide if we want to
          produce a reference model or a fbgemm/qnnpack model
        - backend_config: backend configuration of the target backend of quantization

    Raises:
        AssertionError: if the standalone module reports quantized output
        indices other than ``[0]``.
    """
    # TODO: remove is_reference flag
    if is_reference:
        convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
    else:
        convert_fn = torch.ao.quantization.quantize_fx.convert_fx  # type: ignore[attr-defined]
    # We know that observed standalone module is a GraphModule since
    # it's produced by us
    observed_standalone_module: GraphModule = modules[str(node.target)]  # type: ignore[assignment]
    observed_attrs = observed_standalone_module.meta["_observed_graph_module_attrs"]

    # Inputs the standalone module expects in quantized form: feed it the
    # quantize node directly instead of the preceding dequantize.
    sm_input_quantized_idxs = observed_attrs.standalone_module_input_quantized_idxs
    for idx, arg in enumerate(list(node.args)):
        if idx not in sm_input_quantized_idxs:
            continue
        if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize":
            quantize_node = arg.args[0]
            # Rewrite only this argument position: the same dequantize node may
            # also feed an input the standalone module expects in float.
            node.update_arg(idx, quantize_node)
            if len(arg.users) == 0:
                model.graph.erase_node(arg)

    # add dequantize node for output
    sm_output_quantized_idxs = observed_attrs.standalone_module_output_quantized_idxs
    if len(sm_output_quantized_idxs) > 0:
        assert sm_output_quantized_idxs[0] == 0, \
            "Currently only quantized output idxs = [0] is supported"

        # if it's non-empty, then it means the output is kept in quantized form
        # we'll just add a dequantize node after this node
        _insert_dequantize_node(node, model.graph)

    # TODO: allow convert_custom_config to override backend_config
    # for standalone module
    quantized_standalone_module = convert_fn(
        observed_standalone_module,
        backend_config=backend_config)
    parent_name, name = _parent_name(node.target)
    # update the modules dict
    setattr(modules[parent_name], name, quantized_standalone_module)
    modules[str(node.target)] = quantized_standalone_module
|
|
|
|
def convert_weighted_module(
        node: Node,
        modules: Dict[str, torch.nn.Module],
        observed_node_names: Set[str],
        node_name_to_qconfig: Dict[str, QConfigAny],
        backend_config: BackendConfig,
        is_decomposed: bool = False,
        is_reference: bool = False,
) -> None:
    """ Convert a weighted module to reference quantized module in the model
    If the QConfig of a QAT module is not set, the module will still be converted to
    a float module.

    Args:
        - node: The call_module node of the observed weighted module
        - modules: named_module of original model; parent modules in it are
          mutated in place (via ``setattr``) to install the float / reference module
        - observed_node_names: names for the set of observed fx node, we can skip
          this conversion if the node is not observed
        - node_name_to_qconfig: per-node qconfig; a node mapped to ``None`` is
          not converted to a reference module
        - backend_config: used to look up QAT module classes, supported dtype
          configs, and the float -> reference quantized module mapping
        - is_decomposed: passed to the reference module as ``"is_decomposed"`` in
          the weight qparams dict; together with ``is_reference`` it also decides
          whether the weight fake_quant of a QAT module is re-run here
        - is_reference: see ``is_decomposed``
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(
            original_module,
            qat_module_classes):
        # Converting qat module to a float module, we need to attach
        # weight fake_quant to the module, weight fake_quant is assumed to be run during
        # QAT so we don't need to run it again here
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()  # type: ignore[operator]
        # change qat module to float module
        # (this swap happens before the early returns below, so a QAT module is
        # always turned back into a float module even if it is not quantized)
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # If a qconfig is not defined for this node, then skip converting to a reference module
    if qconfig is None or _has_none_qconfig(node, node_name_to_qconfig) or not is_observed:
        return

    # skip converting to reference quantized module if the qconfig is not supported
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
    is_weight_quantized = weight_is_quantized(qconfig)

    # the condition for swapping the module to reference quantized module is:
    # weights need to be quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # extract the individual float_module and fused module
    # (for a fused module, only its first submodule is replaced below)
    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]  # type: ignore[index]

    # TODO: move this to the reference quantized module
    # weight_qparams or weight_qparams dict
    wq_or_wq_dict = {"is_decomposed": is_decomposed}
    if isinstance(float_module, torch.nn.RNNCellBase):
        # RNN cells: observe weight_ih and weight_hh with separate fresh observers
        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict.update({
            "weight_ih": weight_qparams_ih,
            "weight_hh": weight_qparams_hh,
        })
    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
        # format for wq_or_wq_dict (flattened attributes):
        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
                # the observer is only run for qint8 weights; for other dtypes
                # its qparams are recorded without observing the weight
                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
                    weight_post_process(weight)  # type: ignore[operator, misc]
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        is_ptq = weight_post_process is None
        if is_ptq:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
            device = assert_and_get_unique_device(float_module)
            if device:
                weight_post_process.to(device)

        # Call weight observer/fake_quant at least once to ensure the scales and zero points
        # have the right shapes. Note: there are two cases where we don't have to do this:
        #
        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
        #     and this typically happens during training, so we don't need to do it here.
        #
        # (2) Non-reference (lowered) case: The quantized module's from_float method already
        #     calls the weight observer/fake_quant, so we don't have to do it here.
        #
        # Currently we ignore both cases and call the weight observer/fake_quant here
        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
        # in test code, which may not always train before convert. In the future, we should
        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
        #
        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
        # Note that we still need it for PTQ in the PT2 flow since the model's forward
        # method doesn't call the weight observer.
        is_qat = not is_ptq
        if not (is_decomposed and is_reference and is_qat):
            weight_post_process(float_module.weight)  # type: ignore[operator]

        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(type_before_parametrizations(float_module), None)
    assert (
        ref_qmodule_cls is not None
    ), f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        # replace the first submodule of the fused module in place
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)
|
|
|
|
def _remove_previous_dequantize_in_custom_module(node: Node, prev_node: Node, graph: Graph) -> None:
|
|
"""
|
|
Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows:
|
|
|
|
Before: quantize - dequantize - custom_module
|
|
After: quantize - custom_module
|
|
\\ - dequantize
|
|
"""
|
|
# expecting the input node for a custom module node to be a Node
|
|
assert isinstance(prev_node, Node), \
|
|
f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
|
|
if prev_node.op == "call_method" and prev_node.target == "dequantize":
|
|
node.replace_input_with(prev_node, prev_node.args[0])
|
|
# Remove the dequantize node if it doesn't have other users
|
|
if len(prev_node.users) == 0:
|
|
graph.erase_node(prev_node)
|
|
|
|
def convert_custom_module(
        node: Node,
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        custom_module_class_mapping: Dict[QuantType, Dict[Type, Type]],
        statically_quantized_custom_module_nodes: Set[Node]) -> None:
    """ Converts an observed custom module to a quantized custom module based on
    `custom_module_class_mapping`.

    For static quantization, we'll also remove the previous `dequantize` node(s) and
    attach the observer node for output to the module. The observer for the node
    will be converted to a dequantize node instead of a quantize-dequantize pair
    later in the graph. In the end we would have a quantized custom module that
    has the same interface as a default quantized module in the nn.quantized
    namespace, i.e. quantized input and quantized output.

    Args:
      - node: The call_module node of the observed custom module
      - graph: The graph containing the node
      - modules: named_modules of the original model; used to look up the observed
        custom module (by `node.target`) and its parent, on which the quantized
        module is installed via setattr
      - custom_module_class_mapping: mapping from observed custom module class to
        quantized custom module class, used to swap custom modules
      - statically_quantized_custom_module_nodes: we'll add the custom module node
        if we find it is statically quantized. This is used later when converting
        observers to quant/dequant node pairs: if the observed node is a statically
        quantized custom module node, we'll convert the observer to a dequantize node.
        This keeps the interface the same as the default quantized module.

    TODO: maybe we want to redesign this part to align with reference model design
    as well, but there has been some discussions around the interface, so we can do
    it later.
    """
    observed_custom_module = modules[str(node.target)]
    qconfig = observed_custom_module.qconfig
    if activation_is_statically_quantized(qconfig):
        statically_quantized_custom_module_nodes.add(node)
        if _is_custom_module_lstm(node, modules):
            # The inputs are tuples in the form (input, (hidden0, hidden1))
            # Ensure all three input nodes are quantized
            assert (
                len(node.args) == 2 and
                isinstance(node.args[1], tuple) and
                len(node.args[1]) == 2
            )
            (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
            assert isinstance(inputs, Node)
            assert isinstance(hidden0, Node)
            assert isinstance(hidden1, Node)
            _remove_previous_dequantize_in_custom_module(node, inputs, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
            _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
        elif _is_custom_module_mha(node, modules):
            # Inputs are in the form (query, key, value)
            # TODO: This is the first step in enabling the full fx custom module
            # quantization path for MultiheadAttention, and only covers the inputs
            # to the module.
            # Additional handling is yet to be implemented for the outputs, similar
            # to LSTM custom module
            assert len(node.args) == 3
            query, key, value = node.args
            assert isinstance(query, Node)
            assert isinstance(key, Node)
            assert isinstance(value, Node)
            _remove_previous_dequantize_in_custom_module(node, query, graph)
            _remove_previous_dequantize_in_custom_module(node, key, graph)
            _remove_previous_dequantize_in_custom_module(node, value, graph)
        else:
            # remove the previous dequant node to ensure the inputs are quantized
            arg = node.args[0]
            assert isinstance(arg, Node)
            _remove_previous_dequantize_in_custom_module(node, arg, graph)
        # absorb the following observer into the module conversion; it is looked up
        # only here, where it is actually needed
        activation_post_process = _maybe_get_observer_for_node(node, modules)
        assert activation_post_process is not None
        observed_custom_module.activation_post_process = activation_post_process

    # swap the observed custom module to quantized custom module
    quantized_custom_module_class = get_swapped_custom_module_class(
        observed_custom_module, custom_module_class_mapping, qconfig)
    quantized_custom_module = \
        quantized_custom_module_class.from_observed(observed_custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(modules[parent_name], name, quantized_custom_module)
|
|
|
|
def convert(
        model: GraphModule, is_reference: bool = False,
        convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
        is_standalone_module: bool = False,
        _remove_qconfig_flag: bool = True,
        qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
        is_decomposed: bool = False) -> GraphModule:
    """
    We will convert an observed model (a module with observer calls) to a reference
    quantized model, the rule is simple:
    1. for each observer module call in the graph, we'll convert it to calls to
       quantize and dequantize functions based on the observer instance
    2. for weighted operations like linear/conv, we need to convert them to reference
       quantized module, this requires us to know whether the dtype configured for the
       weight is supported in the backend, this is done in prepare step and the result
       is stored in observed_node_names, we can decide whether we need to swap the
       module based on this set

    When `is_reference` is False, the reference model is then lowered with
    `lower_to_fbgemm` before being returned.

    Args:
       * `model`: the observed GraphModule; it must have been produced by prepare_fx.
         Its graph is rewritten in place and then re-wrapped in a new GraphModule.

       * `is_reference`: if True, return the reference quantized model; otherwise
         also run the fbgemm lowering pass.

       * `convert_custom_config`: a ConvertCustomConfig (a dict is accepted but
         deprecated). Its observed_to_quantized_mapping drives custom module swaps.

       * `is_standalone_module`: when this flag is True, it means we are quantizing
         a submodule that is not inlined in parent module, and will be quantized
         separately as one unit. (Not read directly by this function's body.)

       * `_remove_qconfig_flag`: if True, remove qconfig and activation_post_process
         modules from the converted model.

       * `qconfig_mapping`: optional QConfigMapping (a dict is accepted but
         deprecated). If given, it is deep-copied, updated for QAT/fusion, and the
         per-node qconfigs it yields must either match the ones recorded during
         prepare or be None.

       * `backend_config`: a BackendConfig (a dict is accepted but deprecated);
         defaults to the native backend config.

       * `is_decomposed`: a boolean flag to indicate whether we want to use the
         quantize operator for decomposed quantized tensor
         (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
         quantized tensor (torch.quantize_per_tensor)

    Returns:
         the converted GraphModule. Whether its inputs/outputs are quantized is
         specified by prepare_custom_config, with
         input_quantized_idxs, output_quantized_idxs, please
         see docs for :func:`~torch.ao.quantization.prepare_fx` for details

    Raises:
         AssertionError: if `model` was not produced by prepare_fx, or if
         `qconfig_mapping` changes a qconfig that was set during prepare.
    """
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, Dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.")
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to convert is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
    # copy so the QAT/fusion updates below never mutate the caller's mapping
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    if backend_config is None:
        backend_config = get_native_backend_config()

    assert _is_observed_module(model), \
        'incoming model must be produced by prepare_fx'
    observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
    node_name_to_scope: Dict[str, Tuple[str, type]] = observed_graph_module_attrs.node_name_to_scope
    prepare_custom_config: PrepareCustomConfig = observed_graph_module_attrs.prepare_custom_config
    observed_node_names: Set[str] = observed_graph_module_attrs.observed_node_names
    node_name_to_qconfig: Dict[str, QConfigAny] = observed_graph_module_attrs.node_name_to_qconfig  # type: ignore[assignment]

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    # We use remove_duplicate=False here because torch.cat uses
    # the same activation_post_process module instance but different names
    modules = dict(model.named_modules(remove_duplicate=False))

    # TODO refactor this code once we update the prepare logic to have additional information on
    # which graph nodes have been observed and share that with convert to decide which observers to ignore.
    if qconfig_mapping:
        prepare_qconfig_mapping: QConfigMapping = observed_graph_module_attrs.qconfig_mapping  # type: ignore[assignment]
        modules_copy = copy.deepcopy(modules)

        if observed_graph_module_attrs.is_qat:
            _update_qconfig_for_qat(qconfig_mapping, backend_config)
        _update_qconfig_for_fusion(model, qconfig_mapping)

        _compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping)  # type: ignore[arg-type]
        convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
            model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope)
        # check the convert_node_name_to_qconfig generated and ensure that
        # all the values either match what was set in prepare node_name_to_qconfig
        # or are set to None in the convert_node_name_to_qconfig.
        for k, v in node_name_to_qconfig.items():
            assert k in convert_node_name_to_qconfig, f'Expected key {k} in convert node_name_to_qconfig'
            if convert_node_name_to_qconfig[k] is not None:
                assert qconfig_equals(v, convert_node_name_to_qconfig[k]), \
                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " \
                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
        node_name_to_qconfig = convert_node_name_to_qconfig

    custom_module_classes = get_custom_module_class_keys(convert_custom_config.observed_to_quantized_mapping)
    custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping

    if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
        # If we want to do equalization then do the following:
        # Calculate the equalization scale, update the observers with the scaled
        # inputs, and scale the weight
        weight_eq_obs_dict = update_obs_for_equalization(model, modules)
        convert_eq_obs(model, modules, weight_eq_obs_dict)

    # always run weight observers in the top level forward method
    # for dynamic quant ops or weight only quant ops
    _run_weight_observers(model, backend_config)

    # additional state to override inputs to be quantized, if specified
    # by the user
    placeholder_node_seen_cnt = 0
    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes

    root_module_to_quantized_reference_module = get_root_module_to_quantized_reference_module(backend_config)
    # convert tuples so that it can work with isinstance(module, tuple_of_classes)
    root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
    qat_module_classes = get_qat_module_classes(backend_config)
    fused_module_classes = get_fused_module_classes(backend_config)
    # loop-invariant: every module class that is handled by convert_weighted_module
    weighted_module_classes = set(
        root_module_classes).union(qat_module_classes).union(fused_module_classes)
    statically_quantized_custom_module_nodes: Set[Node] = set()

    # iterate over a snapshot because the graph is mutated while converting
    for node in list(model.graph.nodes):
        if node.op == 'placeholder':
            cur_placeholder_node_idx = placeholder_node_seen_cnt
            placeholder_node_seen_cnt += 1
            if cur_placeholder_node_idx in input_quantized_idxs:
                # Inputs are assumed to be quantized if the user specified the
                # input_quantized_idxs override.
                # we need to dequantize the inputs since all operators took
                # floating point inputs in reference quantized models
                _insert_dequantize_node(node, model.graph)
        elif node.op == "output":
            # If the argument is empty we don't need to do anything
            if len(output_quantized_idxs) == 0:
                continue
            # Result are kept quantized if the user specified the
            # output_quantized_idxs override.
            # Remove the dequantize operator for the node in the end if any
            return_node = node
            output = node.args[0]
            # outputs can be Node, list, tuple, dict, other cases are not supported yet
            if isinstance(output, (list, tuple)):
                for idx in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output[idx], return_node, model.graph)
            elif isinstance(output, (Node, dict)):
                # we treat dict as a single argument currently, but it can be extended
                # to support {"key": dtype} after we change output_quantized_idxs to
                # dict
                if 0 in output_quantized_idxs:
                    _maybe_recursive_remove_dequantize(output, return_node, model.graph)
            else:
                warnings.warn(f"Unsupported node type for output_quantized_idxs: {type(output)}")
        elif node.op == "call_module":
            mod = _get_module(node, modules)
            assert mod is not None
            # `type_before_parametrizations` is used instead of `type` to handle
            # situations with fx quant + sparsity
            mod_type = type_before_parametrizations(mod)
            if _is_activation_post_process(mod):
                observed_node = node.args[0]
                if observed_node in statically_quantized_custom_module_nodes:
                    _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
                elif is_decomposed:
                    _replace_observer_with_quantize_dequantize_node_decomposed(
                        model, node, modules, node_name_to_scope,
                        node_name_to_qconfig)
                else:
                    _replace_observer_with_quantize_dequantize_node(
                        model, node, modules, node_name_to_scope,
                        node_name_to_qconfig)
            elif isinstance(mod, DeQuantStub):
                _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
            elif _is_observed_standalone_module(mod):
                convert_standalone_module(
                    node, modules, model, is_reference, backend_config)
            elif mod_type in weighted_module_classes:
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if mod_type in fused_module_classes and \
                        type_before_parametrizations(mod[0]) not in root_module_classes:  # type: ignore[index]
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, node_name_to_qconfig, backend_config,
                    is_decomposed, is_reference)
            elif mod_type in custom_module_classes:
                convert_custom_module(
                    node, model.graph, modules, custom_module_class_mapping,
                    statically_quantized_custom_module_nodes)

    # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model = GraphModule(model, model.graph)

    # TODO: maybe move this to quantize_fx.py
    if not is_reference:
        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)

    # TODO: this looks hacky, we want to check why we need this and see if we can
    # remove this
    # removes qconfig and activation_post_process modules
    if _remove_qconfig_flag:
        _remove_qconfig(model)
    model.delete_all_unused_submodules()
    model.meta.pop("_observed_graph_module_attrs", None)
    return model
|