576 lines
20 KiB
Python
576 lines
20 KiB
Python
|
from abc import ABC, abstractmethod
|
||
|
from contextlib import contextmanager, nullcontext
|
||
|
from copy import copy
|
||
|
from dataclasses import dataclass
|
||
|
from functools import partial, wraps
|
||
|
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Union
|
||
|
|
||
|
from functorch import make_fx
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
|
||
|
# We need to import _functional_collectives to trigger op registration
|
||
|
import torch.distributed._functional_collectives
|
||
|
import torch.nn as nn
|
||
|
import torch.utils._pytree as pytree
|
||
|
|
||
|
from torch import fx
|
||
|
from torch._decomp.decompositions import native_layer_norm_backward
|
||
|
|
||
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
||
|
from torch.distributed._spmd.data_parallel import gradients_tagging
|
||
|
from torch.distributed._spmd.parallel_mode import (
|
||
|
DataParallel,
|
||
|
DTensorExpandMode,
|
||
|
ParallelMode,
|
||
|
)
|
||
|
from torch.distributed._tensor import Placement
|
||
|
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen
|
||
|
from torch.nn.utils import stateless
|
||
|
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
|
||
|
|
||
|
|
||
|
class Override(ABC):
    r"""Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.

    This is useful when any part of the model is not traceable or if you prefer
    not to trace it for any reason. More specifically, users can implement
    :meth:`torch.distributed._spmd.Override.replacement` to replace an original
    submodule with the returned new submodule. The new submodule contains
    operations that users prefer to be traced, which can simply be a dummy
    placeholder operator. After tracing, users can implement
    :meth:`torch.distributed._spmd.Override.transform` to transform the traced
    graph, where the dummy placeholder operator serves as an anchor to insert
    new sub-graphs.
    """

    @abstractmethod
    def replacement(self, fqn: str, orig_submodule: torch.nn.Module) -> torch.nn.Module:
        r"""Implement this method to return a new :class:`nn.Module` instance to replace the ``orig_submodule``
        argument in the model.

        This helps if ``orig_submodule`` is not traceable or should not be traced.

        Args:
            fqn (str): fully qualified name of the submodule.
            orig_submodule (:class:`nn.Module`): original submodule instance to replace.

        Returns:
            A new :class:`nn.Module` instance to replace the original one.
            Returning ``orig_submodule`` itself (the same object) keeps it in
            place, and the compiler then recurses into its children.
        """
        pass

    @abstractmethod
    def transform(
        self,
        gm: fx.GraphModule,
        flat_state: List[torch.Tensor],
    ) -> fx.GraphModule:
        r"""
        Given a DTensor-expanded graph and sharding schema for every node,
        conduct additional transformation for the sub-graph from the :class:`nn.Module`
        returned by :meth:`torch.distributed._spmd.Override.replacement` if
        necessary.

        Args:
            gm (:class:`fx.GraphModule`): a DTensor-expanded graph module whose
                inputs have already been flattened by the caller and whose
                duplicated collectives have already been removed.
            flat_state (List[:class:`Tensor`]): a reference to the list of
                flattened state. The elements in ``flat_state`` map to the first
                ``len(flat_state)`` placeholders in the graph. The transformation
                can add state to or remove state from ``flat_state`` as long as
                it keeps ``flat_state`` and the placeholders consistent.

        Returns:
            The :class:`fx.GraphModule` after transformation.
        """
        pass
