from enum import Enum from typing import NamedTuple, Dict, List, Set from torch.fx.node import Node, map_arg class Partition: """Partition class contains all the information about an individual partition. It also provides necessary methods for manipulation the partition. """ def __init__(self, partition_id: int) -> None: self.nodes: Set[Node] = set() self.partition_id = partition_id self.parents: Set[Partition] = set() self.children: Set[Partition] = set() self.bfs_level: int = -1 self.used_mem_bytes: int = 0 self.logical_device_ids: List[int] = [] def __str__(self): return str(self.partition_id) def recalculate_mem_size(self): self.used_mem_bytes = 0 for node in self.nodes: self.used_mem_bytes += get_extra_size_of(node, self.nodes) def add_node(self, node): input_nodes: Dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # Add current node's input nodes if they are placeholder or constants for n in input_nodes: if n.op in {"placeholder", "get_attr"}: self.nodes.add(n) self.nodes.add(node) self.recalculate_mem_size() def remove_node(self, node): # Remove a node only if the node is in the partition if node in self.nodes: self.nodes.remove(node) # Collect the node's input nodes input_nodes: Dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # Check if an input node is a placeholder or get_attr, # and this input node is not used by some other nodes in this partition, # the remove this input node for input_node in input_nodes: if all( n not in self.nodes for n in input_node.users ) and input_node.op in {"placeholder", "get_attr"}: self.nodes.remove(input_node) self.recalculate_mem_size() class Device(NamedTuple): name: str available_mem_bytes: int logical_id: int class NodeLatency(NamedTuple): # Latency due to the memory bandwidth mem_latency_sec: float # Latency due to the computation computer_latency_sec: float class PartitionLatency(NamedTuple): # Sum of all nodes' memory latency on the critical path mem_latency_sec: float # Sum of all nodes' compute latency on the critical path computer_latency_sec: float # Latency of the critical path overall_latency_sec: float class PartitionMode(Enum): size_based = 0 sparse_nn = 1 cost_aware = 2 kl_based = 3 aot_based = 4 class PartitionerConfig(NamedTuple): devices: List[Device] mode: PartitionMode = PartitionMode.size_based transfer_rate_bytes_per_sec: float = 0.0 node_to_latency_mapping: Dict[Node, NodeLatency] = {} node_to_partition_mapping: Dict[Node, int] = {} partition_to_logical_device_mapping: Dict[int, List[int]] = {} # Saturate host by replicating partitions to the remaining idle devices. saturate_host: bool = False def get_extra_size_of(node: Node, nodes: Set[Node]) -> int: """Given a node and a set of nodes, this function return the extra size that needed if this node is included in this set. """ # Find all its input nodes input_nodes: Dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # Calculate total size of related nodes total_size_of_input_nodes = 0 for n in input_nodes: # Make sure this node hasn't been in this set yet if n not in nodes: size_bytes = getattr(n, "size_bytes", None) if size_bytes: total_size_of_input_nodes += size_bytes.output_size else: raise RuntimeError("node has no size_bytes attr") # Don't forget the op node itself size_bytes = getattr(node, "size_bytes", None) if size_bytes: total_size_of_input_nodes += size_bytes.total_size else: raise RuntimeError("node has no size_bytes attr") return total_size_of_input_nodes def get_latency_of_one_partition( partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency] ) -> PartitionLatency: """Given a partition and its nodes' latency, return a PartitionLatency for this partition""" def get_top_nodes(partition: Partition) -> List[Node]: """Given a partition, return a list of nodes on the top bfs level""" top_nodes: List[Node] = [] for node in partition.nodes: # Skip placeholder and get_attr nodes if node.op in {"placeholder", "get_attr"}: continue input_nodes: Dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) # If a node has no input nodes in this partition, # or its input nodes in this partition are placeholders and get_attrs # this node is on the top bfs level in this partition if not any( n in partition.nodes and n.op not in {"placeholder", "get_attr"} for n in input_nodes ): top_nodes.append(node) return top_nodes def dfs_helper(node: Node, partition_latency) -> PartitionLatency: """Given a top node of a partition, this function returns the latency of the critical path in the partition """ node_latency = node_to_latency_mapping[node] # Calculate the current overall latency of the partition overall_latency_sec = partition_latency.overall_latency_sec + max( node_latency.computer_latency_sec, node_latency.mem_latency_sec ) # Update the mem latency of this path mem_latency_sec = ( partition_latency.mem_latency_sec + node_latency.mem_latency_sec ) # Update the compute latency of this path computer_latency_sec = ( partition_latency.computer_latency_sec + node_latency.computer_latency_sec ) # Get all users of this node that are in this partition users = set(node.users).intersection(partition.nodes) if users: max_latency = PartitionLatency( mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 ) for n in users: # Get new partition latency recursively new_partition_latency = dfs_helper( n, PartitionLatency( mem_latency_sec, computer_latency_sec, overall_latency_sec ), ) if ( new_partition_latency.overall_latency_sec > max_latency.overall_latency_sec ): max_latency = new_partition_latency return max_latency # If there is no user, the node is at bottom of the partition return PartitionLatency( mem_latency_sec, computer_latency_sec, overall_latency_sec ) # Main part starts # Get all top level nodes of this partition top_nodes = get_top_nodes(partition) critical_path_latency = PartitionLatency( mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 ) # Go through all top nodes and find the largest latency (critical pass latency) for node in top_nodes: partition_latency = dfs_helper( node, PartitionLatency( mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 ), ) if ( partition_latency.overall_latency_sec > critical_path_latency.overall_latency_sec ): critical_path_latency = partition_latency return critical_path_latency def get_partition_to_latency_mapping( partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency] ) -> Dict[Partition, PartitionLatency]: """Given all the partitions and node_to_latency_mapping dictionary, return a mapping dictionary of each partition to its overall latency """ partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {} # Go through each partition and get its latency for partition in partitions: partition_latency = get_latency_of_one_partition( partition, node_to_latency_mapping ) partition_to_latency_mapping[partition] = partition_latency return partition_to_latency_mapping def get_comm_latency_between( parent_partition: Partition, child_partition: Partition, transfer_rate_bytes_per_sec: float, ): """Given two partitions (parent and child), calculate the communication latency between the two. """ # If two partitions are on the same device, the comm latency is 0. if ( parent_partition.logical_device_ids != [] and child_partition.logical_device_ids != [] and parent_partition.logical_device_ids == child_partition.logical_device_ids ): return 0.0 # Keep tracking the communication size between parent and child comm_size = 0 # Keep tracking all the counted node visited_nodes = set() # Go through all nodes in the child partition # If a node has input nodes from the parent partition, # the output size of those input nodes will be counted # and added to comm_size for node in child_partition.nodes: input_nodes: Dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) for n in input_nodes: if n in parent_partition.nodes and n not in visited_nodes: size_bytes = getattr(n, "size_bytes", None) if size_bytes is not None: comm_size += size_bytes.output_size visited_nodes.add(n) return comm_size / transfer_rate_bytes_per_sec def get_latency_of_partitioned_graph( partitions: List[Partition], partition_to_latency_mapping: Dict[Partition, PartitionLatency], transfer_rate_bytes_per_sec: float, ): """Given all partitions in a graph, find the critical path among all partitions and return its latency as the latency of the whole graph """ def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: """This function helps to recursively get the latency of a path of partitions""" # Update latency by adding current partition's latency latency_so_far_sec += partition_to_latency_mapping[ partition ].overall_latency_sec children = partition.children if partition.children: max_latency_sec = 0.0 for child in partition.children: # Calculate latency between comm_latency_sec = get_comm_latency_between( partition, child, transfer_rate_bytes_per_sec ) new_latency_sec = dfs_helper( child, latency_so_far_sec + comm_latency_sec ) if new_latency_sec > max_latency_sec: max_latency_sec = new_latency_sec return max_latency_sec return latency_so_far_sec def get_top_partitions(partitions: List[Partition]) -> List[Partition]: """This function is to return all the partitions without parents as the starting points of all the paths """ top_partitions = [] for partition in partitions: # If a partition has no parents, then it is a top partition if len(partition.parents) == 0: top_partitions.append(partition) return top_partitions top_partitions = get_top_partitions(partitions) critical_path_latency_sec = 0.0 for partition in top_partitions: latency_sec = dfs_helper(partition, 0.0) if latency_sec > critical_path_latency_sec: critical_path_latency_sec = latency_sec return critical_path_latency_sec