# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
from typing import cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import (
    TensorProperties as ShardTensorProperties,
)
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
    ChunkShardingSpec,
)
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    create_read_items_for_chunk_list,
)
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.distributed.checkpoint.utils import (
    _element_wise_add,
    _element_wise_sub,
    _normalize_device_info,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device

STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]


# TODO: Update docstrings for optimizer.py
__all__ = [
    "load_sharded_optimizer_state_dict",
]


def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
    if device_type == "cpu":
        return "cpu"
    device_module = _get_device_module(device_type)
    if device_module.is_available():
        return _normalize_device_info(
            device_type, global_rank % device_module.device_count()
        )
    return "cpu"


def _create_colwise_spec(
    pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
    pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
    if pg is None:
        placements = [
            f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
            for idx in range(dist.get_world_size())
        ]
    else:
        placements = [
            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
            for idx in range(pg.size())
        ]
    return ChunkShardingSpec(
        dim=0,
        placements=cast(List[Union[_remote_device, str]], placements),
    )


def _is_nested_tensor(val: torch.Tensor) -> bool:
    if type(val) is ShardedTensor:
        if len(val.local_shards()) == 0:
            return False
        if type(val.local_shards()[0].tensor) is ShardedTensor:
            return True
        if type(val.local_shards()[0].tensor) is DTensor:
            raise ValueError("Cannot handle DTensor nested inside ShardedTensor")
    elif type(val) is DTensor and (
        type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
    ):
        raise ValueError("Cannot handle nested DTensor")
    return False


def _alloc_tensor(
    props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
) -> torch.Tensor:
    return torch.empty(
        size=size,
        dtype=props.dtype,
        layout=props.layout,
        requires_grad=props.requires_grad,
        pin_memory=props.pin_memory,
        device=cast(torch.device, _get_device_module(device_type).current_device()),
    )
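

# A minimal illustration (not part of this module's API, FQNs and shapes are
# hypothetical) of the STATE_DICT_2D_LAYOUT mapping built by
# _get_state_dict_2d_layout below. A TP-sharded parameter maps to the
# (shard_offsets, shard_sizes) of this rank's TP slice, while a non-nested
# parameter maps to (None, full_size):
#
#   {
#       "model.embed_tokens.weight": (None, torch.Size([50304, 1024])),
#       "model.layers.0.attn.wq.weight": ([512, 0], [512, 1024]),
#   }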
def _get_state_dict_2d_layout(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
    """
    Load the right TP slice of the optimizer state.

    This is not easy since the per-tensor slicing can't be inferred from checkpoint
    metadata. We take advantage of the model state_dict producing a sliced
    ShardedTensor to figure out what we need to load. This is pretty fragile and
    it might be easier for FSDP to compute this info for us.
    Returns a dictionary whose keys are the same as those of the state_dict and
    whose values are a tuple of (offset, size) for the current rank's TP slice.
    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.
    """
    specs: STATE_DICT_2D_LAYOUT = {}
    dp_pg: Optional[dist.ProcessGroup] = None
    for key, value in state_dict.items():
        specs[key] = (None, value.size())
        if _is_nested_tensor(value):
            assert (
                len(value.local_shards()) == 1
            ), "Cannot handle ST with multiple shards"
            assert isinstance(
                value, ShardedTensor
            ), "Can only handle nested ShardedTensor"
            shard = value.local_shards()[0]
            specs[key] = (
                shard.metadata.shard_offsets,
                shard.metadata.shard_sizes,
            )
            dp_pg = shard.tensor._process_group  # type: ignore[attr-defined]

    return (
        specs,
        dp_pg,
    )
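

# A rough sketch (hypothetical numbers, not taken from this module) of what
# _ReaderWithOffset below does: if this rank's TP slice of a parameter starts
# at global offset [512, 0] and FSDP's local shard covers rows [0, 128) of
# that slice, the planner adds the offset to the chunk and requests checkpoint
# rows [512, 640); `translation` then maps the displaced MetadataIndex back to
# the local offset [0, 0] so lookup_tensor writes into the local shard at the
# right place.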
class _ReaderWithOffset(DefaultLoadPlanner):
    translation: Dict[MetadataIndex, MetadataIndex]
    state_dict: STATE_DICT_TYPE
    metadata: Metadata

    def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
        super().__init__()
        self.fqn_to_offset = fqn_to_offset
        self.metadata = Metadata({})
        self.state_dict = {}
        self.translation = {}

    def create_local_plan(self) -> LoadPlan:
        requests = []
        self.translation = {}
        for fqn, obj in self.state_dict.items():
            md = self.metadata.state_dict_metadata[fqn]
            if not isinstance(obj, ShardedTensor):
                requests += _create_read_items(fqn, md, obj)
                continue

            if fqn not in self.fqn_to_offset:
                requests += _create_read_items(fqn, md, obj)
                continue

            offset = self.fqn_to_offset[fqn]

            assert len(obj.local_shards()) == 1
            original_shard = obj.local_shards()[0]
            local_chunks = [
                ChunkStorageMetadata(
                    offsets=torch.Size(
                        _element_wise_add(original_shard.metadata.shard_offsets, offset)
                    ),
                    sizes=torch.Size(original_shard.metadata.shard_sizes),
                )
            ]

            reqs = create_read_items_for_chunk_list(
                fqn, cast(TensorStorageMetadata, md), local_chunks
            )
            # TODO: The ReadItems will have a displaced MetadataIndex, fix it.
            # TODO: we should change _create_sharded_read_items to have a more ergonomic API
            for ri in reqs:
                assert ri.dest_index.offset is not None
                original_offset = _element_wise_sub(ri.dest_index.offset, offset)
                original_index = dataclasses.replace(
                    ri.dest_index, offset=torch.Size(original_offset)
                )
                self.translation[ri.dest_index] = original_index

            requests += reqs
        return LoadPlan(requests)

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        return super().lookup_tensor(self.translation.get(index, index))


def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: StorageReader,
    planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
    """
    Load a state_dict in conjunction with FSDP sharded optimizer state.

    This is the current recommended way to checkpoint FSDP.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> # Save
    >>> model: torch.nn.Module
    >>> optim_params = model.parameters()
    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
    >>> # Save
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     state_dict = {
    >>>         "optimizer": FSDP.optim_state_dict(model, optim),
    >>>         "model": model.state_dict()
    >>>     }
    >>>     dist_cp.save_state_dict(
    >>>         state_dict=state_dict,
    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    >>>         planner=dist_cp.DefaultSavePlanner(),
    >>>     )
    >>>
    >>> # Load
    >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
    >>>     model_state_dict = model_tp.state_dict()
    >>>     checkpoint = {
    >>>         "model": model_state_dict
    >>>     }
    >>>     dist_cp.load_state_dict(
    >>>         state_dict=checkpoint,
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>         planner=dist_cp.DefaultLoadPlanner(),
    >>>     )
    >>>     model.load_state_dict(checkpoint["model"])
    >>>
    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
    >>>         model_state_dict,
    >>>         optimizer_key="optimizer",
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>     )
    >>>
    >>>     flattened_osd = FSDP.optim_state_dict_to_load(
    >>>         model, optim, optim_state["optimizer"]
    >>>     )
    >>>
    >>>     optim.load_state_dict(flattened_osd)
    """
    metadata = storage_reader.read_metadata()

    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
    device_module = _get_device_module(dp_pg_device_type)

    if dp_pg is None:
        placements = []
        for i in range(dist.get_world_size()):
            device_info = _normalize_device_info(
                dp_pg_device_type, i % device_module.device_count()
            )
            placements.append(f"rank:{i}/{device_info}")
        sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]
    else:
        sharding_spec = _create_colwise_spec(dp_pg)

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}

    fqn_to_offset: Dict[str, Sequence[int]] = {}
    for key, value in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[key]
        if key_path[0] != optimizer_key:
            continue

        if isinstance(value, BytesStorageMetadata):
            state_dict[key] = ""
            continue

        # value: TensorStorageMetadata
        if value.size.numel() == 1:
            state_dict[key] = _alloc_tensor(
                value.properties, value.size, dp_pg_device_type
            )
        elif dp_pg is None:
            state_dict[key] = _create_chunk_sharded_tensor(
                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=device_module.device_count(),
                pg=_get_default_group(),
            )
        else:
            spec_key = key_path[2]
            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]

            properties = ShardTensorProperties(
                dtype=value.properties.dtype,
                layout=value.properties.layout,
                requires_grad=value.properties.requires_grad,
                memory_format=value.properties.memory_format,
                pin_memory=value.properties.pin_memory,
            )

            st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties)
            local_shards = []
            current_rank = dist.get_rank(dp_pg)
            for shard_md in st_md.shards_metadata:
                if cast(_remote_device, shard_md.placement).rank() != current_rank:
                    continue
                local_shards.append(
                    Shard(
                        tensor=_alloc_tensor(
                            value.properties, shard_md.shard_sizes, dp_pg_device_type
                        ),
                        metadata=shard_md,
                    )
                )

            st = ShardedTensor._init_from_local_shards_and_global_metadata(
                local_shards, st_md, process_group=dp_pg
            )

            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])

            state_dict[key] = st
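
    # At this point `state_dict` maps flattened optimizer-state FQNs to freshly
    # allocated tensors / ShardedTensors; the load below fills them in place from
    # storage, and unflatten_state_dict then restores the nested structure.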
    # Whether we unflatten before or after doesn't matter
    load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        # FIXME the type of planner is wrong in load_state_dict
        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
    )

    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)

    return state_dict