import weakref
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel

from .contract import _get_registry, contract

_ROOT_MODULE_PREFIX = ""


class _ReplicateState(_State):
    """Per-module state backing the composable ``replicate`` API.

    Collects the parameters of the target module that should be replicated,
    wraps them in a ``DistributedDataParallel`` instance, and forwards the
    module's forward pre/post hooks to that instance.
    """

    def __init__(self) -> None:
        super().__init__()
        # Placeholder until ``init`` stores the real module.
        self.module: nn.Module = nn.ParameterList()
        self.has_initialized: bool = False
        self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
        self._param_names: List[str] = []

    def _collect_params(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        ignored_params: Set[nn.Parameter],
        prefix: str = _ROOT_MODULE_PREFIX,
    ) -> None:
        """Recursively gather the parameters of ``module`` to replicate.

        Parameters are appended to ``self._param_list`` and their
        dot-separated qualified names to ``self._param_names``.

        Args:
            module: module whose direct parameters and children are visited.
            ignored_modules: modules to skip together with their descendants.
            ignored_params: parameters that must not be collected.
            prefix: qualified name of ``module`` relative to the root module.
        """
        # skip if managed by fully_sharded API
        if _is_fully_sharded(module):
            return

        # if a module is ignored, all descendants of the module are ignored.
        if module in ignored_modules:
            return

        # The root module has an empty prefix, so no "." separator is added.
        recurse_prefix = (
            f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
        )

        for n, p in module.named_parameters(recurse=False):
            if p not in ignored_params:
                self._param_list.append(p)
                self._param_names.append(f"{recurse_prefix}{n}")

        for name, child_module in module.named_children():
            self._collect_params(
                child_module,
                ignored_modules,
                ignored_params,
                prefix=f"{recurse_prefix}{name}",
            )

    def init(
        self,
        module: nn.Module,
        ignored_modules: Set[nn.Module],
        **kwargs,
    ) -> None:
        """Set up replication for ``module``.

        Does nothing if this state has already been initialized.

        Args:
            module: module to replicate.
            ignored_modules: modules whose parameters are excluded.
            **kwargs: forwarded to ``DistributedDataParallel``. A
                ``device_id`` entry is translated into the ``device_ids``
                argument before forwarding.

        Raises:
            RuntimeError: if ``module`` is already managed by ``fully_shard``.
        """
        if _is_fully_sharded(module):
            raise RuntimeError(
                "Cannot apply `replicate()` on a Module already managed by `fully_shard`"
            )

        if self.has_initialized:
            return

        self.has_initialized = True
        self.module = module
        ignored_params = {p for m in ignored_modules for p in m.parameters()}
        self._collect_params(module, ignored_modules, ignored_params)
        module.register_forward_pre_hook(self.forward_pre_hook, with_kwargs=True)
        module.register_forward_hook(self.forward_post_hook)  # type: ignore[arg-type]

        if "device_id" in kwargs:
            # replicate() supports a small usability enhancement where
            # user can pass in device_id as a Union[int, torch.device] even for
            # CPU devices so users don't have to change code for CPU/GPU runs.
            # We derive the right device_ids to feed into DDP to support this.
            device_id = kwargs.pop("device_id")
            is_cpu_device = (
                isinstance(device_id, torch.device) and device_id.type == "cpu"
            )
            # CPU modules (and an explicit None) receive device_ids None;
            # GPU modules expect device_ids=[cuda_device].
            # NOTE(review): replicate() rejects a None device_id before it
            # gets here, so the None case only matters for direct callers.
            kwargs["device_ids"] = (
                None if device_id is None or is_cpu_device else [device_id]
            )

        self._ddp = DistributedDataParallel(self._param_list, **kwargs)
        # Weakref to the DDP instance is currently only used for testing.
        replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)

    def forward_pre_hook(
        self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> Any:
        """Forward pre-hook: delegate to the DDP instance's ``_pre_forward``."""
        return self._ddp._pre_forward(*args, **kwargs)

    def forward_post_hook(
        self,
        module: nn.Module,
        input: Tuple[torch.Tensor],
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Forward hook: delegate ``output`` to the DDP instance's ``_post_forward``."""
        return self._ddp._post_forward(output)


@contract(state_cls=_ReplicateState)
def replicate(
    module: nn.Module,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    **kwargs,
) -> nn.Module:
    r"""Replicates a module

    Args:
        module (torch.nn.Module): module to replicate
        ignored_modules (Optional[Iterable[torch.nn.Module]]): modules whose
            parameters (including those of their descendants) are not
            replicated. Defaults to no ignored modules.
        **kwargs: forwarded to the underlying ``DistributedDataParallel``
            instance. ``device_id`` (``int`` or ``torch.device``) may be
            passed instead of ``device_ids``.

    Returns:
        The same ``module`` that was passed in.

    Raises:
        RuntimeError: if ``device_id`` is given but is neither an ``int`` nor
            a ``torch.device``, or if ``module`` is already managed by
            ``fully_shard``.

    Example::
        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> module = nn.Linear(3, 3)
        >>> replicate(module)
    """
    torch._C._log_api_usage_once("torch.distributed.replicate")

    # TODO(fegin): using kwargs is not a good idea if we would like to make
    # replicate a formal API to replace DDP.
    if "device_id" in kwargs and not isinstance(
        kwargs["device_id"], (int, torch.device)
    ):
        raise RuntimeError(
            "Expected device_id to be int or torch.device, "
            f"but got {type(kwargs['device_id'])}"
        )

    # Always hand a real set to the state (``{}`` would be an empty dict).
    ignored_set: Set[nn.Module] = (
        set() if ignored_modules is None else set(ignored_modules)
    )
    replicate.state(module).init(module, ignored_set, **kwargs)

    return module


def _is_fully_sharded(module: nn.Module) -> bool:
    r"""Check if module is marked with fully_shard."""
    registry = _get_registry(module)
    if registry is None:
        return False
    return "fully_shard" in registry