import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union

import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import map_arg
from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module

from .tools_common import NodeList

__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]


@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
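    """
    Fetch the attribute at dotted path `name` (e.g. "sub.linear.weight") from
    `obj`, returning None if any attribute along the path is missing.
    """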
    for layer in name.split("."):
        if hasattr(obj, layer):
            obj = getattr(obj, layer)
        else:
            return None
    return obj


@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
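    """
    Set `value` at the dotted attribute path `attr` (e.g. "sub.linear.weight")
    on `obj`, recursing through the intermediate attributes.
    """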
    if "." not in attr:
        setattr(obj, attr, value)
    else:
        layer = attr.split(".")
        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)


@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph that we will later turn
    into a standalone submodule.
    """

    graph: torch.fx.Graph
    order: int
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: List = field(default_factory=list)

    # Stores the nodes in the original graph that are placeholders in `graph`.
    orig_inputs: List = field(default_factory=list)

    # Stores the nodes in the original graph that are outputs in `graph`.
    orig_outputs: List = field(default_factory=list)

    # Mapping from a get_attr node in the original graph to the corresponding
    # get_attr node in `graph`.
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    constructor_args: List[str] = field(default_factory=list)
    gm: Optional[torch.fx.GraphModule] = None


@compatibility(is_backward_compatible=False)
def split_by_tags(
    gm: torch.fx.GraphModule,
    tags: List[str],
    return_fqn_mapping: bool = False,
    return_tuple: bool = False,
    GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, if tags = ["a", "b", "c"], the function will create the
    initial submodules in the order "a", "b", "c".

    To set a tag on a node:
        node.tag = "mytag"

    This will result in all nodes with the same tag being extracted and placed
    in their own submodule. For placeholder, output and get_attr nodes, the tag
    is ignored: placeholder and output nodes are created when needed, while
    get_attr nodes are copied to the submodules where they are used.

    Given the following module def:

    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower()
    results in the following split:

    ro:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    The top-level module after split, which calls the two submodules above:
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1
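
    A minimal usage sketch (illustrative only; `MyModule` and the tag names are
    hypothetical, and how nodes are assigned to tags is up to the caller):

        gm = torch.fx.symbolic_trace(MyModule())
        # Every call_module / call_function / call_method node must carry a tag
        # before splitting; placeholder, output and get_attr nodes need none.
        for node in gm.graph.nodes:
            if node.op in ("call_module", "call_function", "call_method"):
                node.tag = "front" if node.target == "linear1" else "back"
        split_gm, fqn_map = split_by_tags(gm, ["front", "back"], return_fqn_mapping=True)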

    Returns:
        split_gm: the GraphModule after split.
        orig_to_split_fqn_mapping: a map between the original fqn and the fqn
            after split for call_module and get_attr nodes.
    """

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Collects the nodes in x into a list and returns it.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r

    # Mapping from node in original module to node in created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in the main graph (a dict is used as an
    # insertion-ordered set).
    used_in_main: Dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag; we don't expect to create other
    # components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in the original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to the main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            main_remapping[node].meta = copy.copy(node.meta)
            continue

        # Get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly to the submodules that use them.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, i.e. nodes with op call_module,
        # call_function or call_method. Every callable node should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have an order no lower than that
        # of its upstream components.
        assert comp.order >= mx

        # Map an input of `node` to the corresponding node in the component's graph.
        def remap_func(x):
            # If the input is a get_attr node, copy it to the current
            # component's graph and return the copied get_attr node.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]

            # If the input is not a placeholder, it should already have been
            # put into a component. If that component is the current one,
            # return the corresponding node in the component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If the input is a placeholder or belongs to another component,
            # make it a placeholder in the current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
                placeholder.meta = copy.copy(x.meta)
                comp.input_placeholders.append(placeholder)
                used_in_main[x] = None

            return comp.input_placeholders[comp.orig_inputs.index(x)]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need a component mapping for get_attr nodes that are
            # consumed by the output. We only need to make sure corresponding
            # counterparts are created in the resulting graph.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in the main graph, mark it as an output of the
    # component it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a GraphModule for each component.
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        if return_tuple:
            comp.graph.output(outs)
        else:
            # Take care of the args of the FX output node. If there is a single
            # output, the output node's args look like (output_single,); if
            # there are multiple outputs, they look like
            # ((output_0, output_1, ...),).
            comp.graph.output(outs[0] if len(outs) == 1 else outs)

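        # Lift the component's subgraph into its own GraphModule; this copies
        # the attributes the subgraph references out of `gm` and returns a map
        # from original FQNs to the FQNs inside the new submodule.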
        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
            gm, subgraph=comp.graph, comp_name=comp.name
        )
        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

        # Create a call_module node in the main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1 and not return_tuple:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    # The root module simply holds each component's GraphModule as an attribute.
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
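    # Reuse the original graph's codegen so the generated code for the split
    # graph follows the same conventions as the original graph.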
    main_g._codegen = gm.graph._codegen

    # If the output node consumes get_attr nodes directly in the original
    # graph, then we need to make sure the corresponding attributes are copied
    # to the new root module.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    result_gm = GraphModuleCls(main_root, main_g)
    if return_fqn_mapping:
        return result_gm, orig_to_split_fqn_mapping

    return result_gm