|
||
|
|
||
|
|
||
|
class _PyTreeCodeGenOutputsOnly(_PyTreeCodeGen):
|
||
|
# pyre-ignore[3]
|
||
|
def process_inputs(self, *args: Any) -> Any:
|
||
|
return args
|
||
|
|
||
|
# pyre-ignore[2, 3]
|
||
|
def gen_fn_def(self, free_vars, maybe_return_annotation):
|
||
|
return CodeGen.gen_fn_def(self, free_vars, maybe_return_annotation)
|
||
|
|
||
|
|
||
|
def _to_caller_flattened_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Move the responsibility of flattening the input arguments from the graph module to the caller.

    ``gm`` is modified in place: its codegen is replaced by
    ``_PyTreeCodeGenOutputsOnly``, which keeps the existing output spec but
    drops the input spec, and the module is recompiled. The same ``gm`` object
    is returned for convenience.

    Example::

        output = gm(my_struct)

        gm = _to_caller_flattened_graph_module(gm)

        output = gm(*pytree.tree_flatten(my_struct)[0])

    """
    # pyre-ignore[16]
    gm._graph._codegen = _PyTreeCodeGenOutputsOnly(
        pytree_info=_PyTreeInfo(
            # pyre-ignore[6]
            orig_args=None,  # type: ignore[arg-type]
            # pyre-ignore[6]
            in_spec=None,  # type: ignore[arg-type]
            # Keep un-flattening of the outputs exactly as traced.
            # pyre-ignore[16]
            out_spec=gm._graph._codegen.pytree_info.out_spec,
        )
    )
    # Regenerate gm.forward from the new codegen.
    gm.recompile()
    return gm
|
||
|
|
||
|
|
||
|
# Use a dtensor expand mode for now to preserve the old behavior
# and avoid breaking existing code.
# This shared instance is the ParallelMode used by ``compile`` when no
# ``parallel_mode`` is given; ``_override_placements`` records per-tensor
# placement overrides on it.
dtensor_expand_mode = DTensorExpandMode()
|
||
|
|
||
|
|
||
|
def _override_placements(t: torch.Tensor, placements: List[Placement]):
    """Register ``placements`` as the placement override for tensor ``t``.

    The override is stored on the module-level ``dtensor_expand_mode`` and is
    keyed by ``id(t)``, i.e. by object identity rather than tensor contents.
    """
    # Only a mutation of the shared mapping; the module global is never rebound.
    overrides = dtensor_expand_mode._placements_override
    overrides[id(t)] = placements
|
||
|
|
||
|
|
||
|
@contextmanager
|
||
|
def _rematerialize_optimizer(
|
||
|
opt: torch.optim.Optimizer,
|
||
|
named_states: Dict[str, Any],
|
||
|
params: Dict[str, nn.Parameter],
|
||
|
):
|
||
|
assert opt is not None
|
||
|
|
||
|
# update opt.state with proxy tensors
|
||
|
orig_states = copy(opt.state)
|
||
|
for n in named_states:
|
||
|
# opt.state's key type is string, but optimizer uses Parameter as keys
|
||
|
opt.state[params[n]] = named_states[n] # type: ignore[index]
|
||
|
|
||
|
# FIXME: support multiple parameter groups
|
||
|
param_group = opt.param_groups[0]
|
||
|
orig_params = param_group["params"]
|
||
|
param_group["params"] = params.values()
|
||
|
|
||
|
try:
|
||
|
yield
|
||
|
finally:
|
||
|
param_group["params"] = orig_params
|
||
|
opt.state = orig_states
|
||
|
|
||
|
|
||
|
# Short alias for the ATen operator namespace used by the decompositions below.
aten = torch.ops.aten  # pyre-ignore
|
||
|
|
||
|
|
||
|
@contextmanager
|
||
|
def _enable_compile():
|
||
|
# The return value of torch._utils.is_compiling changes optimizer behavior.
|
||
|
# We need that function to return True to include optimizer in the graph.
|
||
|
# See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41
|
||
|
def f_true():
|
||
|
return True
|
||
|
|
||
|
orig_is_compiling_code = torch._utils.is_compiling.__code__
|
||
|
torch._utils.is_compiling.__code__ = f_true.__code__
|
||
|
try:
|
||
|
yield
|
||
|
finally:
|
||
|
torch._utils.is_compiling.__code__ = orig_is_compiling_code
|
||
|
|
||
|
|
||
|
def _foreach_add_decomp(self, other, alpha=1):
    """Decompose ``aten._foreach_add_.List`` into the out-of-place op plus ``copy_``.

    Each tensor of ``self`` is overwritten with ``self[i] + alpha * other[i]``
    as computed by ``aten._foreach_add.List``. Returns ``None``.
    """
    sums = aten._foreach_add.List(self, other, alpha=alpha)
    for dst, src in zip(self, sums):
        dst.copy_(src)
|
||
|
|
||
|
|
||
|
def _foreach_unaop_decomp(op, self):
|
||
|
self_updated = op(self)
|
||
|
for s, s_u in zip(self, self_updated):
|
||
|
s.copy_(s_u)
|
||
|
|
||
|
|
||
|
def _foreach_binop_list_decomp(op, self, other):
|
||
|
self_updated = op(self, other)
|
||
|
for s, s_u in zip(self, self_updated):
|
||
|
s.copy_(s_u)
|
||
|
|
||
|
|
||
|
def _foreach_binop_scalar_decomp(op, self, scalar=1):
|
||
|
self_updated = op(self, scalar)
|
||
|
for s, s_u in zip(self, self_updated):
|
||
|
s.copy_(s_u)
|
||
|
|
||
|
|
||
|
def _foreach_addcop_scalar_decomp(op, self, tensor1, tensor2, scalar=1):
|
||
|
self_updated = op(self, tensor1, tensor2, scalar)
|
||
|
for s, s_u in zip(self, self_updated):
|
||
|
s.copy_(s_u)
|
||
|
|
||
|
|
||
|
def _fused_adam_decomp(
    self,
    grads,
    exp_avgs,
    exp_avg_sqs,
    max_exp_avg_sqs,
    state_steps,
    *,
    lr=1,
    beta1=1,
    beta2=1,
    weight_decay=1,
    eps=1,
    amsgrad=True,
    maximize=True,
    grad_scale=None,
    found_inf=None,
):
    """Decompose ``aten._fused_adam_`` into ``aten._fused_adam`` plus ``copy_``.

    The outputs of the out-of-place op are paired positionally with
    ``(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)``. Every pair
    except ``grads`` is copied back in place. ``state_steps`` is not written
    back by this decomposition. Returns ``None``.
    """
    destinations = (self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs)
    results = aten._fused_adam.default(
        self,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        weight_decay=weight_decay,
        eps=eps,
        amsgrad=amsgrad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )

    for position, (dst_list, src_list) in enumerate(zip(destinations, results)):
        # Position 1 holds the gradients, which do not need to be copied back.
        if position != 1:
            for dst, src in zip(dst_list, src_list):
                dst.copy_(src)
|
||
|
|
||
|
|
||
|
# Decomposition table handed to ``make_fx`` in ``_compile``. It is needed
# because functionalization does not cover this use case yet (see the FIXME
# there). Each in-place foreach / fused-Adam overload is traced as its
# out-of-place counterpart followed by ``copy_`` back into the inputs; for
# fused Adam, gradients are not copied back. ``native_layer_norm_backward``
# maps to the stock decomposition from ``torch._decomp.decompositions``.
SPMD_DECOMP_TABLE = {
    aten._foreach_add_.List: _foreach_add_decomp,
    aten._foreach_add_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_add.Scalar
    ),
    aten._foreach_addcdiv_.Scalar: partial(
        _foreach_addcop_scalar_decomp, aten._foreach_addcdiv.Scalar
    ),
    aten._foreach_addcmul_.Scalar: partial(
        _foreach_addcop_scalar_decomp, aten._foreach_addcmul.Scalar
    ),
    aten._foreach_div_.List: partial(
        _foreach_binop_list_decomp, aten._foreach_div.List
    ),
    aten._foreach_mul_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_mul.Scalar
    ),
    aten._foreach_div_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_div.Scalar
    ),
    aten._foreach_neg_.default: partial(
        _foreach_unaop_decomp, aten._foreach_neg.default
    ),
    aten._foreach_reciprocal_.default: partial(
        _foreach_unaop_decomp, aten._foreach_reciprocal.default
    ),
    aten._foreach_sqrt_.default: partial(
        _foreach_unaop_decomp, aten._foreach_sqrt.default
    ),
    aten._foreach_sub_.Scalar: partial(
        _foreach_binop_scalar_decomp, aten._foreach_sub.Scalar
    ),
    aten._fused_adam_.default: _fused_adam_decomp,
    aten.native_layer_norm_backward.default: native_layer_norm_backward,
}
|
||
|
|
||
|
|
||
|
# Functional-collective ops that ``_dedup_collectives`` may merge when two call
# nodes share the same target and the same flattened arguments.
DEDUP_TARGETS: Set[torch._ops.OpOverload] = {
    torch.ops.c10d_functional.all_reduce.default,
    torch.ops.c10d_functional.wait_tensor.default,
}
|
||
|
|
||
|
|
||
|
def _dedup_collectives(gm: fx.GraphModule) -> fx.GraphModule:
|
||
|
args_to_node: Dict[Tuple[Any, ...], fx.Node] = {}
|
||
|
|
||
|
for node in gm.graph.nodes:
|
||
|
# replace all args with the results from the first unique comm op
|
||
|
args = pytree.arg_tree_leaves(*node.args)
|
||
|
|
||
|
if node.target in DEDUP_TARGETS:
|
||
|
args_key = (node.target, *args)
|
||
|
unique_node = args_to_node.get(args_key, None)
|
||
|
if unique_node is None:
|
||
|
# first time seeing this combination, remember it
|
||
|
args_to_node[args_key] = node
|
||
|
else:
|
||
|
# the current node is a duplicate, replace it
|
||
|
node.replace_all_uses_with(unique_node)
|
||
|
gm.graph.erase_node(node)
|
||
|
|
||
|
gm.recompile()
|
||
|
|
||
|
return gm
|
||
|
|
||
|
|
||
|
@dataclass
class _CompiledResult:
    """Artifacts produced by ``_compile`` and cached by ``compile``'s wrapper."""

    # Graph module invoked with ``flat_state`` followed by the flattened call arguments.
    gm: fx.GraphModule
    # The single nn.Module found among the compiled function's arguments.
    mod: nn.Module
    # The optimizer found among the arguments, or None if there was none.
    opt: Optional[torch.optim.Optimizer]
    # Flattened parameters, buffers and optimizer states, in graph-placeholder order.
    flat_state: List[torch.Tensor]
