ai-content-maker/.venv/Lib/site-packages/torch/fx/passes/shape_prop.py

196 lines
7.0 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
# mypy: ignore-errors
import torch
import torch.fx
import traceback
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
from torch._guards import detect_fake_mode
__all__ = ['TensorMetadata', 'ShapeProp']
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program.
    #
    # Field order is part of the interface: instances are built and
    # unpacked positionally, so new fields must only be appended.

    # General Tensor metadata
    shape: torch.Size                              # tensor.shape
    dtype: torch.dtype                             # tensor.dtype
    requires_grad: bool                            # tensor.requires_grad
    stride: Tuple[int, ...]                        # tensor.stride()
    memory_format: Optional[torch.memory_format]   # None when no format was recorded

    # Quantization metadata
    is_quantized: bool                             # tensor.is_quantized
    qparams: Dict[str, Any]                        # quantization parameters; empty when not quantized
def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.

    Args:
        result (torch.Tensor): the tensor to describe.
        include_contiguity (bool): when True, probe `result` against the
            known memory formats and record the first one it is contiguous
            in. When False, ``memory_format`` is left as ``None``.

    Returns:
        TensorMetadata: shape, dtype, requires_grad, stride, memory_format,
        is_quantized and qparams read from `result`.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()

    memory_format = None

    if include_contiguity:
        # Ordered tuple rather than a set: the loop stops at the first match,
        # and a tensor may be contiguous in more than one format (e.g. when
        # the channel dimension has size 1). A fixed priority order keeps the
        # recorded memory_format deterministic from run to run.
        memory_formats = (
            torch.contiguous_format,
            torch.channels_last,
            torch.channels_last_3d,
        )
        for query_format in memory_formats:
            if result.is_contiguous(memory_format=query_format):
                memory_format = query_format
                break

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # In this branch, scale and zero_point are tensors; they are
            # converted to Python lists with tolist() before being stored,
            # for easier serialization downstream.
            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]
        # Any other qscheme records only the "qscheme" entry.

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and
record the shape and type of the result
into the corresponding node.
Example:
In this example, we record the shape
and data type of a module given
an example input ``torch.randn(50, D_in)``.
We print the name, shape and dtype of each node.
class TwoLayerNet(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super().__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out)
def forward(self, x):
h_relu = self.linear1(x).clamp(min=0)
y_pred = self.linear2(h_relu)
return y_pred
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = TwoLayerNet(D_in, H, D_out)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(50, D_in)
ShapeProp(gm).propagate(sample_input)
for node in gm.graph.nodes:
print(node.name, node.meta['tensor_meta'].dtype,
node.meta['tensor_meta'].shape)
The output of this code is:
x torch.float32 torch.Size([50, 1000])
linear1 torch.float32 torch.Size([50, 100])
clamp_1 torch.float32 torch.Size([50, 100])
linear2 torch.float32 torch.Size([50, 10])
output torch.float32 torch.Size([50, 10])
Args:
module (GraphModule): The module to be executed
fake_mode (FakeTensorMode): A fake mode for copying the gm
"""
def __init__(self, gm, fake_mode=None):
super().__init__(gm)
if fake_mode is None:
fake_mode = detect_fake_mode()
if fake_mode is not None:
from torch._dynamo.utils import deepcopy_to_fake_tensor
# Note:
# We need fake execution cause the inputs are fake, however, we cannot fakify the module
# - because we need to write to the tensor_meta of the real module. So we fakify to
# produce a result (L131 below), to extract tensor meta, and then keep going.
#
# If we were to fakify, we would write to the wrong node, and then downstream fusion
# would be missing the tensor_meta.
#
# See torch/_inductor/overrides.py for where this is called upstream of fusion.
self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
self.fake_mode = fake_mode
else:
self.fake_module = None
self.fake_mode = None
self.real_module = self.module
def run_node(self, n : Node) -> Any:
try:
if self.fake_module is not None:
# Hacky swap. Alternatively, we could do this with overriding
# call_module and get_attr.
self.module = self.fake_module
try:
if self.fake_mode is not None:
with self.fake_mode, enable_python_dispatcher():
result = super().run_node(n)
else:
result = super().run_node(n)
finally:
self.module = self.real_module
except Exception as e:
traceback.print_exc()
raise RuntimeError(
f"ShapeProp error for: node={n.format_node()} with "
f"meta={n.meta}"
) from e
found_tensor = False
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
nonlocal found_tensor
found_tensor = True
return _extract_tensor_metadata(obj)
else:
return obj
meta = map_aggregate(result, extract_tensor_meta)
if found_tensor:
n.meta['tensor_meta'] = meta
n.meta['type'] = type(result)
return result
def propagate(self, *args):
"""
Run `module` via interpretation and return the result and
record the shape and type of each node.
Args:
*args (Tensor): the sample input.
Returns:
Any: The value returned from executing the Module
"""
if self.fake_mode is not None:
fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
else:
fake_args = args
return super().run(*fake_args)