from contextlib import contextmanager
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)

from .sharding_spec import (
    ShardingSpec,
    ChunkShardingSpec
)
from .sharding_plan import (
    ShardingPlan
)
from .sharder import Sharder

def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Shards the given :class:`torch.Tensor` according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank that is used as the
    ground truth for the data, which is scattered as shards across the rest of
    the ranks.

    Args:
        tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the tensor that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
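
    Example::
        A minimal sketch, assuming 4 ranks and an already initialized default
        process group (e.g. via ``torch.distributed.init_process_group``):

        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=[
        >>>         "rank:0/cuda:0",
        >>>         "rank:1/cuda:1",
        >>>         "rank:2/cuda:2",
        >>>         "rank:3/cuda:3",
        >>>     ],
        >>> )
        >>> local_tensor = torch.rand(16, 16).cuda()
        >>> st = _shard_tensor(local_tensor, spec, src_rank=0)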
    """
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)

    return st

def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank that is used as the ground truth for the
    data, which is scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
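
    Example::
        A minimal sketch, assuming 4 ranks and an already initialized default
        process group; the module and spec below are purely illustrative:

        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=[
        >>>         "rank:0/cuda:0",
        >>>         "rank:1/cuda:1",
        >>>         "rank:2/cuda:2",
        >>>         "rank:3/cuda:3",
        >>>     ],
        >>> )
        >>> m = torch.nn.Linear(32, 32).cuda()
        >>> # After this call, ``m.weight`` is backed by a ShardedTensor.
        >>> shard_parameter(m, "weight", spec)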
    """
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise AttributeError(f'{module._get_name()} has no attribute `{param_name}`')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace param with ShardedTensor.
    module.register_parameter(param_name, nn.Parameter(st))

# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None

@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor.
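
    Example::
        A minimal sketch, assuming a checkpoint containing ShardedTensors was
        saved earlier and ``pg`` is the process group its shards belong to
        (the file path below is purely illustrative):

        >>> pg = dist.new_group(ranks=[0, 1])
        >>> with load_with_process_group(pg):
        >>>     module = torch.load('checkpoint.pt')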
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is not None:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            'context manager')
    _CURRENT_PROCESS_GROUP = process_group
    try:
        yield process_group
    finally:
        _CURRENT_PROCESS_GROUP = None

def _get_current_process_group():
    """
    Retrieves the current process group set by ``load_with_process_group``.
    If not set, it just returns the default group.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is None:
        return distributed_c10d._get_default_group()
    else:
        return _CURRENT_PROCESS_GROUP

def _reshard_output(
        module: torch.nn.Module,
        resharding_spec: ShardingSpec) -> torch.nn.Module:
    """
    Hook a module with output resharding in the forward pass according
    to the given ``resharding_spec``.

    Args:
        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the output of the module will be resharded.

    Returns:
        A :class:`torch.nn.Module` object with reshard API hooked.
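
    Example::
        A minimal sketch, assuming ``sharded_linear`` is a module whose forward
        pass returns a :class:`ShardedTensor` and ``reshard_spec`` is a
        ``ChunkShardingSpec`` (both names are illustrative):

        >>> sharded_linear = _reshard_output(sharded_linear, reshard_spec)
        >>> output = sharded_linear(inp)  # output resharded per ``reshard_spec``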
    """
    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            return output.reshard(resharding_spec)
        return output
    module.register_forward_hook(hook_func)
    return module

def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    """
    Hook a module with local shards collection in the forward pass.

    This API is typically used to convert a sharded representation back to a data parallel
    representation. In particular, it returns the local tensor for this Shard. If the
    size along the sharding dimension of the local tensor is 1, this dimension is removed
    from the final result. For example, a [4, 16] ShardedTensor across 4 ranks is typically
    a local Tensor of size [16] on each rank and not [1, 16] on each rank.

    Args:
        module (:class:`torch.nn.Module`): Module whose output is a ShardedTensor and whose
            local tensor value needs to be returned.

    Returns:
        A :class:`torch.nn.Module` object with collection API hooked.
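
    Example::
        A minimal sketch, assuming ``sharded_linear`` is a module whose forward
        pass returns a :class:`ShardedTensor` (the name is illustrative):

        >>> sharded_linear = _collect_local_shard(sharded_linear)
        >>> output = sharded_linear(inp)  # output is now a regular torch.Tensor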
    """

    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            local_tensor = output.local_tensor()
            # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec
            sharding_spec = output._sharding_spec
            if isinstance(sharding_spec, ChunkShardingSpec) \
               and local_tensor.size(sharding_spec.dim) == 1:  # type: ignore[attr-defined, arg-type]
                local_tensor = local_tensor.squeeze(
                    output._sharding_spec.dim  # type: ignore[attr-defined]
                )
            return local_tensor
    module.register_forward_hook(hook_func)
    return module

def shard_module(
    module: nn.Module,
    plan: ShardingPlan,
    src_rank=0,
    process_group=None
):
    """
    Shards a given module according to the provided sharding ``plan``. This method
    first shards all of the parameters according to the given sharding ``plan``. Then, if
    ``output_plan`` and ``return_local_tensor`` are specified in the sharding ``plan``, it
    will reshard the output of modules according to ``output_plan`` and convert module
    outputs back to data parallel according to ``return_local_tensor``.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to.
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan which specifies the param name to ShardingSpec mapping
            to apply to each parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
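
    Example::
        A minimal sketch, assuming 2 ranks, an initialized default process group,
        and a ``model`` with a submodule named ``fc`` whose ``weight`` should be
        sharded (the module and attribute names are illustrative):

        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        >>> )
        >>> plan = ShardingPlan(
        >>>     plan={"fc.weight": spec},
        >>>     return_local_tensor=["fc"],
        >>> )
        >>> shard_module(model, plan)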
    """
    # Record Sharder paths for a sanity check on the plan to ensure items in the plan
    # do not conflict with the submodule tree that the Sharder is working with.
    sharder_paths = []
    for name, spec in plan.plan.items():
        if isinstance(spec, Sharder):
            sharder_paths.append(name)

    # Shard the parameters according to the ShardingPlan.
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # if found a sharding spec, try to shard the parameter
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(f"ShardingPlan is invalid, trying to shard a parameter: {name},"
                                       f" but there's already a Sharder entry for module {sharder_path},"
                                       f" parameter sharding should not conflict with the submodule tree"
                                       f" that a Sharder is working with!")

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod,
                param_name,
                spec,
                src_rank=src_rank,
                process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # swap this submodule with the sharded module
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'")

    # Reshard a module's output if there's an entry for it in ``output_plan``.
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(f"Only `ShardingSpec` is supported as output_plan for '{module_path}'")
    # Convert the output back to data parallel for the modules that appear in
    # ``return_local_tensor`` of the plan, by calling ``_collect_local_shard``
    # to collect the local tensor for the output of those modules.
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)