import re
from typing import Callable, Dict, Optional, Set, Union

import torch.fx
from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module


__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']


class FoldedGraphModule(torch.fx.GraphModule):
    """
    FoldedGraphModule is a GraphModule which also contains another
    `const_subgraph_module` representing a subgraph which has all const attr
    inputs and which can be run once before running the main standard
    `graph`. The result(s) of running the const subgraph are stored on the
    module under the attr name `fx_const_folded_attrs_name`.
    """

    def __init__(
        self,
        root: torch.nn.Module,
        graph: torch.fx.Graph,
        const_subgraph: Optional[torch.fx.Graph] = None,
        fx_const_folded_attrs_name: Optional[str] = None,
        device_for_folded_attrs: str = "cuda",
    ):
        super().__init__(root, graph)
        self.const_subgraph_module = (
            None
            if const_subgraph is None
            else torch.fx.GraphModule(root, const_subgraph)
        )
        self.has_folding_been_run = False
        self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
        self.device_for_folded_attrs = device_for_folded_attrs

    def __call__(self, *args, **kwargs):
        if not self.has_folding_been_run:
            self.run_folding()
        return super().__call__(*args, **kwargs)

    def run_folding(self):
        # If there's no const subgraph module or attr output names to use, return
        # early as there is no const folding to perform.
        if (
            self.const_subgraph_module is None
            or self.fx_const_folded_attrs_name is None
        ):
            return

        assert not self.has_folding_been_run
        self.has_folding_been_run = True

        # Actually run const folding subgraph. Note that single attr const fold
        # subgraphs output a single Tensor while multiple outputs are returned as
        # Tuple[Tensor,].
        folded_attrs = self.const_subgraph_module()

        def _create_param(i):
            return torch.nn.Parameter(
                i
                if not isinstance(i, int)
                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
            )

        params = (
            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
            if isinstance(folded_attrs, tuple)
            else _create_param(folded_attrs)
        )
        setattr(self, self.fx_const_folded_attrs_name, params)


def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
    """
    Given `gm` and some graph module which is called with target name
    `inline_mod_name`, this helper will inline all of the nodes from that called
    graph module into `gm`.
    """
    # Fetch the inner graph module that we want to inline inside `gm`.
    inline_mod = dict(gm.named_modules())[inline_mod_name]
    assert isinstance(inline_mod, torch.fx.GraphModule)
    call_mod_node_to_replace = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == inline_mod_name:
            call_mod_node_to_replace = node
            break
    assert call_mod_node_to_replace is not None

    # Now actually do the swap. Note that we have to keep track of new nodes that are
    # copied into `gm` -- we do this via replacement_mapping.
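    # Placeholders in the inlined graph correspond positionally to the args of the
    # call_module node; every other node is copied in just before that call node, and
    # the inlined graph's output args replace all uses of the call node itself.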
    call_mod_args = call_mod_node_to_replace.args
    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    ph_count = 0

    def replacement_fn(node):
        new_node = replacement_mapping[node]
        new_node.meta = node.meta.copy()
        return new_node

    for inline_node in inline_mod.graph.nodes:
        if inline_node.op == "placeholder":
            replacement_mapping[inline_node] = call_mod_args[ph_count]
            ph_count += 1
            continue

        if inline_node.op == "output":
            outputs = inline_node.args[0]
            output_replacements = map_arg(outputs, replacement_fn)
            call_mod_node_to_replace.replace_all_uses_with(output_replacements)
            continue

        with gm.graph.inserting_before(call_mod_node_to_replace):
            new_node = gm.graph.node_copy(inline_node, replacement_fn)
        replacement_mapping[inline_node] = new_node

    gm.graph.eliminate_dead_code()


def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
    """
    Make sure the name is unique (in a module) and can represent an attr.
    """
    # Replace all characters that are illegal in a Python identifier with underscores.
    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    if name[0].isdigit():
        name = f"_{name}"
    # Now make sure it is in fact unique to the module by incrementing suffix value.
    while hasattr(mod_traced, name):
        match = re.match(r"(.*)_(\d+)$", name)
        if match is None:
            name = name + "_1"
        else:
            base, num = match.group(1, 2)
            name = f"{base}_{int(num) + 1}"

    return name


def split_const_subgraphs(
    module: Union[torch.nn.Module, torch.fx.GraphModule],
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    device_for_folded_attrs: str = "cpu",
) -> FoldedGraphModule:
    """
    Looks through `module` for any nodes that have all constant attribute inputs
    and separates them out into their own constant subgraph, and returns a
    FoldedGraphModule which runs that constant subgraph on the first run to set
    attributes on the module prior to running the non-constant portion of the
    graph.
    """
    if not isinstance(module, torch.fx.GraphModule):
        mod_traced = torch.fx.symbolic_trace(module)
    else:
        mod_traced = module

    # Build up a list of const_nodes, defined as nodes that are themselves
    # get_attrs, or have all get_attr or other constant node inputs.
    const_nodes: Set[torch.fx.Node] = set()
    found_const_folding = False
    for node in mod_traced.graph.nodes:
        # Skip over placeholders/outputs because they can't be const folded and
        # we don't want to add tags to them.
        if node.op in {"placeholder", "output"}:
            continue

        # If the node itself is constant, or all of its inputs are constant,
        # then tag it as constant.
        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
            const_nodes
        ):
            continue

        # If provided skip folding function says to skip, then skip.
        if skip_folding_node_fn and skip_folding_node_fn(node):
            continue

        # Skip folding side-effectful functions.
        if node.is_impure():
            continue

        # Must be a constant foldable node at this point.
        const_nodes.add(node)
        if node.op != "get_attr":
            found_const_folding = True

    # If we did not find any const folding then return early without a const fold subgraph.
    if not found_const_folding:
        return FoldedGraphModule(mod_traced, mod_traced.graph)

    # Partition the module into two: submod_0 for the constant folding subgraph, and
    # submod_1 for the rest.
    def mod_partition(node: torch.fx.Node):
        return 0 if node in const_nodes else 1

    split = split_module(mod_traced, module, mod_partition)

    const_gm, non_const_gm = split.submod_0, split.submod_1
    const_mod_name, non_const_mod_name = "submod_0", "submod_1"

    # The module that a call_module node refers to gets copied to submodules during split.
    # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
    # attach inlined modules to `split` as it's the owning module now.
    for node in non_const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(non_const_gm, node.target))
    for node in const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(const_gm, node.target))

    # split_module currently does not use get_attrs for attrs. Instead it passes
    # them in as args from the parent module, which used get_attrs. Here we set
    # them as get_attrs inside const_gm, allowing for running folding without
    # somehow a priori knowing the attrs that should be passed as args. We can
    # unconditionally do this for all placeholders because we know all
    # placeholders to const_gm must be constants accessible via get_attr.
    call_const_gm_args = None
    for node in split.graph.nodes:
        if node.op == "call_module":
            if node.target == const_mod_name:
                call_const_gm_args = node.args
                break
    assert call_const_gm_args is not None

    # Here we do the actual replacement of placeholders to get_attrs. Note that here we
    # set the const_gm.graph into a new root_const_gm with split as the root module,
    # because we are fetching attributes directly from the root module, instead of
    # fetching them from const_gm. Example: The const_gm must have some format like:
    # graph():
    #     %inp : [num_users=1] = placeholder[target=const_inp]
    #     %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
    #     return add
    # We replace that with the following, which does not have any placeholders:
    # graph():
    #     %inp_1 : [num_users=1] = get_attr[target=const_inp]
    #     %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
    #     return add
    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
    for node in root_const_gm.graph.nodes:
        if node.op == "output":
            multiple_outputs = isinstance(node.args[0], tuple)
            continue
        if node.op != "placeholder":
            continue
        in_node = next(n for n in call_const_gm_args if n.name == node.target)
        assert in_node.op == "get_attr"
        with root_const_gm.graph.inserting_before(node):
            new_node = root_const_gm.graph.get_attr(in_node.target)
            new_node.meta = node.meta.copy()
            node.replace_all_uses_with(new_node)
            root_const_gm.graph.erase_node(node)
    assert "multiple_outputs" in locals()

    # Now find the call to const_gm inside split, and replace it with a getattr to the
    # folded tensor(s) that result from constant folding. Note that we don't need to
    # worry about whether this is one or more tensors because the original graph
    # correctly uses getitem to extract individual tensors if there are multiple folded.
    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
        split, "_FX_CONST_FOLDED_ATTRS"
    )
    setattr(
        split,
        fx_const_folded_attrs_name,
        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),  # type: ignore[possibly-undefined]
    )
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            with node.graph.inserting_before(node):
                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
                folded_attrs.meta = node.meta.copy()
            node.replace_all_uses_with(folded_attrs)
            break

    split.graph.eliminate_dead_code()

    # Finally, inline the non-constant submod into the split submod. This is so that the
    # original caller who may have passed in a graph module will get back out a graph
    # module whose graph is traced to the same granularity.
    _inline_module(split, non_const_mod_name)

    return FoldedGraphModule(
        split,
        split.graph,
        root_const_gm.graph,
        fx_const_folded_attrs_name,
        device_for_folded_attrs,
    )
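

# A minimal usage sketch (not part of the API above): it assumes a toy module whose
# sum of two registered buffers depends only on attrs and is therefore foldable. The
# class name `ConstFoldExample` is illustrative only. The first call to the folded
# module runs the const subgraph once and caches its result as a module attr.
if __name__ == "__main__":
    class ConstFoldExample(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("w1", torch.randn(4))
            self.register_buffer("w2", torch.randn(4))

        def forward(self, x):
            # `self.w1 + self.w2` has only get_attr inputs, so it gets folded out.
            return x * (self.w1 + self.w2)

    example = ConstFoldExample()
    folded = split_const_subgraphs(example, device_for_folded_attrs="cpu")
    inp = torch.randn(4)
    # Results should match the original eager module.
    torch.testing.assert_close(folded(inp), example(inp))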