from __future__ import annotations

import itertools
import logging

import weakref
from typing import Any, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
from torch._functorch.aot_autograd import MutationType
from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant

from torch._inductor.fx_passes.freezing_patterns import freezing_passes
from torch._inductor.fx_passes.post_grad import view_to_reshape

from . import config

aten = torch.ops.aten
prims = torch.ops.prims

log = logging.getLogger(__name__)


def replace_params_with_constants(
    gm: torch.fx.GraphModule,
    flat_params: list[Any],
    fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]:
    """
    Replaces the parameters of a PyTorch GraphModule with constants wherever possible.

    Returns a list of indices representing the input parameters that were not
    converted to constants.
    """
    params = [node for node in gm.graph.nodes if node.op == "placeholder"]
    # zip() below pairs only the leading placeholders with flat_params
    fake_inp_nodes = params[: len(params)]
    preserved_arg_indices = []
    aliased_input_args = [
        out_info.base_idx
        for out_info in fw_metadata.output_info
        if out_info.base_idx is not None
    ]

    # TODO (tmanlaibaatar) figure out why this is different
    # from mutated_inp_runtime_indices
    mutated_inps = [
        i
        for i, m in enumerate(fw_metadata.input_info)
        if m.mutation_type
        in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
    ]

    for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
        if i in mutated_inps or i in aliased_input_args:
            # mutated or aliased parameters cannot be baked in; keep them as inputs
            preserved_arg_indices.append(i)
            continue
        replace_node_with_constant(gm, node, real_input)
    # add on the non-parameter inputs, which are always preserved
    preserved_arg_indices.extend(range(len(flat_params), len(params)))
    # is this necessary?
    gm.recompile()
    return preserved_arg_indices


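# Hedged sketch (not part of the original module): a minimal illustration of the
# per-parameter rewrite replace_params_with_constants performs, using
# replace_node_with_constant on a hand-built FX graph. The helper name
# _example_placeholder_to_constant is hypothetical; it assumes
# replace_node_with_constant accepts placeholder nodes, as in this module's usage.
def _example_placeholder_to_constant() -> torch.Tensor:
    g = torch.fx.Graph()
    x = g.placeholder("x")
    w = g.placeholder("w")
    out = g.call_function(torch.add, (x, w))
    g.output(out)
    gm = torch.fx.GraphModule(torch.nn.Module(), g)

    # inline w as a constant get_attr, so it is no longer a runtime input
    replace_node_with_constant(gm, w, torch.ones(4))
    gm.recompile()
    return gm(torch.zeros(4))  # tensor([1., 1., 1., 1.])

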
def freeze(
    dynamo_gm: torch.fx.GraphModule,
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: List[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Inlines unmutated parameters into constants and optimizes the graph through
    constant propagation and other techniques. If enabled, the function also
    discards the original parameters of the module for memory efficiency.

    Assumes that this function is run under dynamo tracing, post aot_autograd.

    Args:
        dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule.
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen.
        example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
        of the inputs that were preserved (not turned into constants).
    """
    # Conv weights may have been converted to channels last, which can make
    # fake_tensor_prop fail on .view calls, so convert view to reshape first.
    # See the details in fx_codegen_and_compile of compile_fx.py.
    view_to_reshape(aot_autograd_gm)

    if tracing_context := torch._guards.TracingContext.try_get():
        fw_metadata = tracing_context.fw_metadata
        params_flat = tracing_context.params_flat
        assert fw_metadata is not None and params_flat is not None

        preserved_arg_indices = replace_params_with_constants(
            aot_autograd_gm, params_flat, fw_metadata
        )
    else:
        inputs = [
            node for node in aot_autograd_gm.graph.nodes if node.op == "placeholder"
        ]
        preserved_arg_indices = list(range(len(inputs)))

    # TODO - further restrict cse? right now needed to dedup aliasing ops
    cse_graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.graph = cse_graph
    aot_autograd_gm.recompile()

    aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
    freezing_passes(aot_autograd_gm, aot_example_inputs)

    constant_fold(aot_autograd_gm)
    # invalidate eager nn.Modules so stale parameter storage cannot be reused
    if config.freezing_discard_parameters:
        invalidate_eager_modules()
        discard_traced_gm_params(dynamo_gm)

    log.debug("%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm))

    return aot_autograd_gm, preserved_arg_indices


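# Hedged usage sketch (not part of the original module): freeze() is invoked from
# inside the inductor compile path rather than by users. Assuming the `freezing`
# flag on this config module gates that path, a typical way to exercise it end to
# end looks like this; _example_compile_with_freezing is a hypothetical name.
def _example_compile_with_freezing() -> torch.Tensor:
    config.freezing = True  # assumption: flag routes compilation through freeze()

    mod = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
    opt = torch.compile(mod)
    with torch.no_grad():  # freezing applies to inference graphs
        return opt(torch.randn(2, 8))

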
class ErasedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, name, owning_mod):
        return super().__new__(cls, elem.to(device="meta"))

    def __init__(self, elem, name: Optional[str], mod):
        self.erased_name = name
        self.owning_mod_ref = weakref.ref(mod)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        erased_tensors = [
            e
            for e in pytree.arg_tree_leaves(*args, **(kwargs or {}))
            if isinstance(e, ErasedTensor)
        ]
        assert len(erased_tensors) > 0
        e = erased_tensors[0]

        raise RuntimeError(
            "Trying to run a PyTorch eager module after Dynamo freezing. "
            "The original parameters have been discarded for memory efficiency. "
            f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
        )


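# Hedged sketch (not part of the original module): any dispatched op touching an
# ErasedTensor raises, which is the mechanism the two helpers below rely on to
# poison discarded parameters. _example_erased_tensor is a hypothetical name.
def _example_erased_tensor() -> str:
    lin = torch.nn.Linear(4, 4)
    erased = ErasedTensor(lin.weight.detach(), "weight", lin)
    try:
        _ = erased + 1  # dispatch hits __torch_dispatch__ and raises
    except RuntimeError as exc:
        return str(exc)
    raise AssertionError("expected RuntimeError from erased tensor")

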
@torch.utils._python_dispatch._disable_current_modes()
def invalidate_eager_modules():
    # swap every parameter/buffer of the traced nn.Modules for an ErasedTensor
    # so any later eager use fails loudly instead of reading discarded storage
    for mod in torch._guards.TracingContext.get().module_context.nn_modules.values():
        if not isinstance(mod, torch.nn.Module):
            continue

        for attr_name, tensor in list(
            itertools.chain(
                mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
            )
        ):
            with torch._dispatch.python.no_python_dispatcher():
                e_t = ErasedTensor(tensor, attr_name, mod)
            if isinstance(tensor, torch.nn.Parameter):
                e_t.requires_grad_(True)
                e_t._is_param = True  # type: ignore[attr-defined]
            setattr(mod, attr_name, e_t)


@torch.utils._python_dispatch._disable_current_modes()
def discard_traced_gm_params(mod: torch.fx.GraphModule):
    # same erasure as invalidate_eager_modules, applied to the traced GraphModule
    for attr_name, tensor in list(
        itertools.chain(
            mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
        )
    ):
        with torch._dispatch.python.no_python_dispatcher():
            e_t = ErasedTensor(tensor, attr_name, mod)
        if isinstance(tensor, torch.nn.Parameter):
            e_t.requires_grad_(True)
            e_t._is_param = True  # type: ignore[attr-defined]
        setattr(mod, attr_name, e_t)


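# Hedged sketch (not part of the original module): after discard_traced_gm_params,
# the traced module's parameters live on the meta device and any eager call
# through them raises. _example_discard is a hypothetical name; it assumes
# symbolic_trace copies the Linear's parameters onto the GraphModule root.
def _example_discard() -> None:
    gm = torch.fx.symbolic_trace(torch.nn.Linear(4, 4))
    discard_traced_gm_params(gm)
    assert gm.weight.device.type == "meta"  # original storage can be freed
    try:
        gm(torch.randn(2, 4))
    except RuntimeError:
        pass  # expected: parameters were erased

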
def enforce_output_layout(gm: torch.fx.GraphModule):
    """
    Make sure the output node's layout does not change due to compiler optimizations
    by adding aten.as_strided nodes with the expected strides.

    Only used for inference, so we can assume all graph outputs are model outputs.
    """
    *_, output_node = gm.graph.nodes
    out_list = output_node.args[0]
    with gm.graph.inserting_before(output_node):
        for n in out_list:
            if not isinstance(
                n.meta["val"], torch.Tensor
            ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]):
                continue

            # add a node to enforce eager layout
            ft = n.meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n, ft.stride())
            )

            # cannot call n.replace_all_uses_with(new_node),
            # since that would also replace the usage of n inside new_node itself
            output_node.replace_input_with(n, new_node)

    gm.graph.lint()
    gm.recompile()


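# Hedged sketch (not part of the original module): enforce_output_layout only
# pins strides for outputs that pass the non-overlapping-and-dense predicate; a
# strided slice fails it and is skipped. _example_dense_predicate is a
# hypothetical name.
def _example_dense_predicate() -> None:
    x = torch.randn(4, 4)
    assert torch._prims_common.is_non_overlapping_and_dense(x)
    # every other column: 8 elements spread over a 15-element storage span
    assert not torch._prims_common.is_non_overlapping_and_dense(x[:, ::2])

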
def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
    """
    Make sure the layout of as_strided nodes' inputs does not change due to
    compiler optimizations, because the as_strided stride arguments are only
    meaningful relative to the input tensor's layout.
    """

    as_strided_ops = [
        torch.ops.aten.as_strided.default,
        torch.ops.aten.as_strided_.default,
        torch.ops.aten.as_strided_scatter.default,
    ]
    strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops]
    for n in strided_nodes:
        with gm.graph.inserting_before(n):
            # add a node to enforce eager layout
            ft = n.args[0].meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n.args[0], ft.stride())
            )
            n.replace_input_with(n.args[0], new_node)

    gm.graph.lint()
    gm.recompile()


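# Hedged sketch (not part of the original module): as_strided reads straight from
# the input's storage, so identical (size, stride) arguments return different
# values if a compiler pass changes the input's physical layout. That is what
# enforce_as_strided_input_layout protects against. The helper name
# _example_as_strided_sensitivity is hypothetical.
def _example_as_strided_sensitivity() -> bool:
    a = torch.arange(16.0).reshape(1, 4, 2, 2)      # contiguous NCHW
    b = a.clone(memory_format=torch.channels_last)  # same values, NHWC storage
    args = ((1, 4, 2, 2), a.stride())
    # identical logical tensors, identical as_strided arguments, different reads:
    return torch.equal(a.as_strided(*args), b.as_strided(*args))  # False

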
@dynamo_timed
def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
    """
    Convert 4d convolution weight tensors to channels-last memory format.

    This pass is performed before freezing so that the added clone nodes can be
    constant folded by freezing.
    """
    convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
    for conv in convs:
        weight_node = conv.args[1]
        if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[
            "val"
        ].is_contiguous(memory_format=torch.channels_last):
            # not a 4d tensor or already channels last, skip
            continue

        with gm.graph.inserting_before(conv):
            new_node = gm.graph.call_function(
                aten.clone.default,
                (weight_node,),
                {"memory_format": torch.channels_last},
            )
            conv.replace_input_with(weight_node, new_node)

    enforce_as_strided_input_layout(gm)
    enforce_output_layout(gm)
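
# Hedged sketch (not part of the original module): the graph rewrite above is the
# FX equivalent of cloning a 4d conv weight into channels-last memory format by
# hand; once the weight is a constant, freezing folds the clone away.
# _example_channels_last_weight is a hypothetical name.
def _example_channels_last_weight() -> tuple[int, ...]:
    w = torch.randn(8, 3, 3, 3)  # 4d OIHW conv weight, contiguous
    w_cl = w.clone(memory_format=torch.channels_last)
    assert w_cl.is_contiguous(memory_format=torch.channels_last)
    return w_cl.stride()  # (27, 1, 9, 3)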