import os
import sys
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch

from safetensors import deserialize, safe_open, serialize, serialize_file


def storage_ptr(tensor: torch.Tensor) -> int:
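    """Return the data pointer of the tensor's underlying storage, or 0 when it has none (e.g. meta tensors)."""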
    try:
        return tensor.untyped_storage().data_ptr()
    except Exception:
        # Fallback for torch==1.10
        try:
            return tensor.storage().data_ptr()
        except NotImplementedError:
            # Fallback for meta storage
            return 0


def _end_ptr(tensor: torch.Tensor) -> int:
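    """Return the address one byte past the last element of `tensor` (its exclusive end pointer)."""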
    if tensor.nelement():
        stop = tensor.view(-1)[-1].data_ptr() + _SIZE[tensor.dtype]
    else:
        stop = tensor.data_ptr()
    return stop


def storage_size(tensor: torch.Tensor) -> int:
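    """Return the size in bytes of the tensor's underlying storage."""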
    try:
        return tensor.untyped_storage().nbytes()
    except AttributeError:
        # Fallback for torch==1.10
        try:
            return tensor.storage().size() * _SIZE[tensor.dtype]
        except NotImplementedError:
            # Fallback for meta storage
            # On torch >=2.0 this is the tensor size
            return tensor.nelement() * _SIZE[tensor.dtype]


def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
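    """Split storage-sharing groups so that only tensors whose byte ranges actually overlap stay grouped."""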
    filtered_tensors = []
    for shared in tensors:
        if len(shared) < 2:
            filtered_tensors.append(shared)
            continue

        areas = []
        for name in shared:
            tensor = state_dict[name]
            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
        areas.sort()

        _, last_stop, last_name = areas[0]
        filtered_tensors.append({last_name})
        for start, stop, name in areas[1:]:
            if start >= last_stop:
                filtered_tensors.append({name})
            else:
                filtered_tensors[-1].add(name)
            last_stop = stop

    return filtered_tensors


def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
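    """Group state_dict keys by (device, storage pointer, storage size), keeping only truly overlapping tensors."""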
    tensors = defaultdict(set)
    for k, v in state_dict.items():
        if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
            # Need to add device as key because of multiple GPU.
            tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
    tensors = list(sorted(tensors.values()))
    tensors = _filter_shared_not_shared(tensors, state_dict)
    return tensors


def _is_complete(tensor: torch.Tensor) -> bool:
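    """Return True if `tensor` covers its entire underlying storage."""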
    return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _SIZE[tensor.dtype] == storage_size(tensor)


def _remove_duplicate_names(
    state_dict: Dict[str, torch.Tensor],
    *,
    preferred_names: Optional[List[str]] = None,
    discard_names: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
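    """For each group of shared tensors, pick one name to keep and map it to the alias names to drop."""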
    if preferred_names is None:
        preferred_names = []
    preferred_names = set(preferred_names)
    if discard_names is None:
        discard_names = []
    discard_names = set(discard_names)

    shareds = _find_shared_tensors(state_dict)
    to_remove = defaultdict(list)
    for shared in shareds:
        complete_names = set([name for name in shared if _is_complete(state_dict[name])])
        if not complete_names:
            raise RuntimeError(
                "Error while trying to find names to remove to save state dict, but found no suitable name to keep"
                f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model"
                " since you could be storing much more memory than needed. Please refer to"
                " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
                " issue."
            )

        keep_name = sorted(list(complete_names))[0]

        # Mechanism to preferentially select keys to keep
        # coming from the on-disk file to allow
        # loading models saved with a different choice
        # of keep_name
        preferred = complete_names.difference(discard_names)
        if preferred:
            keep_name = sorted(list(preferred))[0]

        if preferred_names:
            preferred = preferred_names.intersection(complete_names)
            if preferred:
                keep_name = sorted(list(preferred))[0]
        for name in sorted(shared):
            if name != keep_name:
                to_remove[keep_name].append(name)
    return to_remove


def save_model(
    model: torch.nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True
):
    """
    Saves a given torch model to a specified filename.
    This method exists specifically to avoid tensor sharing issues which are
    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)

    Args:
        model (`torch.nn.Module`):
            The model to save on disk.
        filename (`str`):
            The filename location to save the file.
        metadata (`Dict[str, str]`, *optional*):
            Extra information to save along with the file.
            Some metadata will be added for each dropped tensor.
            This information will not be enough to recover the entire
            shared structure but might help understand what happened.
        force_contiguous (`bool`, *optional*, defaults to `True`):
            Force the state_dict to be saved as contiguous tensors.
            This has no effect on the correctness of the model, but it
            could potentially change performance if the layout of the tensor
            was chosen specifically for that reason.
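
    Example (a minimal sketch; the model and filename here are illustrative):

    ```python
    from safetensors.torch import save_model
    import torch

    model = torch.nn.Linear(4, 4)
    save_model(model, "model.safetensors")
    ```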
    """
    state_dict = model.state_dict()
    to_removes = _remove_duplicate_names(state_dict)

    for kept_name, to_remove_group in to_removes.items():
        for to_remove in to_remove_group:
            if metadata is None:
                metadata = {}

            if to_remove not in metadata:
                # Do not override user data
                metadata[to_remove] = kept_name
            del state_dict[to_remove]
    if force_contiguous:
        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
    try:
        save_file(state_dict, filename, metadata=metadata)
    except ValueError as e:
        msg = str(e)
        msg += " Or use save_model(..., force_contiguous=True), read the docs for potential caveats."
        raise ValueError(msg)


def load_model(
    model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu"
) -> Tuple[List[str], List[str]]:
    """
    Loads a given filename onto a torch model.
    This method exists specifically to avoid tensor sharing issues which are
    not allowed in `safetensors`. [More information on tensor sharing](../torch_shared_tensors)

    Args:
        model (`torch.nn.Module`):
            The model to load onto.
        filename (`str`, or `os.PathLike`):
            The filename location to load the file from.
        strict (`bool`, *optional*, defaults to `True`):
            Whether to fail if you're missing keys or have unexpected ones.
            When false, the function simply returns missing and unexpected names.
        device (`Union[str, int]`, *optional*, defaults to `cpu`):
            The device where the tensors need to be located after load.
            Available options are all regular torch device locations.

    Returns:
        `(missing, unexpected): (List[str], List[str])`
            `missing` are names in the model which were not modified during loading.
            `unexpected` are names in the file that were not used during the load.
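
    Example (a minimal sketch; assumes "model.safetensors" was produced by `save_model` for a matching model):

    ```python
    from safetensors.torch import load_model
    import torch

    model = torch.nn.Linear(4, 4)
    missing, unexpected = load_model(model, "model.safetensors")
    ```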
    """
    state_dict = load_file(filename, device=device)
    model_state_dict = model.state_dict()
    to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys())
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    missing = set(missing)
    for to_remove_group in to_removes.values():
        for to_remove in to_remove_group:
            if to_remove not in missing:
                unexpected.append(to_remove)
            else:
                missing.remove(to_remove)
    if strict and (missing or unexpected):
        missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)])
        unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)])
        error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
        if missing:
            error += f"\n    Missing key(s) in state_dict: {missing_keys}"
        if unexpected:
            error += f"\n    Unexpected key(s) in state_dict: {unexpected_keys}"
        raise RuntimeError(error)
    # `missing` was converted to a set above; return a list to match the annotated return type.
    return list(missing), unexpected


def save(tensors: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> bytes:
    """
    Saves a dictionary of tensors into raw bytes in safetensors format.

    Args:
        tensors (`Dict[str, torch.Tensor]`):
            The incoming tensors. Tensors need to be contiguous and dense.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata you might want to save in your header.
            For instance it can be useful to specify more about the underlying
            tensors. This is purely informative and does not affect tensor loading.

    Returns:
        `bytes`: The raw bytes representing the format

    Example:

    ```python
    from safetensors.torch import save
    import torch

    tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
    byte_data = save(tensors)
    ```
    """
    serialized = serialize(_flatten(tensors), metadata=metadata)
    result = bytes(serialized)
    return result


def save_file(
    tensors: Dict[str, torch.Tensor],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
):
    """
    Saves a dictionary of tensors into a file in safetensors format.

    Args:
        tensors (`Dict[str, torch.Tensor]`):
            The incoming tensors. Tensors need to be contiguous and dense.
        filename (`str`, or `os.PathLike`):
            The filename we're saving into.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text-only metadata you might want to save in your header.
            For instance it can be useful to specify more about the underlying
            tensors. This is purely informative and does not affect tensor loading.

    Returns:
        `None`

    Example:

    ```python
    from safetensors.torch import save_file
    import torch

    tensors = {"embedding": torch.zeros((512, 1024)), "attention": torch.zeros((256, 256))}
    save_file(tensors, "model.safetensors")
    ```
    """
    serialize_file(_flatten(tensors), filename, metadata=metadata)


def load_file(filename: Union[str, os.PathLike], device: Union[str, int] = "cpu") -> Dict[str, torch.Tensor]:
    """
    Loads a safetensors file into torch format.

    Args:
        filename (`str`, or `os.PathLike`):
            The name of the file which contains the tensors.
        device (`Union[str, int]`, *optional*, defaults to `cpu`):
            The device where the tensors need to be located after load.
            Available options are all regular torch device locations.

    Returns:
        `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor`

    Example:

    ```python
    from safetensors.torch import load_file

    file_path = "./my_folder/bert.safetensors"
    loaded = load_file(file_path)
    ```
    """
    result = {}
    with safe_open(filename, framework="pt", device=device) as f:
        for k in f.keys():
            result[k] = f.get_tensor(k)
    return result


def load(data: bytes) -> Dict[str, torch.Tensor]:
    """
    Loads a safetensors file into torch format from pure bytes.

    Args:
        data (`bytes`):
            The content of a safetensors file.

    Returns:
        `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor` on cpu

    Example:

    ```python
    from safetensors.torch import load

    file_path = "./my_folder/bert.safetensors"
    with open(file_path, "rb") as f:
        data = f.read()

    loaded = load(data)
    ```
    """
    flat = deserialize(data)
    return _view2torch(flat)


# torch.float8 formats require torch>=2.1; we do not support these dtypes on earlier
# versions, where the getattr fallbacks below resolve to None.
_float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
_float8_e5m2 = getattr(torch, "float8_e5m2", None)

_SIZE = {
    torch.int64: 8,
    torch.float32: 4,
    torch.int32: 4,
    torch.bfloat16: 2,
    torch.float16: 2,
    torch.int16: 2,
    torch.uint8: 1,
    torch.int8: 1,
    torch.bool: 1,
    torch.float64: 8,
    _float8_e4m3fn: 1,
    _float8_e5m2: 1,
}

_TYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    # "U64": torch.uint64,
    "I32": torch.int32,
    # "U32": torch.uint32,
    "I16": torch.int16,
    # "U16": torch.uint16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
    "F8_E4M3": _float8_e4m3fn,
    "F8_E5M2": _float8_e5m2,
}

# On torch versions without float8 support the entries above are None; drop them so
# _SIZE carries no bogus None key and _getdtype raises a clear KeyError for an
# unsupported dtype string instead of silently returning None.
_SIZE = {dtype: size for dtype, size in _SIZE.items() if dtype is not None}
_TYPES = {name: dtype for name, dtype in _TYPES.items() if dtype is not None}


def _getdtype(dtype_str: str) -> torch.dtype:
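    """Map a safetensors dtype string (e.g. "F32") to the corresponding torch.dtype."""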
    return _TYPES[dtype_str]


def _view2torch(safeview) -> Dict[str, torch.Tensor]:
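    """Convert deserialized `(name, {"dtype", "shape", "data"})` pairs into a dict of torch tensors."""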
    result = {}
    for k, v in safeview:
        dtype = _getdtype(v["dtype"])
        if len(v["data"]) == 0:
            # torch.frombuffer does not accept zero-length buffers, so build the
            # empty tensor directly from its shape instead.
            arr = torch.empty(v["shape"], dtype=dtype)
        else:
            arr = torch.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
        if sys.byteorder == "big":
            arr = torch.from_numpy(arr.numpy().byteswap(inplace=False))
        result[k] = arr

    return result


def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
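    """Return the raw bytes of `tensor`, byteswapping to little-endian on big-endian hosts."""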
    if tensor.layout != torch.strided:
        raise ValueError(
            f"You are trying to save a sparse tensor: `{name}` which this library does not support."
            " You can make it a dense tensor before saving with `.to_dense()` but be aware this might"
            " make a much larger file than needed."
        )

    if not tensor.is_contiguous():
        raise ValueError(
            f"You are trying to save a non-contiguous tensor: `{name}` which is not allowed. It either means you"
            " are trying to save tensors which are references of each other in which case it's recommended to save"
            " only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to"
            " pack it before saving."
        )
    if tensor.device.type != "cpu":
        # Moving tensor to cpu before saving
        tensor = tensor.to("cpu")

    import ctypes

    import numpy as np

    # When shape is empty (scalar), np.prod returns a float
    # we need an int for the following calculations
    length = int(np.prod(tensor.shape).item())
    bytes_per_item = _SIZE[tensor.dtype]

    total_bytes = length * bytes_per_item

    ptr = tensor.data_ptr()
    if ptr == 0:
        return b""
    newptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_ubyte))
    data = np.ctypeslib.as_array(newptr, (total_bytes,))  # no internal copy
    if sys.byteorder == "big":
        NPDTYPES = {
            torch.int64: np.int64,
            torch.float32: np.float32,
            torch.int32: np.int32,
            # XXX: This is ok because both have the same width
            torch.bfloat16: np.float16,
            torch.float16: np.float16,
            torch.int16: np.int16,
            torch.uint8: np.uint8,
            torch.int8: np.int8,
            torch.bool: bool,
            torch.float64: np.float64,
            # XXX: This is ok because both have the same width and byteswap is a no-op anyway
            _float8_e4m3fn: np.uint8,
            _float8_e5m2: np.uint8,
        }
        npdtype = NPDTYPES[tensor.dtype]
        # Not in place as that would potentially modify a live running model
        data = data.view(npdtype).byteswap(inplace=False)
    return data.tobytes()


def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:
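    """Validate a dict of tensors and convert it to the `{name: {"dtype", "shape", "data"}}` form `serialize` expects."""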
    if not isinstance(tensors, dict):
        raise ValueError(f"Expected a dict of [str, torch.Tensor] but received {type(tensors)}")

    invalid_tensors = []
    for k, v in tensors.items():
        if not isinstance(v, torch.Tensor):
            raise ValueError(f"Key `{k}` is invalid, expected torch.Tensor but received {type(v)}")

        if v.layout != torch.strided:
            invalid_tensors.append(k)
    if invalid_tensors:
        raise ValueError(
            f"You are trying to save sparse tensors: `{invalid_tensors}` which this library does not support."
            " You can make them dense tensors before saving with `.to_dense()` but be aware this might"
            " make a much larger file than needed."
        )

    shared_pointers = _find_shared_tensors(tensors)
    failing = []
    for names in shared_pointers:
        if len(names) > 1:
            failing.append(names)

    if failing:
        raise RuntimeError(
            f"""
            Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: {failing}.
            A potential way to correctly save your model is to use `save_model`.
            More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
            """
        )

    return {
        k: {
            "dtype": str(v.dtype).split(".")[-1],
            "shape": v.shape,
            "data": _tobytes(v, k),
        }
        for k, v in tensors.items()
    }