50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
|
from typing import List
|
||
|
|
||
|
from torch.ao.quantization.pt2e.utils import _is_sym_size_node
|
||
|
|
||
|
from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
|
||
|
from torch.fx import Node
|
||
|
|
||
|
|
||
|
def _annotate_input_qspec_map(node: Node, input_node: Node, qspec):
|
||
|
quantization_annotation = node.meta.get(
|
||
|
"quantization_annotation", QuantizationAnnotation()
|
||
|
)
|
||
|
if quantization_annotation.input_qspec_map is None:
|
||
|
quantization_annotation.input_qspec_map = {}
|
||
|
quantization_annotation.input_qspec_map[input_node] = qspec
|
||
|
node.meta["quantization_annotation"] = quantization_annotation
|
||
|
|
||
|
|
||
|
def _annotate_output_qspec(node: Node, qspec):
|
||
|
quantization_annotation = node.meta.get(
|
||
|
"quantization_annotation", QuantizationAnnotation()
|
||
|
)
|
||
|
quantization_annotation.output_qspec = qspec
|
||
|
node.meta["quantization_annotation"] = quantization_annotation
|
||
|
|
||
|
|
||
|
def _node_only_used_for_sym_size(node: Node, partition_nodes: List[Node]):
|
||
|
"""
|
||
|
This utility is used to handle cases when dynami_shape=True tracing leads
|
||
|
to symint nodes in the pattern of linear module. In those cases, we need to
|
||
|
distinguish between the nodes that are in input for just extracting value of
|
||
|
some dimentions (and symint nodes) vs. the one that is activation.
|
||
|
For example:
|
||
|
graph(x, y, weight):
|
||
|
size_0 = torch.ops.aten.sym_size([x], [0])
|
||
|
size_1 = torch.ops.aten.sym_size([y], [1])
|
||
|
view_size = size_0 * size_1
|
||
|
size_3 = torch.ops.aten.sym_size([x], [2])
|
||
|
vie_out = torch.ops.aten.view(x, [view_size, size_3])
|
||
|
return mm(view_out, weight)
|
||
|
In the example above y node is not actual input. It exist only to extract size_1
|
||
|
"""
|
||
|
if _is_sym_size_node(node):
|
||
|
return True
|
||
|
|
||
|
return all(
|
||
|
((user not in partition_nodes) or _is_sym_size_node(user))
|
||
|
for user in node.users
|
||
|
)
|