318 lines
12 KiB
Python
318 lines
12 KiB
Python
|
from enum import Enum
|
||
|
from typing import NamedTuple, Dict, List, Set
|
||
|
|
||
|
from torch.fx.node import Node, map_arg
|
||
|
|
||
|
|
||
|
class Partition:
    """Partition class contains all the information about an individual partition.

    It also provides necessary methods for manipulating the partition.
    """

    def __init__(self, partition_id: int) -> None:
        # Nodes in this partition; add_node also pulls in the node's
        # placeholder/get_attr inputs.
        self.nodes: Set[Node] = set()
        self.partition_id = partition_id
        # Neighbouring partitions; latency estimation walks from partitions
        # without parents down through `children`.
        self.parents: Set[Partition] = set()
        self.children: Set[Partition] = set()
        # -1 until a level is assigned by the caller.
        self.bfs_level: int = -1
        # Cached memory footprint, refreshed by recalculate_mem_size().
        self.used_mem_bytes: int = 0
        # Empty until the partition is mapped to logical devices.
        self.logical_device_ids: List[int] = []

    def __str__(self):
        return str(self.partition_id)

    def recalculate_mem_size(self):
        """Recompute used_mem_bytes from scratch.

        Each node contributes get_extra_size_of(node, self.nodes): its own
        total size plus the output size of inputs that are outside this
        partition.
        """
        self.used_mem_bytes = 0
        for node in self.nodes:
            self.used_mem_bytes += get_extra_size_of(node, self.nodes)

    def add_node(self, node):
        """Add `node` (and its placeholder/get_attr inputs), then refresh memory."""
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        # Add current node's input nodes if they are placeholder or constants
        for n in input_nodes:
            if n.op in {"placeholder", "get_attr"}:
                self.nodes.add(n)
        self.nodes.add(node)
        self.recalculate_mem_size()

    def remove_node(self, node):
        """Remove `node` if present, together with now-unused placeholder/get_attr inputs.

        Does nothing (and does not recompute memory) when `node` is not in
        the partition.
        """
        if node not in self.nodes:
            return
        self.nodes.remove(node)
        # Collect the node's input nodes
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        # Drop a placeholder/get_attr input once no remaining node in this
        # partition uses it. `discard` (not `remove`) because the input may
        # never have been added, e.g. when `nodes` was populated directly.
        for input_node in input_nodes:
            if input_node.op in {"placeholder", "get_attr"} and all(
                n not in self.nodes for n in input_node.users
            ):
                self.nodes.discard(input_node)
        self.recalculate_mem_size()
|
||
|
|
||
|
|
||
|
class Device(NamedTuple):
    """A target device that partitions can be placed on."""

    # Human-readable device name.
    name: str
    # Memory capacity of the device, in bytes.
    available_mem_bytes: int
    # Logical identifier of the device.
    logical_id: int
|
||
|
|
||
|
|
||
|
class NodeLatency(NamedTuple):
    """Per-node latency estimate, split into memory and compute components."""

    # Latency attributable to memory bandwidth.
    mem_latency_sec: float
    # Latency attributable to computation (field name spelled "computer"
    # is part of the public interface and is kept as-is).
    computer_latency_sec: float
|
||
|
|
||
|
|
||
|
class PartitionLatency(NamedTuple):
    """Latency summary of the critical path through one partition."""

    # Memory latency summed over the nodes on the critical path.
    mem_latency_sec: float
    # Compute latency summed over the nodes on the critical path.
    computer_latency_sec: float
    # Latency of the critical path as a whole.
    overall_latency_sec: float
|
||
|
|
||
|
|
||
|
class PartitionMode(Enum):
    """Selects the partitioning strategy (used as PartitionerConfig.mode)."""

    size_based = 0
    sparse_nn = 1
    cost_aware = 2
    kl_based = 3
    aot_based = 4
|
||
|
|
||
|
|
||
|
class PartitionerConfig(NamedTuple):
    """Settings that drive partitioning.

    NOTE(review): the dict defaults below are single objects shared by every
    config that does not pass its own value (NamedTuple defaults are not
    copied per instance). Do not mutate them in place; pass fresh dicts instead.
    """

    # Devices available for placing partitions.
    devices: List[Device]
    # Partitioning strategy.
    mode: PartitionMode = PartitionMode.size_based
    # Bandwidth used to turn transferred bytes into latency.
    transfer_rate_bytes_per_sec: float = 0.0
    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
    node_to_partition_mapping: Dict[Node, int] = {}
    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
    # Saturate host by replicating partitions to the remaining idle devices.
    saturate_host: bool = False
|
||
|
|
||
|
|
||
|
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
    """Return how many extra bytes `nodes` would need if `node` joined it.

    The extra size is the `size_bytes.output_size` of each distinct input of
    `node` that is not already in `nodes`, plus `node`'s own
    `size_bytes.total_size`.

    Raises:
        RuntimeError: if `node`, or any input not already in `nodes`, lacks a
            truthy `size_bytes` attribute. Inputs are checked before `node`.
    """

    def _size_bytes_or_raise(candidate: Node):
        # Truthiness check, matching the original contract.
        info = getattr(candidate, "size_bytes", None)
        if not info:
            raise RuntimeError("node has no size_bytes attr")
        return info

    # Distinct input nodes, in first-seen order, from positional and keyword args.
    inputs: Dict[Node, None] = {}
    map_arg(node.args, inputs.setdefault)
    map_arg(node.kwargs, inputs.setdefault)

    # Inputs already inside the set cost nothing extra.
    extra_bytes = sum(
        _size_bytes_or_raise(src).output_size for src in inputs if src not in nodes
    )
    # The node itself is always counted in full.
    extra_bytes += _size_bytes_or_raise(node).total_size
    return extra_bytes
|
||
|
|
||
|
|
||
|
def get_latency_of_one_partition(
    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
    """Given a partition and its nodes' latency, return a PartitionLatency for this partition.

    The result describes the critical path of the partition: among all
    chains that start at a top-level node and follow users inside the
    partition until a node with no in-partition users, the one with the
    largest `overall_latency_sec`. Each node on a path adds
    max(computer_latency_sec, mem_latency_sec) to `overall_latency_sec`,
    while `mem_latency_sec` / `computer_latency_sec` are the plain sums of
    those components along the same path. A partition with no top-level
    nodes yields an all-zero PartitionLatency.

    Raises:
        KeyError: if a visited node is missing from node_to_latency_mapping.
    """

    def get_top_nodes(partition: Partition) -> List[Node]:
        """Given a partition, return a list of nodes on the top bfs level"""
        top_nodes: List[Node] = []
        for node in partition.nodes:
            # Skip placeholder and get_attr nodes
            if node.op in {"placeholder", "get_attr"}:
                continue
            input_nodes: Dict[Node, None] = {}
            map_arg(node.args, input_nodes.setdefault)
            map_arg(node.kwargs, input_nodes.setdefault)
            # If a node has no input nodes in this partition,
            # or its input nodes in this partition are placeholders and get_attrs
            # this node is on the top bfs level in this partition
            if not any(
                n in partition.nodes and n.op not in {"placeholder", "get_attr"}
                for n in input_nodes
            ):
                top_nodes.append(node)
        return top_nodes

    # Critical path from a node (inclusive) down to the bottom of the
    # partition. Path latency is a plain sum, so the best continuation from a
    # node does not depend on how it was reached; caching it keeps the walk
    # linear in nodes + edges instead of exponential in the number of paths
    # (e.g. graphs with many diamonds / residual connections).
    critical_path_from: Dict[Node, PartitionLatency] = {}

    def dfs_helper(node: Node) -> PartitionLatency:
        """Return the latency of the critical path that starts at `node`."""
        if node in critical_path_from:
            return critical_path_from[node]
        node_latency = node_to_latency_mapping[node]
        # Most expensive continuation through users inside this partition;
        # stays all-zero when the node is at the bottom of the partition.
        best_tail = PartitionLatency(
            mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
        )
        for user in node.users:
            if user not in partition.nodes:
                continue
            tail = dfs_helper(user)
            if tail.overall_latency_sec > best_tail.overall_latency_sec:
                best_tail = tail
        result = PartitionLatency(
            mem_latency_sec=node_latency.mem_latency_sec + best_tail.mem_latency_sec,
            computer_latency_sec=(
                node_latency.computer_latency_sec + best_tail.computer_latency_sec
            ),
            overall_latency_sec=(
                max(node_latency.computer_latency_sec, node_latency.mem_latency_sec)
                + best_tail.overall_latency_sec
            ),
        )
        critical_path_from[node] = result
        return result

    # Go through all top nodes and find the largest latency (critical path latency)
    critical_path_latency = PartitionLatency(
        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
    )
    for node in get_top_nodes(partition):
        partition_latency = dfs_helper(node)
        if (
            partition_latency.overall_latency_sec
            > critical_path_latency.overall_latency_sec
        ):
            critical_path_latency = partition_latency
    return critical_path_latency
|
||
|
|
||
|
|
||
|
def get_partition_to_latency_mapping(
    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
    """Map each partition to its critical-path latency.

    Every value is get_latency_of_one_partition(partition,
    node_to_latency_mapping). Keys appear in the order of `partitions`.
    """
    return {
        part: get_latency_of_one_partition(part, node_to_latency_mapping)
        for part in partitions
    }
|
||
|
|
||
|
|
||
|
def get_comm_latency_between(
    parent_partition: Partition,
    child_partition: Partition,
    transfer_rate_bytes_per_sec: float,
):
    """Given two partitions (parent and child),
    calculate the communication latency between the two.

    The latency is the summed `size_bytes.output_size` of the distinct nodes
    of `parent_partition` that feed nodes of `child_partition`, divided by
    `transfer_rate_bytes_per_sec`. Such inputs without a `size_bytes`
    attribute are not counted.

    Returns 0.0 when both partitions have the same non-empty
    `logical_device_ids`, or when nothing needs to cross the boundary.

    Raises:
        ZeroDivisionError: if data must be transferred and the rate is 0.
    """
    # If two partitions are on the same device, the comm latency is 0.
    if (
        parent_partition.logical_device_ids
        and child_partition.logical_device_ids
        and parent_partition.logical_device_ids == child_partition.logical_device_ids
    ):
        return 0.0
    # Bytes that must move from parent to child.
    comm_size = 0
    # Parent nodes already counted, so each output is transferred once.
    visited_nodes: Set[Node] = set()
    # Go through all nodes in the child partition. If a node has input nodes
    # from the parent partition, the output size of those input nodes is
    # added to comm_size.
    for node in child_partition.nodes:
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, input_nodes.setdefault)
        map_arg(node.kwargs, input_nodes.setdefault)
        for n in input_nodes:
            if n in parent_partition.nodes and n not in visited_nodes:
                size_bytes = getattr(n, "size_bytes", None)
                if size_bytes is not None:
                    comm_size += size_bytes.output_size
                visited_nodes.add(n)
    # Nothing crosses the boundary, so there is no transfer latency. Returning
    # early also avoids dividing by a zero rate (PartitionerConfig's default
    # transfer_rate_bytes_per_sec is 0.0).
    if comm_size == 0:
        return 0.0
    return comm_size / transfer_rate_bytes_per_sec
|
||
|
|
||
|
|
||
|
def get_latency_of_partitioned_graph(
    partitions: List[Partition],
    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
    transfer_rate_bytes_per_sec: float,
):
    """Given all partitions in a graph, find the critical path among all partitions
    and return its latency as the latency of the whole graph.

    A path starts at a partition without parents and follows `children`.
    Its latency is the sum of each partition's `overall_latency_sec` plus the
    communication latency (get_comm_latency_between) of every edge taken.
    Returns 0.0 when there are no parentless partitions.

    Raises:
        KeyError: if a visited partition is missing from
            partition_to_latency_mapping.
    """
    # Latency of the most expensive path that starts at a partition
    # (inclusive). Path latency is additive, so the best continuation from a
    # partition is independent of how it was reached; caching it avoids
    # re-walking shared sub-DAGs (exponential blow-up) and recomputing
    # get_comm_latency_between for the same edge.
    critical_path_from: Dict[Partition, float] = {}

    def dfs_helper(partition: Partition) -> float:
        """Return the latency of the critical path starting at `partition`."""
        if partition in critical_path_from:
            return critical_path_from[partition]
        # Most expensive continuation (edge transfer + child's path); stays
        # 0.0 for a partition without children.
        best_tail_sec = 0.0
        for child in partition.children:
            comm_latency_sec = get_comm_latency_between(
                partition, child, transfer_rate_bytes_per_sec
            )
            tail_sec = comm_latency_sec + dfs_helper(child)
            if tail_sec > best_tail_sec:
                best_tail_sec = tail_sec
        latency_sec = (
            partition_to_latency_mapping[partition].overall_latency_sec
            + best_tail_sec
        )
        critical_path_from[partition] = latency_sec
        return latency_sec

    def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
        """This function is to return all the partitions without parents
        as the starting points of all the paths
        """
        top_partitions = []
        for partition in partitions:
            # If a partition has no parents, then it is a top partition
            if len(partition.parents) == 0:
                top_partitions.append(partition)
        return top_partitions

    top_partitions = get_top_partitions(partitions)
    critical_path_latency_sec = 0.0
    for partition in top_partitions:
        latency_sec = dfs_helper(partition)
        if latency_sec > critical_path_latency_sec:
            critical_path_latency_sec = latency_sec
    return critical_path_latency_sec
|