# ai-content-maker/.venv/Lib/site-packages/huggingface_hub/serialization/_base.py
#
# 170 lines
# 7.0 KiB
# Python

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains helpers to split tensors into shards."""
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from .. import logging
# Default upper bound for a single shard: 5 * 10**9 bytes (decimal "5GB").
MAX_SHARD_SIZE = 5 * 10**9
# Template for shard file names; it is formatted with a `suffix` keyword argument.
FILENAME_PATTERN = "model{suffix}.safetensors"

# Generic placeholder for the framework-specific tensor type.
TensorT = TypeVar("TensorT")
# Callable returning the size of a tensor, in bytes.
TensorSizeFn_T = Callable[[TensorT], int]
# Callable returning an identifier of a tensor's underlying storage, or None.
StorageIDFn_T = Callable[[TensorT], Optional[Any]]
# Module logger obtained from the package-level `logging` helper (imported from `..`).
# Note: it is named with `__file__` (this module's file path) rather than `__name__`.
logger = logging.get_logger(__file__)
@dataclass
class StateDictSplit:
    """Outcome of splitting a state dict into one or more shard files.

    Attributes:
        is_sharded: Whether more than one file appears in `filename_to_tensors`. It is derived in
            `__post_init__` and is not a constructor argument.
        metadata: Index metadata (for instance `{"total_size": ...}`).
        filename_to_tensors: Mapping from each shard filename to the tensor names stored in it.
        tensor_to_filename: Mapping from each tensor name to the shard filename that holds it.
    """

    # Derived value: excluded from __init__, computed in __post_init__.
    is_sharded: bool = field(init=False)
    metadata: Dict[str, Any]
    filename_to_tensors: Dict[str, List[str]]
    tensor_to_filename: Dict[str, str]

    def __post_init__(self) -> None:
        shard_count = len(self.filename_to_tensors)
        self.is_sharded = shard_count >= 2
def _parse_size_to_int(size: Union[int, str]) -> int:
    """Convert a human-readable size such as `"5GB"` into a number of bytes.

    Non-string values (normally ints) are returned unchanged, so numeric `max_shard_size` values behave
    exactly as before. Units are decimal (1KB = 1000 bytes), consistent with `MAX_SHARD_SIZE = 5_000_000_000`
    being described as 5GB. A string made only of decimal digits (e.g. `"1000"`) is read as a byte count.

    Raises:
        ValueError: If the string uses an unsupported unit or its numeric part cannot be parsed.
    """
    if not isinstance(size, str):
        return size
    units = {"TB": 10**12, "GB": 10**9, "MB": 10**6, "KB": 10**3}
    text = size.strip()
    if text.isdecimal():
        return int(text)
    unit = text[-2:].upper()
    if unit not in units:
        raise ValueError(f"Unit '{unit}' not supported in size '{size}'. Supported units are: {', '.join(units)}.")
    try:
        value = float(text[:-2].strip())
    except ValueError as err:
        raise ValueError(f"Could not parse the numeric value of size '{size}'.") from err
    return int(value * units[unit])


def split_state_dict_into_shards_factory(
    state_dict: Dict[str, TensorT],
    *,
    get_tensor_size: TensorSizeFn_T,
    get_storage_id: StorageIDFn_T = lambda tensor: None,
    filename_pattern: str = FILENAME_PATTERN,
    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
    """
    Split a model state dictionary into shards so that each shard is smaller than a given size.

    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
    [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard which will have a
    size greater than `max_shard_size`.

    </Tip>

    Tensors that share the same underlying storage (same non-`None` value returned by `get_storage_id`) are placed in
    the shard of the first such tensor. Only that first tensor's size counts towards the shard size and `total_size`.
    String values (bnb serialization) are skipped, and an info message is logged for each one.

    Args:
        state_dict (`Dict[str, Tensor]`):
            The state dictionary to save.
        get_tensor_size (`Callable[[Tensor], int]`):
            A function that returns the size of a tensor in bytes.
        get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
            A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
            same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
            during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.
        filename_pattern (`str`, *optional*):
            The pattern to generate the file names in which the model will be saved. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
            Defaults to `"model{suffix}.safetensors"`.
        max_shard_size (`int` or `str`, *optional*):
            The maximum size of each shard, either in bytes or as a string with a decimal unit such as `"5GB"`
            (supported units: `"KB"`, `"MB"`, `"GB"`, `"TB"`). Defaults to 5GB.

    Returns:
        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
        If no tensor ends up in a shard (empty `state_dict`, or only string values), both mappings are empty.

    Raises:
        `ValueError`: If `max_shard_size` is a string that cannot be parsed.
    """
    max_shard_size = _parse_size_to_int(max_shard_size)

    storage_id_to_tensors: Dict[Any, List[str]] = {}
    shard_list: List[Dict[str, TensorT]] = []
    current_shard: Dict[str, TensorT] = {}
    current_shard_size = 0
    total_size = 0

    for key, tensor in state_dict.items():
        # When bnb serialization is used the weights in the state dict can be strings.
        # Check: https://github.com/huggingface/transformers/pull/24416 for more details.
        if isinstance(tensor, str):
            logger.info("Skipping tensor %s as it is a string (bnb serialization)", key)
            continue

        # If a `tensor` shares the same underlying storage as an earlier tensor, defer it: it is attached to the
        # earlier tensor's shard after all shards are built, and its size is not counted again.
        storage_id = get_storage_id(tensor)
        if storage_id is not None:
            if storage_id in storage_id_to_tensors:
                storage_id_to_tensors[storage_id].append(key)
                continue
            # First tensor seen with this storage id; it is placed normally below.
            storage_id_to_tensors[storage_id] = [key]

        tensor_size = get_tensor_size(tensor)

        # A tensor larger than the limit gets a shard of its own (which will exceed the limit).
        if tensor_size > max_shard_size:
            total_size += tensor_size
            shard_list.append({key: tensor})
            continue

        # Close the current shard if this tensor would tip it over the limit. The current shard cannot be empty
        # here, because `tensor_size <= max_shard_size` at this point.
        if current_shard_size + tensor_size > max_shard_size:
            shard_list.append(current_shard)
            current_shard = {}
            current_shard_size = 0

        current_shard[key] = tensor
        current_shard_size += tensor_size
        total_size += tensor_size

    # Add the last, partially filled shard.
    if current_shard:
        shard_list.append(current_shard)
    nb_shards = len(shard_list)

    # Attach deferred storage-sharing tensors to the shard holding the first tensor of their group.
    for keys in storage_id_to_tensors.values():
        if len(keys) == 1:
            continue  # No other tensor shares this storage; nothing to regroup.
        for shard in shard_list:
            if keys[0] in shard:
                for key in keys[1:]:
                    shard[key] = state_dict[key]
                break

    # If we only have one shard, no suffix and no numbered index are needed.
    if nb_shards == 1:
        filename = filename_pattern.format(suffix="")
        # NOTE(review): unlike the multi-shard branch below, this lists *every* key of `state_dict`, including
        # string entries skipped above (bnb serialization). Kept unchanged because callers may rely on it — confirm.
        return StateDictSplit(
            metadata={"total_size": total_size},
            filename_to_tensors={filename: list(state_dict.keys())},
            tensor_to_filename={key: filename for key in state_dict.keys()},
        )

    # Assign a numbered filename ("-00001-of-0000N") to each shard and build both directions of the index.
    tensor_name_to_filename: Dict[str, str] = {}
    filename_to_tensors: Dict[str, List[str]] = {}
    for idx, shard in enumerate(shard_list):
        filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}")
        for key in shard:
            tensor_name_to_filename[key] = filename
        filename_to_tensors[filename] = list(shard.keys())

    return StateDictSplit(
        metadata={"total_size": total_size},
        filename_to_tensors=filename_to_tensors,
        tensor_to_filename=tensor_name_to_filename,
    )