#!/usr/bin/env python
# coding=utf-8

# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
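
# Overview (descriptive note): this module re-exports backward-compatibility helpers
# from `huggingface_hub`, gathers the shared doc/generic/hub/import/PEFT utilities,
# and defines the file-name constants, dummy inputs, and `check_min_version` helper
# used across the library and its example scripts.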

from huggingface_hub import get_full_repo_name  # for backward compatibility
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY  # for backward compatibility
from packaging import version

from .. import __version__
from .backbone_utils import BackboneConfigMixin, BackboneMixin
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from .doc import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    copy_func,
    replace_return_docstrings,
)
from .generic import (
    ContextManagers,
    ExplicitEnum,
    ModelOutput,
    PaddingStrategy,
    TensorType,
    add_model_info_to_auto_map,
    cached_property,
    can_return_loss,
    expand_dims,
    find_labels,
    flatten_dict,
    infer_framework,
    is_jax_tensor,
    is_numpy_array,
    is_tensor,
    is_tf_symbolic_tensor,
    is_tf_tensor,
    is_torch_device,
    is_torch_dtype,
    is_torch_tensor,
    reshape,
    squeeze,
    strtobool,
    tensor_size,
    to_numpy,
    to_py_obj,
    transpose,
    working_or_temp_dir,
)
from .hub import (
    CLOUDFRONT_DISTRIB_PREFIX,
    HF_MODULES_CACHE,
    HUGGINGFACE_CO_PREFIX,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    PYTORCH_PRETRAINED_BERT_CACHE,
    PYTORCH_TRANSFORMERS_CACHE,
    S3_BUCKET_PREFIX,
    TRANSFORMERS_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
    EntryNotFoundError,
    PushInProgress,
    PushToHubMixin,
    RepositoryNotFoundError,
    RevisionNotFoundError,
    cached_file,
    default_cache_path,
    define_sagemaker_information,
    download_url,
    extract_commit_hash,
    get_cached_models,
    get_file_from_repo,
    has_file,
    http_user_agent,
    is_offline_mode,
    is_remote_url,
    move_cache,
    send_example_telemetry,
    try_to_load_from_cache,
)
from .import_utils import (
    ACCELERATE_MIN_VERSION,
    ENV_VARS_TRUE_AND_AUTO_VALUES,
    ENV_VARS_TRUE_VALUES,
    TORCH_FX_REQUIRED_VERSION,
    USE_JAX,
    USE_TF,
    USE_TORCH,
    XLA_FSDPV2_MIN_VERSION,
    DummyObject,
    OptionalDependencyNotAvailable,
    _LazyModule,
    ccl_version,
    direct_transformers_import,
    get_torch_version,
    is_accelerate_available,
    is_apex_available,
    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_av_available,
    is_bitsandbytes_available,
    is_bs4_available,
    is_coloredlogs_available,
    is_cv2_available,
    is_cython_available,
    is_datasets_available,
    is_decord_available,
    is_detectron2_available,
    is_essentia_available,
    is_faiss_available,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_flax_available,
    is_fsdp_available,
    is_ftfy_available,
    is_g2p_en_available,
    is_galore_torch_available,
    is_in_notebook,
    is_ipex_available,
    is_jieba_available,
    is_jinja_available,
    is_jumanpp_available,
    is_kenlm_available,
    is_keras_nlp_available,
    is_levenshtein_available,
    is_librosa_available,
    is_mlx_available,
    is_natten_available,
    is_ninja_available,
    is_nltk_available,
    is_onnx_available,
    is_openai_available,
    is_optimum_available,
    is_pandas_available,
    is_peft_available,
    is_phonemizer_available,
    is_pretty_midi_available,
    is_protobuf_available,
    is_psutil_available,
    is_py3nvml_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quanto_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_scipy_available,
    is_sentencepiece_available,
    is_seqio_available,
    is_sklearn_available,
    is_soundfile_availble,
    is_spacy_available,
    is_speech_available,
    is_sudachi_available,
    is_sudachi_projection_available,
    is_tensorflow_probability_available,
    is_tensorflow_text_available,
    is_tf2onnx_available,
    is_tf_available,
    is_timm_available,
    is_tokenizers_available,
    is_torch_available,
    is_torch_bf16_available,
    is_torch_bf16_available_on_device,
    is_torch_bf16_cpu_available,
    is_torch_bf16_gpu_available,
    is_torch_compile_available,
    is_torch_cuda_available,
    is_torch_fp16_available_on_device,
    is_torch_fx_available,
    is_torch_fx_proxy,
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_sdpa_available,
    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchaudio_available,
    is_torchdistx_available,
    is_torchdynamo_available,
    is_torchdynamo_compiling,
    is_torchvision_available,
    is_training_run_on_sagemaker,
    is_vision_available,
    requires_backends,
    torch_only_method,
)
from .peft_utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    check_peft_version,
    find_adapter_config_file,
)
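

# Canonical file names used when serializing and loading model weights, configs,
# and processors across the PyTorch, TensorFlow, and Flax backends.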
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
PROCESSOR_NAME = "processor_config.json"
GENERATION_CONFIG_NAME = "generation_config.json"
MODEL_CARD_NAME = "modelcard.json"
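
# Descriptive note: `SENTENCEPIECE_UNDERLINE` is the metaspace character ("▁") that
# SentencePiece-based tokenizers use to mark word boundaries; the dummy inputs and
# mask below serve as small placeholder model inputs.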
SENTENCEPIECE_UNDERLINE = "▁"
SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE  # Kept for backward compatibility

MULTIPLE_CHOICE_DUMMY_INPUTS = [
    [[0, 1, 0, 1], [1, 0, 0, 1]]
] * 2  # Needs to have 0s and 1s only since XLM uses it for langs too.
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]


def check_min_version(min_version):
    if version.parse(__version__) < version.parse(min_version):
        if "dev" in min_version:
            error_message = (
                "This example requires a source install from HuggingFace Transformers (see "
                "`https://huggingface.co/docs/transformers/installation#install-from-source`),"
            )
        else:
            error_message = f"This example requires a minimum version of {min_version},"
        error_message += f" but the version found is {__version__}.\n"
        raise ImportError(
            error_message
            + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other "
            "versions of HuggingFace Transformers."
        )
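

# Minimal usage sketch (an illustration, not part of this module): an example script
# pins the library version it needs and fails fast with the ImportError above when
# the installed version is too old. The version string is hypothetical.
#
#     from transformers.utils import check_min_version
#
#     check_min_version("4.40.0")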