from dataclasses import dataclass
from typing import List, Union, Optional
from functools import reduce

from torch.distributed.remote_device import _remote_device


@dataclass
class ShardMetadata:
    """
    Represents a shard of the overall Tensor including its
    offsets, lengths and device placement.

    Args:
        shard_offsets(List[int]): Offsets in the original tensor indicating
            the start offsets for this shard. Should have the same rank as
            the original tensor.
        shard_sizes(List[int]): Integers indicating the size of each
            dimension for this shard. Should have the same rank as the
            original tensor.
        placement(:class:`torch.distributed._remote_device`): Specifies the
            placement of this shard.
    """

    __slots__ = ['shard_offsets', 'shard_sizes', 'placement']

    shard_offsets: List[int]
    shard_sizes: List[int]
    placement: Optional[_remote_device]

    def __init__(
        self,
        shard_offsets: List[int],
        shard_sizes: List[int],
        placement: Optional[Union[str, _remote_device]] = None
    ):
        self.shard_offsets = shard_offsets
        self.shard_sizes = shard_sizes
        # Accept either a placement string or an already-constructed
        # _remote_device; strings are parsed into a _remote_device.
        if isinstance(placement, str):
            self.placement = _remote_device(placement)
        else:
            self.placement = placement

        # Validate that offsets and sizes describe the same number of
        # dimensions and are non-negative.
        if len(self.shard_offsets) != len(self.shard_sizes):
            raise ValueError(
                f'shard_offsets and shard_sizes should have '
                f'the same number of elements, found {len(self.shard_offsets)} '
                f'and {len(self.shard_sizes)} respectively')

        for i in range(len(self.shard_offsets)):
            if self.shard_offsets[i] < 0:
                raise ValueError('shard_offsets should be >= 0')
            if self.shard_sizes[i] < 0:
                raise ValueError('shard_sizes should be >= 0')

    def __hash__(self):
        def _hash_reduce(a, b):
            return (a << 8) + hash(b)

        res = reduce(_hash_reduce, self.shard_offsets, 37)
        res = reduce(_hash_reduce, self.shard_sizes, res)
        res = _hash_reduce(res, self.placement)
        return res
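

# A minimal usage sketch (not part of the module API): constructing a
# ShardMetadata for one shard of a hypothetical 16x16 tensor split row-wise.
# The tensor shape, offsets, and the "rank:0/cuda:0" placement string are
# illustrative assumptions; any string accepted by
# torch.distributed._remote_device can be used for placement.
if __name__ == "__main__":
    shard = ShardMetadata(
        shard_offsets=[0, 0],        # shard starts at the tensor origin
        shard_sizes=[8, 16],         # covers the first 8 rows, all 16 columns
        placement="rank:0/cuda:0",   # parsed into a _remote_device
    )
    print(shard)        # dataclass-generated repr of offsets, sizes, placement
    print(hash(shard))  # uses the custom __hash__ defined above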