# mypy: ignore-errors

import torch

import torch.fx

import traceback

from torch._dispatch.python import enable_python_dispatcher
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
from torch._guards import detect_fake_mode

__all__ = ['TensorMetadata', 'ShapeProp']

@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    """
    TensorMetadata is a structure containing pertinent information
    about a tensor within a PyTorch program.
    """

    # General tensor metadata
    shape: torch.Size
    dtype: torch.dtype
    requires_grad: bool
    stride: Tuple[int, ...]
    memory_format: Optional[torch.memory_format]

    # Quantization metadata
    is_quantized: bool
    qparams: Dict[str, Any]

def _extract_tensor_metadata(result: torch.Tensor, include_contiguity=True) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.

    Args:
        result: the tensor to describe.
        include_contiguity: when True, probe `result` against the known
            memory formats and record the first one it is contiguous in;
            when False, ``memory_format`` stays None.

    Returns:
        A TensorMetadata capturing shape, dtype, requires_grad, stride,
        the detected memory format, and (for quantized tensors) the
        quantization-scheme parameters.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()

    memory_format = None

    if include_contiguity:
        # Probe in a fixed preference order. A tensor can be contiguous
        # under more than one format (e.g. when it has size-1 dims), so
        # iterating a set here would make the recorded format depend on
        # set iteration order; a tuple keeps the result deterministic.
        memory_formats = (
            torch.contiguous_format,
            torch.channels_last,
            torch.channels_last_3d,
        )
        for query_format in memory_formats:
            if result.is_contiguous(memory_format=query_format):
                memory_format = query_format
                break

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # In this branch, scale and zero_point are expected to be tensors,
            # we store the values as immutable_list in TensorMetadata for
            # easier serialization downstream
            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)

@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record the shape and type of the
    result into the corresponding node.

    Example:
        In this example, we record the shape and data type of a module
        given an example input ``torch.randn(50, D_in)``. We print the
        name, shape and dtype of each node.

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super().__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)
            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred
        N, D_in, H, D_out = 64, 1000, 100, 10
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        model = TwoLayerNet(D_in, H, D_out)
        gm = torch.fx.symbolic_trace(model)
        sample_input = torch.randn(50, D_in)
        ShapeProp(gm).propagate(sample_input)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                node.meta['tensor_meta'].shape)

        The output of this code is:

        x torch.float32 torch.Size([50, 1000])
        linear1 torch.float32 torch.Size([50, 100])
        clamp_1 torch.float32 torch.Size([50, 100])
        linear2 torch.float32 torch.Size([50, 10])
        output torch.float32 torch.Size([50, 10])

    Args:
        module (GraphModule): The module to be executed
        fake_mode (FakeTensorMode): A fake mode for copying the gm

    """
    def __init__(self, gm, fake_mode=None):
        super().__init__(gm)
        fake_mode = detect_fake_mode() if fake_mode is None else fake_mode
        if fake_mode is None:
            self.fake_module = None
            self.fake_mode = None
        else:
            from torch._dynamo.utils import deepcopy_to_fake_tensor
            # Note:
            # We need fake execution cause the inputs are fake, however, we cannot fakify the module
            # - because we need to write to the tensor_meta of the real module. So we fakify to
            # produce a result (L131 below), to extract tensor meta, and then keep going.
            #
            # If we were to fakify, we would write to the wrong node, and then downstream fusion
            # would be missing the tensor_meta.
            #
            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
            self.fake_mode = fake_mode

        # Keep a handle on the real module so run_node can restore it
        # after the temporary fake-module swap.
        self.real_module = self.module

    def run_node(self, n: Node) -> Any:
        try:
            if self.fake_module is not None:
                # Hacky swap. Alternatively, we could do this with overriding
                # call_module and get_attr.
                self.module = self.fake_module
            try:
                if self.fake_mode is None:
                    result = super().run_node(n)
                else:
                    with self.fake_mode, enable_python_dispatcher():
                        result = super().run_node(n)
            finally:
                # Always restore the real module, even on failure.
                self.module = self.real_module
        except Exception as e:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with meta={n.meta}"
            ) from e

        found_tensor = False

        def grab_tensor_meta(obj):
            # Record metadata for tensors; pass everything else through.
            nonlocal found_tensor
            if not isinstance(obj, torch.Tensor):
                return obj
            found_tensor = True
            return _extract_tensor_metadata(obj)

        meta = map_aggregate(result, grab_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result and
        record the shape and type of each node.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        if self.fake_mode is None:
            fake_args = args
        else:
            # Convert real tensor inputs into fake tensors so execution
            # stays inside the fake mode.
            fake_args = [
                self.fake_mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
                for a in args
            ]
        return super().run(*fake_args)