from typing import List, Tuple

from torch.distributed.checkpoint.metadata import ChunkStorageMetadata

# Nothing in this module is part of the public API.
__all__: List[str] = []


def _check_shard_metadata_pair_overlap(
    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
) -> bool:
    """Check if two shards overlap.

    Each shard covers the half-open range ``[offset, offset + size)`` along
    every dimension. Two shards overlap only if their ranges intersect in
    *every* dimension; shards that merely touch at a boundary do not overlap.

    Args:
        shard1: First shard; only its ``offsets`` and ``sizes`` are read.
        shard2: Second shard; only its ``offsets`` and ``sizes`` are read.

    Returns:
        ``True`` if the shards intersect in all dimensions, ``False`` as soon
        as one dimension separates them. With zero dimensions the loop body
        never runs and the result is ``True``.
    """
    # For each dim of each shard, check if one shard resides on the other
    # end of second shard with respect to that dim. As an example for a 2D
    # shard, we would check if one shard is above or on the left of the
    # other shard.
    ndims = len(shard1.offsets)
    for i in range(ndims):
        # shard1 starts at or after the end of shard2 along this dim.
        if shard1.offsets[i] >= shard2.offsets[i] + shard2.sizes[i]:
            return False
        # shard2 starts at or after the end of shard1 along this dim.
        if shard2.offsets[i] >= shard1.offsets[i] + shard1.sizes[i]:
            return False

    return True


def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Return the overlapping region between saved_shard and current_shard.

    The returned list has the same number of elements as the tensor's
    dimension. For each element, we produce a tuple with the following
    contents:
        (dimension, `saved_shard` offset, `current_shard` offset, length)

    Offsets are relative to each shard.

    Note:
        No overlap check is performed here. If the shards do not intersect
        along a dimension, the computed ``length`` for that dimension is zero
        or negative.
    """
    narrows = []
    for dim, (
        saved_shard_offset,
        current_shard_offset,
        saved_shard_size,
        current_shard_size,
    ) in enumerate(
        zip(
            saved_shard.offsets,
            current_shard.offsets,
            saved_shard.sizes,
            current_shard.sizes,
        )
    ):
        # The intersection ends where the earlier-ending shard ends ...
        min_range_end = min(
            saved_shard_offset + saved_shard_size,
            current_shard_offset + current_shard_size,
        )

        # ... and begins where the later-starting shard begins.
        length = min_range_end - max(current_shard_offset, saved_shard_offset)

        # The shard that starts later begins the overlap at its own local
        # offset 0; the other shard must skip the gap between the two starts.
        if saved_shard_offset > current_shard_offset:
            offset_for_saved_tensor = 0
            offset_for_current_tensor = saved_shard_offset - current_shard_offset
        else:
            offset_for_saved_tensor = current_shard_offset - saved_shard_offset
            offset_for_current_tensor = 0

        narrows.append(
            (dim, offset_for_saved_tensor, offset_for_current_tensor, length)
        )

    return narrows