from typing import Any, cast, List

import torch
import torch.distributed as dist
from torch._utils import _get_device_module

from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset

from torch.utils._pytree import tree_map_only

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from .planner import (
    LoadItemType,
    ReadItem,
    SavePlan,
    TensorWriteData,
    WriteItem,
    WriteItemType,
)
from .resharding import (
    _check_shard_metadata_pair_overlap,
    _shards_get_overlap_region_wrt_saved_tensor,
)

__all__: List[str] = ["create_read_items_for_chunk_list"]


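# Helpers for translating tensors, ShardedTensor shards, and DTensor shards into
# the chunk/WriteItem/ReadItem metadata consumed by the distributed checkpoint
# (DCP) planners. Only ``create_read_items_for_chunk_list`` is re-exported via
# ``__all__``; everything else is module-private.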
def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
    )


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    shard_properties = sharded_tensor.metadata().tensor_properties

    properties = TensorProperties(
        dtype=shard_properties.dtype,
        layout=shard_properties.layout,
        requires_grad=shard_properties.requires_grad,
        memory_format=shard_properties.memory_format,
        pin_memory=shard_properties.pin_memory,
    )

    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=properties,
        size=sharded_tensor.metadata().size,
    )


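# Each rank holds exactly one local shard of a DTensor; its global offset and
# local shape are recovered from the device mesh and placements, so a DTensor
# contributes a single SHARD WriteItem per rank.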
def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)

    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(
                offsets=offsets,
                sizes=sizes,
            ),
            properties=TensorProperties.create_from_tensor(tensor.to_local()),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


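# Values that are not tensors are written as opaque byte streams (BYTE_IO);
# their WriteItem carries no tensor metadata.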
def _create_write_item_for_bytesio(fqn: str, bytes: Any) -> WriteItem:
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )


def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: List[ChunkStorageMetadata],
) -> List[ReadItem]:
    """
    Create a list of ``ReadItem`` based on the checkpoint and local chunks.

    This applies the resharding algorithm and computes the reads needed
    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

    Args:
        fqn (str): The state_dict FQN to pass to ``ReadItem``.
        checkpoint_md (TensorStorageMetadata): Metadata for a given tensor
            from a checkpoint.
        local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
            loaded.

    Returns:
        A list of ``ReadItem`` that will satisfy all input chunks.
    """
    read_items = []
    # This is a naive quadratic algorithm that can be optimized later.
    for idx, shard in enumerate(local_chunks):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            if not _check_shard_metadata_pair_overlap(shard, storage_md):
                continue

            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=storage_md, current_shard=shard
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(fqn, shard.offsets, idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items
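
# Illustrative sketch of the resharding above (hypothetical values, shown as a
# comment so nothing runs at import time): a 1-D tensor of global size 8 saved
# as two chunks of 4, loaded into a single local chunk covering elements [2, 6),
# yields one ReadItem per overlapping saved chunk:
#
#   checkpoint_md = TensorStorageMetadata(
#       properties=TensorProperties(dtype=torch.float32),
#       size=torch.Size([8]),
#       chunks=[
#           ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4])),
#           ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4])),
#       ],
#   )
#   local_chunks = [
#       ChunkStorageMetadata(offsets=torch.Size([2]), sizes=torch.Size([4]))
#   ]
#   read_items = create_read_items_for_chunk_list("weight", checkpoint_md, local_chunks)
#   # -> 2 ReadItems: elements [2, 4) read from the first saved chunk,
#   #    elements [4, 6) from the second.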
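
# Builds a SavePlan covering every shard of every state_dict entry (not just the
# shards local to this rank); presumably intended for producing the global,
# metadata-only view of a checkpoint.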
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, DTensor):
            requests.append(_create_write_items_for_dtensor(fqn, obj))
        elif isinstance(obj, ShardedTensor):
            for shard_md in obj.metadata().shards_metadata:
                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


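# Per-rank dispatch on the state_dict value type: a DTensor or plain tensor
# yields one WriteItem, a ShardedTensor yields one WriteItem per local shard,
# and anything else falls back to a single BYTE_IO item.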
def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
    if isinstance(object, DTensor):
        return [_create_write_items_for_dtensor(fqn, object)]
    elif isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, shard.metadata)
            for shard in object.local_shards()
        ]
    elif isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]
    else:
        return [_create_write_item_for_bytesio(fqn, object)]


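# The chunk helpers below describe which region of the global tensor this rank
# holds locally on the load path; they mirror the write-item helpers above.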
def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
    return ChunkStorageMetadata(
        offsets=offsets,
        sizes=sizes,
    )


def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
    if isinstance(tensor, DTensor):
        local_chunks = [_create_chunk_from_dtensor(tensor)]
    elif isinstance(tensor, ShardedTensor):
        local_chunks = [
            _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
        ]
    elif isinstance(tensor, torch.Tensor):
        local_chunks = [_create_chunk_from_tensor(tensor)]
    else:
        raise ValueError(
            "Unsupported type: expected one of [Tensor, DTensor, ShardedTensor], "
            f"but got {type(tensor)}"
        )

    return local_chunks


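# Load-path entry point: given the checkpoint metadata recorded for ``fqn`` and
# the in-memory object that will receive the data, produce the ReadItems to
# schedule. Byte-stream entries map to a single BYTE_IO read with offset and
# length 0.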
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    if not isinstance(md, BytesStorageMetadata):
        try:
            local_chunks = _create_chunk_list(obj)
        except ValueError as ex:
            raise ValueError(
                f"Invalid checkpoint metadata for {fqn}, "
                f"expected BytesStorageMetadata but found {type(md)}"
            ) from ex

        return create_read_items_for_chunk_list(fqn, md, local_chunks)
    else:
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]


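# Support for state_dicts created on the meta device: before loading, materialize
# empty tensors/DTensors on the current device so the checkpoint data has real
# storage to be read into.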
def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    state_dict_assigned_storage = tree_map_only(
        torch.Tensor, _init_meta_tensor, state_dict
    )
    # The in-place variant of tree_map_only (tree_map_only_) doesn't seem to work,
    # so we map to a new structure and then copy each (possibly re-initialized)
    # element back into the original state_dict.
    for k in state_dict.keys():
        state_dict[k] = state_dict_assigned_storage[k]


def _init_meta_tensor(value: Any) -> Any:
    """
    Materialize a ``torch.Tensor`` or ``DTensor`` that lives on the meta device
    as an empty tensor on the current device; values not on the meta device are
    returned unchanged.
    """
    device = getattr(value, "device", None)
    # DCP does the initialization if the value is a meta tensor/DTensor.
    if device == torch.device("meta"):
        device_type = dist.distributed_c10d._get_pg_default_device().type
        device = cast(torch.device, _get_device_module(device_type).current_device())
        if isinstance(value, DTensor):
            new_local_tensor = torch.empty_like(value.to_local(), device=device)
            # We need to pass shape and stride explicitly, since the DTensor
            # might be sharded unevenly.
            dtensor = DTensor.from_local(
                new_local_tensor,
                device_mesh=value.device_mesh,
                placements=value.placements,
                shape=value.size(),
                stride=value.stride(),
            )
            return dtensor
        elif isinstance(value, torch.Tensor):
            tensor = torch.empty_like(value, device=device)
            return tensor
        else:
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
    else:
        return value