174 lines
7.7 KiB
Python
174 lines
7.7 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Processor class for InstructBLIP. Largely a copy of Blip2Processor with the addition of a tokenizer for the Q-Former.
|
|
"""
|
|
|
|
import os
|
|
from typing import List, Optional, Union
|
|
|
|
from ...image_processing_utils import BatchFeature
|
|
from ...image_utils import ImageInput
|
|
from ...processing_utils import ProcessorMixin
|
|
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
|
from ...utils import TensorType
|
|
from ..auto import AutoTokenizer
|
|
|
|
|
|
class InstructBlipProcessor(ProcessorMixin):
    r"""
    Constructs an InstructBLIP processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
    processor.

    [`InstructBlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the
    docstring of [`~InstructBlipProcessor.__call__`] and [`~InstructBlipProcessor.decode`] for more information.

    Args:
        image_processor (`BlipImageProcessor`):
            An instance of [`BlipImageProcessor`]. The image processor is a required input.
        tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
        qformer_tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "BlipImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer, qformer_tokenizer):
        super().__init__(image_processor, tokenizer)

        # The Q-Former tokenizer is not listed in `attributes`, so it is stored (and saved/loaded) by this class
        # itself rather than by `ProcessorMixin`.
        self.qformer_tokenizer = qformer_tokenizer

    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_token_type_ids: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Prepare image(s) and/or text for the model.

        `text`, when given, is encoded twice with the same tokenization options: once with `self.tokenizer` and
        once with `self.qformer_tokenizer`. `images`, when given, are forwarded to `self.image_processor` together
        with `return_tensors`.

        Please refer to the `__call__` docstrings of the wrapped tokenizers and image processor for the meaning of
        the individual arguments.

        Returns:
            [`BatchFeature`]: holds every field returned by `self.tokenizer`, the Q-Former tokenizer's `input_ids`
            under `qformer_input_ids` (and its `attention_mask` under `qformer_attention_mask` when the Q-Former
            tokenizer returned one), plus every field returned by the image processor.

        Raises:
            ValueError: if both `images` and `text` are `None`.
        """
        if images is None and text is None:
            raise ValueError("You have to specify at least images or text.")

        encoding = BatchFeature()

        if text is not None:
            # Both tokenizers receive exactly the same options; build them once.
            shared_tokenizer_kwargs = {
                "add_special_tokens": add_special_tokens,
                "padding": padding,
                "truncation": truncation,
                "max_length": max_length,
                "stride": stride,
                "pad_to_multiple_of": pad_to_multiple_of,
                "return_attention_mask": return_attention_mask,
                "return_overflowing_tokens": return_overflowing_tokens,
                "return_special_tokens_mask": return_special_tokens_mask,
                "return_offsets_mapping": return_offsets_mapping,
                "return_token_type_ids": return_token_type_ids,
                "return_length": return_length,
                "verbose": verbose,
                "return_tensors": return_tensors,
                **kwargs,
            }

            text_encoding = self.tokenizer(text=text, **shared_tokenizer_kwargs)
            encoding.update(text_encoding)

            qformer_text_encoding = self.qformer_tokenizer(text=text, **shared_tokenizer_kwargs)
            encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            # The mask is absent when the caller disabled it (`return_attention_mask=False`); an unconditional
            # `pop` would raise `KeyError` in that case.
            if "attention_mask" in qformer_text_encoding:
                encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

        if images is not None:
            image_encoding = self.image_processor(images, return_tensors=return_tensors)
            encoding.update(image_encoding)

        return encoding

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        # `dict.fromkeys` removes duplicates while keeping first-seen order.
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    # overwrite to save the Q-Former tokenizer in a separate folder
    def save_pretrained(self, save_directory, **kwargs):
        """
        Save the Q-Former tokenizer into `<save_directory>/qformer_tokenizer`, then delegate to the parent
        `save_pretrained` for the remaining components.

        Raises:
            ValueError: if `save_directory` points to an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)
        qformer_tokenizer_path = os.path.join(save_directory, "qformer_tokenizer")
        self.qformer_tokenizer.save_pretrained(qformer_tokenizer_path)
        return super().save_pretrained(save_directory, **kwargs)

    # overwrite to load the Q-Former tokenizer from a separate folder
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the image processor and tokenizer through `_get_arguments_from_pretrained`, and the Q-Former tokenizer
        from the `qformer_tokenizer` subfolder of the same location.

        Download-related options given in `kwargs` (cache directory, revision, authentication token, ...) are also
        applied when fetching the Q-Former tokenizer, so that it is resolved from the same place as the other
        components. Other keyword arguments are not forwarded to the Q-Former tokenizer.
        """
        # Only hub/download options are shared; tokenizer-behaviour kwargs (e.g. `padding_side`) keep applying to the
        # main tokenizer only, as before.
        hub_option_names = (
            "cache_dir",
            "force_download",
            "local_files_only",
            "proxies",
            "resume_download",
            "revision",
            "token",
            "use_auth_token",
        )
        qformer_kwargs = {name: kwargs[name] for name in hub_option_names if name in kwargs}

        # If the caller points at a subfolder, the Q-Former tokenizer directory written by `save_pretrained`
        # lives underneath it.
        parent_subfolder = kwargs.get("subfolder")
        if parent_subfolder:
            qformer_kwargs["subfolder"] = f"{parent_subfolder}/qformer_tokenizer"
        else:
            qformer_kwargs["subfolder"] = "qformer_tokenizer"

        qformer_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **qformer_kwargs)
        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(*args, qformer_tokenizer)
|