import itertools
import operator
from typing import Dict, List, Tuple

import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.node import Node
from torch.fx.passes.tools_common import legalize_graph


def split_result_tensors(
    result: torch.Tensor, inputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, ...]:
    """
    A free function for use in the merge_matmul graph transformation below
    that splits the output from a merged matmul into the individual results
    for each input tensor.

    Arguments:
        result: The merged matmul result tensor.
        inputs: The list of inputs that were merged into one for the matmul.

    Returns:
        List of matmul results for each input tensor.
    """
    # When the fx tracer is running, x.shape[0] will be a torch.fx.Attribute,
    # but torch.split needs an int even when tracing, so use dummy sizes.
    if isinstance(result, torch.fx.Proxy):
        splits = [0] * len(inputs)
    else:
        splits = [x.shape[0] for x in inputs]

    return torch.split(result, splits)


def may_depend_on(a: Node, b: Node, search_depth: int = 6) -> bool:
    """
    Determine if one node depends on another in a torch.fx.Graph.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: In the case of an indirect dependency, this function
                      searches up to this many nodes away in search of a
                      data dependency. If none is found, the function makes
                      the conservative assumption that there is a dependency.

    Returns:
        True if a may depend on b, False if it definitely does not.
    """
    # Equivalence is defined as dependence.
    if a == b:
        return True

    # If a has no inputs, it cannot depend on b.
    if len(a.all_input_nodes) == 0:
        return False

    # If the search depth has been exhausted and no conclusion has been
    # reached, assume that there is a data dependency.
    if search_depth == 0:
        return True

    # Recursively check all inputs of a.
    for inp in a.all_input_nodes:
        if may_depend_on(inp, b, search_depth - 1):
            return True

    return False


def are_nodes_independent(nodes: List[Node]) -> bool:
    """
    Check if all of the given nodes are pairwise-data independent.

    Arguments:
        nodes: The nodes to check for data dependencies.

    Returns:
        True if no pair of nodes has a data dependency on the other,
        False otherwise.
    """
    # For each pair in nodes: a dependency in either direction rules out
    # independence.
    for i, j in itertools.combinations(nodes, 2):
        if may_depend_on(i, j) or may_depend_on(j, i):
            return False

    return True


def merge_matmul(in_mod: torch.nn.Module) -> torch.fx.GraphModule:
    """
    A graph transformation that merges matrix multiplication operations that
    share the same right-hand side operand into one large matrix
    multiplication.
               ____      _________        _________
      ----    |    |    |         |     M|  A * C  |
     M| A  |  T| B | * K|    C    | =    |---------|
      ---- ,  |    |    |         |     T|  B * C  |
       K       ----      ---------        ---------
                K            R                R

    Arguments:
        in_mod: The module whose traced graph should be transformed.

    Returns:
        A torch.fx.GraphModule with shared-RHS matmuls merged.
    """
    gm = symbolic_trace(in_mod)

    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply
    # operands to the matmuls of which they are the LHS/RHS.
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will
        # produce the same results as all the individual matmuls involving
        # rhs in the original graph, but they will all be concatenated
        # together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})

        # Split the result of the merged matmul using the shapes of the LHS
        # operands to ascertain how large each chunk should be.
        merge_mm_split = gm.graph.call_function(
            split_result_tensors, (merge_mm, lhs), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the
        # equivalent split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

    # All of the new nodes created above were inserted at the end, so we need
    # to sort the nodes topologically to make sure all definitions precede
    # uses.
    legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm