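# Source path: ai-content-maker/.venv/Lib/site-packages/torch/_inductor/comms.py
# 364 lines, 14 KiB, Python (header captured from a repository web viewer:
# "Raw | Normal View | History").
#
# Snapshot timestamp: 2024-05-03 04:18:51 +03:00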
# pyre-strict
from typing import List
import torch
from . import config, ir, scheduler
from .dependencies import WeakDep
from .utils import tuple_sorted
# Artifact logger for the "overlap" channel. The visualization helpers below write
# their compute/communication overlap diagrams here at DEBUG level.
# NOTE(review): presumably switched on via TORCH_LOGS="overlap" — confirm.
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
def sink_waits(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Push every wait node as late as possible, i.e. right before its first user.

    Wait nodes are held back while walking ``snodes``. When a non-wait node that
    directly consumes held waits (it appears in their ``node_users``) is reached,
    those waits are emitted in ``tuple_sorted`` order immediately ahead of it.
    Waits that never meet a consumer are appended at the very end. This greedy
    placement maximizes how long communication can overlap other work.
    """
    result = []
    pending = set()
    for cur in snodes:
        if isinstance(cur.node, ir.Wait):
            pending.add(cur)
            continue
        # Release every held wait that `cur` directly uses.
        for held in tuple_sorted(pending):
            if cur not in held.node_users:
                continue
            result.append(held)
            pending.remove(held)
        result.append(cur)
    # Waits with no consumer in this list go last.
    result.extend(tuple_sorted(pending))
    return result
def raise_comms(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Hoist communication (``ir.CollectiveKernel``) nodes as early as possible,
    i.e. until they meet a node that produces one of their inputs. Greedy and
    optimal in terms of communication overlap.

    The schedule is walked backwards. Comms are held aside until a node that
    appears in some held comm's ``inverse_users`` is reached. Held comms are then
    flushed from the front of the held list until none of the remaining ones
    depends on that node.

    TODO: We might want to adjust this in the future to account for memory limitations.
    e.g. when we are compiling FSDP, this heuristic will cause the all-gathers to be
    prefetched as soon as possible, which is the beginning of the forward pass. We'll
    have to either do a special pass for FSDP, or redo this pass with memory
    considerations so the FSDP case is handled in a general way.
    """
    rev_order: List["scheduler.BaseSchedulerNode"] = []
    held_comms: List["scheduler.BaseSchedulerNode"] = []
    for cur in reversed(snodes):
        if isinstance(cur.node, ir.CollectiveKernel):
            held_comms.append(cur)
            continue
        # Every held comm must have at least one producer to stop at.
        for held in held_comms:
            assert len(held.inverse_users) > 0
        # Flush front-first (keeping the comms' relative order) while any held
        # comm still depends on `cur`.
        while held_comms and any(cur in c.inverse_users for c in held_comms):
            rev_order.append(held_comms.pop(0))
        rev_order.append(cur)
    assert len(held_comms) <= 1
    rev_order.extend(tuple_sorted(held_comms))
    rev_order.reverse()
    return rev_order
def get_ancestors(node):
    """
    Return the set of every node that transitively produces an input of ``node``.

    Walks ``inverse_users`` breadth-first, one frontier at a time. ``node`` itself
    is only included if it is reachable from itself (a cycle).
    """
    seen = set()
    frontier = [node]
    while frontier:
        next_frontier = []
        for cur in frontier:
            for producer in cur.inverse_users:
                if producer in seen:
                    continue
                seen.add(producer)
                next_frontier.append(producer)
        frontier = next_frontier
    return seen
def get_descendants(node):
    """
    Return the set of every node that transitively consumes an output of ``node``.

    Walks ``node_users`` breadth-first, one frontier at a time. ``node`` itself is
    only included if it is reachable from itself (a cycle).
    """
    seen = set()
    frontier = [node]
    while frontier:
        next_frontier = []
        for cur in frontier:
            for consumer in cur.node_users:
                if consumer in seen:
                    continue
                seen.add(consumer)
                next_frontier.append(consumer)
        frontier = next_frontier
    return seen
def decide_global_ordering_of_comms(nodes: List["scheduler.BaseSchedulerNode"]):
    """
    Pin the relative order of communication nodes to their order in ``nodes``
    (the input graph order, which need not match the eager-mode program order).

    Each ``ir.CollectiveKernel`` node gets a ``WeakDep`` on the name of the
    collective immediately before it. Mutates the nodes in place; returns None.

    TODO: Come up with a better approach
    """
    comms = [n for n in nodes if isinstance(n.node, ir.CollectiveKernel)]
    for earlier, later in zip(comms, comms[1:]):
        later.add_fake_dep(WeakDep(earlier.get_name()))
def assert_no_comm_nodes(snodes: List["scheduler.BaseSchedulerNode"]) -> None:
    """Assert that no node in ``snodes`` wraps an ``ir.CollectiveKernel``."""
    assert all(not isinstance(s.node, ir.CollectiveKernel) for s in snodes)
def estimate_op_runtime(snode: "scheduler.BaseSchedulerNode") -> float:
    """
    Return the estimated runtime of ``snode`` in nanoseconds (ns).

    When ``config.estimate_op_runtime`` is the string ``"default"``, the node's own
    ``get_estimated_runtime()`` is used. Otherwise the config value must be a
    callable, which is invoked with the node.
    """
    estimator = config.estimate_op_runtime
    if estimator != "default":
        assert callable(estimator)
        return estimator(snode)
    return snode.get_estimated_runtime()
def reorder_compute_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Decides a global ordering of all compute and communication nodes,
    assuming that we already have a global ordering of communication nodes.

    Overall scheduling procedure is:
        Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
            that are required for comm N + 1 but do not depend on comm N, to run at the same
            time with comm N.
        Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
            Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
            We prioritize compute nodes that are needed sooner.
        Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
        Step 4: We schedule comm N + 1.
        Repeat this for subsequent comm nodes.

    Args:
        snodes: scheduler nodes. Dependencies are read from each node's
            ``node_users`` (and, via ``get_ancestors``, ``inverse_users``).

    Returns:
        ``snodes`` itself when it contains no ``ir.CollectiveKernel``. Otherwise a
        new list holding every node of ``snodes`` in the chosen order.

    Raises:
        Exception: if a group of nodes can never become ready, which should be
            unreachable for a DAG. AssertionError on violated internal invariants.
    """
    final_order: List["scheduler.BaseSchedulerNode"] = []

    comm_nodes = [
        snode for snode in snodes if isinstance(snode.node, ir.CollectiveKernel)
    ]
    if not comm_nodes:
        # if there is no comm nodes, return the current order
        return snodes

    comm_ancestors = {node: get_ancestors(node) for node in comm_nodes}
    comm_descendants = {node: get_descendants(node) for node in comm_nodes}

    # For each node, the index of the first comm (in comm order) that transitively
    # needs it; nodes no comm needs are ranked len(comm_nodes). Step 2 uses this to
    # prefer compute that is needed sooner. Built once instead of rescanning every
    # comm for every candidate on every iteration.
    first_needing_comm_idx = {}
    for comm_idx, comm in enumerate(comm_nodes):
        for anc in comm_ancestors[comm]:
            first_needing_comm_idx.setdefault(anc, comm_idx)

    # In-degree of each node, counting only edges between nodes of `snodes`.
    indeg = dict.fromkeys(snodes, 0)
    for snode in snodes:
        for user in snode.node_users:
            if user in indeg:
                indeg[user] += 1
    ready_to_schedule_nodes = {node for node in snodes if indeg[node] == 0}
    unscheduled_nodes = set(snodes)

    def schedule_node(snode):
        """
        Append a single ready node to the order and mark any users whose last
        outstanding producer it was as ready.
        """
        assert snode in unscheduled_nodes
        assert snode in ready_to_schedule_nodes
        ready_to_schedule_nodes.remove(snode)
        unscheduled_nodes.remove(snode)
        final_order.append(snode)
        for user in tuple_sorted(snode.node_users):
            if user in indeg:
                indeg[user] -= 1
                if indeg[user] == 0:
                    ready_to_schedule_nodes.add(user)

    def schedule_nodes(nodes):
        """
        Schedules all of `nodes` in an arbitrary topologically valid order.
        """
        all_nodes = set(nodes)
        assert all(node in unscheduled_nodes for node in all_nodes)
        while all_nodes:
            # NOTE: since model graph is always a DAG and does not have circular dependency inside,
            # there should be at least one node that is a "free node" (i.e. indeg == 0),
            # hence infinite loop is not possible. But we check here just to be safe.
            progress = False
            for node in tuple_sorted(all_nodes):
                if node in ready_to_schedule_nodes:
                    schedule_node(node)
                    all_nodes.remove(node)
                    progress = True
            if not progress:
                raise Exception(
                    "Unable to find a free node (indeg == 0). This is an impossible state to reach. "
                    "Please report a bug to PyTorch."
                )

    # First, schedule all compute nodes that are required by first comm node,
    # as well as the first comm node itself.
    schedule_nodes(
        list(comm_ancestors[comm_nodes[0]]) + [comm_nodes[0]],
    )

    rolled_over_compute_cost: float = 0
    for idx in range(1, len(comm_nodes)):
        prev_comm = comm_nodes[idx - 1]
        next_comm = comm_nodes[idx]

        # Step 1: Given that we've currently scheduled comm `idx-1`, we now schedule
        # all compute nodes that are required for comm `idx` but do not depend on comm `idx-1`,
        # to run at the same time with comm `idx-1`.
        needed_by_next_comm_and_ready_compute_nodes = unscheduled_nodes & (
            comm_ancestors[next_comm] - comm_descendants[prev_comm]
        )
        assert_no_comm_nodes(needed_by_next_comm_and_ready_compute_nodes)

        total_compute_runtime_cost = rolled_over_compute_cost + sum(
            estimate_op_runtime(node)
            for node in needed_by_next_comm_and_ready_compute_nodes
        )
        prev_comm_runtime_cost = estimate_op_runtime(prev_comm)
        schedule_nodes(tuple_sorted(needed_by_next_comm_and_ready_compute_nodes))

        # Step 2: If all those compute nodes are sufficient to overlap comm `idx-1`, we're done.
        # Otherwise, we now need to look elsewhere to find compute that overlaps with comm `idx-1`.
        # We prioritize compute nodes that are needed sooner.
        step1_runtime_cost = total_compute_runtime_cost
        if not (step1_runtime_cost >= prev_comm_runtime_cost):
            # Find all ready to schedule compute nodes that do not depend on comm `idx-1`.
            ready_to_schedule_compute_nodes = tuple_sorted(
                ready_to_schedule_nodes - comm_descendants[prev_comm]
            )
            assert_no_comm_nodes(ready_to_schedule_compute_nodes)

            # Prioritize compute nodes that are needed sooner. The sort is stable,
            # so ties keep their `tuple_sorted` order.
            ready_to_schedule_compute_nodes = sorted(
                ready_to_schedule_compute_nodes,
                key=lambda node: first_needing_comm_idx.get(node, len(comm_nodes)),
            )

            for snode in ready_to_schedule_compute_nodes:
                if total_compute_runtime_cost >= prev_comm_runtime_cost:
                    # If accumulated compute runtime cost is greater than comm `idx-1` runtime cost,
                    # it means we have maximized overlap for comm `idx-1`, and hence we stop looking
                    # for more compute to schedule.
                    break
                compute_runtime_cost = estimate_op_runtime(snode)
                # If we're not able to leverage more than half of this
                # node's compute to overlap, we skip it.
                # TODO: Smarter heuristics here
                if (
                    prev_comm_runtime_cost - total_compute_runtime_cost
                ) <= compute_runtime_cost / 2:
                    continue
                schedule_node(snode)
                total_compute_runtime_cost += compute_runtime_cost
        rollable_compute_cost = total_compute_runtime_cost - step1_runtime_cost

        # Step 3: We schedule the compute nodes dependent on comm `idx-1` and required for comm `idx`.
        needed_by_next_comm_nodes = unscheduled_nodes & comm_ancestors[next_comm]
        schedule_nodes(list(needed_by_next_comm_nodes))

        # Step 4: We schedule comm `idx`.
        schedule_nodes([next_comm])

        # The idea here is that if there are no compute nodes from Step 3
        # (i.e. if prev comm is not blocking next comm), we can roll over the compute nodes
        # in Step 2 to overlap with the next comm, since they're not required to finish
        # before the next comm starts.
        is_prev_comm_blocking_next_comm = len(needed_by_next_comm_nodes) > 0
        rolled_over_compute_cost = (
            0 if is_prev_comm_blocking_next_comm else rollable_compute_cost
        )

    # Anything left is not needed by any comm; append it in a valid order.
    schedule_nodes(unscheduled_nodes)
    return final_order
def node_summary(snode):
    """
    Build a one-line, human-readable description of ``snode`` for overlap logs.

    Shape: ``<IR class name>[ (<python_kernel_name>)][ (size=..., stride=...)] (<name>)``.
    The kernel name appears only for ``ir.ExternKernelOut`` nodes. The size/stride
    part appears only when the node has a ``layout`` exposing both attributes. The
    trailing name is empty when the node has no ``name`` attribute.
    """
    ir_node = snode.node
    kernel_part = (
        f" ({ir_node.python_kernel_name})"
        if isinstance(ir_node, ir.ExternKernelOut)
        else ""
    )
    layout = getattr(ir_node, "layout", None)
    tensor_part = ""
    if hasattr(layout, "size") and hasattr(layout, "stride"):
        tensor_part = f" (size={layout.size}, stride={layout.stride})"
    name_part = getattr(ir_node, "name", "")
    return f"{ir_node.__class__.__name__}{kernel_part}{tensor_part} ({name_part})"
def visualize_overlap(order):
    """
    Log, via ``overlap_log.debug``, how compute interleaves with collectives in ``order``.

    A collective (``ir.CollectiveKernel``) is considered in flight until the next
    ``ir.Wait``. Compute logged while a collective is in flight is prefixed with
    ``"| "``. The collective node itself is not logged. Its estimated runtime is
    added to the total, as is that of every exposed (non-overlapped) compute node.
    Overlapped compute and waits add nothing. Finally the total is logged as
    ``estimate / 1000 / 1000`` labelled "ms".

    Raises:
        Exception: if a wait appears with no collective in flight, or a second
            collective starts before the current one has been waited on.
    """
    in_flight = None
    est_total: float = 0.0
    for snode in order:
        ir_node = snode.node
        if in_flight is not None:
            if isinstance(ir_node, ir.CollectiveKernel):
                raise Exception(
                    "Found two collectives running at the same time. "
                    "`visualize_overlap` needs to be updated to handle this case"
                )
            if isinstance(ir_node, ir.Wait):
                # The in-flight collective completes here.
                overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
                in_flight = None
            else:
                # Compute hidden behind the in-flight collective.
                overlap_log.debug(f"| {node_summary(snode)}")  # noqa: G004
        elif isinstance(ir_node, ir.CollectiveKernel):
            est_total += estimate_op_runtime(snode)
            in_flight = ir_node
        elif isinstance(ir_node, ir.Wait):
            raise Exception(
                "Wait is not expected when there is no collective running"
            )
        else:
            # Exposed compute: nothing to overlap with.
            est_total += estimate_op_runtime(snode)
            overlap_log.debug(f"{node_summary(snode)}")  # noqa: G004
    overlap_log.debug(
        f"Est. runtime (ms): {est_total / 1000 / 1000}"  # noqa: G004
    )
def reorder_compute_and_comm_for_overlap(
    snodes: List["scheduler.BaseSchedulerNode"],
) -> List["scheduler.BaseSchedulerNode"]:
    """
    Apply every pass in ``config.reorder_for_compute_comm_overlap_passes`` to ``snodes``.

    Each entry is either a callable taking and returning a list of scheduler nodes,
    or the name (str) of a pass defined in this module (looked up in ``globals()``).
    Passes run in order, each receiving the previous pass's output.

    Before and after each pass, the estimated overlap is visualized on
    ``overlap_log``. This only happens when that logger is enabled for DEBUG,
    since nothing would be emitted otherwise. In a distributed run it happens
    only on rank 0. When ``torch.distributed`` is unavailable or has no initialized
    process group, the process is treated as rank 0 rather than calling
    ``get_rank()``, which would raise. Visualization errors are logged, never raised.

    Returns:
        The node order produced by the last pass, or ``snodes`` if there are no passes.
    """
    import logging

    dist = torch.distributed
    is_rank_zero = (
        not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
    )
    should_visualize = is_rank_zero and overlap_log.isEnabledFor(logging.DEBUG)

    def _visualize(stage, pass_, nodes):
        # Best-effort debug output: failures are reported, not propagated.
        overlap_log.debug(
            f"==== Visualize overlap {stage} reordering pass {pass_} ===="  # noqa: G004
        )
        try:
            visualize_overlap(nodes)
        except Exception as e:
            overlap_log.debug(str(e))

    order = snodes
    for p in config.reorder_for_compute_comm_overlap_passes:
        if isinstance(p, str) and p in globals():
            p = globals()[p]  # it is a builtin pass
        if should_visualize:
            _visualize("before", p, order)
        order = p(order)  # type: ignore[operator]
        if should_visualize:
            _visualize("after", p, order)
    return order