import dataclasses
import math
import operator
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

import torch
from torch._subclasses.fake_tensor import FakeTensor

from torch.export import ExportedProgram
from torch.utils._pytree import (
    _register_pytree_node,
    Context,
    FlattenFunc,
    FromDumpableContextFn,
    KeyPath,
    keystr,
    MappingKey,
    SequenceKey,
    ToDumpableContextFn,
    UnflattenFunc,
)

def _check_input_constraints_for_graph(
    input_placeholders: List[torch.fx.Node], flat_args_with_path, range_constraints
):
    def get_keystr(key_path: KeyPath) -> str:
        """For a given index into the flat_args, return a human readable string
        describing how to access it, e.g. "*args["foo"][0].bar"
        """
        # Prefix the keypath with "*args" or "**kwargs" to make it clearer where
        # the arguments come from. Ultimately we ought to serialize the
        # original arg names for the best error message here.
        args_kwargs_key_path = key_path[0]
        assert isinstance(args_kwargs_key_path, SequenceKey)
        if args_kwargs_key_path.idx == 0:
            return f"*args{keystr(key_path[1:])}"
        else:
            kwarg_key = key_path[1]
            assert isinstance(kwarg_key, MappingKey)
            name = str(kwarg_key)[1:-1]  # get rid of the enclosed []
            return f"{name}{keystr(key_path[2:])}"

    import sympy

    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        _convert_range_to_int,
    )
    from torch.utils._sympy.solve import try_solve

    if len(flat_args_with_path) != len(input_placeholders):
        raise RuntimeError(
            "Unexpected number of inputs "
            f"(expected {len(input_placeholders)}, got {len(flat_args_with_path)})"
        )
    # NOTE: export already guarantees that the same symbol is used in metadata
    # for all InputDims related by equality constraints, so we can just unify
    # symbols with given input dimension values to check equality constraints.
    unification_map: "Dict[sympy.Symbol, Any]" = {}
    for (key_path, arg), node in zip(flat_args_with_path, input_placeholders):
        node_val = node.meta.get("val")
        if isinstance(node_val, FakeTensor):
            if not isinstance(arg, torch.Tensor):
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be a tensor, but got {type(arg)}",
                )

            if len(node_val.shape) != len(arg.shape):
                raise RuntimeError(
                    f"Unexpected number of dimensions in input at {get_keystr(key_path)}.shape "
                    f"(expected {node_val.shape}, got {arg.shape})"
                )

            for j, (arg_dim, node_dim) in enumerate(zip(arg.shape, node_val.shape)):
                # TODO(avik): Assert the following property in the IR verifier:
                # node_dim is either an int or a SymInt containing an int or a unary sympy.Expr
                if (
                    isinstance(node_dim, torch.SymInt)
                    and len(node_dim.node.expr.free_symbols) == 1
                ):
                    symbol = next(iter(node_dim.node.expr.free_symbols))
                    if symbol in unification_map:
                        existing_dim = node_dim.node.expr.subs(unification_map)
                        if arg_dim != existing_dim:
                            raise RuntimeError(
                                f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                                f"{existing_dim}, but got {arg_dim}",
                            )
                    else:
                        if (
                            isinstance(arg_dim, torch.SymInt)
                            and not arg_dim.node.expr.is_number
                        ):
                            # This can happen when, say, arg is a fake tensor.
                            # We do not run checks on symbolic shapes of fake inputs as
                            # such checks can affect the shape env.
                            pass
                        else:
                            solution = try_solve(
                                sympy.Eq(node_dim.node.expr, arg_dim), symbol
                            )
                            if solution is None:
                                raise RuntimeError(  # noqa: TRY200
                                    f"Expected input {node.name}.shape[{j}] = {arg_dim} to be "
                                    f"of the form {node_dim.node.expr}, where {symbol} is an integer"
                                )
                            else:
                                unification_map[symbol] = int(solution[1])

                    if node_dim.node.expr in range_constraints:
                        min_val, max_val = _convert_range_to_int(
                            range_constraints[node_dim.node.expr]
                        )
                        # NOTE: we allow dimensions to be 0/1 at runtime
                        if min_val > 2:
                            if arg_dim < min_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be >= "
                                    f"{min_val}, but got {arg_dim}",
                                )
                        if max_val < math.inf:
                            if arg_dim > max_val:
                                raise RuntimeError(
                                    f"Expected input at {get_keystr(key_path)}.shape[{j}] to be <= "
                                    f"{max_val}, but got {arg_dim}",
                                )
                else:
                    if arg_dim != node_dim:
                        raise RuntimeError(
                            f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
                            f"{node_dim}, but got {arg_dim}",
                        )
        elif isinstance(node_val, (int, float, str)):
            if type(arg) != type(node_val) or arg != node_val:
                raise RuntimeError(
                    f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}",
                )

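# A minimal sketch (not part of the original file) of the unification step
# above: given a symbolic dim expression such as 2*s0 recorded at export time
# and a concrete runtime size, solve for the symbol and record the binding, so
# any later dim that shares s0 must agree with it. This mirrors the file's own
# use of sympy.Eq + try_solve; the symbol name "s0" is illustrative.
#
#     import sympy
#     from torch.utils._sympy.solve import try_solve
#
#     s0 = sympy.Symbol("s0", positive=True, integer=True)
#     expr = 2 * s0                             # symbolic dim from export
#     solution = try_solve(sympy.Eq(expr, 8), s0)
#     assert solution is not None
#     unification_map = {s0: int(solution[1])}  # {s0: 4}
#     # A later dim with expression s0 + 1 must then be 5 at runtime:
#     assert (s0 + 1).subs(unification_map) == 5
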
def register_dataclass_as_pytree_node(
    cls: Type[Any],
    flatten_fn: Optional[FlattenFunc] = None,
    unflatten_fn: Optional[UnflattenFunc] = None,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    return_none_fields: bool = False,
) -> None:
    assert dataclasses.is_dataclass(
        cls
    ), f"Only dataclasses can be registered with this function: {cls}"

    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
        flattened = []
        flat_names = []
        none_names = []
        for f in dataclasses.fields(obj):
            name, val = f.name, getattr(obj, f.name)
            if val is not None or return_none_fields:
                flattened.append(val)
                flat_names.append(name)
            else:
                none_names.append(name)
        return flattened, [flat_names, none_names]

    def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
        flat_names, none_names = context
        return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))

    flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
    unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn

    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
            "be None or registered."
        )

    _register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )

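# Hedged usage sketch (illustrative, not part of the original file): register
# a dataclass so torch.utils._pytree can flatten it into leaves and rebuild it.
# The dataclass "Point" and the serialized_type_name value are hypothetical.
#
#     import dataclasses
#     import torch
#     import torch.utils._pytree as pytree
#
#     @dataclasses.dataclass
#     class Point:
#         x: torch.Tensor
#         y: torch.Tensor
#
#     register_dataclass_as_pytree_node(Point, serialized_type_name="mypkg.Point")
#
#     leaves, spec = pytree.tree_flatten(Point(torch.ones(2), torch.zeros(2)))
#     # leaves holds the two tensors; spec remembers the field names, so the
#     # round trip reconstructs an equivalent Point:
#     p = pytree.tree_unflatten(leaves, spec)
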
def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a parameter within the exported program
    """

    return node.name in program.graph_signature.inputs_to_parameters


def get_param(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.nn.Parameter]:
    """
    Returns the parameter associated with the given node in the exported program.
    Returns None if the node is not a parameter within the exported program
    """

    if is_param(program, node):
        parameter_name = program.graph_signature.inputs_to_parameters[node.name]
        return program.state_dict[parameter_name]

    return None


def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a buffer within the exported program
    """

    return node.name in program.graph_signature.inputs_to_buffers


def get_buffer(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the buffer associated with the given node in the exported program.
    Returns None if the node is not a buffer within the exported program
    """

    if is_buffer(program, node):
        buffer_name = program.graph_signature.inputs_to_buffers[node.name]
        if buffer_name in program.graph_signature.non_persistent_buffers:
            return program.constants[buffer_name]
        else:
            return program.state_dict[buffer_name]

    return None


def is_lifted_tensor_constant(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> bool:
    """
    Checks if the given node is a lifted tensor constant within the exported program
    """

    return node.name in program.graph_signature.inputs_to_lifted_tensor_constants


def get_lifted_tensor_constant(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the lifted tensor constant associated with the given node in the exported program.
    Returns None if the node is not a lifted tensor constant within the exported program
    """

    if is_lifted_tensor_constant(program, node):
        lifted_tensor_name = program.graph_signature.inputs_to_lifted_tensor_constants[
            node.name
        ]
        return program.constants[lifted_tensor_name]

    return None

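# Hedged usage sketch (illustrative, not part of the original file): the
# accessors above share one pattern -- look up a placeholder's name in the
# graph signature, then fetch the backing tensor from state_dict or constants.
#
#     ep: ExportedProgram = ...  # e.g. the result of torch.export.export(...)
#     for node in ep.graph.nodes:
#         if node.op != "placeholder":
#             continue
#         if is_param(ep, node):
#             print(node.name, "-> parameter", get_param(ep, node).shape)
#         elif is_buffer(ep, node):
#             print(node.name, "-> buffer", get_buffer(ep, node).shape)
#         elif is_lifted_tensor_constant(ep, node):
#             print(node.name, "-> constant", get_lifted_tensor_constant(ep, node).shape)
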
def sequential_split(gm: torch.fx.GraphModule, node_call_back) -> torch.fx.GraphModule:
    """
    Splits the graph module into multiple submodules based on the node_call_back.
    The node_call_back should return True if the node is a delimiter. The delimiter
    becomes the first node in the next submodule.
    """
    from torch.fx.passes.split_module import split_module

    split_map = {}
    split_id = 0
    for node in gm.graph.nodes:
        if node_call_back(node):
            split_id += 1
        split_map[node] = split_id

    new_gm = split_module(
        gm,
        gm,
        lambda node: split_map[node],
        keep_original_order=True,
        keep_original_node_name=True,
    )
    # Keep the codegen from the original graph module to preserve e.g. pytree info.
    new_gm.graph._codegen = gm.graph._codegen
    new_gm.recompile()
    return new_gm

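# Hedged usage sketch (illustrative, not part of the original file): split a
# traced GraphModule at every aten.sin call, so each such node starts a new
# submodule. Any predicate over torch.fx.Node works as node_call_back; the
# name "traced_gm" is a stand-in for a module you have already traced.
#
#     def is_delimiter(node: torch.fx.Node) -> bool:
#         return (
#             node.op == "call_function"
#             and node.target == torch.ops.aten.sin.default
#         )
#
#     split_gm = sequential_split(traced_gm, is_delimiter)
#     # split_gm now contains submod_0, submod_1, ... invoked in the original
#     # order, and split_gm(*args) matches traced_gm(*args).
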
def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
    """Returns the nodes that match the node_call_back as a list."""
    return [node for node in nodes if node_call_back(node)]


def nodes_first(
    nodes: List[torch.fx.Node], node_call_back=None
) -> Optional[torch.fx.Node]:
    """
    Returns the first node that matches the node_call_back. If no node matches, returns None.
    When node_call_back is None, returns the first node in the node list.
    """
    ret = nodes_filter(nodes, node_call_back if node_call_back else lambda node: True)
    if len(ret) > 0:
        return ret[0]
    return None


def nodes_count(nodes: List[torch.fx.Node], node_call_back) -> int:
    """Returns the number of nodes that match the node_call_back."""
    return len(nodes_filter(nodes, node_call_back))


def nodes_map(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.Node]:
    """
    Sequentially visit the nodes list and invoke node_call_back on each element.
    Returns the nodes list after the node_call_back is invoked on each element.
    """
    for node in nodes:
        node_call_back(node)
    return nodes

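# Hedged usage sketch (illustrative, not part of the original file): the
# nodes_* helpers compose into small graph queries; "gm" below stands for any
# torch.fx.GraphModule.
#
#     all_nodes = list(gm.graph.nodes)
#     placeholders = nodes_filter(all_nodes, lambda n: n.op == "placeholder")
#     first_call = nodes_first(all_nodes, lambda n: n.op == "call_function")
#     n_outputs = nodes_count(all_nodes, lambda n: n.op == "output")
#     nodes_map(placeholders, lambda n: print(n.name))  # visit each; list is returned
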
def node_replace_(
    old_node: torch.fx.Node, new_node: torch.fx.Node, delete_old: bool = False
) -> None:
    """
    Replace all uses of old_node with new_node.
    """
    old_node.replace_all_uses_with(new_node)
    if delete_old:
        old_node.users.clear()
        old_node.graph.erase_node(old_node)

def node_inline_(call_mod_node: torch.fx.Node) -> torch.fx.GraphModule:
    """
    Inline the submodule of the given node into the parent module.
    Note: we only support the case where the submodule takes tensor inputs.
    """
    assert call_mod_node.op == "call_module"
    gm = call_mod_node.graph.owning_module

    assert isinstance(call_mod_node.target, str)
    sub_gm = getattr(gm, call_mod_node.target)

    phs = (node for node in sub_gm.graph.nodes if node.op == "placeholder")
    body = (
        node for node in sub_gm.graph.nodes if node.op not in ("placeholder", "output")
    )
    output = [node for node in sub_gm.graph.nodes if node.op == "output"]

    for ph, arg in zip(phs, call_mod_node.args):
        assert isinstance(arg, torch.fx.Node)
        node_replace_(ph, arg, delete_old=True)

    with gm.graph.inserting_before(call_mod_node):
        for node in body:
            new_node = gm.graph.node_copy(node)
            node_replace_(node, new_node, delete_old=True)

        if len(output) > 0:
            assert len(output) == 1 and len(output[0].args) == 1
            new_output = output[0].args[0]

            if isinstance(new_output, torch.fx.Node):
                node_replace_(call_mod_node, new_output, delete_old=True)
            elif isinstance(new_output, (list, tuple)):
                # Inline the get_item calls for the output node.
                get_item_users = nodes_filter(
                    list(call_mod_node.users.keys()),
                    lambda node: node.op == "call_function"
                    and node.target == operator.getitem,
                )
                # get_item_node.args[1] is the idx referring to new_output[idx]
                nodes_map(
                    get_item_users,
                    lambda get_item_node: node_replace_(
                        get_item_node,
                        new_output[get_item_node.args[1]],
                        delete_old=True,
                    ),
                )
                call_mod_node.graph.erase_node(call_mod_node)
            else:
                raise NotImplementedError(
                    f"Unsupported output type {type(new_output)}. Expect it to be a Node or a list/tuple of Nodes."
                )
        else:
            call_mod_node.graph.erase_node(call_mod_node)

    gm.delete_all_unused_submodules()
    gm.recompile()
    return gm
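
# Hedged usage sketch (illustrative, not part of the original file): a common
# round trip is to split a graph with sequential_split, rewrite some of the
# resulting call_module nodes, then flatten everything back with node_inline_.
# "gm" and "node_call_back" stand in for your own graph module and predicate.
#
#     split_gm = sequential_split(gm, node_call_back)
#     for node in list(split_gm.graph.nodes):
#         if node.op == "call_module":
#             node_inline_(node)  # splice the submodule's body into split_gm
#     # split_gm is now a single flat graph again, equivalent to the original.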