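# NOTE: the following lines are package-hosting web page text (Gemfury) that was captured together with
# this file. They are not part of the module and are kept only as comments so the file remains valid Python.
#
# Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages
#
# Repository URL to install this package:
#
# Details
# Size: Mime: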
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class."""


import warnings
from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


# Module-level logger. NOTE(review): nothing in the visible part of this file logs through it.
logger = logging.get_logger(__name__)


# Model-type key -> name of the base TensorFlow model class for that model type.
# Values are class names written as strings; a tuple registers more than one class name for a single key
# (see "funnel"). Notes on entries that do not follow the usual `TF<Name>Model` pattern:
#   * "dpr" is registered with `TFDPRQuestionEncoder`.
#   * "gpt-sw3" is registered with the GPT-2 class (`TFGPT2Model`).
# The table is paired with CONFIG_MAPPING_NAMES through `_LazyAutoMapping` to build `TF_MODEL_MAPPING`.
TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
        ("convnext", "TFConvNextModel"),
        ("ctrl", "TFCTRLModel"),
        ("cvt", "TFCvtModel"),
        ("data2vec-vision", "TFData2VecVisionModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deit", "TFDeiTModel"),
        ("distilbert", "TFDistilBertModel"),
        ("dpr", "TFDPRQuestionEncoder"),
        ("efficientformer", "TFEfficientFormerModel"),
        ("electra", "TFElectraModel"),
        ("esm", "TFEsmModel"),
        ("flaubert", "TFFlaubertModel"),
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("gpt-sw3", "TFGPT2Model"),
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
        ("led", "TFLEDModel"),
        ("longformer", "TFLongformerModel"),
        ("lxmert", "TFLxmertModel"),
        ("marian", "TFMarianModel"),
        ("mbart", "TFMBartModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mobilevit", "TFMobileViTModel"),
        ("mpnet", "TFMPNetModel"),
        ("mt5", "TFMT5Model"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("opt", "TFOPTModel"),
        ("pegasus", "TFPegasusModel"),
        ("regnet", "TFRegNetModel"),
        ("rembert", "TFRemBertModel"),
        ("resnet", "TFResNetModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("sam", "TFSamModel"),
        ("segformer", "TFSegformerModel"),
        ("speech_to_text", "TFSpeech2TextModel"),
        ("swin", "TFSwinModel"),
        ("t5", "TFT5Model"),
        ("tapas", "TFTapasModel"),
        ("transfo-xl", "TFTransfoXLModel"),
        ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
        ("vit", "TFViTModel"),
        ("vit_mae", "TFViTMAEModel"),
        ("wav2vec2", "TFWav2Vec2Model"),
        ("whisper", "TFWhisperModel"),
        ("xglm", "TFXGLMModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
        ("xlnet", "TFXLNetModel"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for pre-training.
# Not every entry is a dedicated `...ForPreTraining` class: several model types are registered with their
# masked-LM, LM-head or conditional-generation class instead (e.g. "camembert", "ctrl", "bart").
# "gpt-sw3" is registered with the GPT-2 class.
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForPreTraining"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForPreTraining"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForPreTraining"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("lxmert", "TFLxmertForPreTraining"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("vit_mae", "TFViTMAEForPreTraining"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

# Model-type key -> name of a TensorFlow class with a language-modeling head, without distinguishing
# between masked, causal and conditional-generation heads (all three kinds appear as values).
# This table backs `TF_MODEL_WITH_LM_HEAD_MAPPING`, which in this file is only used by
# `_TFAutoModelWithLMHead` and therefore by the deprecated public `TFAutoModelWithLMHead`.
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("whisper", "TFWhisperForConditionalGeneration"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for causal language modeling.
# "gpt-sw3" is registered with the GPT-2 class (`TFGPT2LMHeadModel`).
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bert", "TFBertLMHeadModel"),
        ("camembert", "TFCamembertForCausalLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),
        ("roberta", "TFRobertaForCausalLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xglm", "TFXGLMForCausalLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForCausalLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for masked image modeling.
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked Image Modeling mapping
        ("deit", "TFDeiTForMaskedImageModeling"),
        ("swin", "TFSwinForMaskedImageModeling"),
    ]
)

# Model-type key -> name(s) of the TensorFlow class registered for image classification.
# "deit" and "efficientformer" each register a tuple of two class names: the plain classifier and its
# `...WithTeacher` variant.
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("convnext", "TFConvNextForImageClassification"),
        ("cvt", "TFCvtForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
        ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
        (
            "efficientformer",
            ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
        ),
        ("mobilevit", "TFMobileViTForImageClassification"),
        ("regnet", "TFRegNetForImageClassification"),
        ("resnet", "TFResNetForImageClassification"),
        ("segformer", "TFSegformerForImageClassification"),
        ("swin", "TFSwinForImageClassification"),
        ("vit", "TFViTForImageClassification"),
    ]
)


# Model-type key -> name of the TensorFlow class registered for zero-shot image classification.
# Both entries reuse the same class names that the base-model table registers for these keys.
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)


# Model-type key -> name of the TensorFlow class registered for semantic segmentation.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for vision-to-sequence generation.
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Vision 2 Seq mapping
        ("blip", "TFBlipForConditionalGeneration"),
        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for masked language modeling.
# "flaubert" and "xlm" are registered with their `...WithLMHeadModel` class rather than a `...ForMaskedLM` one.
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("deberta", "TFDebertaForMaskedLM"),
        ("deberta-v2", "TFDebertaV2ForMaskedLM"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("tapas", "TFTapasForMaskedLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for sequence-to-sequence language modeling.
# Despite the "CAUSAL_LM" in its name, this is the table behind `TFAutoModelForSeq2SeqLM` in this file.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "TFBartForConditionalGeneration"),
        ("blenderbot", "TFBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "TFEncoderDecoderModel"),
        ("led", "TFLEDForConditionalGeneration"),
        ("marian", "TFMarianMTModel"),
        ("mbart", "TFMBartForConditionalGeneration"),
        ("mt5", "TFMT5ForConditionalGeneration"),
        ("pegasus", "TFPegasusForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for speech sequence-to-sequence modeling.
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Speech Seq2Seq mapping
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("whisper", "TFWhisperForConditionalGeneration"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for sequence classification.
# "gpt-sw3" is registered with the GPT-2 class (`TFGPT2ForSequenceClassification`).
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "TFAlbertForSequenceClassification"),
        ("bart", "TFBartForSequenceClassification"),
        ("bert", "TFBertForSequenceClassification"),
        ("camembert", "TFCamembertForSequenceClassification"),
        ("convbert", "TFConvBertForSequenceClassification"),
        ("ctrl", "TFCTRLForSequenceClassification"),
        ("deberta", "TFDebertaForSequenceClassification"),
        ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
        ("distilbert", "TFDistilBertForSequenceClassification"),
        ("electra", "TFElectraForSequenceClassification"),
        ("esm", "TFEsmForSequenceClassification"),
        ("flaubert", "TFFlaubertForSequenceClassification"),
        ("funnel", "TFFunnelForSequenceClassification"),
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),
        ("mobilebert", "TFMobileBertForSequenceClassification"),
        ("mpnet", "TFMPNetForSequenceClassification"),
        ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
        ("rembert", "TFRemBertForSequenceClassification"),
        ("roberta", "TFRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "TFRoFormerForSequenceClassification"),
        ("tapas", "TFTapasForSequenceClassification"),
        ("transfo-xl", "TFTransfoXLForSequenceClassification"),
        ("xlm", "TFXLMForSequenceClassification"),
        ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
        ("xlnet", "TFXLNetForSequenceClassification"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for question answering.
# "flaubert", "xlm" and "xlnet" are registered with their `...ForQuestionAnsweringSimple` class.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "TFAlbertForQuestionAnswering"),
        ("bert", "TFBertForQuestionAnswering"),
        ("camembert", "TFCamembertForQuestionAnswering"),
        ("convbert", "TFConvBertForQuestionAnswering"),
        ("deberta", "TFDebertaForQuestionAnswering"),
        ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
        ("distilbert", "TFDistilBertForQuestionAnswering"),
        ("electra", "TFElectraForQuestionAnswering"),
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),
        ("mpnet", "TFMPNetForQuestionAnswering"),
        ("rembert", "TFRemBertForQuestionAnswering"),
        ("roberta", "TFRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "TFRoFormerForQuestionAnswering"),
        ("xlm", "TFXLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
        ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
    ]
)
# Model-type key -> name of the TensorFlow class registered for audio classification.
# The single entry registers wav2vec2's `...ForSequenceClassification` class for this task.
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])

# Model-type key -> name of the TensorFlow class registered for document question answering.
# "layoutlm" is registered here only; it has no entry in the generic question-answering table above.
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Document Question Answering mapping
        ("layoutlm", "TFLayoutLMForQuestionAnswering"),
    ]
)


# Model-type key -> name of the TensorFlow class registered for table question answering.
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TFTapasForQuestionAnswering"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for token classification.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "TFAlbertForTokenClassification"),
        ("bert", "TFBertForTokenClassification"),
        ("camembert", "TFCamembertForTokenClassification"),
        ("convbert", "TFConvBertForTokenClassification"),
        ("deberta", "TFDebertaForTokenClassification"),
        ("deberta-v2", "TFDebertaV2ForTokenClassification"),
        ("distilbert", "TFDistilBertForTokenClassification"),
        ("electra", "TFElectraForTokenClassification"),
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
        ("mobilebert", "TFMobileBertForTokenClassification"),
        ("mpnet", "TFMPNetForTokenClassification"),
        ("rembert", "TFRemBertForTokenClassification"),
        ("roberta", "TFRobertaForTokenClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
        ("roformer", "TFRoFormerForTokenClassification"),
        ("xlm", "TFXLMForTokenClassification"),
        ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
        ("xlnet", "TFXLNetForTokenClassification"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for multiple choice.
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "TFAlbertForMultipleChoice"),
        ("bert", "TFBertForMultipleChoice"),
        ("camembert", "TFCamembertForMultipleChoice"),
        ("convbert", "TFConvBertForMultipleChoice"),
        ("distilbert", "TFDistilBertForMultipleChoice"),
        ("electra", "TFElectraForMultipleChoice"),
        ("flaubert", "TFFlaubertForMultipleChoice"),
        ("funnel", "TFFunnelForMultipleChoice"),
        ("longformer", "TFLongformerForMultipleChoice"),
        ("mobilebert", "TFMobileBertForMultipleChoice"),
        ("mpnet", "TFMPNetForMultipleChoice"),
        ("rembert", "TFRemBertForMultipleChoice"),
        ("roberta", "TFRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "TFRoFormerForMultipleChoice"),
        ("xlm", "TFXLMForMultipleChoice"),
        ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
        ("xlnet", "TFXLNetForMultipleChoice"),
    ]
)

# Model-type key -> name of the TensorFlow class registered for next sentence prediction.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "TFBertForNextSentencePrediction"),
        ("mobilebert", "TFMobileBertForNextSentencePrediction"),
    ]
)
# Model-type key -> name of the TensorFlow class registered for mask generation.
# The single entry reuses the class name that the base-model table registers for "sam".
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Mask Generation mapping
        ("sam", "TFSamModel"),
    ]
)
# Model-type key -> name of the TensorFlow class registered for text encoding.
# Most entries reuse the class name from the base-model table; "mt5" and "t5" instead register their
# `...EncoderModel` class (`TFMT5EncoderModel`, `TFT5EncoderModel`).
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text Encoding mapping
        ("albert", "TFAlbertModel"),
        ("bert", "TFBertModel"),
        ("convbert", "TFConvBertModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("distilbert", "TFDistilBertModel"),
        ("electra", "TFElectraModel"),
        ("flaubert", "TFFlaubertModel"),
        ("longformer", "TFLongformerModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mt5", "TFMT5EncoderModel"),
        ("rembert", "TFRemBertModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("t5", "TFT5EncoderModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
    ]
)

# Build one `_LazyAutoMapping` per task by pairing CONFIG_MAPPING_NAMES with the matching `*_MAPPING_NAMES`
# table defined above. These objects are what the auto classes further down assign to `_model_mapping`.
# NOTE(review): how and when the string names are resolved into classes is implemented by
# `_LazyAutoMapping` in auto_factory, not in this file.
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)

TF_MODEL_FOR_MASK_GENERATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES
)

TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)


# Auto class for mask generation: it only selects TF_MODEL_FOR_MASK_GENERATION_MAPPING as the mapping used
# by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
# Unlike most auto classes in this file, it is not passed through `auto_class_update`.
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING


# Auto class for text encoding: it only selects TF_MODEL_FOR_TEXT_ENCODING_MAPPING as the mapping used
# by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
# Unlike most auto classes in this file, it is not passed through `auto_class_update`.
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING


# Auto class for base models: selects TF_MODEL_MAPPING as the mapping used by the inherited
# `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModel(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_MAPPING


# Post-process the class with `auto_class_update` using its default arguments and rebind the name to the result.
# NOTE(review): the exact effect of `auto_class_update` is defined in auto_factory.
TFAutoModel = auto_class_update(TFAutoModel)


# Auto class for audio classification: selects TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING as the mapping
# used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification, head_doc="audio classification"
)


# Auto class for pre-training: selects TF_MODEL_FOR_PRETRAINING_MAPPING as the mapping used by the
# inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
# Selects TF_MODEL_WITH_LM_HEAD_MAPPING; the public `TFAutoModelWithLMHead` at the end of this file
# subclasses it and wraps `from_config` / `from_pretrained` with a `FutureWarning`.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")


# Auto class for causal language modeling: selects TF_MODEL_FOR_CAUSAL_LM_MAPPING as the mapping used by
# the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")


# Auto class for masked image modeling: selects TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING as the mapping
# used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling, head_doc="masked image modeling"
)


# Auto class for image classification: selects TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING as the mapping
# used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification, head_doc="image classification"
)


# Auto class for zero-shot image classification: selects TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING
# as the mapping used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification, head_doc="zero-shot image classification"
)


# Auto class for semantic segmentation: selects TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING as the mapping
# used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation, head_doc="semantic segmentation"
)


# Auto class for vision-to-text modeling: selects TF_MODEL_FOR_VISION_2_SEQ_MAPPING as the mapping used by
# the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForVision2Seq = auto_class_update(TFAutoModelForVision2Seq, head_doc="vision-to-text modeling")


# Auto class for masked language modeling: selects TF_MODEL_FOR_MASKED_LM_MAPPING as the mapping used by
# the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")


# Auto class for sequence-to-sequence language modeling: selects TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
# as the mapping used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label and `checkpoint_for_example` names the "t5-base" checkpoint; presumably
# both are used for the generated documentation (see auto_factory).
TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)


# Auto class for sequence classification: selects TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING as the
# mapping used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification, head_doc="sequence classification"
)


# Auto class for question answering: selects TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING as the mapping used
# by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")


# Auto class for document question answering: selects TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING as
# the mapping used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# NOTE(review): the `checkpoint_for_example` value deliberately contains `", revision="` inside a
# single-quoted string. Presumably the value is substituted between double quotes in a generated example,
# so that it renders as `"impira/layoutlm-document-qa", revision="52e01b3"` -- confirm against
# `auto_class_update` before "fixing" the quoting.
TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    head_doc="document question answering",
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
)


# Auto class for table question answering: selects TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING as the
# mapping used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label and `checkpoint_for_example` names a TAPAS checkpoint; presumably both
# are used for the generated documentation (see auto_factory).
TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)


# Auto class for token classification: selects TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING as the mapping
# used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification, head_doc="token classification"
)


# Auto class for multiple choice: selects TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING as the mapping used by the
# inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")


# Auto class for next sentence prediction: selects TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING as the
# mapping used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


# Auto class for speech sequence-to-sequence modeling: selects TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING as the
# mapping used by the inherited `_BaseAutoModelClass` machinery (implemented in auto_factory).
class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# Post-process the class with `auto_class_update` and rebind the name to the result.
# `head_doc` passes the task label; presumably it is used for the generated documentation (see auto_factory).
TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)


# Deprecated public wrapper around `_TFAutoModelWithLMHead`: both entry points emit a `FutureWarning`
# and then defer to the parent implementation with their arguments unchanged.
# Comments are used instead of docstrings on purpose, so that documentation lookups on this class and
# its methods keep falling through to the parent's.
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    # Single copy of the deprecation text so the two entry points cannot drift apart.
    _deprecation_message = (
        "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
        " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
        " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        # `config` is forwarded untouched to the parent `from_config`; its result is returned as-is.
        # stacklevel=2 attributes the warning to the caller's line rather than to this file, so users
        # can see which of their call sites still uses the deprecated class.
        warnings.warn(cls._deprecation_message, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # All positional and keyword arguments are forwarded untouched to the parent `from_pretrained`;
        # its result is returned as-is.
        # stacklevel=2 attributes the warning to the caller's line rather than to this file.
        warnings.warn(cls._deprecation_message, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)