# mypy: ignore-errors
import functools
import operator
from collections import defaultdict
from typing import Dict, List, Optional

import torch
from torch._dynamo.backends.debugging import boxed_nop
from torch._inductor.cudagraph_trees import cudagraphify_impl
from torch._inductor.cudagraph_utils import (
    BoxedDeviceIndex,
    check_multiple_devices_or_any_cpu_nodes,
    get_mutation_stack_trace,
)
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
    has_incompatible_cudagraph_ops,
    num_fw_fixed_arguments,
    output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef

from .common import aot_autograd
from .registry import register_backend

perf_log = torch._logging.getArtifactLogger(__name__, "perf_hints")


def find_input_mutations(g):
    """Return the set of placeholder indices in FX graph ``g`` whose storage
    is written to by some ``call_function`` node (per the op schema's
    ``alias_info.is_write``).

    Indices are grouped by storage (via ``StorageWeakRef``), so aliased
    placeholders sharing a storage are all reported when any alias is mutated.
    """

    def meta_fk(meta):
        # Prefer "val" (current meta key); fall back to the legacy
        # "fake_result" key.
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            if isinstance(meta_fk(n.meta), torch.Tensor):
                inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx)
            # Non-tensor placeholders still consume an input slot.
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                # Resolve the actual argument for this schema slot, whether it
                # was passed positionally or by keyword.
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta)._typed_storage())
                    ]
        # TODO: error on unrecognized nodes
    return mutated_inputs


def get_device_node_mapping(gm: torch.fx.GraphModule):
    """Map each device that appears in ``gm``'s node metadata to the first
    node observed on that device (insertion order of the graph)."""
    device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
    for n in gm.graph.nodes:
        t = n.meta.get("val", None)
        if isinstance(t, torch.Tensor) and t.device not in device_node_mapping:
            device_node_mapping[t.device] = n
    return device_node_mapping


