import copy
import torch
import warnings
from torch.fx import (
    GraphModule,
)
from torch.fx.graph import (
    Graph,
    Node,
)
from torch.fx.node import Argument

from ..quantize import (
    propagate_qconfig_,
)
from ..observer import (
    _is_activation_post_process,
    _PartialWrapper,
)
from ..qconfig import (
    _is_reuse_input_qconfig,
    QConfigAny,
)
from ..qconfig_mapping import (
    QConfigMapping,
)
from .qconfig_mapping_utils import (
    _generate_node_name_to_qconfig,
    _update_qconfig_for_fusion,
    _get_flattened_qconfig_dict,
    _update_qconfig_for_qat,
)

from .quantize_handler import (
    _default_root_node_getter,
    _get_pattern_to_quantize_handlers,
    QuantizeHandler,
)

from torch.ao.quantization import (
    ObserverBase,
    FixedQParamsObserver,
    FixedQParamsFakeQuantize,
    _DerivedObserverOrFakeQuantize,
)

from torch.ao.quantization.utils import (
    Pattern,
    NodePattern,
)

from ._equalize import (
    is_equalization_observer,
    node_supports_equalization,
)

from .pattern_utils import (
    _sorted_patterns_dict,
)

from .match_utils import (
    _MatchResultWithQConfig,
    _find_matches,
)

from .utils import (
    _insert_dequant_stubs_for_custom_module_lstm_output,
    _is_custom_module_lstm,
    _maybe_get_custom_module_lstm_from_node_arg,
    _qconfig_satisfies_dtype_config_constraints,
    get_custom_module_class_keys,
    all_node_args_have_no_tensors,
    assert_and_get_unique_device,
    get_non_observable_arg_indexes_and_types,
    get_new_attr_name_with_prefix,
    node_arg_is_weight,
    node_arg_is_bias,
    NON_QUANTIZABLE_WEIGHT_OPS,
    ObservedGraphModuleAttrs,
)

from torch.ao.quantization import (
    PlaceholderObserver
)
from torch.ao.quantization.quantize import (
    convert
)

from ..utils import (
    _parent_name,
    get_qconfig_dtypes,
    get_swapped_custom_module_class,
)

from ..backend_config.utils import (
    get_pattern_to_dtype_configs,
    get_module_to_qat_module,
    get_fusion_pattern_to_root_node_getter,
)
from ..backend_config import (
    BackendConfig,
    DTypeConfig,
    get_native_backend_config,
)
from .custom_config import (
    PrepareCustomConfig,
    StandaloneModuleConfigEntry,
)
from torch.ao.quantization.quantizer import (
    EdgeOrNode,
    QuantizationSpec,
    QuantizationSpecBase,
    FixedQParamsQuantizationSpec,
    SharedQuantizationSpec,
    DerivedQuantizationSpec,
)
from torch.ao.quantization import ObserverOrFakeQuantize

from torch._subclasses import FakeTensor

from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from dataclasses import asdict

__all__ = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
]


# list of dtypes to not add observers to
_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]
_OBS_DTYPE_LIST = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.float16,
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
]

_DEFAULT_FP32_OBS_OR_FQ_CTR = PlaceholderObserver.with_args(dtype=torch.float)

# note: the following default target dtype info dicts are temporary
# and should be moved to the new programmable API class soon
_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
}

_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
}


def _get_observer_kwargs(quant_spec: Union[QuantizationSpec, FixedQParamsQuantizationSpec]):
    kwargs_dict = asdict(quant_spec)
    return copy.deepcopy(kwargs_dict)


def _get_qspec_for_arg(
    arg: Node,
    input_qspec_map: Dict[Node, QuantizationSpecBase],
    named_modules: Dict[str, torch.nn.Module]
) -> Optional[QuantizationSpecBase]:
    while _is_activation_post_process_node(arg, named_modules):
        arg = arg.args[0]  # type: ignore[assignment]
    return input_qspec_map.get(arg, None)


def _create_obs_or_fq_from_qspec(
    quantization_spec: Optional[QuantizationSpecBase],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    """ Create observer or fake quantize objects based on quantization spec

    Args:
       quantization_spec: used to store parameters to create the observer or fake quantizer
       obs_or_fq_map: this is a map from edge/output to the corresponding observer/fake_quant
       instance, it may be reused for different edge/output depending on configuration
    """
    if quantization_spec is None:
        return None
    if isinstance(quantization_spec, SharedQuantizationSpec):
        edge_or_node = quantization_spec.edge_or_node
        assert edge_or_node in obs_or_fq_map, \
            "please make sure only to refer to an edge or node that has an " \
            f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
        return obs_or_fq_map[edge_or_node]
    elif isinstance(quantization_spec, DerivedQuantizationSpec):
        # can't use asdict, so not calling get_observer_kwargs here
        kwargs = {
            "dtype": quantization_spec.dtype,
            "derive_qparams_fn": quantization_spec.derive_qparams_fn,
            "quant_min": quantization_spec.quant_min,
            "quant_max": quantization_spec.quant_max,
            "qscheme": quantization_spec.qscheme,
            "ch_axis": quantization_spec.ch_axis,
        }
        edge_or_nodes = quantization_spec.derived_from
        obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
        kwargs["obs_or_fqs"] = obs_or_fqs
        return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
    elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
        kwargs = _get_observer_kwargs(quantization_spec)
        observer_ctr = FixedQParamsObserver.with_args(**kwargs)
        if is_qat:
            return FixedQParamsFakeQuantize.with_args(observer=observer_ctr)
        else:
            return observer_ctr()

    assert isinstance(quantization_spec, QuantizationSpec)
    observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
    kwargs = _get_observer_kwargs(quantization_spec)
    kwargs.pop("observer_or_fake_quant_ctr")
    # ch_axis is only meaningful for per-channel observers, so drop it for
    # everything else before constructing the observer
    obs_or_fq_class = observer_or_fake_quant_ctr
    if isinstance(observer_or_fake_quant_ctr, _PartialWrapper):
        obs_or_fq_class = observer_or_fake_quant_ctr.p.func  # type: ignore[union-attr, assignment]
    if "PerChannel" not in obs_or_fq_class.__name__:  # type: ignore[operator, union-attr]
        kwargs.pop("ch_axis")
    return observer_or_fake_quant_ctr.with_args(**kwargs)()
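
# Example (illustrative, not executed by this module): how a
# SharedQuantizationSpec resolves to an existing observer. This assumes `x`
# and `conv` are fx Nodes whose edge (x, conv) was annotated earlier and
# therefore already has an entry in `obs_or_fq_map`:
#
#   shared = SharedQuantizationSpec((x, conv))
#   obs = _create_obs_or_fq_from_qspec(shared, obs_or_fq_map, is_qat=False)
#   assert obs is obs_or_fq_map[(x, conv)]  # same instance, no new observer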


def _needs_obs_or_fq(
    prev_output_dtype: Any,
    prev_output_is_dynamic: bool,
    cur_target_dtype: Any,
    cur_target_is_dynamic: bool,
    reuse_input_obs_or_fq: bool,
    is_zeroth_arg: bool = False) -> bool:
    """
    note: we will treat "not specified" as torch.float for now
    utility function that checks if we should insert an observer or fake quant node
    based on the dtype that the user requested for the nodes

    is_zeroth_arg: we only dynamically quantize the first arg of the node right now
      this should be removed when we enable configuring dynamic quantization
      for a specific argument, this can be removed if we deprecate fx graph mode
      quantization

    """

    # need to insert placeholder observer for dynamic quantization so that it can
    # be converted to choose_qparams -> q -> dq in convert step
    if cur_target_is_dynamic:
        assert cur_target_dtype in _OBS_DTYPE_LIST, \
            f"Expected cur_target_dtype to be in {_OBS_DTYPE_LIST}, but got: {cur_target_dtype}"
        assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
        return is_zeroth_arg
    if reuse_input_obs_or_fq:
        return False
    # non dynamic quantization
    if cur_target_dtype in _OBS_DTYPE_LIST:
        return prev_output_dtype in _OBS_DTYPE_LIST + [torch.float] and cur_target_dtype != prev_output_dtype

    # lots of error checking is skipped here for now
    return False
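
# For example (derived from the logic above):
#   _needs_obs_or_fq(torch.float, False, torch.quint8, False, False) -> True
#     (fp32 producer feeding a statically quantized consumer needs an observer)
#   _needs_obs_or_fq(torch.quint8, False, torch.quint8, False, False) -> False
#     (dtypes already match, nothing to observe)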


def _is_activation_post_process_node(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool:
    return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
        _is_activation_post_process(named_modules[str(node.target)])


def _get_dtype_and_is_dynamic(obs_or_fq: Optional[ObserverOrFakeQuantize]) -> Tuple[Optional[torch.dtype], bool]:
    """ Given an instance of an observer or fake quant module, returns
    a Tuple of dtype and is_dynamic
    """
    # TODO: instead of instantiating the instance, we can use inspect to get the default args
    if obs_or_fq is None:
        return None, False
    else:
        return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False)  # type: ignore[return-value]


def _is_input_arg_dtype_supported_by_backend(
    arg: Argument,
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the configured qconfig for the argument
    is supported by the backend or not
    """
    if isinstance(arg, (list, tuple)):
        return all(_is_input_arg_dtype_supported_by_backend(
            a, node, qconfig,
            dtype_config, backend_config) for a in arg)
    if not isinstance(arg, Node):
        return True
    # TODO: support check for standalone module
    is_weight = node_arg_is_weight(node, arg)
    is_bias = node_arg_is_bias(node, arg)
    is_activation = not is_weight and not is_bias
    if is_activation:
        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
        input_act_obs_or_fq = input_act_obs_or_fq_ctr() if input_act_obs_or_fq_ctr else None
        qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq)
        # TODO(future PR): remove the cast to bool below after figuring
        # out why backend_config has is_dynamic set to None in some cases.
        return (dtype_config.input_dtype is None) or (
            dtype_config.input_dtype == qconfig_dtype and
            bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) and
            _qconfig_satisfies_dtype_config_constraints(qconfig, dtype_config.input_dtype_with_constraints)
        )
    elif is_weight:
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None)
        weight_obs_or_fq = weight_obs_or_fq_ctr() if weight_obs_or_fq_ctr else None
        qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq)
        backend_config_weight_dtype = dtype_config.weight_dtype
        dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
        qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False)
        return backend_config_weight_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
    else:  # bias
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None)
        bias_obs_or_fq = bias_obs_or_fq_ctr() if bias_obs_or_fq_ctr else None
        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq)
        backend_config_bias_dtype = dtype_config.bias_dtype
        return backend_config_bias_dtype is None or qconfig_bias_dtype == backend_config_bias_dtype
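
# For example (illustrative): if the backend declares
#   DTypeConfig(input_dtype=torch.quint8, weight_dtype=torch.qint8)
# then an activation arg whose qconfig-derived observer has dtype torch.quint8
# (with a matching is_dynamic flag) passes the check above, while a qconfig
# that produces a torch.float16 activation observer is rejected for that
# dtype config.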


def _is_output_dtype_supported_by_backend(
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
) -> bool:
    """ Check if the configured qconfig for the output
    is supported by the backend or not
    """
    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
    backend_config_output_dtype = dtype_config.output_dtype
    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
    # for the input activation check can be reused here
    qconfig_output_dtype = None
    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
    # qconfig (qconfig.activation), and we are only supporting the dynamically quantized
    # linear op, which has fp32 output dtype; this should be removed if we generalize
    # the structure of qconfig in the future
    if qconfig_output_is_dynamic:
        qconfig_output_dtype = torch.float32
    dtype_matches = qconfig_output_dtype == backend_config_output_dtype
    qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
        qconfig, dtype_config.output_dtype_with_constraints)
    return backend_config_output_dtype is None or (dtype_matches and qconfig_satisfies_constraints)


def _is_observer_in_same_graph(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat,
):
    """ Check if the observer for `node` lives in the same graph as `node`.
    When the node output is not fp32 and the input is a 'placeholder',
    the input is assumed to be quantized already, so it is observed
    in a different place rather than not observed at all.
    """
    node_output_dtype = _get_arg_target_dtype_as_output(node, named_modules, obs_or_fq_map, is_qat)
    if len(node.args) > 0 and isinstance(node.args[0], Node):
        if node_output_dtype in [torch.quint8, torch.uint8] and node.args[0].op == 'placeholder':
            return False
    return True


def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
    pattern: Optional[Pattern],
    matched_node_pattern: Optional[List[Node]],
    qconfig: QConfigAny,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the dtype configuration of a pattern is supported by
    the backend or not, and whether the qconfig satisfies constraints
    specified in the corresponding dtype config.
    """
    if backend_config is None or pattern is None:
        return True
    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)

    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
    root_node = root_node_getter(matched_node_pattern)
    input_node = root_node
    output_node = matched_node_pattern[0]
    for dtype_config in dtype_configs:
        # check if the arg dtypes are supported
        supported = True
        for arg in list(input_node.args) + list(input_node.kwargs.values()):
            supported = supported and _is_input_arg_dtype_supported_by_backend(
                arg, input_node, qconfig, dtype_config, backend_config)
        # check if the output dtype is supported
        supported = supported and _is_output_dtype_supported_by_backend(
            output_node, qconfig, dtype_config)
        if supported:
            return True
    return False


def _get_standalone_module_configs(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    prepare_custom_config: PrepareCustomConfig,
    parent_qconfig: QConfigAny,
    parent_backend_config: Optional[BackendConfig],
) -> Tuple[QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]]:
    """
    Returns the standalone module QConfigMapping and PrepareCustomConfig
    for `node`, assuming that the module pointed to by `node` is
    a standalone module.
    """
    module_name = str(node.target)
    module_type = type(named_modules[module_name])  # type: ignore[index]
    # name config has precedence over type config
    config_entry = StandaloneModuleConfigEntry(None, (), None, None)
    config_entry = prepare_custom_config.standalone_module_classes.get(module_type, config_entry)
    config_entry = prepare_custom_config.standalone_module_names.get(module_name, config_entry)
    # fall back to the parent module's qconfig if the user didn't specify a qconfig dict
    qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(parent_qconfig)
    example_inputs = config_entry.example_inputs
    prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
    backend_config = config_entry.backend_config or parent_backend_config
    return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
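
# For example (illustrative): given
#   prepare_custom_config.set_standalone_module_name(
#       "submodule", qconfig_mapping, example_inputs,
#       child_prepare_custom_config, None)
# a node targeting "submodule" picks up that entry even if an entry for its
# class is also registered, because the name lookup above runs last.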


def _qat_swap_modules(
        root: torch.nn.Module,
        module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]]) -> None:
    convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)


def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
    if isinstance(matched_node_pattern, Node):
        s.add(matched_node_pattern.name)
    elif isinstance(matched_node_pattern, (list, tuple)):
        for maybe_node in matched_node_pattern:
            _add_matched_node_name_to_set(maybe_node, s)


def _insert_obs_or_fq(
    node: Node,
    obs_or_fq: ObserverOrFakeQuantize,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attaches `obs_or_fq` to `model`, and creates a node which calls
    `obs_or_fq` on the output of `node`.

    obs_or_fq: an instance of Observer or FakeQuantize module
    """
    model_device = assert_and_get_unique_device(model)
    if model_device:
        obs_or_fq.to(model_device)
    # add obs_or_fq module as attribute
    if is_equalization_observer(obs_or_fq):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = 'activation_post_process_'
    get_new_obs_or_fq_name = get_new_attr_name_with_prefix(prefix)
    obs_or_fq_name = get_new_obs_or_fq_name(model)
    setattr(model, obs_or_fq_name, obs_or_fq)
    named_modules[obs_or_fq_name] = obs_or_fq
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', obs_or_fq_name, (node,), {})
    return new_obs
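
# For example (illustrative): for a node named "conv", the code above stores
# the module under a fresh attribute name such as "activation_post_process_0"
# (the exact suffix depends on existing attributes), and creates a
# call_module node that consumes `conv` immediately after it in the graph;
# callers are responsible for rewiring conv's users to the returned
# observer node.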


def _set_target_dtype_info_for_matched_node_pattern(
    matched_node_pattern: NodePattern,
    last_node: Node,
    qconfig: QConfigAny,
    qhandler: Optional[QuantizeHandler],
    backend_config: BackendConfig,
    named_modules: Dict[str, torch.nn.Module],
    cache_for_no_tensor_check: Dict[Node, bool],
    processed_nodes: Set[Node],
) -> None:
    """ Sets the target_dtype_info for each node in matched_node_pattern
    Note: processed_nodes is used to ensure we only process each node once
    """
    if isinstance(matched_node_pattern, (list, tuple)):
        for node_pattern in matched_node_pattern:
            _set_target_dtype_info_for_matched_node_pattern(
                node_pattern,
                last_node,
                qconfig,
                qhandler,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes
            )

    # set target_dtype_info if matched_node_pattern is a Node
    # other types of matched object, e.g. int, float literals, are ignored
    elif isinstance(matched_node_pattern, Node):
        # for pyre
        assert isinstance(matched_node_pattern, Node)
        node = matched_node_pattern
        if node in processed_nodes:
            return
        processed_nodes.add(node)

        if qconfig is None:
            return
        # TODO: refactor the following code in terms of applying a qconfig to a pattern
        # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
        # we set the input_obs_or_fq_ctr for the arguments of op1 based on qconfig.input_act,
        # and set output_obs_or_fq_ctr based on qconfig.output_act
        # this also requires we extend the structure of QConfig to support more fine
        # grained configurations
        target_dtype_info: Dict[str, Any] = (
            _get_target_activation_dtype_for_node(
                node,
                qconfig,
                qhandler,
                named_modules,
                backend_config,
                cache_for_no_tensor_check,
            )
        )
        node.meta["target_dtype_info"] = target_dtype_info


def _get_target_activation_dtype_for_node(
    node: Node,
    qconfig: QConfigAny,
    qhandler: Optional[QuantizeHandler],
    named_modules: Dict[str, torch.nn.Module],
    backend_config: BackendConfig,
    cache_for_no_tensor_check: Dict[Node, bool],
) -> Dict[str, Any]:
    """
    For each op attribute in the op's input activation, output activation,
    weight, bias - returns the settings of dtype and is_dynamic we expect
    for the `quantize` call in the reference model representation, or None
    if there is no `quantize` call needed.

    For example, if we have a node corresponding to `op0` in

      x0 -> op0 -> x1

    And we want a reference quantized representation to be

      x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1

    Then this function will return

      {
        "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
        "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
      }

    TODO(future PR, if needed): explicitly spell out the non-Tensor
    dtypes.
    """
    args_have_no_tensors = \
        all_node_args_have_no_tensors(
            node, named_modules, cache_for_no_tensor_check)
    if args_have_no_tensors:
        return {
            "input_act_obs_or_fq_ctr": None,
            "output_act_obs_or_fq_ctr": None,
        }
    # get qconfig to determine the eventual dtype of this node
    if qconfig is not None:
        act_dtype, weight_dtype, input_act_is_dynamic = \
            get_qconfig_dtypes(qconfig)

        # Currently `QConfig` only has one `activation` field.
        # For static quantization, it is reused for both input
        # and output activation. For dynamic quantization, this
        # field is currently only used for the input activation,
        # with the output activation being in fp32.
        # In the future this may change as we add more fields
        # to the `QConfig` object.
        output_act_dtype = act_dtype \
            if (not input_act_is_dynamic) else torch.float

        bias_dtype = torch.float16 \
            if (
                act_dtype == torch.float16
                and weight_dtype == torch.float16
                and (not input_act_is_dynamic)
            ) else torch.float

        is_general_tensor_value_op = \
            (qhandler is not None and qhandler.is_general_tensor_value_op())

        _is_standalone_module = (
            qhandler is not None and qhandler.is_standalone_module()
        )

        weight_index = None
        if isinstance(node, Node) and node.op == "call_function" and \
                node.target in backend_config._pattern_complex_format_to_config:
            weight_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("weight")

        bias_index = None
        if isinstance(node, Node) and node.op == "call_function" and \
                node.target in backend_config._pattern_complex_format_to_config:
            bias_index = backend_config._pattern_complex_format_to_config[node.target]._input_type_to_index.get("bias")

        return {
            "input_act_obs_or_fq_ctr": qconfig.activation,
            "weight_obs_or_fq_ctr": qconfig.weight,
            "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
            "weight_index": weight_index,
            "bias_index": bias_index,
            "output_act_obs_or_fq_ctr": qconfig.activation,
            "reuse_input_obs_or_fq": _is_reuse_input_qconfig(qconfig),
            "input_output_share_observers": is_general_tensor_value_op,
            "_is_standalone_module": _is_standalone_module,
        }
    return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)


def _get_output_act_obs_or_fq(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> ObserverOrFakeQuantize:
    """ Get the observer or fake quant instance for the argument
    in the original graph, treated as the output of the previous node,
    skipping inserted observers

    We are assuming that the observers are inserted correctly, and the dtype for
    the argument in the quantized graph will match what is specified by the qconfig
    """
    assert isinstance(arg, Node)
    if "quantization_annotation" in arg.meta:
        return _create_obs_or_fq_from_qspec(arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)

    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    # Since we modified the graph in this case, we must trace back from the args through
    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
    output_act_obs_or_fq_ctr = None
    if custom_module_lstm_node is not None:
        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    elif _is_activation_post_process_node(arg, named_modules):
        observed_arg = arg.args[0]
        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
        if "quantization_annotation" in observed_arg.meta:
            output_act_obs_or_fq = \
                _create_obs_or_fq_from_qspec(
                    observed_arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat)
        else:
            assert "target_dtype_info" in observed_arg.meta
            output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
            output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    else:
        if "target_dtype_info" in arg.meta:
            output_act_obs_or_fq_ctr = \
                arg.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
        else:
            output_act_obs_or_fq_ctr = _DEFAULT_FP32_OBS_OR_FQ_CTR
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None

    return output_act_obs_or_fq


def _get_arg_target_dtype_as_output(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[torch.dtype]:
    arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
    arg_as_output_target_dtype, _ = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)
    return arg_as_output_target_dtype


def _get_arg_as_input_act_obs_or_fq(
    arg: Node,
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[ObserverOrFakeQuantize]:
    """ Get the observer or fake quant instance for the Argument `arg`, as input
    to Node `node`
    """
    assert isinstance(arg, Node)
    # "input_qspec_map" is the more general design we'll use for the pt2e path,
    # it is a map from input argument node to observer or fake quant constructor, for example
    # for the following graph:
    # x -> conv -> output
    #
    # we may annotate the conv node like the following:
    # conv.meta[...] = QuantizationAnnotation("input_qspec_map": {x: MinMaxObserver.with_args(dtype=torch.qint8)}, ...)
    #
    if "quantization_annotation" in node.meta:
        input_qspec_map = node.meta["quantization_annotation"].input_qspec_map
        input_arg_qspec = _get_qspec_for_arg(arg, input_qspec_map, named_modules)
        if input_arg_qspec is None:
            input_arg_obs_or_fq = _DEFAULT_FP32_OBS_OR_FQ_CTR()
        else:
            input_arg_obs_or_fq = _create_obs_or_fq_from_qspec(input_arg_qspec, obs_or_fq_map, is_qat)
        return input_arg_obs_or_fq

    # we can remove the following path in the future if fx graph mode quantization is
    # no longer used
    is_weight = node_arg_is_weight(node, arg)
    is_bias = node_arg_is_bias(node, arg)
    is_activation = not is_weight and not is_bias
    obs_or_fq_ctr = None
    if is_activation:
        obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    elif is_weight:
        if node.target not in NON_QUANTIZABLE_WEIGHT_OPS:
            obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    else:
        obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", _DEFAULT_FP32_OBS_OR_FQ_CTR)
    return obs_or_fq_ctr() if obs_or_fq_ctr else None


def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: Optional[BackendConfig] = None,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules,
                graph,
                qhandler,
                prepare_custom_config,
                obs_or_fq_map,
                is_qat,
                backend_config)
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
    # TODO: move this to a separate function
    if not is_standalone_module:
        # Note: qconfig can be None in this branch since we are getting the act/fq from
        # node.meta now
        # regular flow for most nodes, except standalone modules

        if "quantization_annotation" in node.meta:
            reuse_input_obs_or_fq = node.meta["quantization_annotation"]._reuse_input_obs_or_fq
        else:
            assert "target_dtype_info" in node.meta
            # TODO: we are assuming "target_dtype_info" exists here, maybe
            # a default value also needs to be provided here
            target_dtype_info = node.meta["target_dtype_info"]
            # for nodes that don't have `reuse_input_obs_or_fq` configured,
            # we'll default to False, this makes configuring this field optional for users
            reuse_input_obs_or_fq = target_dtype_info.get("reuse_input_obs_or_fq", False)
        arg_as_input_act_obs_or_fq = _get_arg_as_input_act_obs_or_fq(arg, node, named_modules, obs_or_fq_map, is_qat)
        arg_as_input_target_dtype, arg_as_input_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_input_act_obs_or_fq)

        arg_as_output_act_obs_or_fq = _get_output_act_obs_or_fq(arg, named_modules, obs_or_fq_map, is_qat)
        arg_as_output_target_dtype, arg_as_output_target_is_dynamic = _get_dtype_and_is_dynamic(arg_as_output_act_obs_or_fq)

        needs_obs_or_fq = _needs_obs_or_fq(
            arg_as_output_target_dtype,
            arg_as_output_target_is_dynamic,
            arg_as_input_target_dtype,
            arg_as_input_target_is_dynamic,
            reuse_input_obs_or_fq,
            is_zeroth_arg=len(node.args) > 0 and arg is node.args[0],
        )

    else:
        assert qconfig is not None
        # custom flow for standalone modules
        _, _, sm_prepare_custom_config, _ = \
            _get_standalone_module_configs(
                node, named_modules, prepare_custom_config, qconfig, backend_config)
        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes

        # for args, this is set to the index of the current arg
        # for kwargs, this is left at None
        cur_input_idx = None
        for arg_idx, arg_to_check in enumerate(node.args):
            if arg_to_check is arg:
                cur_input_idx = arg_idx
                break

        if cur_input_idx is None:
            needs_obs_or_fq = False
        else:
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules, obs_or_fq_map, is_qat)
            arg_as_input_target_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \
                else torch.float
            needs_obs_or_fq = (
                (arg_as_output_target_dtype != arg_as_input_target_dtype) and
                (arg_as_input_target_dtype != torch.float)
            )

        act_post_process_ctr = qconfig.activation
        arg_as_input_act_obs_or_fq = act_post_process_ctr() if act_post_process_ctr else None

    if needs_obs_or_fq:

        existing_obs_node = None

        # Before using the new observer, check if an observer
        # of the correct type already exists. If it does, use it.
        # This prevents duplicate observer insertions if a node is
        # used by multiple nodes.
        # TODO: this is looking into how the value is used in the future,
        # we should remove this
        # removing this means we insert one observer for each use, even if they
        # have the same dtype, we can have an extra pass that removes the extra observers
        for maybe_obs_node in arg.users.keys():
            if maybe_obs_node.op == 'call_module':
                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                if (
                    type(maybe_obs_mod) == type(arg_as_input_act_obs_or_fq) and
                    maybe_obs_mod.dtype == arg_as_input_target_dtype  # type: ignore[possibly-undefined]
                ):
                    arg_as_input_act_obs_or_fq = maybe_obs_mod  # type: ignore[assignment]
                    existing_obs_node = maybe_obs_node
                    break

        assert arg_as_input_act_obs_or_fq is not None
        obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
        if existing_obs_node is None:
            new_obs_node = _insert_obs_or_fq(
                arg, arg_as_input_act_obs_or_fq, model, named_modules, graph)
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg
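
# For example (illustrative): if both `linear1` and `linear2` consume `x` with
# the same target dtype, the reuse check above lets the second call pick up
# the observer node that the first call already inserted:
#
#   x -> activation_post_process_0 -> linear1
#                                  \
#                                   -> linear2
#
# instead of inserting a second observer for `x`.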


def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
    backend_config: Optional[BackendConfig] = None
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    To

      prev_node -> obs -> cur_node

    Note: backend_config is only needed for standalone_module nodes
    """
    # Look through every input arg. If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, arg, qconfig, model, named_modules, graph,
            qhandler,
            prepare_custom_config,
            obs_or_fq_map,
            is_qat,
            backend_config)
        new_args.append(new_arg)

    new_kwargs = {}
    for k, kwarg in node.kwargs.items():
        new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, kwarg, qconfig, model, named_modules, graph,
            qhandler,
            prepare_custom_config,
            obs_or_fq_map,
            is_qat,
            backend_config)
        new_kwargs[k] = new_kwarg

    # assign the new args and kwargs to the node, inplace
    node.args = tuple(new_args)
    node.kwargs = new_kwargs


def _maybe_insert_input_equalization_observers_for_node(
    node: Node,
    equalization_qconfig: Any,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    is_branch: bool,
) -> None:
    """
    If `node` needs to be equalized, finds the input/weight observers it needs in
    `equalization_qconfig`, creates them, and inserts them into `graph`.

    If `node` does not need an equalization observer, returns None.
    """
    if equalization_qconfig is None or not node_supports_equalization(node, named_modules):
        return

    if is_branch:
        warnings.warn(
            f"Cannot equalize {node} because it is part of a branch."
        )
        return

    new_args = []
    for arg in node.args:
        if not isinstance(arg, Node) or node_arg_is_bias(node, arg):
            new_args.append(arg)
            continue

        is_weight = node_arg_is_weight(node, arg)

        act_eq_process_ctr = equalization_qconfig.weight if is_weight else \
            equalization_qconfig.input_activation

        new_eq_obs_mod = act_eq_process_ctr()
        new_eq_obs_node = _insert_obs_or_fq(
            arg, new_eq_obs_mod, model, named_modules, graph)

        new_args.append(new_eq_obs_node)

    # assign the new args to the node, inplace
    node.args = tuple(new_args)


def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    """
    If `node` needs an output observer, creates it, inserts it into `graph`
    and returns it.

    If `node` does not need an output observer, returns None.

    Note: inserting dynamic quantization ops for output is not supported in the fx graph mode
    quantization code path right now
    """
    assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'

    is_standalone_module = False
    if "quantization_annotation" in node.meta:
        output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
            node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
        )
    else:
        assert "target_dtype_info" in node.meta
        is_standalone_module = node.meta["target_dtype_info"].get("_is_standalone_module", False)
        output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr")
        output_act_obs_or_fq = output_act_obs_or_fq_ctr() if output_act_obs_or_fq_ctr else None
    target_dtype, target_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq)
    # uncomment after we support reuse_input_obs_or_fq properly by having separate
    # implementations for this key instead of reusing the input_output_share_observers
    # code
    # reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)
    # for now we set this to False since reuse_input_obs_or_fq for
    # the output of a node is implemented in the same code path as observer sharing,
    # we should refactor this part to make it clearer in the future
    # and then we would be able to read this from config directly
    reuse_input_obs_or_fq = False

    # Note: prev_output_dtype = torch.float and prev_output_is_dynamic=False
    # because the prev_output is the output of an fp32 op, although technically
    # we should get the dtype of the output from node.meta["val"] in the future
    # if we deprecate fx graph mode quantization
    needs_obs_or_fq = _needs_obs_or_fq(torch.float, False, target_dtype, target_is_dynamic, reuse_input_obs_or_fq)
    # currently the activation in QConfig(activation=...,) is for both input
    # and output, and when the activation is configured to be dynamic quantization,
    # e.g. PlaceholderObserver(dtype=torch.quint8, is_dynamic=True, ...), it means
    # the input should be dynamically quantized, but the output should not be quantized
    #
    # there is no way we can specify different observer/fq for input and output
    # activation through QConfig today, this limitation is lifted in the
    # quantizer/annotation API in the pytorch 2.0 export quantization code path,
    # but since this code is reused, annotating output to be dynamically quantized
    # would not work either for that.
    # we can change QConfig to support input/output activation if we want
    # to remove the following check, or if we can deprecate fx graph mode quantization
    if target_is_dynamic:
        needs_obs_or_fq = False

    # we never insert observers to the output of a standalone module, we assume
    # if needed, they are inserted inside the standalone module
    needs_obs_or_fq = needs_obs_or_fq and \
        (not is_standalone_module)

    if needs_obs_or_fq:
        obs_or_fq_map[node] = output_act_obs_or_fq
        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
    else:
        return None


def _maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If the output needs to be quantized and there are any nodes
    in the output which are not already observed, inserts observers
    for those nodes.
    """

    def _recursive_maybe_replace_node_with_obs(
        maybe_node: Argument,
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module],
        graph: Graph,
    ) -> Argument:
        """
        Navigate an arbitrary data structure of lists, tuples, dicts.
        For each container type, recurse on all inputs. Once any Node
        is found, insert an observer if needed and do not recurse further.

        For example, given a structure of

          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}

        we recurse down to bar1 and bar3, observe them if necessary,
        and if we inserted an observer then replace the original node
        with its observer.

        Returns the data structure with all nodes needing observation being
        replaced by their observers.
        """
        if isinstance(maybe_node, Node):
            # check dtype of this node
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(maybe_node, named_modules, obs_or_fq_map, is_qat)
            observer_mod = None
            arg_as_input_target_dtype = torch.float
            if "target_dtype_info" in maybe_node.meta:
                observer_cls = maybe_node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr", None)
                if observer_cls is not None:
                    observer_mod = observer_cls()
                    arg_as_input_target_dtype = observer_mod.dtype
            # TODO: this does not handle dynamic quantization yet
            need_obs = (
                arg_as_output_target_dtype != arg_as_input_target_dtype and
                arg_as_input_target_dtype != torch.float
            )
            if need_obs:
                assert observer_mod is not None
                # insert observer
                observer_node = _insert_obs_or_fq(
                    maybe_node, observer_mod, model, named_modules, graph)
                return observer_node
            else:
                return maybe_node
        elif isinstance(maybe_node, (list, tuple)):
            results = []
            for inner_node in maybe_node:
                results.append(_recursive_maybe_replace_node_with_obs(
                    inner_node, model, named_modules, graph))
            if isinstance(maybe_node, list):
                return results
            else:
                return tuple(results)
        elif isinstance(maybe_node, dict):
            results_dict = {}
            for k, inner_v in maybe_node.items():
                results_dict[k] = _recursive_maybe_replace_node_with_obs(
                    inner_v, model, named_modules, graph)
            return results_dict
        elif maybe_node is None:
            return None
        else:
            raise Exception("Unhandled type for returned node:", maybe_node)

    new_args = []
    for old_arg in graph_output_node.args:
        new_args.append(
            _recursive_maybe_replace_node_with_obs(
                old_arg, model, named_modules, graph))

    graph_output_node.args = tuple(new_args)  # type: ignore[assignment]


def _maybe_propagate_dtype_for_node(
    node: Node,
    target_dtype: Union[torch.dtype, type],
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node`
    is a general tensor shape op, also calls this function recursively on
    the first argument, to propagate the dtype to the caller.
    """
    node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None
    node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None
    # if this is a copy node, propagate to first arg
    root_node, _, pattern, qhandler, qconfig = node_name_to_match_result_with_qconfig.get(
        node.name, (None, None, None, None, None))
    # TODO: probably need to remove `is_general_tensor_value_op`
    if qhandler is not None and qhandler.is_general_tensor_value_op():
        prev_node = node.args[0]
        if isinstance(prev_node, Node):
            _maybe_propagate_dtype_for_node(
                prev_node, target_dtype, node_name_to_match_result_with_qconfig)


def propagate_dtypes_for_known_nodes(
    graph: Graph,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Currently we assume that inputs to the graph are either `torch.float` or
    `torch.quint8`, which is not always correct. For ops such as
    `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a
    `BoolTensor`. Propagate this information throughout the graph.

    Note: not all dtypes in the graph will be correct after this pass, but a
    higher percentage of them will be correct. Hopefully in the future we can
    replace this with a better way to reason about dtypes of tensors.
    """
    for node in graph.nodes:
        non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)

        for arg_type in non_observable_arg_dict:
            non_observable_indices = non_observable_arg_dict[arg_type](node)

            for index in non_observable_indices:
                arg = node.args[index]

                # when an argument is a tuple, it does not show up as another node so we need to go through
                # all elements of the tuple manually
                if isinstance(arg, (tuple, list)):
                    arg_list = list(arg)
                else:
                    arg_list = [arg]

                for cur_arg in arg_list:
                    # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated
                    if isinstance(cur_arg, torch.fx.node.Node):
                        _maybe_propagate_dtype_for_node(
                            cur_arg, arg_type, node_name_to_match_result_with_qconfig)
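
# For example (illustrative): for a call like `x.masked_fill(mask, value)`,
# `get_non_observable_arg_indexes_and_types` reports that the arg at the
# `mask` position holds a bool tensor, so the loop above marks `mask` (and,
# through `_maybe_propagate_dtype_for_node`, any general-tensor-shape ops
# that produced it) as not needing observation.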


def _maybe_make_input_output_share_observers(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
) -> bool:
    """
    Ensures that we share an observer
    for all input arguments as well as the output argument. In detail, given
    a graph of

      x0 -> obs0 -> op -> obs2 -> x2
                  /
      x1 -> obs1 /

    where node obs0 points to observer instance observer0,
    obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
    and obs2 point to observer0.
    Returns: whether the operation succeeded or not
    """
    first_arg = None
    # find the first arg that is a Node or a list/tuple of Nodes
    for i in range(len(node.args)):
        if isinstance(node.args[i], (Node, list, tuple)):
            first_arg = node.args[i]
            break

    # if there is no such arg, return directly
    if first_arg is None:
        return False

    if isinstance(first_arg, (list, tuple)):
        first_arg_arg = first_arg[0]
    elif isinstance(first_arg, Node):
        first_arg_arg = first_arg
    else:
        return False

    # if we have a graph such as
    # observed_node -> non_observed_node -> cat
    # we need to navigate up to the first observer
    iteration_guard = 0
    while not _is_activation_post_process_node(first_arg_arg, named_modules):
        if not isinstance(first_arg_arg, Node):
            return False
        # did not find an activation_post_process for the op
        if first_arg_arg.op == "placeholder":
            return False
        # trace back the args until we find the first Tensor/Node
        trace_back_node = None
        for i in range(len(first_arg_arg.args)):
            trace_back_node = first_arg_arg.args[i]
            if isinstance(trace_back_node, Node):
                break
        if trace_back_node is None:
            return False
        first_arg_arg = trace_back_node

        iteration_guard += 1
        if iteration_guard > 10000:
            raise AssertionError('Unable to find observer of previous node')

    assert isinstance(first_arg_arg, Node)
    target_to_use = first_arg_arg.target
    assert isinstance(target_to_use, str)
    obs_mod_to_use = named_modules[target_to_use]

    if isinstance(first_arg, (list, tuple)):
        # set all other input observer nodes to use that module
        for input_idx, input_arg in enumerate(first_arg):
            if input_idx == 0:
                continue
            iteration_guard = 0
            while not _is_activation_post_process_node(input_arg, named_modules):
                # failed to trace back since there is no input arg for the current node
                if len(input_arg.args) < 1:
                    return False
                input_arg = input_arg.args[0]
                iteration_guard += 1
                if iteration_guard > 10000:
                    raise AssertionError('Unable to find observer of previous node')

            parent_name, name = _parent_name(input_arg.target)
            setattr(named_modules[parent_name], name, obs_mod_to_use)

    # set the output observer node to use that module
    for output_obs_node in node.users.keys():
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        parent_name, name = _parent_name(output_obs_node.target)
        setattr(named_modules[parent_name], name, obs_mod_to_use)

    # TODO(future PR): delete the orphaned observer modules
    return True
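
# For example (illustrative): an op like torch.cat shares one observer across
# its inputs and output so that all tensors carry the same qparams. Assuming
# `cat` is the cat node in an already-observed graph:
#
#   _maybe_make_input_output_share_observers(cat, model, named_modules)
#
# after which every input observer module and the output observer module of
# `cat` refer to the same observer instance.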


def _remove_output_observer(
        node: Node,
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module]):
    items = list(node.users.items())
    for output_obs_node, _ in items:
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        output_obs_node.replace_all_uses_with(node)
        model.graph.erase_node(output_obs_node)  # type: ignore[union-attr, operator]


def _swap_custom_module_to_observed(
        node: Node,
        qconfig: QConfigAny,
        named_modules: Dict[str, torch.nn.Module],
        prepare_custom_config: PrepareCustomConfig):
    custom_module = named_modules[node.target]  # type: ignore[index]
    custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping
    observed_custom_module_class = \
        get_swapped_custom_module_class(
            custom_module, custom_module_class_mapping, qconfig)
    observed_custom_module = \
        observed_custom_module_class.from_float(custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(named_modules[parent_name], name, observed_custom_module)
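
# For example (illustrative, with hypothetical class names): given
#   prepare_custom_config.set_float_to_observed_mapping(
#       CustomLSTM, ObservedCustomLSTM)
# a call_module node targeting a CustomLSTM instance is swapped in place for
# ObservedCustomLSTM.from_float(custom_module) by the function above.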
def insert_observers_for_model(
|
||
|
model: GraphModule,
|
||
|
node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
|
||
|
node_name_to_qconfig: Dict[str, QConfigAny],
|
||
|
prepare_custom_config: PrepareCustomConfig,
|
||
|
equalization_config_map: Dict[str, Any],
|
||
|
backend_config: BackendConfig,
|
||
|
observed_node_names: Set[str],
|
||
|
is_qat: bool,
|
||
|
) -> Optional[Node]:
|
||
|
"""
|
||
|
Inserts observers, using the following high level algorithm:
|
||
|
|
||
|
For each node in the graph:
|
||
|
1. determine the target dtype of this node in the quantized graph, and save
|
||
|
it for future steps
|
||
|
2. determine the target dtype or all args and kwargs of this node
|
||
|
3. if any arg or kwarg's target dtype does not match the current node's
|
||
|
dtype, insert an observer
|
||
|
4. if the current node needs an output observer, insert it
|
||
|
|
||
|
For example:
|
||
|
|
||
|
- starting graph:
|
||
|
x0 -> linear -> x1
|
||
|
|
||
|
- observed graph after processing x0:
|
||
|
x0(fp32)
|
||
|
|
||
|
- observed graph after processing linear:
|
||
|
x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)
|
||
|
|
||
|
- observed graph after processing x1:
|
||
|
x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1
|
||
|
|
||
|
After a node is processed, the naive observer placement is guaranteed to be
|
||
|
complete for that node and all of its predecessors. There can be future
|
||
|
passes which optimize the graph by deduplicating observers, etc.
|
||
|
"""

    # node.meta["target_dtype_info"] stores the target dtype information
    # that's derived from qconfig for the Node, for example, if we have
    # a conv2d node that has a qconfig
    # qconfig = QConfig(activation=..., weight=...)
    # # information for input and bias node omitted
    # # for getattr node
    # # weight = getattr(self, 'weight')
    # weight.meta["target_dtype_info"] = {
    #    'output_act_obs_or_fq_ctr': qconfig.weight,
    # }
    # # for conv2d node
    # # conv2d = call_function[target=torch.nn.functional.conv2d](
    # #            args=(input, weight, bias))
    # conv2d.meta["target_dtype_info"] = {
    #    'input_act_obs_or_fq_ctr': qconfig.activation,
    #    'weight_obs_or_fq_ctr': qconfig.weight,
    #    'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
    #    'output_act_obs_or_fq_ctr': qconfig.activation,
    # }
    #
    cache_for_no_tensor_check: Dict[Node, bool] = {}

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    # graph inputs are fp32 by default, and int8 where overridden
    # other nodes output dtype is specified by the qconfig
    named_modules = dict(model.named_modules(remove_duplicate=False))

    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
    processed_nodes: Set[Node] = set()
    # initialize target_dtype_info
    for node in model.graph.nodes:
        node.meta["target_dtype_info"] = copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    placeholder_node_to_input_index: Dict[Node, int] = {}
    # TODO: we probably don't need this counter since each graph will only have
    # one output node?
    output_node_to_output_index: Dict[Node, int] = {}
    for node in model.graph.nodes:
        if node.op == "placeholder":
            placeholder_node_to_input_index[node] = inputs_seen_counter
            inputs_seen_counter += 1
        if node.op == "output":
            output_node_to_output_index[node] = outputs_seen_counter
            outputs_seen_counter += 1

    # Step 1, set the observer or fake quantize module constructor for each node in the
    # matched_node_pattern

    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
        last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
        assert qhandler is not None
        _set_target_dtype_info_for_matched_node_pattern(
            matched_node_pattern,
            last_node,
            qconfig,
            qhandler,
            backend_config,
            named_modules,
            cache_for_no_tensor_check,
            processed_nodes
        )
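
    # For orientation, a match for a fused linear + relu pattern would unpack
    # roughly as follows (illustrative, not computed here):
    #   last_node:            the relu Node (last node in execution order)
    #   matched_node_pattern: (relu_node, linear_node), mirroring the pattern
    #   pattern:              (torch.nn.ReLU, torch.nn.Linear), in reverse order
    #   qhandler:             the QuantizeHandler registered for the pattern
    #   qconfig:              the QConfig resolved for the pattern's root module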

    # Step 2. Special cases for some operators, we might be able to remove them
    # in the future if we know dtype information of each node better

    # Step 2.1. some settings are not based on patterns, we need to process each node
    # instead
    for node in model.graph.nodes:
        if node.op == "placeholder" and placeholder_node_to_input_index[node] in input_quantized_idxs:
            # users are not supposed to call calculate_qparams on PlaceholderObserver, and
            # this is OK because we are only using it to encode the dtypes of the input
            # tensor, we won't actually insert these observers in the graph and won't
            # actually call calculate_qparams
            node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
        elif node.op in ("call_module", "call_method", "call_function"):
            args_have_no_tensors = \
                all_node_args_have_no_tensors(
                    node, named_modules, cache_for_no_tensor_check)
            if args_have_no_tensors:
                node.meta["target_dtype_info"] = {
                    "input_act_obs_or_fq_ctr": None,
                    "output_act_obs_or_fq_ctr": None,
                }
        elif node.op == "output" and output_node_to_output_index[node] in output_quantized_idxs:
            # TODO(future PR): update the output_quantized_idxs API to match
            # arbitrary data structures. There is always a single output, and
            # that output can have arbitrary nesting of values. List[int] is
            # not the right data type for this.

            # TODO(future PR): support more dtypes in model outputs, if necessary
            node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)

    # Step 2.2, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(model.graph, node_name_to_match_result_with_qconfig)
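
    # The propagation above relies on get_non_observable_arg_indexes_and_types:
    # for an op like masked_fill, the mask argument is known to be non-observable
    # (torch.bool), so its dtype is recorded and no observer is inserted on it.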

    # Step 3, check if the requested target_dtype_info is supported by backend or not
    # if not, we'll reset the target_dtype_info to use the default (float Tensor)

    # reset the counters and set of processed_nodes
    processed_nodes = set()
    for match_res_with_qconfig in node_name_to_match_result_with_qconfig.values():
        last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
        is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
            pattern, matched_node_pattern, qconfig, backend_config)
        assert qhandler is not None

        # get output_act_dtype so that we don't also reset the special typed nodes
        # TODO: we might want to handle these more uniformly with the default path
        # this can be improved if we can use node.meta["val"]
        output_act_or_fq_ctr = last_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
        output_act_or_fq = output_act_or_fq_ctr() if output_act_or_fq_ctr else None
        output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_or_fq)
        if not is_supported_by_backend and output_act_dtype not in [None, int, float, torch.bool]:
            # restore target_dtype_info to default if it is not supported by backend
            _set_target_dtype_info_for_matched_node_pattern(
                matched_node_pattern,
                last_node,
                torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
                None,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes
            )
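
    # For reference, the backend check above compares the dtypes implied by the
    # qconfig against the DTypeConfigs registered for the pattern, which look
    # roughly like (values illustrative):
    #   DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8,
    #               weight_dtype=torch.qint8, bias_dtype=torch.float)
    # If no registered DTypeConfig accepts the qconfig's dtypes, the match is
    # reset to the default fp32 handling above.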

    # After this point, the current node and all of its arguments
    # have a target_dtype_info assigned. Now, we insert observers for inputs
    # of this node (if needed for this node), and the output of this node
    # (if needed for this node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # Avoid duplicate custom module swaps for multiple nodes with the same target.
    custom_module_names_already_swapped: Set[str] = set()

    # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
    # reset inputs/outputs counters
    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
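
    # Keys of obs_or_fq_map are EdgeOrNode: either a Node (for an output
    # observer/fq) or a (producer, consumer) tuple of Nodes (for an input
    # observer/fq on one specific edge); values are the ObserverOrFakeQuantize
    # instance chosen for that location, so sharing decisions can be reused
    # across insertions.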

    # TODO: change this to insert obs/fq by pattern instead of by node
    for node in nodes_before_observation:

        if node.op == 'placeholder':
            # if a graph input is in fp32, it does not need observation
            # if a graph input is in int8, we assume the observation happens
            # outside of the graph, and no additional observation is needed
            pass

        elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
            # check for matches
            last_node, matched_node_pattern, pattern, qhandler, qconfig = (
                node_name_to_match_result_with_qconfig.get(node.name, (None, None, None, None, None))  # type: ignore[assignment]
            )
            equalization_qconfig = equalization_config_map.get(node.name, None)

            this_node_dtype_info = node.meta["target_dtype_info"]
            if "val" in node.meta:
                output_is_a_tensor = (
                    this_node_dtype_info is not None and
                    isinstance(node.meta["val"], FakeTensor)
                )
            else:
                output_is_a_tensor = this_node_dtype_info is not None

            skip_inserting_observers = (
                (qconfig is None) or
                not output_is_a_tensor
            ) and (
                not node.op == 'output'
            )
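
            # In other words: skip insertion when there is no qconfig or the
            # node does not produce a Tensor, except for the graph output,
            # which is always visited so outputs can be observed when requested.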

            # TODO: take a closer look to see if we can remove this check
            # right now it is here because of `observed_node_names`, we are using
            # it as an indicator for swapping the modules to reference modules in
            # convert
            is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
                pattern, matched_node_pattern, qconfig, backend_config)

            if not skip_inserting_observers and is_supported_by_backend:
                named_modules = dict(model.named_modules(remove_duplicate=False))
                if node.op != 'output':
                    assert matched_node_pattern is not None
                    # add matched nodes to the observed node name set
                    _add_matched_node_name_to_set(matched_node_pattern, observed_node_names)

                    # This is currently only used for equalization.
                    # Checks if the current node is in a branch in which the two
                    # first layers are both being quantized.
                    #
                    # ex.       conv2
                    #          /
                    #      x -> conv1
                    #
                    # If this is the case, we will not apply equalization to the
                    # initial two layers.
                    is_quantized_branch = False
                    if (
                        len(node.args) > 0 and
                        isinstance(node.args[0], Node) and
                        len(node.args[0].users) > 1
                    ):
                        for user in node.args[0].users:
                            # Checks if there exists another user being quantized
                            is_user_quantized = (
                                node_name_to_qconfig.get(user.name, None) is not None or
                                (user.op == 'call_module' and isinstance(named_modules[str(user.target)], ObserverBase))
                            )
                            if user != node and is_user_quantized:
                                is_quantized_branch = True

                    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
                    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
                    root_node = root_node_getter(matched_node_pattern)
                    is_input_node_of_the_pattern = node is root_node
                    if is_input_node_of_the_pattern:
                        # this modifies node inplace
                        _maybe_insert_input_observers_for_node(
                            node, qconfig, model, named_modules, model.graph,
                            qhandler,
                            prepare_custom_config,
                            obs_or_fq_map,
                            is_qat,
                            backend_config)

                        # insert equalization input observers if needed
                        _maybe_insert_input_equalization_observers_for_node(
                            node, equalization_qconfig, model, named_modules, model.graph,
                            is_quantized_branch)
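
                    # Note: input observers are inserted only at the pattern's
                    # root node, so internal nodes of a fused pattern (e.g. the
                    # relu of linear + relu) do not get their own input observers.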

                    is_last_node_of_pattern = node is last_node
                    input_output_share_observers = node.meta["target_dtype_info"].get("input_output_share_observers", False)
                    reuse_input_obs_or_fq = node.meta["target_dtype_info"].get("reuse_input_obs_or_fq", False)

                    if is_last_node_of_pattern:
                        if _is_custom_module_lstm(node, named_modules, qconfig, qhandler):
                            # Currently custom module outputs are assumed to be already quantized,
                            # so we need to insert a DeQuantStub after the output. For custom module
                            # LSTM specifically, the outputs are also a nested tuple, so we must first
                            # break down the tuple to insert DeQuantStubs after the internal nodes.

                            # TODO: This currently diverges from how custom modules are handled today,
                            # where we insert observers after the output instead of DeQuantStubs, and
                            # replace these observers with "dequantize" nodes during convert. Conceptually,
                            # these output observers are the same as DeQuantStubs. In the future, we
                            # should resolve this inconsistency by inserting DeQuantStubs for all custom
                            # modules, not just for LSTM.
                            _insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
                            if node.target not in custom_module_names_already_swapped:
                                custom_module_names_already_swapped.add(node.target)
                                _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
                        else:
                            # this returns the new observer node if it was needed
                            maybe_output_obs_node = _maybe_insert_output_observer_for_node(
                                node, model, named_modules, model.graph, obs_or_fq_map, is_qat)

                            if maybe_output_obs_node is not None:
                                # Update users of original node to use the output observer
                                # instead. For example, change
                                #
                                #           next_node
                                #          /
                                # cur_node -> obs
                                #
                                # to
                                #
                                #                 next_node
                                #                 /
                                # cur_node -> obs
                                #
                                # We need to save orig users before updating uses because
                                # the list of users will change as we update uses
                                orig_users = list(node.users.keys())
                                for user_node in orig_users:
                                    if user_node is maybe_output_obs_node:
                                        continue
                                    user_node.replace_input_with(node, maybe_output_obs_node)

                                _is_observer_in_same_graph_ = _is_observer_in_same_graph(
                                    node, named_modules, obs_or_fq_map, is_qat)

                                # for ops whose inputs and outputs share observer/fqs, we modify the graph
                                # to make all inputs and outputs use the first input's
                                # observer/fq
                                if (input_output_share_observers and _is_observer_in_same_graph_) or \
                                        reuse_input_obs_or_fq:
                                    if not _maybe_make_input_output_share_observers(node, model, named_modules):
                                        _remove_output_observer(node, model, named_modules)

                                if qhandler is not None and qhandler.is_custom_module():
                                    if node.target not in custom_module_names_already_swapped:
                                        custom_module_names_already_swapped.add(node.target)
                                        _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)

                else:  # output
                    _maybe_insert_observers_before_graph_output(node, model, named_modules, model.graph, obs_or_fq_map, is_qat)

        #
        # After this point, the current node has input and output observers
        # that it needs for itself inserted.
        #

        # increment the counters, so future inputs and outputs are assigned
        # correct dtypes
        if node.op == 'placeholder':
            inputs_seen_counter += 1
        elif node.op == 'output':
            outputs_seen_counter += 1
            results_node = node

    return results_node

def _run_prepare_fx_on_standalone_modules(
    model: torch.nn.Module,
    is_qat: bool,
    named_modules: Dict[str, torch.nn.Module],
    node_name_to_match_result_with_qconfig: Any,
    prepare_custom_config: PrepareCustomConfig,
    backend_config: BackendConfig,
) -> None:
    """
    Runs prepare_fx on each standalone module. Note: this does
    not modify the graph, it just replaces the unobserved modules with
    their observed versions.
    """
    for (root_node, _, pattern, qhandler, qconfig) in node_name_to_match_result_with_qconfig.values():
        if qhandler is None:
            continue
        elif not qhandler.is_standalone_module():
            continue

        sm_qconfig_mapping, sm_example_inputs, sm_prepare_custom_config, \
            sm_backend_config = _get_standalone_module_configs(
                root_node, named_modules, prepare_custom_config, qconfig, backend_config)

        standalone_module = named_modules[root_node.target]
        prepare = \
            torch.ao.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore[attr-defined]
        observed_standalone_module = \
            prepare(
                standalone_module,
                sm_qconfig_mapping,
                is_qat,
                example_inputs=sm_example_inputs,
                prepare_custom_config=sm_prepare_custom_config,
                backend_config=sm_backend_config)
        parent_name, name = _parent_name(root_node.target)
        setattr(named_modules[parent_name], name, observed_standalone_module)
        named_modules[root_node.target] = observed_standalone_module
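
# For reference, a standalone module is registered through PrepareCustomConfig,
# e.g. (module name and child configs are illustrative):
#
#   prepare_custom_config.set_standalone_module_name(
#       "submodule", sm_qconfig_mapping, sm_example_inputs,
#       sm_prepare_custom_config, sm_backend_config)
#
# _get_standalone_module_configs recovers these values for the matched root
# node, and the child module is then prepared as its own unit.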

def _save_state(
    observed: GraphModule,
    node_name_to_qconfig: Dict[str, QConfigAny],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    prepare_custom_config: PrepareCustomConfig,
    equalization_node_name_to_qconfig: Dict[str, Any],
    qconfig_mapping: QConfigMapping,
    is_qat: bool,
    observed_node_names: Set[str],
) -> None:
    observed.meta["_observed_graph_module_attrs"] = (
        ObservedGraphModuleAttrs(
            node_name_to_qconfig=node_name_to_qconfig,
            node_name_to_scope=node_name_to_scope,
            prepare_custom_config=prepare_custom_config,
            equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
            qconfig_mapping=qconfig_mapping,
            is_qat=is_qat,
            observed_node_names=observed_node_names,
        )
    )

def prepare(
        model: GraphModule,
        qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
        is_qat: bool,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
        _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
        is_standalone_module: bool = False) -> GraphModule:
    """ standalone_module means it is a submodule that is not inlined in
    the parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module

    Args:
        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
            The scope is a tuple of the fully qualified path of the module and the type of the module

    Returns:
        model(GraphModule): prepared standalone module
        attributes related to standalone module
        in model.meta["_observed_graph_module_attrs"]:
            is_observed_standalone_module (bool): boolean value that shows whether the
                current model is an observed standalone module or not
            standalone_module_input_quantized_idxs(List[Int]): a list of
                indexes for the graph input that is expected to be quantized,
                same as the input_quantized_idxs configuration provided
                for the standalone module
            standalone_module_output_quantized_idxs(List[Int]): a list of
                indexes for the graph output that is quantized,
                same as the output_quantized_idxs configuration provided
                for the standalone module
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)

    if isinstance(_equalization_config, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
            "be supported in a future version. Please pass in a QConfigMapping instead.")
        _equalization_config = QConfigMapping.from_dict(_equalization_config)

    if isinstance(prepare_custom_config, Dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.")
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    assert isinstance(qconfig_mapping, QConfigMapping)
    assert isinstance(_equalization_config, QConfigMapping)
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    _equalization_config = copy.deepcopy(_equalization_config)

    # mapping from a tuple of nodes in reverse order to uninitialized
    # QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.ao.quantization.fx.quantize.Add'>),
    # }

    pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
    if backend_config is None:
        backend_config = get_native_backend_config()
    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
    pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)

    root_node_getter_mapping = \
        get_fusion_pattern_to_root_node_getter(backend_config)

    _update_qconfig_for_fusion(model, qconfig_mapping)
    _update_qconfig_for_fusion(model, _equalization_config)
    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())

    if is_qat:
        module_to_qat_module = get_module_to_qat_module(backend_config)
        _qat_swap_modules(model, module_to_qat_module)
        _update_qconfig_for_qat(qconfig_mapping, backend_config)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    named_modules = dict(model.named_modules(remove_duplicate=False))

    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
        model, named_modules, model.graph, _equalization_config, node_name_to_scope)
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        model, named_modules, model.graph, qconfig_mapping, node_name_to_scope)

    # match the patterns that will get quantized
    standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
    standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())

    custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
    matches_without_qconfig = _find_matches(
        model.graph, named_modules, pattern_to_quantize_handler, root_node_getter_mapping,
        standalone_module_names, standalone_module_classes, custom_module_classes)

    # map qconfig instances to matches
    node_name_to_match_result_with_qconfig = {}
    for node_name, match_without_qconfig in matches_without_qconfig.items():
        match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
        node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig

    _run_prepare_fx_on_standalone_modules(
        model, is_qat, named_modules, node_name_to_match_result_with_qconfig, prepare_custom_config, backend_config)

    # record names for the set of observed nodes, so that in the convert step
    # we know whether we need to convert a floating point module to a reference
    # quantized module or not
    observed_node_names: Set[str] = set()

    result_node = insert_observers_for_model(
        model,
        node_name_to_match_result_with_qconfig,
        node_name_to_qconfig,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        backend_config,
        observed_node_names,
        is_qat,
    )
    model = GraphModule(model, model.graph)

    _save_state(model, node_name_to_qconfig, node_name_to_scope,
                prepare_custom_config, equalization_node_name_to_qconfig,
                qconfig_mapping, is_qat, observed_node_names)

    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently " \
            "(not tuple, dict etc.)"
        # these inputs are observed in the parent module
        input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
        output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
        observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
        # inplace modification
        observed_graph_module_attrs.is_observed_standalone_module = True
        observed_graph_module_attrs.standalone_module_input_quantized_idxs = \
            input_quantized_idxs
        observed_graph_module_attrs.standalone_module_output_quantized_idxs = \
            output_quantized_idxs
    return model