# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class."""
|
|
|
|
import warnings
|
|
from collections import OrderedDict
|
|
|
|
from ...utils import logging
|
|
from .auto_factory import (
|
|
_BaseAutoBackboneClass,
|
|
_BaseAutoModelClass,
|
|
_LazyAutoMapping,
|
|
auto_class_update,
|
|
)
|
|
from .configuration_auto import CONFIG_MAPPING_NAMES
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
# Maps a model type string (config `model_type`) to the class name of the
# corresponding *base* model (no task head). A tuple value lists several
# candidate base classes for that model type.
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "AlbertModel"),
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("audio-spectrogram-transformer", "ASTModel"),
        ("autoformer", "AutoformerModel"),
        ("bark", "BarkModel"),
        ("bart", "BartModel"),
        ("beit", "BeitModel"),
        ("bert", "BertModel"),
        ("bert-generation", "BertGenerationEncoder"),
        ("big_bird", "BigBirdModel"),
        ("bigbird_pegasus", "BigBirdPegasusModel"),
        ("biogpt", "BioGptModel"),
        ("bit", "BitModel"),
        ("blenderbot", "BlenderbotModel"),
        ("blenderbot-small", "BlenderbotSmallModel"),
        ("blip", "BlipModel"),
        ("blip-2", "Blip2Model"),
        ("bloom", "BloomModel"),
        ("bridgetower", "BridgeTowerModel"),
        ("bros", "BrosModel"),
        ("camembert", "CamembertModel"),
        ("canine", "CanineModel"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("chinese_clip_vision_model", "ChineseCLIPVisionModel"),
        ("clap", "ClapModel"),
        ("clip", "CLIPModel"),
        ("clip_vision_model", "CLIPVisionModel"),
        ("clipseg", "CLIPSegModel"),
        ("clvp", "ClvpModelForConditionalGeneration"),
        ("code_llama", "LlamaModel"),
        ("codegen", "CodeGenModel"),
        ("cohere", "CohereModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convbert", "ConvBertModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("cpmant", "CpmAntModel"),
        ("ctrl", "CTRLModel"),
        ("cvt", "CvtModel"),
        ("data2vec-audio", "Data2VecAudioModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("dbrx", "DbrxModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("decision_transformer", "DecisionTransformerModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("deta", "DetaModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("distilbert", "DistilBertModel"),
        ("donut-swin", "DonutSwinModel"),
        ("dpr", "DPRQuestionEncoder"),
        ("dpt", "DPTModel"),
        ("efficientformer", "EfficientFormerModel"),
        ("efficientnet", "EfficientNetModel"),
        ("electra", "ElectraModel"),
        ("encodec", "EncodecModel"),
        ("ernie", "ErnieModel"),
        ("ernie_m", "ErnieMModel"),
        ("esm", "EsmModel"),
        ("falcon", "FalconModel"),
        ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
        ("flaubert", "FlaubertModel"),
        ("flava", "FlavaModel"),
        ("fnet", "FNetModel"),
        ("focalnet", "FocalNetModel"),
        ("fsmt", "FSMTModel"),
        ("funnel", ("FunnelModel", "FunnelBaseModel")),
        ("gemma", "GemmaModel"),
        ("git", "GitModel"),
        ("glpn", "GLPNModel"),
        ("gpt-sw3", "GPT2Model"),
        ("gpt2", "GPT2Model"),
        ("gpt_bigcode", "GPTBigCodeModel"),
        ("gpt_neo", "GPTNeoModel"),
        ("gpt_neox", "GPTNeoXModel"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseModel"),
        ("gptj", "GPTJModel"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("graphormer", "GraphormerModel"),
        ("grounding-dino", "GroundingDinoModel"),
        ("groupvit", "GroupViTModel"),
        ("hubert", "HubertModel"),
        ("ibert", "IBertModel"),
        ("idefics", "IdeficsModel"),
        ("idefics2", "Idefics2Model"),
        ("imagegpt", "ImageGPTModel"),
        ("informer", "InformerModel"),
        ("jamba", "JambaModel"),
        ("jukebox", "JukeboxModel"),
        ("kosmos-2", "Kosmos2Model"),
        ("layoutlm", "LayoutLMModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("layoutlmv3", "LayoutLMv3Model"),
        ("led", "LEDModel"),
        ("levit", "LevitModel"),
        ("lilt", "LiltModel"),
        ("llama", "LlamaModel"),
        ("longformer", "LongformerModel"),
        ("longt5", "LongT5Model"),
        ("luke", "LukeModel"),
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("mamba", "MambaModel"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("mask2former", "Mask2FormerModel"),
        ("maskformer", "MaskFormerModel"),
        ("maskformer-swin", "MaskFormerSwinModel"),
        ("mbart", "MBartModel"),
        ("mctct", "MCTCTModel"),
        ("mega", "MegaModel"),
        ("megatron-bert", "MegatronBertModel"),
        ("mgp-str", "MgpstrForSceneTextRecognition"),
        ("mistral", "MistralModel"),
        ("mixtral", "MixtralModel"),
        ("mobilebert", "MobileBertModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("mpnet", "MPNetModel"),
        ("mpt", "MptModel"),
        ("mra", "MraModel"),
        ("mt5", "MT5Model"),
        ("mvp", "MvpModel"),
        ("nat", "NatModel"),
        ("nezha", "NezhaModel"),
        ("nllb-moe", "NllbMoeModel"),
        ("nystromformer", "NystromformerModel"),
        ("olmo", "OlmoModel"),
        ("oneformer", "OneFormerModel"),
        ("open-llama", "OpenLlamaModel"),
        ("openai-gpt", "OpenAIGPTModel"),
        ("opt", "OPTModel"),
        ("owlv2", "Owlv2Model"),
        ("owlvit", "OwlViTModel"),
        ("patchtsmixer", "PatchTSMixerModel"),
        ("patchtst", "PatchTSTModel"),
        ("pegasus", "PegasusModel"),
        ("pegasus_x", "PegasusXModel"),
        ("perceiver", "PerceiverModel"),
        ("persimmon", "PersimmonModel"),
        ("phi", "PhiModel"),
        ("plbart", "PLBartModel"),
        ("poolformer", "PoolFormerModel"),
        ("prophetnet", "ProphetNetModel"),
        ("pvt", "PvtModel"),
        ("pvt_v2", "PvtV2Model"),
        ("qdqbert", "QDQBertModel"),
        ("qwen2", "Qwen2Model"),
        ("qwen2_moe", "Qwen2MoeModel"),
        ("recurrent_gemma", "RecurrentGemmaModel"),
        ("reformer", "ReformerModel"),
        ("regnet", "RegNetModel"),
        ("rembert", "RemBertModel"),
        ("resnet", "ResNetModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("rwkv", "RwkvModel"),
        ("sam", "SamModel"),
        ("seamless_m4t", "SeamlessM4TModel"),
        ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
        ("segformer", "SegformerModel"),
        ("seggpt", "SegGptModel"),
        ("sew", "SEWModel"),
        ("sew-d", "SEWDModel"),
        ("siglip", "SiglipModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("speech_to_text", "Speech2TextModel"),
        ("speecht5", "SpeechT5Model"),
        ("splinter", "SplinterModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("stablelm", "StableLmModel"),
        ("starcoder2", "Starcoder2Model"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("switch_transformers", "SwitchTransformersModel"),
        ("t5", "T5Model"),
        ("table-transformer", "TableTransformerModel"),
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("tvlt", "TvltModel"),
        ("tvp", "TvpModel"),
        ("udop", "UdopModel"),
        ("umt5", "UMT5Model"),
        ("unispeech", "UniSpeechModel"),
        ("unispeech-sat", "UniSpeechSatModel"),
        ("univnet", "UnivNetModel"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vilt", "ViltModel"),
        ("vision-text-dual-encoder", "VisionTextDualEncoderModel"),
        ("visual_bert", "VisualBertModel"),
        ("vit", "ViTModel"),
        ("vit_hybrid", "ViTHybridModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vits", "VitsModel"),
        ("vivit", "VivitModel"),
        ("wav2vec2", "Wav2Vec2Model"),
        ("wav2vec2-bert", "Wav2Vec2BertModel"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerModel"),
        ("wavlm", "WavLMModel"),
        ("whisper", "WhisperModel"),
        ("xclip", "XCLIPModel"),
        ("xglm", "XGLMModel"),
        ("xlm", "XLMModel"),
        ("xlm-prophetnet", "XLMProphetNetModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
        ("xlnet", "XLNetModel"),
        ("xmod", "XmodModel"),
        ("yolos", "YolosModel"),
        ("yoso", "YosoModel"),
    ]
)
# Maps a model type string to the class name used for pre-training
# (e.g. MLM/NSP heads, causal LM heads, or model-specific pretraining heads).
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "AlbertForPreTraining"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForPreTraining"),
        ("big_bird", "BigBirdForPreTraining"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForPreTraining"),
        ("ernie", "ErnieForPreTraining"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("flava", "FlavaForPreTraining"),
        ("fnet", "FNetForPreTraining"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForPreTraining"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("ibert", "IBertForMaskedLM"),
        ("idefics", "IdeficsForVisionText2Text"),
        ("idefics2", "Idefics2ForConditionalGeneration"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("mamba", "MambaForCausalLM"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForPreTraining"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("retribert", "RetriBertModel"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForPreTraining"),
        ("rwkv", "RwkvForCausalLM"),
        ("splinter", "SplinterForPreTraining"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("tvlt", "TvltForPreTraining"),
        ("unispeech", "UniSpeechForPreTraining"),
        ("unispeech-sat", "UniSpeechSatForPreTraining"),
        ("videomae", "VideoMAEForPreTraining"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("visual_bert", "VisualBertForPreTraining"),
        ("vit_mae", "ViTMAEForPreTraining"),
        ("wav2vec2", "Wav2Vec2ForPreTraining"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForPreTraining"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForMaskedLM"),
    ]
)
# Maps a model type string to the class name of its generic language-modeling
# head variant (masked, causal, or seq2seq depending on the architecture).
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("funnel", "FunnelForMaskedLM"),
        ("git", "GitForCausalLM"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("led", "LEDForConditionalGeneration"),
        ("longformer", "LongformerForMaskedLM"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("mamba", "MambaForCausalLM"),
        ("marian", "MarianMTModel"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mpt", "MptForCausalLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("tapas", "TapasForMaskedLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("whisper", "WhisperForConditionalGeneration"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)
# Maps a model type string to the class name of its causal (autoregressive)
# language-modeling head variant.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bart", "BartForCausalLM"),
        ("bert", "BertLMHeadModel"),
        ("bert-generation", "BertGenerationDecoder"),
        ("big_bird", "BigBirdForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
        ("biogpt", "BioGptForCausalLM"),
        ("blenderbot", "BlenderbotForCausalLM"),
        ("blenderbot-small", "BlenderbotSmallForCausalLM"),
        ("bloom", "BloomForCausalLM"),
        ("camembert", "CamembertForCausalLM"),
        ("code_llama", "LlamaForCausalLM"),
        ("codegen", "CodeGenForCausalLM"),
        ("cohere", "CohereForCausalLM"),
        ("cpmant", "CpmAntForCausalLM"),
        ("ctrl", "CTRLLMHeadModel"),
        ("data2vec-text", "Data2VecTextForCausalLM"),
        ("dbrx", "DbrxForCausalLM"),
        ("electra", "ElectraForCausalLM"),
        ("ernie", "ErnieForCausalLM"),
        ("falcon", "FalconForCausalLM"),
        ("fuyu", "FuyuForCausalLM"),
        ("gemma", "GemmaForCausalLM"),
        ("git", "GitForCausalLM"),
        ("gpt-sw3", "GPT2LMHeadModel"),
        ("gpt2", "GPT2LMHeadModel"),
        ("gpt_bigcode", "GPTBigCodeForCausalLM"),
        ("gpt_neo", "GPTNeoForCausalLM"),
        ("gpt_neox", "GPTNeoXForCausalLM"),
        ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
        ("gptj", "GPTJForCausalLM"),
        ("jamba", "JambaForCausalLM"),
        ("llama", "LlamaForCausalLM"),
        ("mamba", "MambaForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("mega", "MegaForCausalLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
        ("mistral", "MistralForCausalLM"),
        ("mixtral", "MixtralForCausalLM"),
        ("mpt", "MptForCausalLM"),
        ("musicgen", "MusicgenForCausalLM"),
        ("musicgen_melody", "MusicgenMelodyForCausalLM"),
        ("mvp", "MvpForCausalLM"),
        ("olmo", "OlmoForCausalLM"),
        ("open-llama", "OpenLlamaForCausalLM"),
        ("openai-gpt", "OpenAIGPTLMHeadModel"),
        ("opt", "OPTForCausalLM"),
        ("pegasus", "PegasusForCausalLM"),
        ("persimmon", "PersimmonForCausalLM"),
        ("phi", "PhiForCausalLM"),
        ("plbart", "PLBartForCausalLM"),
        ("prophetnet", "ProphetNetForCausalLM"),
        ("qdqbert", "QDQBertLMHeadModel"),
        ("qwen2", "Qwen2ForCausalLM"),
        ("qwen2_moe", "Qwen2MoeForCausalLM"),
        ("recurrent_gemma", "RecurrentGemmaForCausalLM"),
        ("reformer", "ReformerModelWithLMHead"),
        ("rembert", "RemBertForCausalLM"),
        ("roberta", "RobertaForCausalLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForCausalLM"),
        ("roc_bert", "RoCBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("rwkv", "RwkvForCausalLM"),
        ("speech_to_text_2", "Speech2Text2ForCausalLM"),
        ("stablelm", "StableLmForCausalLM"),
        ("starcoder2", "Starcoder2ForCausalLM"),
        ("transfo-xl", "TransfoXLLMHeadModel"),
        ("trocr", "TrOCRForCausalLM"),
        ("whisper", "WhisperForCausalLM"),
        ("xglm", "XGLMForCausalLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-prophetnet", "XLMProphetNetForCausalLM"),
        ("xlm-roberta", "XLMRobertaForCausalLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForCausalLM"),
        ("xlnet", "XLNetLMHeadModel"),
        ("xmod", "XmodForCausalLM"),
    ]
)
# Maps a model type string to the class name of its base image (vision) model.
MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image mapping
        ("beit", "BeitModel"),
        ("bit", "BitModel"),
        ("conditional_detr", "ConditionalDetrModel"),
        ("convnext", "ConvNextModel"),
        ("convnextv2", "ConvNextV2Model"),
        ("data2vec-vision", "Data2VecVisionModel"),
        ("deformable_detr", "DeformableDetrModel"),
        ("deit", "DeiTModel"),
        ("deta", "DetaModel"),
        ("detr", "DetrModel"),
        ("dinat", "DinatModel"),
        ("dinov2", "Dinov2Model"),
        ("dpt", "DPTModel"),
        ("efficientformer", "EfficientFormerModel"),
        ("efficientnet", "EfficientNetModel"),
        ("focalnet", "FocalNetModel"),
        ("glpn", "GLPNModel"),
        ("imagegpt", "ImageGPTModel"),
        ("levit", "LevitModel"),
        ("mobilenet_v1", "MobileNetV1Model"),
        ("mobilenet_v2", "MobileNetV2Model"),
        ("mobilevit", "MobileViTModel"),
        ("mobilevitv2", "MobileViTV2Model"),
        ("nat", "NatModel"),
        ("poolformer", "PoolFormerModel"),
        ("pvt", "PvtModel"),
        ("regnet", "RegNetModel"),
        ("resnet", "ResNetModel"),
        ("segformer", "SegformerModel"),
        ("siglip_vision_model", "SiglipVisionModel"),
        ("swiftformer", "SwiftFormerModel"),
        ("swin", "SwinModel"),
        ("swin2sr", "Swin2SRModel"),
        ("swinv2", "Swinv2Model"),
        ("table-transformer", "TableTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("van", "VanModel"),
        ("videomae", "VideoMAEModel"),
        ("vit", "ViTModel"),
        ("vit_hybrid", "ViTHybridModel"),
        ("vit_mae", "ViTMAEModel"),
        ("vit_msn", "ViTMSNModel"),
        ("vitdet", "VitDetModel"),
        ("vivit", "VivitModel"),
        ("yolos", "YolosModel"),
    ]
)
# Maps a model type string to the class name of its masked-image-modeling head.
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        ("deit", "DeiTForMaskedImageModeling"),
        ("focalnet", "FocalNetForMaskedImageModeling"),
        ("swin", "SwinForMaskedImageModeling"),
        ("swinv2", "Swinv2ForMaskedImageModeling"),
        ("vit", "ViTForMaskedImageModeling"),
    ]
)
# Maps a model type string to the class name of its causal image-modeling head.
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    # Model for Causal Image Modeling mapping
    [
        ("imagegpt", "ImageGPTForCausalImageModeling"),
    ]
)
# Maps a model type string to the class name(s) of its image-classification
# head. A tuple value lists alternative classification heads for that type.
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("beit", "BeitForImageClassification"),
        ("bit", "BitForImageClassification"),
        ("clip", "CLIPForImageClassification"),
        ("convnext", "ConvNextForImageClassification"),
        ("convnextv2", "ConvNextV2ForImageClassification"),
        ("cvt", "CvtForImageClassification"),
        ("data2vec-vision", "Data2VecVisionForImageClassification"),
        (
            "deit",
            ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher"),
        ),
        ("dinat", "DinatForImageClassification"),
        ("dinov2", "Dinov2ForImageClassification"),
        (
            "efficientformer",
            (
                "EfficientFormerForImageClassification",
                "EfficientFormerForImageClassificationWithTeacher",
            ),
        ),
        ("efficientnet", "EfficientNetForImageClassification"),
        ("focalnet", "FocalNetForImageClassification"),
        ("imagegpt", "ImageGPTForImageClassification"),
        (
            "levit",
            ("LevitForImageClassification", "LevitForImageClassificationWithTeacher"),
        ),
        ("mobilenet_v1", "MobileNetV1ForImageClassification"),
        ("mobilenet_v2", "MobileNetV2ForImageClassification"),
        ("mobilevit", "MobileViTForImageClassification"),
        ("mobilevitv2", "MobileViTV2ForImageClassification"),
        ("nat", "NatForImageClassification"),
        (
            "perceiver",
            (
                "PerceiverForImageClassificationLearned",
                "PerceiverForImageClassificationFourier",
                "PerceiverForImageClassificationConvProcessing",
            ),
        ),
        ("poolformer", "PoolFormerForImageClassification"),
        ("pvt", "PvtForImageClassification"),
        ("pvt_v2", "PvtV2ForImageClassification"),
        ("regnet", "RegNetForImageClassification"),
        ("resnet", "ResNetForImageClassification"),
        ("segformer", "SegformerForImageClassification"),
        ("siglip", "SiglipForImageClassification"),
        ("swiftformer", "SwiftFormerForImageClassification"),
        ("swin", "SwinForImageClassification"),
        ("swinv2", "Swinv2ForImageClassification"),
        ("van", "VanForImageClassification"),
        ("vit", "ViTForImageClassification"),
        ("vit_hybrid", "ViTHybridForImageClassification"),
        ("vit_msn", "ViTMSNForImageClassification"),
    ]
)
# Legacy image-segmentation mapping; per the original in-file note it is
# slated for deprecation and must not receive new entries.
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Do not add new models here, this class will be deprecated in the future.
        # Model for Image Segmentation mapping
        ("detr", "DetrForSegmentation"),
    ]
)
# Maps a model type string to the class name of its semantic-segmentation head.
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("beit", "BeitForSemanticSegmentation"),
        ("data2vec-vision", "Data2VecVisionForSemanticSegmentation"),
        ("dpt", "DPTForSemanticSegmentation"),
        ("mobilenet_v2", "MobileNetV2ForSemanticSegmentation"),
        ("mobilevit", "MobileViTForSemanticSegmentation"),
        ("mobilevitv2", "MobileViTV2ForSemanticSegmentation"),
        ("segformer", "SegformerForSemanticSegmentation"),
        ("upernet", "UperNetForSemanticSegmentation"),
    ]
)
# Maps a model type string to the class name of its instance-segmentation head.
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Instance Segmentation mapping
        # MaskFormerForInstanceSegmentation can be removed from this mapping in v5
        ("maskformer", "MaskFormerForInstanceSegmentation"),
    ]
)
# Maps a model type string to the class name of its universal-segmentation head.
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Universal Segmentation mapping
        ("detr", "DetrForSegmentation"),
        ("mask2former", "Mask2FormerForUniversalSegmentation"),
        ("maskformer", "MaskFormerForInstanceSegmentation"),
        ("oneformer", "OneFormerForUniversalSegmentation"),
    ]
)
# Maps a model type string to the class name of its video-classification head.
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        ("timesformer", "TimesformerForVideoClassification"),
        ("videomae", "VideoMAEForVideoClassification"),
        ("vivit", "VivitForVideoClassification"),
    ]
)
# Maps a model type string to the class name of its vision-to-text
# (image captioning / VQA-style conditional generation) head.
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("blip", "BlipForConditionalGeneration"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("git", "GitForCausalLM"),
        ("idefics2", "Idefics2ForConditionalGeneration"),
        ("instructblip", "InstructBlipForConditionalGeneration"),
        ("kosmos-2", "Kosmos2ForConditionalGeneration"),
        ("llava", "LlavaForConditionalGeneration"),
        ("llava_next", "LlavaNextForConditionalGeneration"),
        ("pix2struct", "Pix2StructForConditionalGeneration"),
        ("vipllava", "VipLlavaForConditionalGeneration"),
        ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
    ]
)
# Maps a model type string to the class name of its masked language-modeling head.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "AlbertForMaskedLM"),
        ("bart", "BartForConditionalGeneration"),
        ("bert", "BertForMaskedLM"),
        ("big_bird", "BigBirdForMaskedLM"),
        ("camembert", "CamembertForMaskedLM"),
        ("convbert", "ConvBertForMaskedLM"),
        ("data2vec-text", "Data2VecTextForMaskedLM"),
        ("deberta", "DebertaForMaskedLM"),
        ("deberta-v2", "DebertaV2ForMaskedLM"),
        ("distilbert", "DistilBertForMaskedLM"),
        ("electra", "ElectraForMaskedLM"),
        ("ernie", "ErnieForMaskedLM"),
        ("esm", "EsmForMaskedLM"),
        ("flaubert", "FlaubertWithLMHeadModel"),
        ("fnet", "FNetForMaskedLM"),
        ("funnel", "FunnelForMaskedLM"),
        ("ibert", "IBertForMaskedLM"),
        ("layoutlm", "LayoutLMForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
        ("luke", "LukeForMaskedLM"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForMaskedLM"),
        ("mobilebert", "MobileBertForMaskedLM"),
        ("mpnet", "MPNetForMaskedLM"),
        ("mra", "MraForMaskedLM"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nezha", "NezhaForMaskedLM"),
        ("nystromformer", "NystromformerForMaskedLM"),
        ("perceiver", "PerceiverForMaskedLM"),
        ("qdqbert", "QDQBertForMaskedLM"),
        ("reformer", "ReformerForMaskedLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roberta", "RobertaForMaskedLM"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
        ("roc_bert", "RoCBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("squeezebert", "SqueezeBertForMaskedLM"),
        ("tapas", "TapasForMaskedLM"),
        ("wav2vec2", "Wav2Vec2ForMaskedLM"),
        ("xlm", "XLMWithLMHeadModel"),
        ("xlm-roberta", "XLMRobertaForMaskedLM"),
        ("xlm-roberta-xl", "XLMRobertaXLForMaskedLM"),
        ("xmod", "XmodForMaskedLM"),
        ("yoso", "YosoForMaskedLM"),
    ]
)
# Maps a model type string to the class name of its object-detection head.
MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Object Detection mapping
        ("conditional_detr", "ConditionalDetrForObjectDetection"),
        ("deformable_detr", "DeformableDetrForObjectDetection"),
        ("deta", "DetaForObjectDetection"),
        ("detr", "DetrForObjectDetection"),
        ("table-transformer", "TableTransformerForObjectDetection"),
        ("yolos", "YolosForObjectDetection"),
    ]
)
# Maps a model type string to the class name of its zero-shot
# (open-vocabulary) object-detection head.
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Object Detection mapping
        ("grounding-dino", "GroundingDinoForObjectDetection"),
        ("owlv2", "Owlv2ForObjectDetection"),
        ("owlvit", "OwlViTForObjectDetection"),
    ]
)
# Maps a model type string to the class name of its depth-estimation head.
MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for depth estimation mapping
        ("depth_anything", "DepthAnythingForDepthEstimation"),
        ("dpt", "DPTForDepthEstimation"),
        ("glpn", "GLPNForDepthEstimation"),
    ]
)
# Maps a model type string to the class name of its encoder-decoder
# (sequence-to-sequence) conditional-generation head.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "BartForConditionalGeneration"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
        ("blenderbot", "BlenderbotForConditionalGeneration"),
        ("blenderbot-small", "BlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "EncoderDecoderModel"),
        ("fsmt", "FSMTForConditionalGeneration"),
        ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
        ("led", "LEDForConditionalGeneration"),
        ("longt5", "LongT5ForConditionalGeneration"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("marian", "MarianMTModel"),
        ("mbart", "MBartForConditionalGeneration"),
        ("mt5", "MT5ForConditionalGeneration"),
        ("mvp", "MvpForConditionalGeneration"),
        ("nllb-moe", "NllbMoeForConditionalGeneration"),
        ("pegasus", "PegasusForConditionalGeneration"),
        ("pegasus_x", "PegasusXForConditionalGeneration"),
        ("plbart", "PLBartForConditionalGeneration"),
        ("prophetnet", "ProphetNetForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForTextToText"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToText"),
        ("switch_transformers", "SwitchTransformersForConditionalGeneration"),
        ("t5", "T5ForConditionalGeneration"),
        ("umt5", "UMT5ForConditionalGeneration"),
        ("xlm-prophetnet", "XLMProphetNetForConditionalGeneration"),
    ]
)
# Maps a model type string to the class name of its speech-to-text
# sequence-to-sequence head.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        ("pop2piano", "Pop2PianoForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForSpeechToText"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForSpeechToText"),
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
        ("speecht5", "SpeechT5ForSpeechToText"),
        ("whisper", "WhisperForConditionalGeneration"),
    ]
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        # (config model_type -> sequence-classification head class name)
        ("albert", "AlbertForSequenceClassification"),
        ("bart", "BartForSequenceClassification"),
        ("bert", "BertForSequenceClassification"),
        ("big_bird", "BigBirdForSequenceClassification"),
        ("bigbird_pegasus", "BigBirdPegasusForSequenceClassification"),
        ("biogpt", "BioGptForSequenceClassification"),
        ("bloom", "BloomForSequenceClassification"),
        ("camembert", "CamembertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
        # code_llama reuses the Llama head class.
        ("code_llama", "LlamaForSequenceClassification"),
        ("convbert", "ConvBertForSequenceClassification"),
        ("ctrl", "CTRLForSequenceClassification"),
        ("data2vec-text", "Data2VecTextForSequenceClassification"),
        ("deberta", "DebertaForSequenceClassification"),
        ("deberta-v2", "DebertaV2ForSequenceClassification"),
        ("distilbert", "DistilBertForSequenceClassification"),
        ("electra", "ElectraForSequenceClassification"),
        ("ernie", "ErnieForSequenceClassification"),
        ("ernie_m", "ErnieMForSequenceClassification"),
        ("esm", "EsmForSequenceClassification"),
        ("falcon", "FalconForSequenceClassification"),
        ("flaubert", "FlaubertForSequenceClassification"),
        ("fnet", "FNetForSequenceClassification"),
        ("funnel", "FunnelForSequenceClassification"),
        ("gemma", "GemmaForSequenceClassification"),
        # gpt-sw3 reuses the GPT2 head class.
        ("gpt-sw3", "GPT2ForSequenceClassification"),
        ("gpt2", "GPT2ForSequenceClassification"),
        ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
        ("gpt_neo", "GPTNeoForSequenceClassification"),
        ("gpt_neox", "GPTNeoXForSequenceClassification"),
        ("gptj", "GPTJForSequenceClassification"),
        ("ibert", "IBertForSequenceClassification"),
        ("jamba", "JambaForSequenceClassification"),
        ("layoutlm", "LayoutLMForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
        ("led", "LEDForSequenceClassification"),
        ("lilt", "LiltForSequenceClassification"),
        ("llama", "LlamaForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
        ("luke", "LukeForSequenceClassification"),
        ("markuplm", "MarkupLMForSequenceClassification"),
        ("mbart", "MBartForSequenceClassification"),
        ("mega", "MegaForSequenceClassification"),
        ("megatron-bert", "MegatronBertForSequenceClassification"),
        ("mistral", "MistralForSequenceClassification"),
        ("mixtral", "MixtralForSequenceClassification"),
        ("mobilebert", "MobileBertForSequenceClassification"),
        ("mpnet", "MPNetForSequenceClassification"),
        ("mpt", "MptForSequenceClassification"),
        ("mra", "MraForSequenceClassification"),
        ("mt5", "MT5ForSequenceClassification"),
        ("mvp", "MvpForSequenceClassification"),
        ("nezha", "NezhaForSequenceClassification"),
        ("nystromformer", "NystromformerForSequenceClassification"),
        ("open-llama", "OpenLlamaForSequenceClassification"),
        ("openai-gpt", "OpenAIGPTForSequenceClassification"),
        ("opt", "OPTForSequenceClassification"),
        ("perceiver", "PerceiverForSequenceClassification"),
        ("persimmon", "PersimmonForSequenceClassification"),
        ("phi", "PhiForSequenceClassification"),
        ("plbart", "PLBartForSequenceClassification"),
        ("qdqbert", "QDQBertForSequenceClassification"),
        ("qwen2", "Qwen2ForSequenceClassification"),
        ("qwen2_moe", "Qwen2MoeForSequenceClassification"),
        ("reformer", "ReformerForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("roberta", "RobertaForSequenceClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
        ("roc_bert", "RoCBertForSequenceClassification"),
        ("roformer", "RoFormerForSequenceClassification"),
        ("squeezebert", "SqueezeBertForSequenceClassification"),
        ("stablelm", "StableLmForSequenceClassification"),
        ("starcoder2", "Starcoder2ForSequenceClassification"),
        ("t5", "T5ForSequenceClassification"),
        ("tapas", "TapasForSequenceClassification"),
        ("transfo-xl", "TransfoXLForSequenceClassification"),
        ("umt5", "UMT5ForSequenceClassification"),
        ("xlm", "XLMForSequenceClassification"),
        ("xlm-roberta", "XLMRobertaForSequenceClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
        ("xlnet", "XLNetForSequenceClassification"),
        ("xmod", "XmodForSequenceClassification"),
        ("yoso", "YosoForSequenceClassification"),
    ]
)
|
|
|
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        # (config model_type -> extractive-QA head class name)
        ("albert", "AlbertForQuestionAnswering"),
        ("bart", "BartForQuestionAnswering"),
        ("bert", "BertForQuestionAnswering"),
        ("big_bird", "BigBirdForQuestionAnswering"),
        ("bigbird_pegasus", "BigBirdPegasusForQuestionAnswering"),
        ("bloom", "BloomForQuestionAnswering"),
        ("camembert", "CamembertForQuestionAnswering"),
        ("canine", "CanineForQuestionAnswering"),
        ("convbert", "ConvBertForQuestionAnswering"),
        ("data2vec-text", "Data2VecTextForQuestionAnswering"),
        ("deberta", "DebertaForQuestionAnswering"),
        ("deberta-v2", "DebertaV2ForQuestionAnswering"),
        ("distilbert", "DistilBertForQuestionAnswering"),
        ("electra", "ElectraForQuestionAnswering"),
        ("ernie", "ErnieForQuestionAnswering"),
        ("ernie_m", "ErnieMForQuestionAnswering"),
        ("falcon", "FalconForQuestionAnswering"),
        ("flaubert", "FlaubertForQuestionAnsweringSimple"),
        ("fnet", "FNetForQuestionAnswering"),
        ("funnel", "FunnelForQuestionAnswering"),
        ("gpt2", "GPT2ForQuestionAnswering"),
        ("gpt_neo", "GPTNeoForQuestionAnswering"),
        ("gpt_neox", "GPTNeoXForQuestionAnswering"),
        ("gptj", "GPTJForQuestionAnswering"),
        ("ibert", "IBertForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
        ("led", "LEDForQuestionAnswering"),
        ("lilt", "LiltForQuestionAnswering"),
        ("llama", "LlamaForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
        ("luke", "LukeForQuestionAnswering"),
        ("lxmert", "LxmertForQuestionAnswering"),
        ("markuplm", "MarkupLMForQuestionAnswering"),
        ("mbart", "MBartForQuestionAnswering"),
        ("mega", "MegaForQuestionAnswering"),
        ("megatron-bert", "MegatronBertForQuestionAnswering"),
        ("mobilebert", "MobileBertForQuestionAnswering"),
        ("mpnet", "MPNetForQuestionAnswering"),
        ("mpt", "MptForQuestionAnswering"),
        ("mra", "MraForQuestionAnswering"),
        ("mt5", "MT5ForQuestionAnswering"),
        ("mvp", "MvpForQuestionAnswering"),
        ("nezha", "NezhaForQuestionAnswering"),
        ("nystromformer", "NystromformerForQuestionAnswering"),
        ("opt", "OPTForQuestionAnswering"),
        ("qdqbert", "QDQBertForQuestionAnswering"),
        ("reformer", "ReformerForQuestionAnswering"),
        ("rembert", "RemBertForQuestionAnswering"),
        ("roberta", "RobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
        ("roc_bert", "RoCBertForQuestionAnswering"),
        ("roformer", "RoFormerForQuestionAnswering"),
        ("splinter", "SplinterForQuestionAnswering"),
        ("squeezebert", "SqueezeBertForQuestionAnswering"),
        ("t5", "T5ForQuestionAnswering"),
        ("umt5", "UMT5ForQuestionAnswering"),
        ("xlm", "XLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "XLMRobertaForQuestionAnswering"),
        ("xlm-roberta-xl", "XLMRobertaXLForQuestionAnswering"),
        ("xlnet", "XLNetForQuestionAnsweringSimple"),
        ("xmod", "XmodForQuestionAnswering"),
        ("yoso", "YosoForQuestionAnswering"),
    ]
)
|
|
|
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TapasForQuestionAnswering"),
    ]
)
|
|
|
|
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Visual Question Answering mapping
        ("blip", "BlipForQuestionAnswering"),
        ("blip-2", "Blip2ForConditionalGeneration"),
        ("vilt", "ViltForQuestionAnswering"),
    ]
)
|
|
|
|
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Document Question Answering mapping
        ("layoutlm", "LayoutLMForQuestionAnswering"),
        ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"),
        ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"),
    ]
)
|
|
|
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "AlbertForTokenClassification"),
        ("bert", "BertForTokenClassification"),
        ("big_bird", "BigBirdForTokenClassification"),
        ("biogpt", "BioGptForTokenClassification"),
        ("bloom", "BloomForTokenClassification"),
        ("bros", "BrosForTokenClassification"),
        ("camembert", "CamembertForTokenClassification"),
        ("canine", "CanineForTokenClassification"),
        ("convbert", "ConvBertForTokenClassification"),
        ("data2vec-text", "Data2VecTextForTokenClassification"),
        ("deberta", "DebertaForTokenClassification"),
        ("deberta-v2", "DebertaV2ForTokenClassification"),
        ("distilbert", "DistilBertForTokenClassification"),
        ("electra", "ElectraForTokenClassification"),
        ("ernie", "ErnieForTokenClassification"),
        ("ernie_m", "ErnieMForTokenClassification"),
        ("esm", "EsmForTokenClassification"),
        ("falcon", "FalconForTokenClassification"),
        ("flaubert", "FlaubertForTokenClassification"),
        ("fnet", "FNetForTokenClassification"),
        ("funnel", "FunnelForTokenClassification"),
        # gpt-sw3 reuses the GPT2 head class.
        ("gpt-sw3", "GPT2ForTokenClassification"),
        ("gpt2", "GPT2ForTokenClassification"),
        ("gpt_bigcode", "GPTBigCodeForTokenClassification"),
        ("gpt_neo", "GPTNeoForTokenClassification"),
        ("gpt_neox", "GPTNeoXForTokenClassification"),
        ("ibert", "IBertForTokenClassification"),
        ("layoutlm", "LayoutLMForTokenClassification"),
        ("layoutlmv2", "LayoutLMv2ForTokenClassification"),
        ("layoutlmv3", "LayoutLMv3ForTokenClassification"),
        ("lilt", "LiltForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
        ("luke", "LukeForTokenClassification"),
        ("markuplm", "MarkupLMForTokenClassification"),
        ("mega", "MegaForTokenClassification"),
        ("megatron-bert", "MegatronBertForTokenClassification"),
        ("mobilebert", "MobileBertForTokenClassification"),
        ("mpnet", "MPNetForTokenClassification"),
        ("mpt", "MptForTokenClassification"),
        ("mra", "MraForTokenClassification"),
        ("mt5", "MT5ForTokenClassification"),
        ("nezha", "NezhaForTokenClassification"),
        ("nystromformer", "NystromformerForTokenClassification"),
        ("phi", "PhiForTokenClassification"),
        ("qdqbert", "QDQBertForTokenClassification"),
        ("rembert", "RemBertForTokenClassification"),
        ("roberta", "RobertaForTokenClassification"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
        ("roc_bert", "RoCBertForTokenClassification"),
        ("roformer", "RoFormerForTokenClassification"),
        ("squeezebert", "SqueezeBertForTokenClassification"),
        ("t5", "T5ForTokenClassification"),
        ("umt5", "UMT5ForTokenClassification"),
        ("xlm", "XLMForTokenClassification"),
        ("xlm-roberta", "XLMRobertaForTokenClassification"),
        ("xlm-roberta-xl", "XLMRobertaXLForTokenClassification"),
        ("xlnet", "XLNetForTokenClassification"),
        ("xmod", "XmodForTokenClassification"),
        ("yoso", "YosoForTokenClassification"),
    ]
)
|
|
|
|
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "AlbertForMultipleChoice"),
        ("bert", "BertForMultipleChoice"),
        ("big_bird", "BigBirdForMultipleChoice"),
        ("camembert", "CamembertForMultipleChoice"),
        ("canine", "CanineForMultipleChoice"),
        ("convbert", "ConvBertForMultipleChoice"),
        ("data2vec-text", "Data2VecTextForMultipleChoice"),
        ("deberta-v2", "DebertaV2ForMultipleChoice"),
        ("distilbert", "DistilBertForMultipleChoice"),
        ("electra", "ElectraForMultipleChoice"),
        ("ernie", "ErnieForMultipleChoice"),
        ("ernie_m", "ErnieMForMultipleChoice"),
        ("flaubert", "FlaubertForMultipleChoice"),
        ("fnet", "FNetForMultipleChoice"),
        ("funnel", "FunnelForMultipleChoice"),
        ("ibert", "IBertForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
        ("luke", "LukeForMultipleChoice"),
        ("mega", "MegaForMultipleChoice"),
        ("megatron-bert", "MegatronBertForMultipleChoice"),
        ("mobilebert", "MobileBertForMultipleChoice"),
        ("mpnet", "MPNetForMultipleChoice"),
        ("mra", "MraForMultipleChoice"),
        ("nezha", "NezhaForMultipleChoice"),
        ("nystromformer", "NystromformerForMultipleChoice"),
        ("qdqbert", "QDQBertForMultipleChoice"),
        ("rembert", "RemBertForMultipleChoice"),
        ("roberta", "RobertaForMultipleChoice"),
        ("roberta-prelayernorm", "RobertaPreLayerNormForMultipleChoice"),
        ("roc_bert", "RoCBertForMultipleChoice"),
        ("roformer", "RoFormerForMultipleChoice"),
        ("squeezebert", "SqueezeBertForMultipleChoice"),
        ("xlm", "XLMForMultipleChoice"),
        ("xlm-roberta", "XLMRobertaForMultipleChoice"),
        ("xlm-roberta-xl", "XLMRobertaXLForMultipleChoice"),
        ("xlnet", "XLNetForMultipleChoice"),
        ("xmod", "XmodForMultipleChoice"),
        ("yoso", "YosoForMultipleChoice"),
    ]
)
|
|
|
|
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "BertForNextSentencePrediction"),
        ("ernie", "ErnieForNextSentencePrediction"),
        ("fnet", "FNetForNextSentencePrediction"),
        ("megatron-bert", "MegatronBertForNextSentencePrediction"),
        ("mobilebert", "MobileBertForNextSentencePrediction"),
        ("nezha", "NezhaForNextSentencePrediction"),
        ("qdqbert", "QDQBertForNextSentencePrediction"),
    ]
)
|
|
|
|
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Classification mapping
        ("audio-spectrogram-transformer", "ASTForAudioClassification"),
        ("data2vec-audio", "Data2VecAudioForSequenceClassification"),
        ("hubert", "HubertForSequenceClassification"),
        ("sew", "SEWForSequenceClassification"),
        ("sew-d", "SEWDForSequenceClassification"),
        ("unispeech", "UniSpeechForSequenceClassification"),
        ("unispeech-sat", "UniSpeechSatForSequenceClassification"),
        ("wav2vec2", "Wav2Vec2ForSequenceClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForSequenceClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForSequenceClassification"),
        ("wavlm", "WavLMForSequenceClassification"),
        ("whisper", "WhisperForAudioClassification"),
    ]
)
|
|
|
|
MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
    [
        # Model for Connectionist temporal classification (CTC) mapping
        ("data2vec-audio", "Data2VecAudioForCTC"),
        ("hubert", "HubertForCTC"),
        ("mctct", "MCTCTForCTC"),
        ("sew", "SEWForCTC"),
        ("sew-d", "SEWDForCTC"),
        ("unispeech", "UniSpeechForCTC"),
        ("unispeech-sat", "UniSpeechSatForCTC"),
        ("wav2vec2", "Wav2Vec2ForCTC"),
        ("wav2vec2-bert", "Wav2Vec2BertForCTC"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForCTC"),
        ("wavlm", "WavLMForCTC"),
    ]
)
|
|
|
|
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        # (the previous "Audio Classification" header was a copy-paste of the table above)
        ("data2vec-audio", "Data2VecAudioForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("wav2vec2-bert", "Wav2Vec2BertForAudioFrameClassification"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForAudioFrameClassification"),
        ("wavlm", "WavLMForAudioFrameClassification"),
    ]
)
|
|
|
|
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
        # (the previous "Audio Classification" header was a copy-paste of an earlier table)
        ("data2vec-audio", "Data2VecAudioForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("wav2vec2-bert", "Wav2Vec2BertForXVector"),
        ("wav2vec2-conformer", "Wav2Vec2ConformerForXVector"),
        ("wavlm", "WavLMForXVector"),
    ]
)
|
|
|
|
MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Spectrogram mapping
        ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
        ("speecht5", "SpeechT5ForTextToSpeech"),
    ]
)
|
|
|
|
MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text-To-Waveform mapping
        ("bark", "BarkModel"),
        ("fastspeech2_conformer", "FastSpeech2ConformerWithHifiGan"),
        ("musicgen", "MusicgenForConditionalGeneration"),
        ("musicgen_melody", "MusicgenMelodyForConditionalGeneration"),
        ("seamless_m4t", "SeamlessM4TForTextToSpeech"),
        ("seamless_m4t_v2", "SeamlessM4Tv2ForTextToSpeech"),
        ("vits", "VitsModel"),
    ]
)
|
|
|
|
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("align", "AlignModel"),
        ("altclip", "AltCLIPModel"),
        ("blip", "BlipModel"),
        ("chinese_clip", "ChineseCLIPModel"),
        ("clip", "CLIPModel"),
        ("clipseg", "CLIPSegModel"),
        ("siglip", "SiglipModel"),
    ]
)
|
|
|
|
MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
    [
        # Backbone mapping
        ("beit", "BeitBackbone"),
        ("bit", "BitBackbone"),
        ("convnext", "ConvNextBackbone"),
        ("convnextv2", "ConvNextV2Backbone"),
        ("dinat", "DinatBackbone"),
        ("dinov2", "Dinov2Backbone"),
        ("focalnet", "FocalNetBackbone"),
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("nat", "NatBackbone"),
        ("pvt_v2", "PvtV2Backbone"),
        ("resnet", "ResNetBackbone"),
        ("swin", "SwinBackbone"),
        ("swinv2", "Swinv2Backbone"),
        ("timm_backbone", "TimmBackbone"),
        ("vitdet", "VitDetBackbone"),
    ]
)
|
|
|
|
MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Mask Generation mapping
        ("sam", "SamModel"),
    ]
)
|
|
|
|
|
|
MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Keypoint Detection mapping
        ("superpoint", "SuperPointForKeypointDetection"),
    ]
)
|
|
|
|
|
|
MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text Encoding mapping
        ("albert", "AlbertModel"),
        ("bert", "BertModel"),
        ("big_bird", "BigBirdModel"),
        ("data2vec-text", "Data2VecTextModel"),
        ("deberta", "DebertaModel"),
        ("deberta-v2", "DebertaV2Model"),
        ("distilbert", "DistilBertModel"),
        ("electra", "ElectraModel"),
        ("flaubert", "FlaubertModel"),
        ("ibert", "IBertModel"),
        ("longformer", "LongformerModel"),
        ("mobilebert", "MobileBertModel"),
        # Encoder-decoder T5-family models map to their encoder-only variants here.
        ("mt5", "MT5EncoderModel"),
        ("nystromformer", "NystromformerModel"),
        ("reformer", "ReformerModel"),
        ("rembert", "RemBertModel"),
        ("roberta", "RobertaModel"),
        ("roberta-prelayernorm", "RobertaPreLayerNormModel"),
        ("roc_bert", "RoCBertModel"),
        ("roformer", "RoFormerModel"),
        ("squeezebert", "SqueezeBertModel"),
        ("t5", "T5EncoderModel"),
        ("umt5", "UMT5EncoderModel"),
        ("xlm", "XLMModel"),
        ("xlm-roberta", "XLMRobertaModel"),
        ("xlm-roberta-xl", "XLMRobertaXLModel"),
    ]
)
|
|
|
|
MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Time Series Classification mapping
        ("patchtsmixer", "PatchTSMixerForTimeSeriesClassification"),
        ("patchtst", "PatchTSTForClassification"),
    ]
)
|
|
|
|
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Time Series Regression mapping
        ("patchtsmixer", "PatchTSMixerForRegression"),
        ("patchtst", "PatchTSTForRegression"),
    ]
)
|
|
|
|
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image-To-Image mapping
        ("swin2sr", "Swin2SRForImageSuperResolution"),
    ]
)
|
|
|
|
# Instantiate the lazy config-name -> model-class mappings from the name tables above.
# Each _LazyAutoMapping resolves classes on first access so importing this module
# does not import every model file.
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES
)
MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
MODEL_FOR_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_MAPPING_NAMES)
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
)
MODEL_FOR_DEPTH_ESTIMATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES)
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)

MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES
)

MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES)

MODEL_FOR_BACKBONE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_BACKBONE_MAPPING_NAMES)

MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)

MODEL_FOR_KEYPOINT_DETECTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_KEYPOINT_DETECTION_MAPPING_NAMES
)

MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)

MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING_NAMES
)

MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES
)

MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES)
|
|
|
|
|
|
class AutoModelForMaskGeneration(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_MASK_GENERATION_MAPPING (e.g. SAM).
    _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING
|
|
|
|
|
|
class AutoModelForKeypointDetection(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_KEYPOINT_DETECTION_MAPPING (e.g. SuperPoint).
    _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING
|
|
|
|
|
|
class AutoModelForTextEncoding(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_TEXT_ENCODING_MAPPING (encoder-only variants).
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
|
|
|
|
|
|
class AutoModelForImageToImage(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_IMAGE_TO_IMAGE_MAPPING (e.g. Swin2SR).
    _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING
|
|
|
|
|
|
class AutoModel(_BaseAutoModelClass):
    # Generic auto class backed by MODEL_MAPPING (bare base models, no task head).
    _model_mapping = MODEL_MAPPING


# auto_class_update fills in the class and from_config/from_pretrained docstrings.
AutoModel = auto_class_update(AutoModel)
|
|
|
|
|
|
class AutoModelForPreTraining(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_PRETRAINING_MAPPING.
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining")
|
|
|
|
|
|
# Private on purpose, the public class will add the deprecation warnings.
class _AutoModelWithLMHead(_BaseAutoModelClass):
    # Auto class backed by MODEL_WITH_LM_HEAD_MAPPING; see AutoModelWithLMHead below.
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling")
|
|
|
|
|
|
class AutoModelForCausalLM(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_CAUSAL_LM_MAPPING.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")
|
|
|
|
|
|
class AutoModelForMaskedLM(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_MASKED_LM_MAPPING.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling")
|
|
|
|
|
|
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="google-t5/t5-base",
)
|
|
|
|
|
|
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification, head_doc="sequence classification"
)
|
|
|
|
|
|
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_QUESTION_ANSWERING_MAPPING.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering")
|
|
|
|
|
|
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING (TAPAS).
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
|
|
|
|
|
|
class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING.
    _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING


AutoModelForVisualQuestionAnswering = auto_class_update(
    AutoModelForVisualQuestionAnswering,
    head_doc="visual question answering",
    checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa",
)
|
|
|
|
|
|
class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING.
    _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


AutoModelForDocumentQuestionAnswering = auto_class_update(
    AutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    # The embedded quotes pin the example to a specific revision in the generated docstring.
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)
|
|
|
|
|
|
class AutoModelForTokenClassification(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification")
|
|
|
|
|
|
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice")
|
|
|
|
|
|
class AutoModelForNextSentencePrediction(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
    _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


AutoModelForNextSentencePrediction = auto_class_update(
    AutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
|
|
|
|
|
|
class AutoModelForImageClassification(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification")
|
|
|
|
|
|
class AutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


AutoModelForZeroShotImageClassification = auto_class_update(
    AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)
|
|
|
|
|
|
class AutoModelForImageSegmentation(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_IMAGE_SEGMENTATION_MAPPING.
    _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING


AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")
|
|
|
|
|
|
class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING.
    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


AutoModelForSemanticSegmentation = auto_class_update(
    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)
|
|
|
|
|
|
class AutoModelForUniversalSegmentation(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING.
    _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING


AutoModelForUniversalSegmentation = auto_class_update(
    AutoModelForUniversalSegmentation, head_doc="universal image segmentation"
)
|
|
|
|
|
|
class AutoModelForInstanceSegmentation(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.
    _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING


AutoModelForInstanceSegmentation = auto_class_update(
    AutoModelForInstanceSegmentation, head_doc="instance segmentation"
)
|
|
|
|
|
|
class AutoModelForObjectDetection(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_OBJECT_DETECTION_MAPPING.
    _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING


AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection")
|
|
|
|
|
|
class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING.
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING


AutoModelForZeroShotObjectDetection = auto_class_update(
    AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection"
)
|
|
|
|
|
|
class AutoModelForDepthEstimation(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_DEPTH_ESTIMATION_MAPPING.
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING


AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation")
|
|
|
|
|
|
class AutoModelForVideoClassification(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING


AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification")
|
|
|
|
|
|
class AutoModelForVision2Seq(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_VISION_2_SEQ_MAPPING.
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING


AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")
|
|
|
|
|
|
class AutoModelForAudioClassification(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification")
|
|
|
|
|
|
class AutoModelForCTC(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_CTC_MAPPING.
    _model_mapping = MODEL_FOR_CTC_MAPPING


AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification")
|
|
|
|
|
|
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
|
|
|
|
|
|
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING.
    _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING


AutoModelForAudioFrameClassification = auto_class_update(
    AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
)
|
|
|
|
|
|
class AutoModelForAudioXVector(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_AUDIO_XVECTOR_MAPPING.
    # Its auto_class_update call appears further below, after AutoBackbone.
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
|
|
|
|
|
|
class AutoModelForTextToSpectrogram(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING.
    _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING
|
|
|
|
|
|
class AutoModelForTextToWaveform(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING.
    _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING
|
|
|
|
|
|
class AutoBackbone(_BaseAutoBackboneClass):
    # Auto class backed by MODEL_FOR_BACKBONE_MAPPING; note the backbone-specific base class.
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
|
|
|
|
|
|
# Docstring update for AutoModelForAudioXVector (class defined earlier in this file).
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
|
|
|
|
|
|
class AutoModelForMaskedImageModeling(_BaseAutoModelClass):
    # Auto class backed by MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.
    _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling")
|
|
|
|
|
|
class AutoModelWithLMHead(_AutoModelWithLMHead):
    """Deprecated public alias of `_AutoModelWithLMHead`.

    Both construction entry points emit a `FutureWarning` steering users to the
    task-specific auto classes, then defer to the parent implementation.
    """

    # Single source of truth for the deprecation text; previously this message was
    # duplicated verbatim in from_config and from_pretrained.
    _deprecation_warning = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        # Warn on every call (not just once) so downstream code surfaces the deprecation.
        warnings.warn(cls._deprecation_warning, FutureWarning)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        warnings.warn(cls._deprecation_warning, FutureWarning)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)