def check_for_mutation(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return a skip-reason string if any non-fixed input is mutated, else None.

    The first ``num_fixed`` inputs (parameters/buffers) are allowed to be
    mutated; only mutations of the remaining user inputs disqualify
    cudagraphs.
    """
    mutation_indices = find_input_mutations(aot_model.graph) - set(range(num_fixed))
    if not mutation_indices:
        return None

    return get_mutation_stack_trace(aot_model, mutation_indices)


def check_for_skip(aot_model: torch.fx.GraphModule, num_fixed) -> Optional[str]:
    """Return a human-readable reason to skip cudagraphs, or None if eligible.

    Checks, in order: user-input mutation, multiple devices / CPU nodes, and
    ops known to be incompatible with CUDA graph capture.
    """
    if mut_skip := check_for_mutation(aot_model, num_fixed):
        return mut_skip

    if skip := check_multiple_devices_or_any_cpu_nodes(
        get_device_node_mapping(aot_model)
    ):
        return skip

    if has_incompatible_cudagraph_ops(aot_model):
        return "skipping cudagraphs due to incompatible op"

    return None


def get_device_index(gm) -> int:
    """Return the CUDA device index of the (single) device ``gm`` runs on.

    Assumes check_for_skip already ruled out CPU/multi-device graphs, so the
    first mapped device must be CUDA.
    """
    device = next(iter(get_device_node_mapping(gm)))
    assert device.type == "cuda"
    return device.index


def get_stack_traces(gm) -> List[Optional[str]]:
    """Return one stack trace (or None) per graph output of ``gm``."""
    output = output_node(gm)
    assert len(output.args) == 1
    return [
        (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None)
        for arg in output.args[0]
    ]


def cudagraphs(dynamo_model, dynamo_inputs):
    """Dynamo backend that wraps the AOTAutograd-produced forward/backward
    graphs with CUDA-graph capture (cudagraph trees).

    Falls back to eager execution of the compiled graph (boxed_nop) when a
    graph is ineligible, logging the reason to the "perf_hints" artifact.
    """
    do_cudagraphs = BoxedBool(True)
    boxed_device_index = BoxedDeviceIndex(None)

    def forward_cudagraphs(aot_model, aot_inputs, is_inference=False):
        interp = boxed_nop(aot_model, aot_inputs)
        fixed = num_fw_fixed_arguments(len(dynamo_inputs), len(aot_inputs))
        if skip_msg := check_for_skip(aot_model, fixed):
            # Disable cudagraphs for the backward pass too.
            BoxedBool.disable(do_cudagraphs)
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)
            return interp

        boxed_device_index.set(get_device_index(aot_model))
        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=boxed_device_index.value,
            is_backward=False,
            # BUGFIX: previously hardcoded to False, silently ignoring the
            # is_inference=True passed by the inference_compiler partial.
            is_inference=is_inference,
            stack_traces=get_stack_traces(aot_model),
        )
        out._boxed_call = True
        return out

    def backward_cudagraphs(aot_model, aot_inputs):
        interp = boxed_nop(aot_model, aot_inputs)
        # Forward already decided cudagraphs are off for this compilation.
        if not do_cudagraphs:
            return aot_model

        fixed = count_tangents(aot_model)
        if skip_msg := check_for_skip(aot_model, fixed):
            perf_log.warning("skipping cudagraphs due to %s", skip_msg)

            # See [Backward Generation Handling]
            manager = torch._inductor.cudagraph_trees.get_manager(
                boxed_device_index.value, create_if_none_exists=False
            )
            assert manager is not None

            def fn(inputs):
                # Even though this backward is not captured, the tree manager
                # must be told a backward is running.
                manager.set_to_running_backward()
                return aot_model(inputs)

            fn._boxed_call = True
            return fn

        out = cudagraphify_impl(
            interp,
            aot_inputs,
            range(fixed),
            device_index=get_device_index(aot_model),
            is_backward=True,
            is_inference=False,
            stack_traces=get_stack_traces(aot_model),
        )
        out._boxed_call = True
        return out

    aot_cudagraphs = aot_autograd(
        fw_compiler=forward_cudagraphs,
        bw_compiler=backward_cudagraphs,
        inference_compiler=functools.partial(forward_cudagraphs, is_inference=True),
        keep_inference_input_mutations=torch._dynamo.config.cudagraph_backend_keep_input_mutation,
    )
    return aot_cudagraphs(dynamo_model, dynamo_inputs)


class CudagraphsBackend:
    """Registry-facing wrapper around the ``cudagraphs`` backend function."""

    compiler_name = "cudagraphs"

    @staticmethod
    def reset():
        # Drop all recorded cudagraph trees (used between tests/compilations).
        from torch._inductor.cudagraph_trees import reset_cudagraph_trees

        reset_cudagraph_trees()

    @staticmethod
    def __call__(model, inputs):
        return cudagraphs(model, inputs)


# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
# for debugging and can serve as a perf baseline.
register_backend(name="cudagraphs", compiler_fn=CudagraphsBackend())


def cudagraphs_inner(model, inputs, copy_outputs=True, copy_inputs=True):
    """This isn't registered as a backend, but is used in some benchmarks"""
    assert isinstance(inputs, (list, tuple))

    # Stage input buffers for capture: either fresh zero-filled tensors the
    # caller's data gets copied into, or the caller's own tensors.
    if copy_inputs:
        staged_inputs = [torch.zeros_like(t) for t in inputs]
    else:
        staged_inputs = list(inputs)

    # Warm up once on a side stream before capture, then fully synchronize.
    torch.cuda.synchronize()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        model(*inputs)
    side_stream.synchronize()
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # Record the model into a CUDA graph against the staged input buffers.
    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph, stream=side_stream):
        staged_outputs = model(*staged_inputs)
    if not isinstance(staged_outputs, (list, tuple)):
        staged_outputs = (staged_outputs,)

    def run(*new_inputs):
        assert len(staged_inputs) == len(new_inputs)
        if copy_inputs:
            # Copy fresh data into the captured input buffers before replay.
            for dst, src in zip(staged_inputs, new_inputs):
                dst.copy_(src)
        cuda_graph.replay()
        if copy_outputs:
            # Clone so the next replay can't clobber what we hand back.
            return [out.clone() for out in staged_outputs]
        else:
            return staged_outputs

    return run