|
||
|
|
||
|
|
||
|
def _compile(
    func: Callable,
    module_override: Optional[List[Override]],
    parallel_mode: ParallelMode,
    *args: Any,
    **kwargs: Any,
) -> _CompiledResult:
    """Trace ``func`` and expand it into a distributed graph module.

    Steps:
    1. Locate the single :class:`nn.Module` (required) and at most one
       :class:`torch.optim.Optimizer` among the leaves of ``args``/``kwargs``.
    2. Apply ``module_override`` replacements.
    3. Trace a stateless version of ``func`` with ``make_fx``, with parameters,
       buffers and optimizer states lifted to graph inputs.
    4. Let ``parallel_mode`` partition the traced graph.
    5. Move input flattening to the caller.
    6. Dedup collectives.
    7. Apply each override's ``transform``.

    Raises:
        AssertionError: if no module is found, or more than one module or
            optimizer is found.

    Returns:
        A :class:`_CompiledResult` with the graph module, the module, the
        optimizer (or ``None``) and the flattened state list.
    """
    # 1. Extract nn.Module and Optimizer from args and kwargs
    # FIXME(@mrshenli): support multiple nn.Module instances
    # FIXME(@mrshenli): support multiple Optimizer instances
    # FIXME(@mrshenli): need to broadcast model to sync parameters
    mod, opt = None, None
    for arg in pytree.arg_tree_leaves(*args, **kwargs):
        if isinstance(arg, nn.Module):
            assert mod is None, "Only support single nn.Module for now"
            mod = arg
        if isinstance(arg, torch.optim.Optimizer):
            assert opt is None, "Only support single Optimizer for now"
            opt = arg

    assert mod is not None, "Couldn't find nn.Module instances from the arguments."

    # 2. Override target submodules (e.g., MoE) with dummy replacements
    if module_override:
        accessor = NamedMemberAccessor(mod)

        # Recursively offer every named child to each override. A child is
        # swapped when replacement() returns a different object; otherwise
        # its own children are visited.
        def swap(fqn_prefix: str, module: torch.nn.Module) -> None:
            for override in module_override:  # type: ignore[union-attr]
                for name, child in module.named_children():
                    if len(name) == 0:
                        continue
                    fqn = fqn_prefix + "." + name if fqn_prefix != "" else name
                    new_child = override.replacement(fqn, child)
                    if id(new_child) == id(child):
                        swap(fqn, new_child)
                    else:
                        accessor.swap_submodule(fqn, new_child)

        swap("", mod)

    # 3. Trace stateless version of the train_step
    params = dict(mod.named_parameters(remove_duplicate=False))
    buffers = dict(mod.named_buffers(remove_duplicate=False))

    named_states = {}
    if opt is not None:
        # Pass named_states instead of opt.state to stateless_func, because
        # the latter uses nn.Parameter as key. During tracing, we need to
        # make sure optimizers can find the states using proxy tensors.
        for n, p in params.items():
            if p in opt.state:
                # named_states is keyed by parameter name (str), whereas
                # opt.state is keyed by the Parameter object itself.
                named_states[n] = opt.state[p]  # type: ignore[index]

    is_data_parallel_mode = isinstance(parallel_mode, DataParallel)

    # Lift states and parameters as function arguments so that make_fx
    # can trace operations applied to them.
    def stateless_func(func, params, buffers, named_states, args, kwargs):
        with stateless._reparametrize_module(
            mod, {**params, **buffers}
        ), _rematerialize_optimizer(
            opt, named_states, params
        ) if opt else nullcontext():
            # For DataParallel mode, install hooks first to tag the gradients
            with gradients_tagging(params) if is_data_parallel_mode else nullcontext():
                ret = func(*args, **kwargs)

            # make sure updated parameters are returned
            return ret, list(mod.parameters()), list(named_states.values())  # type: ignore[union-attr]

    # FIXME: Use symbolic tracing to work around an issue in DTensor expand mode.
    # Otherwise it hits shape mismatch error, as we use local inputs to
    # trace local graph and use DTensor to expand operators, where
    # DTensor's shape is the global shape.
    tracing_mode = "fake" if is_data_parallel_mode else "symbolic"

    if is_data_parallel_mode:
        fake_mode = FakeTensorMode()
        data_parallel_mode = cast(DataParallel, parallel_mode)

        def _get_full_batch_arg(arg: torch.Tensor) -> torch.Tensor:
            # Since compilation happens in the first iteration and we
            # receive mini-batch inputs, convert them to full-batch
            # fake tensor inputs first for data parallel sharding
            # propagation.
            fake_arg = fake_mode.from_tensor(arg)
            arg_dims = [1] * arg.ndim
            # expand the tensor to full batch size on its batch dim
            arg_dims[data_parallel_mode.input_batch_dim] *= dist.get_world_size()
            return fake_arg.repeat(arg_dims)

        args = pytree.tree_map_only(
            torch.Tensor,
            _get_full_batch_arg,
            args,
        )
        kwargs = pytree.tree_map_only(
            torch.Tensor,
            _get_full_batch_arg,
            kwargs,
        )

    with _enable_compile(), torch.autograd.detect_anomaly(check_nan=False):
        # FIXME(@mrshenli): functionalization does not work for our use
        # case yet. Use explicit decompositions for foreach ops.
        # Remove this when the following issue is addressed.
        # Issue: https://github.com/pytorch/pytorch/issues/97852
        gm = make_fx(
            partial(stateless_func, func),
            tracing_mode=tracing_mode,
            decomposition_table=SPMD_DECOMP_TABLE,
            _allow_non_fake_inputs=False,
        )(params, buffers, named_states, args, kwargs)

    params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
        **params,
        **buffers,
    }

    # 4. parallel mode to expand a single device graph to a distributed graph
    gm = parallel_mode.partition(
        gm,
        mod,
        opt,
        params_and_buffers,
        named_states,
        args,
        kwargs,
    )

    # 5. Move the responsibility of flattening the input arguments from the
    # graph module to the caller. This serves two purposes:
    #   - Transformations that add/remove state need to manipulate a state
    #     container that maintains the state tensors in the same order as they
    #     appear in graph placeholders.
    #   - Reduced runtime cost. The state container is only flattened once upfront.
    flat_state = pytree.tree_leaves([params_and_buffers, named_states])
    gm = _to_caller_flattened_graph_module(gm)

    # 6. dedup comm operators.
    # The duplication could come from DTensor args and kwargs redistribution.
    # Suppose one operator produces a Partial gradient tensor and model
    # parameters are replicated. In this case, every optimizer operation using
    # that Partial gradient tensor would trigger an allreduce. This is because
    # DTensor only has local information on individual tensor/operator, which is
    # not sufficient to detect duplications in the graph. This situation can
    # also happen when inserting FSDP allgather if a parameter is used multiple
    # times in the forward method.
    # TODO(@mrshenli): @yifuwang has a suggestion of conducting expansion and
    # dedup at tracer-level to avoid multiple graph passes.
    gm = _dedup_collectives(gm)

    # 7. Replace previously inserted dummy ones with real graphs.
    if module_override:
        for override in module_override:
            gm = override.transform(gm, flat_state)

    return _CompiledResult(gm, mod, opt, flat_state)
