315 lines
10 KiB
Python
315 lines
10 KiB
Python
|
import copy
|
||
|
from itertools import chain
|
||
|
from typing import Any, Dict, List, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.utils._pytree as pytree
|
||
|
from torch._export.utils import _check_input_constraints_for_graph
|
||
|
from torch.export.unflatten import _assign_attr, _AttrKind
|
||
|
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||
|
from ._remove_effect_tokens_pass import _remove_effect_tokens
|
||
|
|
||
|
from .exported_program import (
|
||
|
ExportedProgram,
|
||
|
ExportGraphSignature,
|
||
|
InputKind,
|
||
|
OutputKind,
|
||
|
)
|
||
|
|
||
|
|
||
|
@torch._dynamo.disable
|
||
|
def _check_input_constraints_pre_hook(self, *args, **kwargs):
|
||
|
flat_args_with_path, received_spec = pytree.tree_flatten_with_path(args)
|
||
|
|
||
|
if received_spec != self._in_spec:
|
||
|
raise ValueError( # noqa: TRY200
|
||
|
"Trying to flatten user inputs with exported input tree spec: \n"
|
||
|
f"{self._in_spec}\n"
|
||
|
"but actually got inputs with tree spec of: \n"
|
||
|
f"{received_spec}"
|
||
|
)
|
||
|
|
||
|
return _check_input_constraints_for_graph(
|
||
|
[node for node in self.graph.nodes if node.op == "placeholder"],
|
||
|
flat_args_with_path,
|
||
|
self.range_constraints,
|
||
|
)
|
||
|
|
||
|
|
||
|
def _unlift_inputs_as_getattr(
|
||
|
gm: torch.fx.GraphModule,
|
||
|
lifted_inputs: List[Optional[str]],
|
||
|
) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]:
|
||
|
"""
|
||
|
Unlift inputs referring to params/buffers/constants as getattr nodes in the
|
||
|
graph
|
||
|
"""
|
||
|
unlifted_name_to_node = {}
|
||
|
input_name_to_node = {}
|
||
|
|
||
|
placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
|
||
|
assert len(lifted_inputs) == len(placeholder_nodes)
|
||
|
for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs):
|
||
|
if lifted_node is None:
|
||
|
input_name_to_node[input_node.name] = input_node
|
||
|
|
||
|
else:
|
||
|
with gm.graph.inserting_after(input_node):
|
||
|
getattr_node = gm.graph.get_attr(lifted_node)
|
||
|
input_node.replace_all_uses_with(getattr_node)
|
||
|
metadata = input_node.meta
|
||
|
gm.graph.erase_node(input_node)
|
||
|
getattr_node.meta = metadata
|
||
|
unlifted_name_to_node[lifted_node] = getattr_node
|
||
|
|
||
|
return unlifted_name_to_node, input_name_to_node
|
||
|
|
||
|
|
||
|
def _insert_copy_for_mutations(
|
||
|
gm: torch.fx.GraphModule,
|
||
|
mutated_outputs: List[Optional[str]],
|
||
|
unlifted_name_to_node: Dict[str, torch.fx.Node],
|
||
|
input_name_to_node: Dict[str, torch.fx.Node],
|
||
|
) -> None:
|
||
|
"""
|
||
|
Find the all the buffers and inputs that were mutated and insert copy_
|
||
|
operators to reflect mutations.
|
||
|
"""
|
||
|
output_node = None
|
||
|
for node in gm.graph.nodes:
|
||
|
if node.op == "output":
|
||
|
output_node = node
|
||
|
break
|
||
|
assert output_node is not None
|
||
|
outputs = pytree.tree_flatten(output_node.args)[0]
|
||
|
assert len(outputs) == len(mutated_outputs)
|
||
|
|
||
|
user_output_nodes = []
|
||
|
for return_node, mutated_node_name in zip(outputs, mutated_outputs):
|
||
|
if mutated_node_name is None:
|
||
|
user_output_nodes.append(return_node)
|
||
|
continue
|
||
|
|
||
|
if mutated_node_name in unlifted_name_to_node:
|
||
|
mutated_node = unlifted_name_to_node[mutated_node_name]
|
||
|
elif mutated_node_name in input_name_to_node:
|
||
|
mutated_node = input_name_to_node[mutated_node_name]
|
||
|
else:
|
||
|
raise RuntimeError(
|
||
|
f"Could not find {mutated_node_name} in either buffer or input nodes"
|
||
|
)
|
||
|
|
||
|
with gm.graph.inserting_before(output_node):
|
||
|
_ = gm.graph.call_function(
|
||
|
torch.ops.aten.copy_.default, (mutated_node, return_node)
|
||
|
)
|
||
|
|
||
|
with gm.graph.inserting_before(output_node):
|
||
|
# Only return user outputs
|
||
|
new_output = gm.graph.output(tuple(user_output_nodes))
|
||
|
output_node.replace_all_uses_with(new_output)
|
||
|
gm.graph.erase_node(output_node)
|
||
|
|
||
|
|
||
|
def _get_codegen(
|
||
|
in_spec: pytree.TreeSpec,
|
||
|
out_spec: Optional[pytree.TreeSpec],
|
||
|
) -> _PyTreeCodeGen:
|
||
|
"""
|
||
|
Create the codegen for the graph module based on the in/out specs
|
||
|
"""
|
||
|
if (
|
||
|
in_spec.type == tuple
|
||
|
and in_spec.num_children == 2
|
||
|
and in_spec.children_specs[0].type == tuple
|
||
|
and in_spec.children_specs[1].type == dict
|
||
|
):
|
||
|
# if in_spec contains the args (tuple) and kwargs (dict)
|
||
|
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
|
||
|
# add kwarg names
|
||
|
names.extend(in_spec.children_specs[1].context)
|
||
|
else:
|
||
|
names = [f"arg_{i}" for i in range(in_spec.num_children)]
|
||
|
|
||
|
return _PyTreeCodeGen(
|
||
|
_PyTreeInfo(
|
||
|
names,
|
||
|
in_spec,
|
||
|
out_spec,
|
||
|
)
|
||
|
)
|
||
|
|
||
|
|
||
|
def _unlift(
|
||
|
gm: torch.fx.GraphModule,
|
||
|
lifted_inputs: List[Optional[str]],
|
||
|
mutated_outputs: List[Optional[str]],
|
||
|
in_spec: pytree.TreeSpec,
|
||
|
out_spec: Optional[pytree.TreeSpec],
|
||
|
state_dict: Dict[str, Any],
|
||
|
constants: Dict[str, Any],
|
||
|
):
|
||
|
"""
|
||
|
Args:
|
||
|
lifted_inputs: A list matching the graph module's input nodes. For
|
||
|
an input node that is referring to a lifted parameter/buffer, this
|
||
|
list will contain the fqn the corresponding attribute. Otherwise, this
|
||
|
list will contain None. This is used to unlift the lifted parameters as
|
||
|
get_attr nodes.
|
||
|
|
||
|
mutated_outputs: A list matching the graph module's output nodes. For
|
||
|
an output node that is referring to a mutated buffer or user input, this
|
||
|
list will contain the name of the corresponding buffer or user input
|
||
|
that needs to be mutated. Otherwise, this list will contain None. This
|
||
|
is used to re-insert an inplace copy_ operator to copy the mutated
|
||
|
values back to the original node.
|
||
|
"""
|
||
|
unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr(
|
||
|
gm, lifted_inputs
|
||
|
)
|
||
|
_insert_copy_for_mutations(
|
||
|
gm, mutated_outputs, unlifted_name_to_node, input_name_to_node
|
||
|
)
|
||
|
gm.graph._codegen = _get_codegen(in_spec, out_spec)
|
||
|
gm.graph.lint()
|
||
|
gm.graph.eliminate_dead_code()
|
||
|
gm.recompile()
|
||
|
return gm
|
||
|
|
||
|
|
||
|
def _register_attrs_to_new_gm(
|
||
|
new_gm: torch.fx.GraphModule,
|
||
|
graph_signature: ExportGraphSignature,
|
||
|
state_dict: Dict[str, Any],
|
||
|
constants: Dict[str, Any],
|
||
|
) -> None:
|
||
|
non_persistent_buffers = set(graph_signature.non_persistent_buffers)
|
||
|
for name in graph_signature.buffers:
|
||
|
if name in non_persistent_buffers:
|
||
|
persistent = False
|
||
|
value = constants[name]
|
||
|
else:
|
||
|
persistent = True
|
||
|
value = state_dict[name]
|
||
|
_assign_attr(
|
||
|
value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent
|
||
|
)
|
||
|
for name in graph_signature.parameters:
|
||
|
value = state_dict[name]
|
||
|
_assign_attr(
|
||
|
value,
|
||
|
new_gm,
|
||
|
name,
|
||
|
attr_kind=_AttrKind.PARAMETER,
|
||
|
)
|
||
|
|
||
|
for name in chain(
|
||
|
graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants
|
||
|
):
|
||
|
value = constants[name]
|
||
|
_assign_attr(
|
||
|
value,
|
||
|
new_gm,
|
||
|
name,
|
||
|
attr_kind=_AttrKind.CONSTANT,
|
||
|
)
|
||
|
|
||
|
|
||
|
class _StatefulGraphModuleFactory(type):
|
||
|
"""
|
||
|
Metaclass that ensures a private constructor for _StatefulGraphModule
|
||
|
"""
|
||
|
|
||
|
def __call__(cls, *args, **kwargs):
|
||
|
raise TypeError(
|
||
|
f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
|
||
|
)
|
||
|
|
||
|
def _create(cls, root, graph, range_constraints=None):
|
||
|
return super().__call__(
|
||
|
root,
|
||
|
graph,
|
||
|
range_constraints=range_constraints,
|
||
|
)
|
||
|
|
||
|
|
||
|
class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory):
|
||
|
def __init__(self, root, graph, range_constraints=None):
|
||
|
super().__init__(root, graph)
|
||
|
# Need to fix up non-persistent buffers.
|
||
|
self.range_constraints = range_constraints or []
|
||
|
|
||
|
|
||
|
def _create_stateful_graph_module(
|
||
|
plain_graph_module: torch.fx.GraphModule,
|
||
|
range_constraints,
|
||
|
# TODO(suo) this should not be optional, but is since we still ahve
|
||
|
# capture_pre_autograd_graph grr
|
||
|
graph_signature: Optional[ExportGraphSignature] = None,
|
||
|
):
|
||
|
stateful_gm = _StatefulGraphModule._create(
|
||
|
plain_graph_module,
|
||
|
plain_graph_module.graph,
|
||
|
range_constraints=range_constraints,
|
||
|
)
|
||
|
stateful_gm.register_forward_pre_hook(
|
||
|
_check_input_constraints_pre_hook, with_kwargs=True
|
||
|
)
|
||
|
|
||
|
if graph_signature is None:
|
||
|
return stateful_gm
|
||
|
# Fix up non-persistent buffers. torch.fx does not distinguish between
|
||
|
# persistent and non-persistent buffers, so we must restore that distinction
|
||
|
# here.
|
||
|
for buffer in graph_signature.non_persistent_buffers:
|
||
|
_assign_attr(
|
||
|
plain_graph_module.get_buffer(buffer),
|
||
|
stateful_gm,
|
||
|
buffer,
|
||
|
attr_kind=_AttrKind.BUFFER,
|
||
|
persistent=False,
|
||
|
)
|
||
|
|
||
|
return stateful_gm
|
||
|
|
||
|
|
||
|
def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
|
||
|
ep = _remove_effect_tokens(ep)
|
||
|
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
|
||
|
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
|
||
|
|
||
|
lifted_inputs: List[Optional[str]] = [
|
||
|
in_spec.target
|
||
|
if in_spec.kind
|
||
|
in (
|
||
|
InputKind.BUFFER,
|
||
|
InputKind.CONSTANT_TENSOR,
|
||
|
InputKind.PARAMETER,
|
||
|
InputKind.CUSTOM_OBJ,
|
||
|
)
|
||
|
else None
|
||
|
for in_spec in ep.graph_signature.input_specs
|
||
|
]
|
||
|
|
||
|
mutated_outputs: List[Optional[str]] = [
|
||
|
out_spec.target
|
||
|
if out_spec.kind in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION)
|
||
|
else None
|
||
|
for out_spec in ep.graph_signature.output_specs
|
||
|
]
|
||
|
|
||
|
new_gm = _unlift(
|
||
|
new_gm,
|
||
|
lifted_inputs,
|
||
|
mutated_outputs,
|
||
|
ep.call_spec.in_spec,
|
||
|
ep.call_spec.out_spec,
|
||
|
ep.state_dict,
|
||
|
ep.constants,
|
||
|
)
|
||
|
unlift_gm = _create_stateful_graph_module(
|
||
|
new_gm, ep.range_constraints, ep.graph_signature
|
||
|
)
|
||
|
unlift_gm.meta.update(ep.graph_module.meta)
|
||
|
return unlift_gm
|