# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
import os
import re
import warnings
from functools import partial
from pickle import UnpicklingError
from typing import Any, Dict, Optional, Set, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import FlaxGenerationMixin, GenerationConfig
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
from .utils import (
    FLAX_WEIGHTS_INDEX_NAME,
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    PushToHubMixin,
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
    cached_file,
    copy_func,
    download_url,
    has_file,
    is_offline_mode,
    is_remote_url,
    logging,
    replace_return_docstrings,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import is_safetensors_available


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.flax import load_file as safe_load_file
    from safetensors.flax import save_file as safe_save_file


logger = logging.get_logger(__name__)

def quick_gelu(x):
    return x * jax.nn.sigmoid(1.702 * x)


ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
    "quick_gelu": quick_gelu,
    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
}
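
# Illustrative sketch (not part of this module's API): model implementations
# typically resolve their activation from the config's `hidden_act` string via
# ACT2FN. The "gelu_new" key and the toy input below are assumptions for the
# example.
#
#     >>> act_fn = ACT2FN["gelu_new"]  # e.g. config.hidden_act == "gelu_new"
#     >>> hidden = act_fn(jnp.array([[0.5, -1.0]]))
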

def dtype_byte_size(dtype):
    """
    Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:

    ```py
    >>> dtype_byte_size(np.float32)
    4
    ```
    """
    if dtype == bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", dtype.name)
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


def flax_shard_checkpoint(params, max_shard_size="10GB"):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
    there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
    example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
    [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Args:
        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
    """
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = []
    current_block = {}
    current_block_size = 0
    total_size = 0

    # flatten the weights to chunk
    weights = flatten_dict(params, sep="/")
    for item in weights:
        weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)

        # If this weight is going to tip over the maximal size, we split.
        if current_block_size + weight_size > max_shard_size:
            sharded_state_dicts.append(current_block)
            current_block = {}
            current_block_size = 0

        current_block[item] = weights[item]
        current_block_size += weight_size
        total_size += weight_size

    # Add the last block
    sharded_state_dicts.append(current_block)

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
        shards[shard_file] = shard
        for weight_name in shard.keys():
            weight_map[weight_name] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
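
# Illustrative sketch (assumptions: NumPy is available and the toy layer names
# below are invented for the example). A deliberately small `max_shard_size`
# forces the two tensors into separate shards; the returned index maps each
# flattened weight name ("a/kernel", "b/kernel") to its shard file.
#
#     >>> import numpy as np
#     >>> toy_params = {
#     ...     "a": {"kernel": np.zeros((512, 512), dtype=np.float32)},
#     ...     "b": {"kernel": np.zeros((512, 512), dtype=np.float32)},
#     ... }
#     >>> shards, index = flax_shard_checkpoint(toy_params, max_shard_size="1500KB")
#     >>> sorted(shards)  # two shard files, e.g. 'flax_model-00001-of-00002.msgpack', ...
#     >>> index["weight_map"]["a/kernel"]  # name of the shard holding that tensor
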

class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
    r"""
    Base class for all models.

    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
    downloading and saving models.

    Class attributes (overridden by derived classes):

        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
          for this model architecture.
        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
          classes of the same architecture adding modules on top of the base model.
        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
          models, `pixel_values` for vision models and `input_values` for speech models).
    """

    config_class = None
    base_model_prefix = ""
    main_input_name = "input_ids"
    _auto_class = None
    _missing_keys = set()

    def __init__(
        self,
        config: PretrainedConfig,
        module: nn.Module,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
    ):
        if config is None:
            raise ValueError("config cannot be None")

        if module is None:
            raise ValueError("module cannot be None")

        # Those are private to be exposed as typed property on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to every derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype
        self.input_shape = input_shape
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

        # To check if the model was initialized automatically.
        self._is_initialized = _do_init

        if _do_init:
            # randomly initialized parameters
            random_params = self.init_weights(self.key, input_shape)
            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
        else:
            init_fn = partial(self.init_weights, input_shape=input_shape)
            params_shape_tree = jax.eval_shape(init_fn, self.key)

            logger.info(
                "Model weights are not initialized as `_do_init` is set to `False`. "
                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
            )

        # get the shape of the parameters
        self._params_shape_tree = params_shape_tree

        # save required_params as set
        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())

        # initialize the parameters
        if _do_init:
            self.params = random_params

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        raise NotImplementedError(f"init method has to be implemented for {self}")

    def enable_gradient_checkpointing(self):
        raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")

    @classmethod
    def _from_config(cls, config, **kwargs):
        """
        All context managers that the model should be initialized under go here.
        """
        return cls(config, **kwargs)

    @property
    def framework(self) -> str:
        """
        :str: Identifies that this is a Flax model.
        """
        return "flax"

    @property
    def config(self) -> PretrainedConfig:
        return self._config

    @property
    def module(self) -> nn.Module:
        return self._module

    @property
    def params(self) -> Union[Dict, FrozenDict]:
        if not self._is_initialized:
            raise ValueError(
                "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
                "You must call `init_weights` manually and store the params outside of the model and "
                "pass it explicitly where needed."
            )
        return self._params

    @property
    def required_params(self) -> Set:
        return self._required_params

    @property
    def params_shape_tree(self) -> Dict:
        return self._params_shape_tree

    @params.setter
    def params(self, params: Union[Dict, FrozenDict]):
        # don't set params if the model is not initialized
        if not self._is_initialized:
            raise ValueError(
                "`params` cannot be set from model when the model is created with `_do_init=False`. "
                "You must store the params outside of the model."
            )

        if isinstance(params, FrozenDict):
            params = unfreeze(params)
        param_keys = set(flatten_dict(params).keys())
        if len(self.required_params - param_keys) > 0:
            raise ValueError(
                "Some parameters are missing. Make sure that `params` include the following "
                f"parameters {self.required_params - param_keys}"
            )
        self._params = params
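
    # Illustrative sketch of the `_do_init=False` contract enforced above (the
    # `FlaxBertModel` checkpoint name, and the assumption that the derived
    # `init_weights` fills missing keys when given `params`, are not guaranteed
    # by this base class): the params live outside the model and are passed
    # explicitly to each call instead of being stored on the instance.
    #
    #     >>> model, params = FlaxBertModel.from_pretrained("google-bert/bert-base-cased", _do_init=False)
    #     >>> params = model.init_weights(model.key, model.input_shape, params=params)
    #     >>> outputs = model(input_ids, params=params)  # input_ids prepared elsewhere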

    def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
        """
        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
        """

        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
        def conditional_cast(param):
            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
                param = param.astype(dtype)
            return param

        if mask is None:
            return jax.tree_util.tree_map(conditional_cast, params)

        flat_params = flatten_dict(params)
        flat_mask, _ = jax.tree_util.tree_flatten(mask)

        for masked, key in zip(flat_mask, sorted(flat_params.keys())):
            if masked:
                flat_params[key] = conditional_cast(flat_params[key])

        return unflatten_dict(flat_params)

    def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
        the `params` in place.

        This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel

        >>> # load model
        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
        >>> model.params = model.to_bf16(model.params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> flat_params = traverse_util.flatten_dict(model.params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> model.params = model.to_bf16(model.params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.bfloat16, mask)

    def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel

        >>> # Download model and configuration from huggingface.co
        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
        >>> # we'll first cast to fp16 and back to fp32
        >>> model.params = model.to_fp16(model.params)
        >>> # now cast back to fp32
        >>> model.params = model.to_fp32(model.params)
        ```"""
        return self._cast_floating_to(params, jnp.float32, mask)

    def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
        r"""
        Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
        `params` in place.

        This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.

        Arguments:
            params (`Union[Dict, FrozenDict]`):
                A `PyTree` of model parameters.
            mask (`Union[Dict, FrozenDict]`):
                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
                you want to cast, and `False` for those you want to skip.

        Examples:

        ```python
        >>> from transformers import FlaxBertModel

        >>> # load model
        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> # By default, the model params will be in fp32, to cast these to float16
        >>> model.params = model.to_fp16(model.params)
        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
        >>> # then pass the mask as follows
        >>> from flax import traverse_util

        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> flat_params = traverse_util.flatten_dict(model.params)
        >>> mask = {
        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
        ...     for path in flat_params
        ... }
        >>> mask = traverse_util.unflatten_dict(mask)
        >>> model.params = model.to_fp16(model.params, mask)
        ```"""
        return self._cast_floating_to(params, jnp.float16, mask)

    @classmethod
    def load_flax_weights(cls, resolved_archive_file):
        try:
            if resolved_archive_file.endswith(".safetensors"):
                state = safe_load_file(resolved_archive_file)
                state = unflatten_dict(state, sep=".")
            else:
                with open(resolved_archive_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
            try:
                with open(resolved_archive_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")

        return state

    @classmethod
    def load_flax_sharded_weights(cls, shard_files):
        """
        This is the same as [`flax.serialization.from_bytes`]
        (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded
        checkpoint.

        This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
        loaded in the model.

        Args:
            shard_files (`List[str]`):
                The list of shard files to load.

        Returns:
            `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model':
            {'params': {'...'}}}`.
        """

        # Load the index
        state_sharded_dict = {}

        for shard_file in shard_files:
            # load using msgpack utils
            try:
                with open(shard_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                with open(shard_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")

            state = flatten_dict(state, sep="/")
            state_sharded_dict.update(state)
            del state
            gc.collect()

        # the state dict is unflattened to match the format of model.params
        return unflatten_dict(state_sharded_dict, sep="/")
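
    # Illustrative sketch (the local directory below is an assumption): for a
    # checkpoint saved with `save_pretrained` and sharded on disk, the shard list
    # is usually derived from the `weight_map` of the `.index.json` file before
    # being handed to this classmethod.
    #
    #     >>> with open("./my_model/flax_model.msgpack.index.json") as f:
    #     ...     index = json.load(f)
    #     >>> shard_files = sorted({os.path.join("./my_model", v) for v in index["weight_map"].values()})
    #     >>> state = FlaxPreTrainedModel.load_flax_sharded_weights(shard_files)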

    @classmethod
    def can_generate(cls) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.

        Returns:
            `bool`: Whether this model can generate sequences with `.generate()`.
        """
        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
        # Alternatively, the model can also have a custom `generate` function.
        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
            return False
        return True

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        **kwargs,
    ):
        r"""
        Instantiate a pretrained flax model from a pre-trained model configuration.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                - A path to a *directory* containing model weights saved using
                  [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                - A path or url to a *pt index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In this case,
                  `from_pt` should be set to `True`.
            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
                `jax.numpy.bfloat16` (on TPUs).

                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
                specified all the computation will be performed with the given `dtype`.

                **Note that this only specifies the dtype of the computation and does not influence the dtype of model
                parameters.**

                If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
                [`~FlaxPreTrainedModel.to_bf16`].
            model_args (sequence of positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
                Can be either:

                - an instance of a class derived from [`PretrainedConfig`],
                - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
                be automatically loaded when:

                - The model is a model provided by the library (loaded with the *model id* string of a pretrained
                  model).
                - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
                  save directory.
                - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
                  configuration JSON file named *config.json* is found in the directory.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file (see docstring of
                `pretrained_model_name_or_path` argument).
            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
                checkpoint with 3 labels).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

                <Tip>

                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

                </Tip>

            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
                specify the folder name here.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
                automatically loaded:

                - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
                  underlying model's `__init__` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, `kwargs` will be first passed to the configuration class
                  initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
                  corresponds to a configuration attribute will be used to override said attribute with the
                  supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
                  will be passed to the underlying model's `__init__` function.

        Examples:

        ```python
        >>> from transformers import BertConfig, FlaxBertModel

        >>> # Download model and configuration from huggingface.co and cache.
        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
        >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
        >>> config = BertConfig.from_json_file("./pt_model/config.json")
        >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
        ```"""
        from_pt = kwargs.pop("from_pt", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)
        _do_init = kwargs.pop("_do_init", True)
        subfolder = kwargs.pop("subfolder", "")
        commit_hash = kwargs.pop("_commit_hash", None)

        # Not relevant for Flax Models
        _ = kwargs.pop("adapter_kwargs", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if trust_remote_code is True:
            logger.warning(
                "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
                " ignored."
            )

        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                _commit_hash=commit_hash,
                **kwargs,
            )
        else:
            model_kwargs = kwargs.copy()

        if commit_hash is None:
            commit_hash = getattr(config, "_commit_hash", None)

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
        # index of the files.
        is_sharded = False

        # Load model
        if pretrained_model_name_or_path is not None:
            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
            is_local = os.path.isdir(pretrained_model_name_or_path)
            if os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
                    # Load from a sharded Flax checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                elif is_safetensors_available() and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
                ):
                    # Load from a safetensors checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
                elif from_pt and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
                ):
                    # Load from a sharded pytorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
                    is_sharded = True
                # At this stage we don't have a weight file so we will raise an error.
                elif is_safetensors_available() and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                ):
                    # Load from a sharded safetensors checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
                    is_sharded = True
                    raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
                        "weights."
                    )
                else:
                    raise EnvironmentError(
                        f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                        f"{pretrained_model_name_or_path}."
                    )
            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
                archive_file = pretrained_model_name_or_path
                is_local = True
            elif is_remote_url(pretrained_model_name_or_path):
                filename = pretrained_model_name_or_path
                resolved_archive_file = download_url(pretrained_model_name_or_path)
            else:
                if from_pt:
                    filename = WEIGHTS_NAME
                else:
                    filename = FLAX_WEIGHTS_NAME

                try:
                    # Load from URL or cache if already cached
                    cached_file_kwargs = {
                        "cache_dir": cache_dir,
                        "force_download": force_download,
                        "proxies": proxies,
                        "resume_download": resume_download,
                        "local_files_only": local_files_only,
                        "token": token,
                        "user_agent": user_agent,
                        "revision": revision,
                        "subfolder": subfolder,
                        "_raise_exceptions_for_gated_repo": False,
                        "_raise_exceptions_for_missing_entries": False,
                        "_commit_hash": commit_hash,
                    }
                    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
                    if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
                        )
                        if resolved_archive_file is not None:
                            is_sharded = True

                    # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
                    if resolved_archive_file is None and from_pt:
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
                        )
                        if resolved_archive_file is not None:
                            is_sharded = True

                    # If we still haven't found anything, look for `safetensors`.
                    if resolved_archive_file is None:
                        # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
                        filename = SAFE_WEIGHTS_NAME
                        resolved_archive_file = cached_file(
                            pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
                        )

                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
                    # result when internet is up, the repo and revision exist, but the file does not.
                    if resolved_archive_file is None:
                        # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
                        # message.
                        has_file_kwargs = {
                            "revision": revision,
                            "proxies": proxies,
                            "token": token,
                        }
                        if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
                            is_sharded = True
                            raise NotImplementedError(
                                "Support for sharded checkpoints using safetensors is coming soon!"
                            )
                        elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
                                " load this model from those weights."
                            )
                        elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
                                " `from_pt=True` to load this model from those weights."
                            )
                        else:
                            raise EnvironmentError(
                                f"{pretrained_model_name_or_path} does not appear to have a file named"
                                f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                            )
                except EnvironmentError:
                    # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
                    # to the original exception.
                    raise
                except Exception:
                    # For any other exception, we throw a generic error.
                    raise EnvironmentError(
                        f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
                        " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                        f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                        f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                    )

            if is_local:
                logger.info(f"loading weights file {archive_file}")
                resolved_archive_file = archive_file
                filename = resolved_archive_file.split(os.path.sep)[-1]
            else:
                logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
        else:
            resolved_archive_file = None

        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
        if is_sharded:
            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
            resolved_archive_file, _ = get_checkpoint_shard_files(
                pretrained_model_name_or_path,
                resolved_archive_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
                user_agent=user_agent,
                revision=revision,
                subfolder=subfolder,
                _commit_hash=commit_hash,
            )

        safetensors_from_pt = False
        if filename == SAFE_WEIGHTS_NAME:
            with safe_open(resolved_archive_file, framework="flax") as f:
                safetensors_metadata = f.metadata()
            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
                raise OSError(
                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
                    " Make sure you save your model with the `save_pretrained` method."
                )
            safetensors_from_pt = safetensors_metadata.get("format") == "pt"

        # init random models
        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)

        if from_pt or safetensors_from_pt:
            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
        else:
            if is_sharded:
                state = cls.load_flax_sharded_weights(resolved_archive_file)
            else:
                state = cls.load_flax_weights(resolved_archive_file)
            # make sure all arrays are stored as jnp.arrays
            # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
            if _do_init:
                state = jax.tree_util.tree_map(jnp.array, state)
            else:
                # keep the params on CPU if we don't want to initialize
                state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)

        if "batch_stats" in state:  # if flax model contains batch norm layers
            # if model is base model only use model_prefix key
            if (
                cls.base_model_prefix not in dict(model.params_shape_tree["params"])
                and cls.base_model_prefix in state["params"]
            ):
                state["params"] = state["params"][cls.base_model_prefix]
                state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]

            # if model is head model and we are loading weights from base model
            # we initialize new params dict with base_model_prefix
            if (
                cls.base_model_prefix in dict(model.params_shape_tree["params"])
                and cls.base_model_prefix not in state["params"]
            ):
                state = {
                    "params": {cls.base_model_prefix: state["params"]},
                    "batch_stats": {cls.base_model_prefix: state["batch_stats"]},
                }

        else:
            # if model is base model only use model_prefix key
            if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
                state = state[cls.base_model_prefix]

            # if model is head model and we are loading weights from base model
            # we initialize new params dict with base_model_prefix
            if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
                state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # Disabling warning when porting PyTorch weights to Flax, Flax does not use num_batches_tracked
        for unexpected_key in unexpected_keys.copy():
            if "num_batches_tracked" in unexpected_key[-1]:
                unexpected_keys.remove(unexpected_key)

        if missing_keys and not _do_init:
            logger.warning(
                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
                "Make sure to call model.init_weights to initialize the missing weights."
            )
            cls._missing_keys = missing_keys

        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
        for key in state.keys():
            if key in random_state and state[key].shape != random_state[key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Use `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model."
                    )

        # add missing keys as random parameters if we are initializing
        if missing_keys and _do_init:
            for missing_key in missing_keys:
                state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
                " to use it for predictions and inference."
            )

        # dictionary of key: dtypes for the model params
        param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
        # extract keys of parameters not in jnp.float32
        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]

        # raise a warning if any of the parameters are not in jnp.float32
        if len(fp16_params) > 0:
            logger.warning(
                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
                "You should probably UPCAST the model weights to float32 if this was not intended. "
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

        if len(bf16_params) > 0:
            logger.warning(
                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
                "You should probably UPCAST the model weights to float32 if this was not intended. "
                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
            )

        # If it is a model with generation capabilities, attempt to load the generation config
        if model.can_generate():
            try:
                model.generation_config = GenerationConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    _from_auto=from_auto_class,
                    _from_pipeline=from_pipeline,
                    **kwargs,
                )
            except OSError:
                logger.info(
                    "Generation config file not found, using a generation config created from the model config."
                )
                pass

        if _do_init:
            # set correct parameters
            model.params = unflatten_dict(state)
            return model
        else:
            return model, unflatten_dict(state)
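
    # Illustrative sketch (the local 3-label checkpoint path and the 10-label head
    # are assumptions): `ignore_mismatched_sizes=True` lets a checkpoint with a
    # differently-shaped head load anyway, re-initializing only the mismatched
    # weights and reporting them in the "newly initialized" warning above.
    #
    #     >>> from transformers import FlaxBertForSequenceClassification
    #     >>> model = FlaxBertForSequenceClassification.from_pretrained(
    #     ...     "./my-3-label-checkpoint", num_labels=10, ignore_mismatched_sizes=True
    #     ... )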

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        params=None,
        push_to_hub=False,
        max_shard_size="10GB",
        token: Optional[Union[str, bool]] = None,
        safe_serialization: bool = False,
        **kwargs,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~FlaxPreTrainedModel.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
                than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).

                <Tip warning={true}>

                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
                which will be bigger than `max_shard_size`.

                </Tip>

            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
            safe_serialization (`bool`, *optional*, defaults to `False`):
                Whether to save the model using `safetensors` or through msgpack.
        """
        use_auth_token = kwargs.pop("use_auth_token", None)

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if token is not None:
            kwargs["token"] = token

        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        os.makedirs(save_directory, exist_ok=True)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            repo_id = self._create_repo(repo_id, **kwargs)
            files_timestamps = self._get_files_timestamps(save_directory)

        # get abs dir
        save_directory = os.path.abspath(save_directory)
        # save config as well
        self.config.architectures = [self.__class__.__name__[4:]]

        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
        # loaded from the Hub.
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self.config)

        self.config.save_pretrained(save_directory)
        if self.can_generate():
            self.generation_config.save_pretrained(save_directory)

        # save model
        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
        output_model_file = os.path.join(save_directory, weights_name)

        shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
        # Clean the folder from a previous save
        for filename in os.listdir(save_directory):
            full_filename = os.path.join(save_directory, filename)
            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
            if (
                filename.startswith(weights_no_suffix)
                and os.path.isfile(full_filename)
                and filename not in shards.keys()
            ):
                os.remove(full_filename)

        if index is None:
            if safe_serialization:
                params = params if params is not None else self.params
                flat_dict = flatten_dict(params, sep=".")
                safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
            else:
                with open(output_model_file, "wb") as f:
                    params = params if params is not None else self.params
                    model_bytes = to_bytes(params)
                    f.write(model_bytes)

        else:
            save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)
            logger.info(
                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
                f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
                f"index located at {save_index_file}."
            )
            for shard_file, shard in shards.items():
                # the shard items are unflattened, to save them we need to flatten them again
                with open(os.path.join(save_directory, shard_file), mode="wb") as f:
                    params = unflatten_dict(shard, sep="/")
                    shard_bytes = to_bytes(params)
                    f.write(shard_bytes)

        logger.info(f"Model weights saved in {output_model_file}")

        if push_to_hub:
            self._upload_modified_files(
                save_directory,
                repo_id,
                files_timestamps,
                commit_message=commit_message,
                token=token,
            )
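
    # Illustrative sketch (the checkpoint name and local path are assumptions): a
    # save / reload round trip. A small `max_shard_size` produces a sharded msgpack
    # checkpoint plus an index file, which `from_pretrained` reassembles
    # transparently on reload.
    #
    #     >>> from transformers import FlaxBertModel
    #     >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
    #     >>> model.save_pretrained("./my_flax_bert", max_shard_size="200MB")
    #     >>> reloaded = FlaxBertModel.from_pretrained("./my_flax_bert")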

    @classmethod
    def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
        """
        Register this class with a given auto class. This should only be used for custom models as the ones in the
        library are already mapped with an auto class.

        <Tip warning={true}>

        This API is experimental and may have some slight breaking changes in the next releases.

        </Tip>

        Args:
            auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
                The auto class to register this new model with.
        """
        if not isinstance(auto_class, str):
            auto_class = auto_class.__name__

        import transformers.models.auto as auto_module

        if not hasattr(auto_module, auto_class):
            raise ValueError(f"{auto_class} is not a valid auto class.")

        cls._auto_class = auto_class


# To update the docstring, we need to copy the method, otherwise we change the original docstring.
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
    FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
        object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
    )


def overwrite_call_docstring(model_class, docstring):
    # copy __call__ function to be sure docstring is changed only for this function
    model_class.__call__ = copy_func(model_class.__call__)
    # delete existing docstring
    model_class.__call__.__doc__ = None
    # set correct docstring
    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)


def append_call_sample_docstring(
    model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = add_code_sample_docstrings(
        checkpoint=checkpoint,
        output_type=output_type,
        config_class=config_class,
        model_cls=model_class.__name__,
        revision=revision,
        real_checkpoint=real_checkpoint,
    )(model_class.__call__)


def append_replace_return_docstrings(model_class, output_type, config_class):
    model_class.__call__ = copy_func(model_class.__call__)
    model_class.__call__ = replace_return_docstrings(
        output_type=output_type,
        config_class=config_class,
    )(model_class.__call__)
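
# Illustrative sketch (the model class, docstring constant, output type and
# checkpoint below are assumptions, not defined in this module): Flax model files
# typically call these helpers at import time to patch the generated `__call__`
# docstring of a concrete model class.
#
#     >>> overwrite_call_docstring(FlaxBertForSequenceClassification, BERT_INPUTS_DOCSTRING)
#     >>> append_call_sample_docstring(
#     ...     FlaxBertForSequenceClassification,
#     ...     checkpoint="google-bert/bert-base-cased",
#     ...     output_type=FlaxSequenceClassifierOutput,
#     ...     config_class=BertConfig,
#     ... )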