import operator
from typing import Any, Callable, Dict, Tuple, Optional

import torch
import torch.fx as fx  # importing the submodule also binds the `torch.fx` attribute used below
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
    normalize_module,
    normalize_function,
    create_type_hint,
)

from .schema_type_annotation import AnnotateTypesWithSchema


class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module's or functional's
    signature and, if `normalize_to_only_use_kwargs` is true, rewritten
    so the call site exclusively uses kwargs (in positional order).
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    If the nodes have 'type' metadata, it will use it to disambiguate
    overloads. Otherwise, it will throw an error.

    Example usage:
        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        super().__init__(module)
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            if isinstance(arg, fx.Node):
                # Read the 'type' metadata off the argument node itself,
                # not off `n`, the node currently being normalized.
                return arg.meta["type"] if "type" in arg.meta else None
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple([create_type_hint(i) for i in arg_types])
        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)
        if n.op != "output":
            self.node_map[out] = n
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        else:
            return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        else:
            return super().call_module(target, args, kwargs)
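

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; `_example_normalize_args` and the toy
# module below are assumptions for demonstration, not part of the pass).
# `F.relu` is used because its Python signature can be matched without extra
# 'type' metadata on the graph nodes; after the transform, its call site
# carries explicit keyword arguments with defaults filled in, i.e.
# relu(input=x, inplace=False).
def _example_normalize_args() -> "fx.GraphModule":
    import torch.nn.functional as F

    class _Toy(torch.nn.Module):
        def forward(self, x):
            return F.relu(x)  # positional call before normalization

    traced = fx.symbolic_trace(_Toy())
    return NormalizeArgs(traced).transform()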


class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Normalizing `torch` binary ops (e.g. torch.add) to the builtin
       operators (e.g. operator.add) that dispatch to the same Tensor
       magic methods, when it is possible to statically reason that the
       call site is the plain two-argument form.

    Example usage:
        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeOperators(traced).transform()
    """

    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950

        assert callable(target)

        if target in self.binary_magic_method_remap:
            # Only rewrite the plain binary form. Extra positional or keyword
            # arguments (e.g. torch.add(x, y, alpha=2)) cannot be expressed
            # through the builtin operator and would otherwise be silently
            # dropped, so leave those call sites alone.
            if len(args) != 2 or kwargs:
                return super().call_function(target, args, kwargs)
            lhs, rhs = args

            return super().call_function(
                target=self.binary_magic_method_remap[target],
                args=(lhs, rhs),
                kwargs={},
            )

        return super().call_function(target, args, kwargs)
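

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; `_example_normalize_operators` and the toy
# module below are assumptions for demonstration, not part of the pass).
# After the transform, the torch.add call site in the traced graph targets
# operator.add instead:
def _example_normalize_operators() -> "fx.GraphModule":
    class _Toy(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)  # plain binary form, so it gets remapped

    traced = fx.symbolic_trace(_Toy())
    return NormalizeOperators(traced).transform()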