125 lines
5.1 KiB
Python
125 lines
5.1 KiB
Python
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import importlib
|
|
import os
|
|
from typing import Dict, Optional, Union
|
|
|
|
from packaging import version
|
|
|
|
from .hub import cached_file
|
|
from .import_utils import is_peft_available
|
|
|
|
|
|
# Canonical file names used for PEFT adapter checkpoints.
# The adapter configuration file; its presence is what marks a model as an adapter model.
ADAPTER_CONFIG_NAME: str = "adapter_config.json"
# Adapter weights stored with the `.bin` extension.
ADAPTER_WEIGHTS_NAME: str = "adapter_model.bin"
# Adapter weights stored in the safetensors format.
ADAPTER_SAFE_WEIGHTS_NAME: str = "adapter_model.safetensors"
|
|
|
|
|
|
def find_adapter_config_file(
    model_id: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    subfolder: str = "",
    _commit_hash: Optional[str] = None,
) -> Optional[str]:
    r"""
    Checks whether the model stored on the Hub or locally is an adapter model. Returns the path of the adapter
    config file if it is, `None` otherwise.

    Args:
        model_id (`str`):
            The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
            If `None`, the function returns `None` immediately.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.

            <Tip>

            To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

            </Tip>

        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the adapter configuration from local files.
        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co (or of the
            local directory `model_id`), you can specify the folder name here.
        _commit_hash (`str`, *optional*):
            Private argument, forwarded unchanged to `cached_file`.

    Returns:
        `Optional[str]`: For a local directory, the path of the adapter config file if that file exists, else `None`.
        For a Hub repo id, whatever `cached_file` returns with gated-repo, missing-entry and connection errors
        suppressed (presumably `None` when the file is absent — confirm against `cached_file`).
    """
    if model_id is None:
        return None

    if os.path.isdir(model_id):
        # Honor `subfolder` for local directories as well, matching how it is forwarded to `cached_file` for Hub
        # repos. With the default `""`, `os.path.join` yields `<model_id>/adapter_config.json` exactly as before.
        candidate = os.path.join(model_id, subfolder or "", ADAPTER_CONFIG_NAME)
        # `isfile` (rather than a directory-listing membership test) avoids matching a directory that happens to be
        # named like the config file.
        return candidate if os.path.isfile(candidate) else None

    return cached_file(
        model_id,
        ADAPTER_CONFIG_NAME,
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        token=token,
        revision=revision,
        local_files_only=local_files_only,
        subfolder=subfolder,
        _commit_hash=_commit_hash,
        _raise_exceptions_for_gated_repo=False,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
    )
|
|
|
|
|
|
def check_peft_version(min_version: str) -> None:
    r"""
    Checks that the installed version of PEFT is compatible, i.e. at least `min_version`.

    Args:
        min_version (`str`):
            The minimum version of PEFT required, e.g. `"0.5.0"`.

    Raises:
        `ValueError`: If PEFT is not installed, or if the installed PEFT version is lower than `min_version`.
    """
    # `import importlib` alone does not guarantee that the `importlib.metadata` submodule is loaded, so import it
    # explicitly rather than relying on another module having done so.
    import importlib.metadata

    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    installed_version = importlib.metadata.version("peft")
    is_peft_version_compatible = version.parse(installed_version) >= version.parse(min_version)

    if not is_peft_version_compatible:
        # The check accepts equality, so the message must say "greater than or equal to".
        raise ValueError(
            f"The version of PEFT you are using ({installed_version}) is not compatible, please use a version that is"
            f" greater than or equal to {min_version}"
        )
|