import torch
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.prepare import (
    _insert_obs_or_fq,
    _save_state,
    _is_activation_post_process_node,
    _create_obs_or_fq_from_qspec,
)
from torch.fx import (
    GraphModule,
    Graph,
    Node,
)
from torch.fx.node import Argument
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
from typing import Dict, Tuple, Union, Any, Optional
from torch.ao.quantization.quantizer import (
    EdgeOrNode,
    SharedQuantizationSpec,
    QuantizationSpecBase,
)
from torch.ao.quantization import ObserverOrFakeQuantize

# TODO: make pt2e folder private?
__all__ = [
    "prepare",
]

def _find_root_edge_or_node(edge_or_node: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> EdgeOrNode:
    """Find the root node for the sharing tree
    Args:
        edge_or_node: edge/node whose root we want to find
        shared_with_map: each edge/node points to its parent; the root node points to itself

    Returns:
        root edge/node
    """
    parent = shared_with_map[edge_or_node]
    if parent == edge_or_node:
        return edge_or_node
    root = _find_root_edge_or_node(parent, shared_with_map)
    # path compression
    shared_with_map[edge_or_node] = root
    return root

def _union(parent: EdgeOrNode, child: EdgeOrNode, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]) -> None:
    """Merge the subtree of `child` into the tree of `parent`; the argument order matters here.
    """
    root_parent = _find_root_edge_or_node(parent, shared_with_map)
    root_child = _find_root_edge_or_node(child, shared_with_map)
    # union the two trees by pointing the root of child to the root of parent
    shared_with_map[root_child] = root_parent

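# Illustrative sketch (hypothetical helper, never called by this pass): plain strings
# stand in for edges/nodes to show how `_find_root_edge_or_node` and `_union` behave.
def _example_union_find_sketch() -> None:
    # a <- b <- c: "c" points at "b", which points at the root "a"
    shared_with_map = {"a": "a", "b": "a", "c": "b"}
    assert _find_root_edge_or_node("c", shared_with_map) == "a"
    # path compression: after the lookup, "c" points directly at the root
    assert shared_with_map["c"] == "a"
    # `_union` hangs the root of `child` off the root of `parent`
    shared_with_map["d"] = "d"
    _union("c", "d", shared_with_map)
    assert _find_root_edge_or_node("d", shared_with_map) == "a"
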
def _update_shared_with(child: EdgeOrNode, qspec: QuantizationSpecBase, shared_with_map: Dict[EdgeOrNode, EdgeOrNode]):
    """Update the `shared_with_map` based on the qspec. This applies the `SharedQuantizationSpec`
    configuration and establishes the relationship between `edge_or_node` and the edge/node that it
    is pointing to; we'll use this information in the end to get the group id.
    """
    if isinstance(qspec, SharedQuantizationSpec):
        parent = qspec.edge_or_node
        # we point from edge_or_node to the node that it is sharing with, e.g.
        # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
        _union(parent, child, shared_with_map)

def _unwrap_shared_qspec(
    qspec: QuantizationSpecBase,
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode]
) -> QuantizationSpecBase:
    """Unwraps qspec to get the final root qspec (non-SharedQuantizationSpec).
    If qspec is a SharedQuantizationSpec:
        (1) find the root edge or node for the edge/node that the qspec points to
        (2) recursively find the root qspec based on the qspec for the root node
    """
    if isinstance(qspec, SharedQuantizationSpec):
        sharing_with = qspec.edge_or_node
        root = _find_root_edge_or_node(sharing_with, shared_with_map)
        qspec = edge_or_node_to_qspec[root]
        return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    return qspec

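# Illustrative sketch (hypothetical helper, never called by this pass): a chain of
# SharedQuantizationSpec entries resolves to the concrete qspec at the root; the
# string "int8_qspec" stands in for a real QuantizationSpec instance.
def _example_unwrap_shared_qspec_sketch() -> None:
    edge_or_node_to_qspec = {
        "n1": "int8_qspec",  # stand-in for a concrete (non-shared) qspec
        "n2": SharedQuantizationSpec("n1"),
        "n3": SharedQuantizationSpec("n2"),
    }
    shared_with_map = {k: k for k in edge_or_node_to_qspec}
    qspec = _unwrap_shared_qspec(edge_or_node_to_qspec["n3"], edge_or_node_to_qspec, shared_with_map)
    assert qspec == "int8_qspec"
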
def _has_same_dtype(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "dtype") and
        hasattr(qspec_b, "dtype") and
        qspec_a.dtype == qspec_b.dtype
    )

def _has_same_is_dynamic(qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase):
    return (
        hasattr(qspec_a, "is_dynamic") and
        hasattr(qspec_b, "is_dynamic") and
        qspec_a.is_dynamic == qspec_b.is_dynamic
    )

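# Illustrative sketch (hypothetical helper, never called by this pass): the two checks
# above only compare attributes that both qspecs actually carry, so SimpleNamespace can
# stand in for real quantization specs here.
def _example_compat_check_sketch() -> None:
    from types import SimpleNamespace
    a = SimpleNamespace(dtype=torch.int8, is_dynamic=False)
    b = SimpleNamespace(dtype=torch.int8, is_dynamic=True)
    assert _has_same_dtype(a, b)
    assert not _has_same_is_dynamic(a, b)
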
def _get_edge_or_node_to_qspec(model: torch.fx.GraphModule) -> Dict[EdgeOrNode, QuantizationSpecBase]:
    """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes
    """
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
    for n in model.graph.nodes:
        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
            qa = n.meta["quantization_annotation"]
            for input_to_n, qspec in qa.input_qspec_map.items():
                input_edge = (input_to_n, n)
                edge_or_node_to_qspec[input_edge] = qspec
            if qa.output_qspec is not None:
                output_node = n
                qspec = qa.output_qspec
                edge_or_node_to_qspec[output_node] = qspec
    return edge_or_node_to_qspec

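# Illustrative sketch (hypothetical helper, never called by this pass): annotating a tiny
# traced module shows the shape of the returned map; the string "int8_qspec" stands in
# for a real QuantizationSpec instance.
def _example_edge_or_node_to_qspec_sketch() -> None:
    from torch.ao.quantization.quantizer import QuantizationAnnotation

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = torch.fx.symbolic_trace(M())
    x, relu = list(gm.graph.nodes)[:2]
    relu.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={x: "int8_qspec"},
        output_qspec="int8_qspec",
    )
    qspec_map = _get_edge_or_node_to_qspec(gm)
    # keys are the input edge (x, relu) and the output node relu
    assert (x, relu) in qspec_map and relu in qspec_map
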
def _union_input_edge_with(input_edge, input_edge_root_qspec, edge_or_node, edge_or_node_to_qspec, shared_with_map):
    """Union an input edge with another edge or node. Used in implicit sharing to point the current
    input edge to the other user edges of the producer node, or to the output of the producer node,
    since all of these refer to the same Tensor.
    """
    root_qspec = None
    if edge_or_node in edge_or_node_to_qspec:
        qspec = edge_or_node_to_qspec[edge_or_node]
        root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
    # TODO: add assertions for types of root qspecs
    if (
        root_qspec is not None and
        _has_same_dtype(root_qspec, input_edge_root_qspec) and
        _has_same_is_dynamic(root_qspec, input_edge_root_qspec)
    ):
        # the input arg to the node should reuse the existing output observer for arg
        # since the dtype is the same (we may want to make this a stricter check
        # in the future)
        # so we point from `input_edge` to `arg` (the output of the argument)
        _union(edge_or_node, input_edge, shared_with_map)

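# Illustrative sketch (hypothetical helper, never called by this pass): when the root
# qspec of the producer's output matches the input edge's root qspec in dtype and
# is_dynamic, the edge joins the producer's sharing group. SimpleNamespace stands in
# for a real qspec.
def _example_union_input_edge_with_sketch() -> None:
    from types import SimpleNamespace
    int8 = SimpleNamespace(dtype=torch.int8, is_dynamic=False)
    # "arg" is the producer node; ("arg", "n") is the input edge into its consumer "n"
    edge_or_node_to_qspec = {"arg": int8, ("arg", "n"): int8}
    shared_with_map = {k: k for k in edge_or_node_to_qspec}
    _union_input_edge_with(("arg", "n"), int8, "arg", edge_or_node_to_qspec, shared_with_map)
    # dtype and is_dynamic match, so the edge now shares "arg"'s group
    assert _find_root_edge_or_node(("arg", "n"), shared_with_map) == "arg"
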
def _get_edge_or_node_to_group_id(edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase]) -> Dict[EdgeOrNode, int]:
    """Map from edge/node to the group ID, generated from quantization annotations;
    edges/nodes with the same group ID should use the same observer/fake_quant instance.

    This applies the SharedQuantizationSpec configuration and maps each edge/node to a group.
    There is another implicit sharing built into quantization; when we have the following:
       * op1 -> op2
       * output of op1: int8_qspec
       * (op1 -> op2) input edge: int8_qspec
    we'll assume sharing between the output of op1 and the input of (op1 -> op2), since these are
    the same Tensor.

    Figuring out the correct group ID for all edges/nodes is a standard union-find problem:
    https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/

    Args:
        edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
    Returns:
        edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int); all edges or
        nodes that belong to the same group have the same id

    Example:
        op2 -> cat1 -> cat2
           op1 /       /
                  op3
        edge_or_node_to_qspec: {
            op1: int8_qspec,
            op2: int8_qspec,
            (op1, cat1): int8_qspec,
            (op2, cat1): SharedQuantizationSpec((op1, cat1)),
            cat1: SharedQuantizationSpec((op1, cat1)),
            (op3, cat2): int8_qspec,
            (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
            cat2: SharedQuantizationSpec((op3, cat2)),
        }

        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
        edge_or_node_to_group_id: {
            op1: 1,
            op2: 1,
            (op1, cat1): 1,
            (op2, cat1): 1,
            cat1: 1,
            (op3, cat2): 1,
            (cat1, cat2): 1,
            cat2: 1,
        }
        # everything ends up in the same group because cat1 and (cat1, cat2) are implicitly
        # shared, which connects the sharing groups around cat1 and cat2 through transitive
        # sharing
    """
    # means the observer of key should be shared with the observer of value; by default
    # every edge/node is shared with itself
    shared_with_map: Dict[EdgeOrNode, EdgeOrNode] = {k: k for k in edge_or_node_to_qspec.keys()}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        if isinstance(edge_or_node, torch.fx.Node):
            output_node = edge_or_node
            _update_shared_with(output_node, qspec, shared_with_map)
        else:
            input_edge = edge_or_node
            input_edge_root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
            assert isinstance(input_edge, tuple)
            arg, n = input_edge
            if n.meta["quantization_annotation"].allow_implicit_sharing:
                # NOTE: the order is important here: we first share with the other users and then
                # share with the previous output, because the reverse order could cause a circular
                # dependency
                # e.g node1 -> node2
                #          \ -> node3
                # when processing (node1, node2), if we first point (node1, node2) to node1,
                # Step 1. shared_map = {(node1, node2): node1}
                # Step 2. after that, we point (node1, node2) to its other user (node1, node3),
                # which means shared_map = {(node1, node2): node1, node1: (node1, node3)},
                # because we point the root of (node1, node2) (in this case node1) to the root
                # of (node1, node3)
                # Step 3. then when we process (node1, node3), it can try to point to node1 as
                # well, and we'd have a circular dependency
                # the following order works around this issue, but it does not allow arbitrary
                # sharing configurations, so it might break in a different case in the future;
                # when it breaks, quantizer writers can check the notes here to debug the issue

                # sharing with other users of the producer node
                # (arg, user)
                if not isinstance(arg, Node) or not isinstance(n, Node):
                    raise Exception(f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}")
                for user in arg.users:
                    if user is n:
                        continue
                    arg_to_user_edge = (arg, user)
                    _union_input_edge_with(
                        input_edge,
                        input_edge_root_qspec,
                        arg_to_user_edge,
                        edge_or_node_to_qspec,
                        shared_with_map
                    )

                # sharing with the output of the producer node
                _union_input_edge_with(input_edge, input_edge_root_qspec, arg, edge_or_node_to_qspec, shared_with_map)
            _update_shared_with(input_edge, qspec, shared_with_map)

    # now that we have the sharing relations between all edges and nodes, we can assign group ids
    cur_group_id = 0
    edge_or_node_to_group_id: Dict[EdgeOrNode, int] = {}
    for edge_or_node in shared_with_map.keys():
        root = _find_root_edge_or_node(edge_or_node, shared_with_map)
        if root not in edge_or_node_to_group_id:
            edge_or_node_to_group_id[root] = cur_group_id
            cur_group_id += 1
        edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]
    return edge_or_node_to_group_id

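# Illustrative sketch (hypothetical helper, never called by this pass): an end-to-end
# run over a tiny traced graph; SimpleNamespace stands in for a concrete qspec. The
# relu output shares with its input edge via SharedQuantizationSpec, so both land in
# the same group.
def _example_group_id_sketch() -> None:
    from types import SimpleNamespace
    from torch.ao.quantization.quantizer import QuantizationAnnotation

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = torch.fx.symbolic_trace(M())
    x, relu = list(gm.graph.nodes)[:2]
    int8 = SimpleNamespace(dtype=torch.int8, is_dynamic=False)
    relu.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={x: int8},
        output_qspec=SharedQuantizationSpec((x, relu)),
    )
    qspec_map = _get_edge_or_node_to_qspec(gm)
    group_ids = _get_edge_or_node_to_group_id(qspec_map)
    assert group_ids[(x, relu)] == group_ids[relu]
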
def _get_obs_or_fq_map(
    edge_or_node_to_group_id: Dict[EdgeOrNode, int],
    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase],
    is_qat: bool
) -> Dict[EdgeOrNode, ObserverOrFakeQuantize]:
    """Generates the map from EdgeOrNode to observer/fake_quant instances.
    Makes sure that EdgeOrNodes with the same group_id use the same observer or fake
    quant instance.
    """
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
    group_id_to_obs_or_fq: Dict[int, ObserverOrFakeQuantize] = {}
    for edge_or_node, qspec in edge_or_node_to_qspec.items():
        group_id = edge_or_node_to_group_id[edge_or_node]
        if group_id not in group_id_to_obs_or_fq:
            # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
            # the implementation for _create_obs_or_fq_from_qspec
            group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(qspec, obs_or_fq_map, is_qat)
        obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
    return obs_or_fq_map

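# Illustrative sketch (hypothetical helper, never called by this pass): two entries in
# the same group receive the *same* observer instance. HistogramObserver and the int8
# settings below mirror what typical pt2e quantizers produce, but are only an example.
def _example_obs_or_fq_map_sketch() -> None:
    from torch.ao.quantization.observer import HistogramObserver
    from torch.ao.quantization.quantizer import QuantizationSpec
    int8_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=-128,
        quant_max=127,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=HistogramObserver,
    )
    group_ids = {"n1": 0, "n2": 0}
    qspecs = {"n1": int8_spec, "n2": int8_spec}
    obs_map = _get_obs_or_fq_map(group_ids, qspecs, is_qat=False)
    assert obs_map["n1"] is obs_map["n2"]
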
def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
            )
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
    original_arg = arg
    while _is_activation_post_process_node(original_arg, named_modules):
        original_arg = original_arg.args[0]  # type: ignore[assignment]
    assert isinstance(original_arg, Node), f"expect original argument to be a Node, but got: {type(original_arg)}"

    input_edge = (original_arg, node)
    if input_edge not in obs_or_fq_map:
        return new_arg
    # input_edge needs to be observed
    input_edge_obs_or_fq = obs_or_fq_map[input_edge]
    if input_edge_obs_or_fq is None:
        return new_arg

    arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
    # the arg is observed as an output and is using the same instance as the input_edge,
    # so we'll reuse the inserted observer/fake_quant
    if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(input_edge_obs_or_fq):
        return new_arg

    # otherwise, we'll insert a new observer/fake_quant node

    # skip inserting a new observer if the same observer instance was already inserted
    # for another user
    # Example:
    # conv1 -> existing_obs -> conv2
    #      \ -> conv3
    #
    # instead of inserting a new observer for the (conv1, conv3) edge we will have:
    # conv1 -> existing_obs -> conv2
    #                      \ -> conv3
    for maybe_obs_node in arg.users.keys():
        if not _is_activation_post_process_node(maybe_obs_node, named_modules):
            continue
        maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
        if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
            return maybe_obs_node

    new_arg = _insert_obs_or_fq(arg, input_edge_obs_or_fq, model, named_modules, model.graph)
    return new_arg

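# Illustrative sketch (hypothetical helper, never called by this pass): if the map holds
# an observer for the (x, relu) input edge, a fresh observer node is spliced in and
# returned as the new argument; the caller is responsible for rewiring `relu.args`.
def _example_insert_input_observer_sketch() -> None:
    from torch.ao.quantization.observer import MinMaxObserver

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = torch.fx.symbolic_trace(M())
    x, relu = list(gm.graph.nodes)[:2]
    named_modules = dict(gm.named_modules(remove_duplicate=False))
    obs_or_fq_map = {(x, relu): MinMaxObserver()}
    new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
        relu, x, None, gm, named_modules, obs_or_fq_map, is_qat=False,
    )
    # a fresh observer node now sits between x and relu
    assert new_arg is not x and _is_activation_post_process_node(new_arg, named_modules)
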
def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> None:
    """
    If needed, inserts observers for the input args and kwargs of `node`.
    Note: modifies `node` in place.

    For example, if cur_node needs an observer after prev_node, we change from
      prev_node -> cur_node
    to
      prev_node -> obs -> cur_node
    """
    # Look through every input arg. If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    # map from old arg to new arg, used for updating the numeric debug handle map
    remap = {}
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, arg, qconfig, model, named_modules, obs_or_fq_map, is_qat,
        )
        new_args.append(new_arg)
        remap[arg] = new_arg

    if "numeric_debug_handle" in node.meta:

        def remap_fn(x):
            return remap.get(x, x)

        numeric_debug_handle = node.meta["numeric_debug_handle"]
        node.meta["numeric_debug_handle"] = {remap_fn(k): v for k, v in numeric_debug_handle.items()}

    # Clone has a memory_format kwarg and zeros_like has a pin_memory kwarg
    # that persist in the exported graph. This is just a workaround for these.
    assert (
        node.target == torch.ops.aten.clone.default or
        node.target == torch.ops.aten.zeros_like.default or
        len(node.kwargs) == 0
    ), "expecting kwargs for aten op IR to be empty"

    # assign the new args to the node, in place
    node.args = tuple(new_args)

def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
) -> Optional[Node]:
    if node in obs_or_fq_map:
        output_act_obs_or_fq = obs_or_fq_map[node]
        return _insert_obs_or_fq(node, output_act_obs_or_fq, model, named_modules, graph)
    return None

def _maybe_insert_input_and_output_observers_for_node(
    node: Node,
    model: torch.fx.GraphModule,
    obs_or_fq_map: Dict[EdgeOrNode, ObserverOrFakeQuantize],
    is_qat: bool,
):
    this_node_quantization_annotation = node.meta.get("quantization_annotation", None)
    if this_node_quantization_annotation is None:
        return

    named_modules = dict(model.named_modules(remove_duplicate=False))
    _maybe_insert_input_observers_for_node(
        node,
        None,  # qconfig
        model,
        named_modules,
        obs_or_fq_map,
        is_qat,
    )

    output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
    if not output_is_a_tensor:
        return

    # this returns the new observer node if it was needed
    maybe_output_obs_node = _maybe_insert_output_observer_for_node(
        node, model, named_modules, model.graph, obs_or_fq_map, is_qat)
    if maybe_output_obs_node is None:
        return

    # Update users of the original node to use the output observer
    # instead. For example, change
    #
    #           next_node
    #          /
    # cur_node -> obs
    #
    # to
    #
    #                 next_node
    #                 /
    # cur_node -> obs
    #
    # We need to save the original users before updating them, because
    # the list of users changes as we update it
    orig_users = list(node.users.keys())
    for user_node in orig_users:
        if user_node is maybe_output_obs_node:
            continue
        user_node.replace_input_with(node, maybe_output_obs_node)

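# Illustrative sketch (hypothetical helper, never called by this pass): a node with a
# quantization annotation, an entry in the map for both its input edge and its output,
# and a FakeTensor "val" gets observers on both sides, and its users are rewired to
# consume the output observer.
def _example_insert_input_and_output_observers_sketch() -> None:
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.ao.quantization.observer import MinMaxObserver
    from torch.ao.quantization.quantizer import QuantizationAnnotation

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = torch.fx.symbolic_trace(M())
    x, relu = list(gm.graph.nodes)[:2]
    relu.meta["quantization_annotation"] = QuantizationAnnotation()
    with FakeTensorMode() as mode:
        relu.meta["val"] = mode.from_tensor(torch.empty(1))
    obs_or_fq_map = {(x, relu): MinMaxObserver(), relu: MinMaxObserver()}
    _maybe_insert_input_and_output_observers_for_node(relu, gm, obs_or_fq_map, is_qat=False)
    # relu's input was rewired to an observer, and relu's only remaining user is the
    # output observer that now feeds the graph output
    named_modules = dict(gm.named_modules(remove_duplicate=False))
    assert _is_activation_post_process_node(relu.args[0], named_modules)
    assert all(_is_activation_post_process_node(u, named_modules) for u in relu.users)
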
def prepare(
    model: GraphModule,
    node_name_to_scope: Dict[str, Tuple[str, type]],
    is_qat: bool,
) -> GraphModule:
    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # At a high level, we construct a map from EdgeOrNode to an observer_or_fake_quant
    # instance; all edges/nodes that belong to the same group will use the same instance,
    # and when we insert observers we just query this map to get the correct
    # observer_or_fake_quant instance
    edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
    edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
    obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)

    for node in nodes_before_observation:
        # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(node, model, obs_or_fq_map, is_qat)

    model = GraphModule(model, model.graph)

    _save_state(
        model,
        {},  # node_name_to_qconfig
        node_name_to_scope,
        PrepareCustomConfig(),
        {},  # equalization_node_name_to_qconfig
        QConfigMapping(),
        is_qat,
        set()  # observed_node_names
    )
    return model
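
# Usage sketch (not part of this module): `prepare` is normally reached through the
# public pt2e APIs rather than called directly. The exact capture entry point varies
# across torch releases; `capture_pre_autograd_graph` is assumed here.
#
#     from torch._export import capture_pre_autograd_graph
#     from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
#     from torch.ao.quantization.quantizer.xnnpack_quantizer import (
#         XNNPACKQuantizer,
#         get_symmetric_quantization_config,
#     )
#
#     example_inputs = (torch.randn(1, 3, 224, 224),)
#     exported = capture_pre_autograd_graph(eager_model, example_inputs)
#     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
#     prepared = prepare_pt2e(exported, quantizer)  # annotates, then calls prepare()
#     prepared(*example_inputs)  # calibration
#     quantized = convert_pt2e(prepared)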