|
||
|
|
||
|
|
||
|
# Note that the Python convention of __dict__ requires the key to be str.
# TODO: ensure the key is unique.
# Key under which ``compile``'s wrapper caches its ``_CompiledResult`` in its
# own ``__dict__`` after the first call.
COMPILED_OBJECT_KEY = "_compiled_obj"
|
||
|
|
||
|
|
||
|
def compile(
    module_override: Optional[List[Override]] = None,
    gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
    parallel_mode: Optional[ParallelMode] = None,
):
    r"""Compile and optimize a callable, which can be a train step within a training loop.

    This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
    instances from the input arguments and trace operations applied to their
    parameters and states. Compilation happens lazily on the first call of the
    decorated function, and the result is cached on the returned wrapper.

    The wrapper also accepts an extra ``last_train_step`` keyword argument,
    which is consumed and not forwarded. When true, the compiled graph module
    is called with ``last_iter=True`` if it supports that argument.

    Args:
        module_override (Optional[List[Override]]): a list of Override instances
            that will be applied to the module in order. The :class:`Override`
            objects provide :class:`nn.Module` replacements during tracing and a
            graph transformation function after tracing. (Default: ``None``)
        gm_transformation (Optional[Callable[[fx.GraphModule], fx.GraphModule]]):
            a callback that will be called after the original callable is
            compiled and distributed (usually after the first iteration) to
            transform the compiled GraphModule into a new optimized one.
        parallel_mode (Optional[ParallelMode]): a :class:`ParallelMode` object
            that specifies how to parallelize the callable. Each ParallelMode
            would have its own strategy to partition the model and the captured
            graph (Default: ``None``)

    """

    def inner(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # ``kwargs`` is always a fresh dict here, so popping is safe even
            # when it is empty.
            last_train_step = kwargs.pop("last_train_step", False)

            # The cache lives on ``wrapper`` (not ``func``) because ``wrapper``
            # is the object users hold on to.
            compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY)
            first_iter = compiled_obj is None
            if first_iter:
                mode: ParallelMode = (
                    parallel_mode if parallel_mode is not None else dtensor_expand_mode
                )
                compiled_obj = _compile(func, module_override, mode, *args, **kwargs)
                wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj

            flat_inps = compiled_obj.flat_state + pytree.arg_tree_leaves(
                *args, **kwargs
            )

            # Backward has already been captured in the graph, so autograd is
            # not needed while running it.
            with torch.no_grad():
                if first_iter and gm_transformation:
                    # TODO: SPMD should provide a default and configurable
                    # transformation.
                    compiled_obj.gm = gm_transformation(compiled_obj.gm)

                if last_train_step:
                    # Prefer the IterGraphModule-style ``last_iter`` argument;
                    # fall back to a plain call when the graph module does not
                    # accept it.
                    try:
                        output = compiled_obj.gm(*flat_inps, last_iter=last_train_step)[0]
                    except TypeError as e:
                        if "last_iter" not in str(e):
                            raise e
                        output = compiled_obj.gm(*flat_inps)[0]
                else:
                    output = compiled_obj.gm(*flat_inps)[0]

            return output

        return wrapper

    return inner
|