import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.prepare import (
    _insert_obs_or_fq,
    _save_state,
    _is_activation_post_process_node,
    _create_obs_or_fq_from_qspec,
)
from torch.fx import (
    GraphModule,
    Graph,
    Node,
)
from torch.fx.node import Argument

from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any, Optional
from torch.ao.quantization.quantizer import (
    EdgeOrNode,
    SharedQuantizationSpec,
    QuantizationSpecBase,
)
from torch.ao.quantization import ObserverOrFakeQuantize

# TODO: make pt2e folder private?
__all__ = [
    "prepare",
]

def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
    """Find the root node for the sharing tree

    Args:
        edge_or_node: edge/node whose root we want to find
        shared_with_map: each edge/node points to its parent; the root node points to itself

    Returns:
        root edge/node
    """
    parent = shared_with_map[edge_or_node]
    if parent == edge_or_node:
        return edge_or_node
    root = _find_root_edge_or_node(parent, shared_with_map)
    # path compression
    shared_with_map[edge_or_node] = root
    return root
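
# Illustrative sketch (not executed; strings stand in for real edges/nodes,
# which works because the union-find helpers only use equality and dict lookups):
#
#     >>> m = {"a": "b", "b": "c", "c": "c"}
#     >>> _find_root_edge_or_node("a", m)
#     'c'
#     >>> m  # path compression: both "a" and "b" now point directly at the root
#     {'a': 'c', 'b': 'c', 'c': 'c'}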

def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
    """Merge the subtree for `child` with `parent`; the order of arguments matters here
    """
    root_parent = _find_root_edge_or_node(parent, shared_with_map)
    root_child = _find_root_edge_or_node(child, shared_with_map)
    # union the two trees by pointing the root of child to root of parent
    shared_with_map[root_child] = root_parent


def _update_shared_with(child: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
    """Update the `shared_with_map` based on the qspec. This applies the `SharedQuantizationSpec`
    configuration and establishes the relationship between `edge_or_node` and the edge/node that it
    points to; we'll use this information in the end to get the group id
    """
    if isinstance(qspec, SharedQuantizationSpec):
        parent = qspec.edge_or_node
        # we point from edge_or_node to the node that it is sharing with, e.g.
        # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
        _union(parent, child, shared_with_map)
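
# Illustrative sketch (not executed; `a` and `b` are hypothetical stand-ins):
# annotating `b` with SharedQuantizationSpec(a) hangs `b`'s subtree off `a`'s
# root, so both resolve to the same root (and later the same group id):
#
#     >>> m = {"a": "a", "b": "b"}
#     >>> _update_shared_with("b", SharedQuantizationSpec("a"), m)
#     >>> m
#     {'a': 'a', 'b': 'a'}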

def _unwrap_shared_qspec(
    qspec: QuantizationSpecBase,
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> QuantizationSpecBase:
    """Unwraps qspec to get the final root qspec (non-SharedQuantizationSpec)
    if qspec is SharedQuantizationSpec:
       (1) tries to find the root edge or node for the node that the qspec points to
       (2) recursively finds the root qspec based on the qspec for the root node
    """
    if isinstance(qspec, SharedQuantizationSpec):
        sharing_with = qspec.edge_or_node
        root = _find_root_edge_or_node(sharing_with, shared_with_map)
        qspec = edge_or_node_to_qspec[root]
        return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    return qspec
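
# Illustrative sketch (not executed): with annotations `a: int8_qspec`,
# `b: SharedQuantizationSpec(a)` and `c: SharedQuantizationSpec(b)`, unwrapping
# the qspec for `c` follows the chain c -> b -> a and returns int8_qspec.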

def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "dtype") and
        hasattr(qspec_b, "dtype") and
        qspec_a.dtype == qspec_b.dtype
    )


def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "is_dynamic") and
        hasattr(qspec_b, "is_dynamic") and
        qspec_a.is_dynamic == qspec_b.is_dynamic
    )

def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
    """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
    """
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
    for n in model.graph.nodes:
        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
            qa = n.meta["quantization_annotation"]
            for input_to_n, qspec in qa.input_qspec_map.items():
                input_edge = (input_to_n, n)
                edge_or_node_to_qspec[input_edge] = qspec
            if qa.output_qspec is not None:
                output_node = n
                qspec = qa.output_qspec
                edge_or_node_to_qspec[output_node] = qspec
    return edge_or_node_to_qspec
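
# Shape of the annotation this function reads (hedged sketch; a quantizer
# typically attaches a QuantizationAnnotation during its annotate pass, and
# `input_node`/`int8_qspec` below are hypothetical):
#
#     node.meta["quantization_annotation"] = QuantizationAnnotation(
#         input_qspec_map={input_node: int8_qspec},  # keyed by the arg of the input edge
#         output_qspec=int8_qspec,                   # spec for the node's output
#     )
#
# which gets flattened into {(input_node, node): int8_qspec, node: int8_qspec}.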

def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
    """Union input edge with another edge or node. Used in implicit sharing to point the current input
    edge to other user edges of the producer node, or to the output of the producer node, since these
    all refer to the same Tensor
    """
    root_qspec = None
    if edge_or_node in edge_or_node_to_qspec:
        qspec = edge_or_node_to_qspec[edge_or_node]
        root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    # TODO: add assertions for types of root qspecs
    if (
        root_qspec is not None and
        _has_same_dtype(root_qspec, input_edge_root_qspec) and
        _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
    ):
        # the input arg to the node should reuse the existing output observer for arg
        # since dtype is the same (we may want to extend this to be a more strict check
        # in the future)
        # so we point from `input_edge` to `arg` (output of the argument)
        _union(edge_or_node, input_edge, shared_with_map)

def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
    """Map from edge/node to the group ID, generated from quantization annotations;
    edges/nodes with the same group ID should use the same observer/fake_quant instance

    This applies the SharedQuantizationSpec configuration and maps each edge/node to a group.
    There is another, implicit sharing built into quantization: when we have the following:
       * op1 -> op2
       * output of op1: int8_qspec
       * (op1 -> op2) input edge: int8_qspec
    we'll assume sharing between the output of op1 and the input of (op1 -> op2), since these are the same Tensor.

    Figuring out the correct group ID for all edges/nodes is a standard union-find problem:
    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/

    Args:
        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
    Returns:
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int); all edges or nodes
        that belong to the same group should have the same id

    Example:
        op2 -> cat1 -> cat2
           op1 /        /
                     op3
        edge_or_node_to_qspec: {
            op1: int8_qspec,
            op2: int8_qspec,
            (op1, cat1): int8_qspec,
            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
            cat1: SharedQuantizationSpec((op1, cat1)),
            (op3, cat2): int8_qspec,
            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
            cat2: SharedQuantizationSpec((op3, cat2)),
        }

        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        edge_or_node_to_group_id: {
            op1: 1,
            op2: 1,
            (op1, cat1): 1,
            (op2, cat1): 1,
            cat1: 1,
            (op3, cat2): 1,
            (cat1, cat2): 1,
            cat2: 1,
        }
        # everything is in the same group because cat1 and (cat1, cat2) are implicitly shared, which
        # connects the two sharing groups around the cat1 and cat2 ops through transitive sharing
    """
    # means the observer of key should be shared with the observer of value; by default it will
    # be shared with itself
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        if isinstance(edge_or_node, torch.fx.Node):
            output_node = edge_or_node
            _update_shared_with(output_node, qspec, shared_with_map)
        else:
            input_edge = edge_or_node
            input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)

            assert isinstance(input_edge, tuple)
            arg, n = input_edge
            if n.meta["quantization_annotation"].allow_implicit_sharing:
                # NOTE: the order is important here: we first share with other users and then with the
                # previous output, because the reverse order could cause a circular dependency
                # e.g. node1 -> node2
                #            \ -> node3
                # when processing (node1, node2), if we first point (node1, node2) to node1
                # Step 1. shared_map = {(node1, node2): node1}
                # Step 2. after that, we point the (node1, node2) to its other user (node1, node3),
                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)},
                # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
                # Step 3. and when we process (node1, node3), it can try to point to node1 as well, and then we'll
                # have a circular dependency
                # the following order works around this issue, but it does not allow arbitrary configurations
                # of sharing, so it might break in a different case in the future; when it breaks,
                # a quantizer writer can check the notes here to debug the issue

                # sharing with other users of the producer node
                # (arg, user)
                if not isinstance(arg, Node) or not isinstance(n, Node):
                    raise Exception(f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}")
                for user in arg.users:
                    if user is n:
                        continue
                    arg_to_user_edge = (arg, user)
                    _union_input_edge_with(
                        input_edge,
                        input_edge_root_qspec,
                        arg_to_user_edge,
                        edge_or_node_to_qspec,
                        shared_with_map
                    )

                # sharing with output of producer node
                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)

            _update_shared_with(input_edge, qspec, shared_with_map)

    # now that we have the sharing relations between all edges and nodes, we can assign group ids
    cur_group_id = 0
    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
    for edge_or_node in shared_with_map.keys():
        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
        if root not in edge_or_node_to_group_id:
            edge_or_node_to_group_id[root] = cur_group_id
            cur_group_id += 1
        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]

    return edge_or_node_to_group_id
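
# Minimal sketch of the final grouping step (not executed; strings stand in
# for hypothetical edges/nodes): given sharing links y -> x and z -> y,
#
#     >>> shared_with_map = {"x": "x", "y": "x", "z": "y"}
#     >>> [_find_root_edge_or_node(k, shared_with_map) for k in ("x", "y", "z")]
#     ['x', 'x', 'x']
#
# all three resolve to root "x", so they receive one group id and will later
# share one observer/fake_quant instance.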

def _get_obs_or_fq_map(
    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    is_qat: bool
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
    """Generates the map from EdgeOrNode to observer/fake_quant instances.
    Makes sure that EdgeOrNodes with the same group_id share the same observer or fake quant
    instance
    """
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
    group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        group_id = edge_or_node_to_group_id[edge_or_node]
        if group_id not in group_id_to_obs_or_fq:
            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec; this would simplify
            # the implementation of _create_obs_or_fq_from_qspec
            group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
        obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
    return obs_or_fq_map
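
# Illustrative sketch (not executed; `node_a`/`node_b` are hypothetical entries
# that map to the same group id): sharing is by object identity, i.e. two
# entries in one group get the *same* observer instance, not equal copies:
#
#     obs_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat=False)
#     assert obs_map[node_a] is obs_map[node_b]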

def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
            )
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
    original_arg = arg
    while _is_activation_post_process_node(original_arg, named_modules):
        original_arg = original_arg.args[0]  # type: ignore[assignment]
    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"

    input_edge = (original_arg, node)
    if input_edge not in obs_or_fq_map:
        return new_arg
    # input_edge needs to be observed
    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
    if input_edge_obs_or_fq is None:
        return new_arg

    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
    # the arg is observed as the output and is using the same instance as the input_edge;
    # we'll reuse the inserted observer/fake_quant
    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
        return new_arg

    # otherwise, we'll insert a new observer/fake_quant node

    # skip inserting new observers if the same observer instance was already inserted for another user
    # Example:
    # conv1 -> obs1 -> existing_obs -> conv2
    #             \ -> conv3
    #
    # instead of inserting new observers we will have:
    # conv1 -> obs1 -> existing_obs -> conv2
    #             \ -> conv3
    for maybe_obs_node in arg.users.keys():
        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
            continue
        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
            return maybe_obs_node

    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
    return new_arg

def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` in place.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    to

      prev_node -> obs -> cur_node

    """
    # Look through every input arg. If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    # map from old arg to new arg, used for updating the numeric debug handle map
    remap = {}
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
        )
        new_args.append(new_arg)
        remap[arg] = new_arg

    if "numeric_debug_handle" in node.meta:

        def remap_fn(x):
            return remap.get(x, x)

        numeric_debug_handle = node.meta["numeric_debug_handle"]
        node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}

    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
    # that persist in the exported graph. This is just a workaround for these.
    assert (
        node.target == torch.ops.aten.clone.default or
        node.target == torch.ops.aten.zeros_like.default or
        len(node.kwargs) == 0
    ), "expecting kwargs for aten op IR to be empty"

    # assign the new args to the node, in place
    node.args = tuple(new_args)

def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    if node in obs_or_fq_map:
        output_act_obs_or_fq = obs_or_fq_map[node]
        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
    return None

def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    this_node_quantization_annotation = node.meta["quantization_annotation"] if "quantization_annotation" in node.meta else None
    if this_node_quantization_annotation is None:
        return

    named_modules = dict(model.named_modules(remove_duplicate=False))
    _maybe_insert_input_observers_for_node(
        node,
        None,  # qconfig
        model,
        named_modules,
        obs_or_fq_map,
        is_qat,
    )

    output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
    if not output_is_a_tensor:
        return

    # this returns the new observer node if it was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat)

    if maybe_output_obs_node is None:
        return
    # Update users of the original node to use the output observer
    # instead. For example, change
    #
    #           next_node
    #          /
    #   cur_node -> obs
    #
    # to
    #
    #                 next_node
    #                /
    #   cur_node -> obs
    #
    # We need to save orig users before updating uses because
    # the list of users will change as we update uses
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is maybe_output_obs_node:
            continue
        user_node.replace_input_with(node, maybe_output_obs_node)

def prepare(
    model: GraphModule,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    is_qat: bool,
) -> GraphModule:
    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # At a high level we construct a map from EdgeOrNode to an observer_or_fake_quant instance;
    # all edges/nodes that belong to the same group will use the same instance,
    # and when we insert observers we'll just query this map to get the correct
    # observer_or_fake_quant instance
    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
    obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)

    for node in nodes_before_observation:
        # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)

    model = GraphModule(model, model.graph)

    _save_state(
        model,
        {},  # node_name_to_qconfig
        node_name_to_scope,
        PrepareCustomConfig(),
        {},  # equalization_node_name_to_qconfig
        QConfigMapping(),
        is_qat,
        set()  # observed_node_names
    )
    return model
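
# Usage sketch (hedged; not executed): `prepare` is an internal helper, and end
# users typically go through the public `prepare_pt2e` wrapper, which runs a
# Quantizer's annotation pass over a captured model and then calls into this
# module. `exported_model`, `quantizer`, and `example_inputs` are assumptions:
#
#     from torch.ao.quantization.quantize_pt2e import prepare_pt2e
#
#     # exported_model: GraphModule from the pt2e capture step
#     # quantizer: a torch.ao.quantization.quantizer.Quantizer instance
#     prepared = prepare_pt2e(exported_model, quantizer)
#     prepared(*example_inputs)  # calibration populates the inserted observers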