# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Speech processor class for Wav2Vec2
"""
import warnings
from contextlib import contextmanager

from ...processing_utils import ProcessorMixin
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer


class Wav2Vec2Processor(ProcessorMixin):
    r"""
    Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor and a Wav2Vec2 CTC tokenizer into a single
    processor.

    [`Wav2Vec2Processor`] offers all the functionalities of [`Wav2Vec2FeatureExtractor`] and [`PreTrainedTokenizer`].
    See the docstring of [`~Wav2Vec2Processor.__call__`] and [`~Wav2Vec2Processor.decode`] for more information.

    Args:
        feature_extractor (`Wav2Vec2FeatureExtractor`):
            An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is a required input.
        tokenizer ([`PreTrainedTokenizer`]):
            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
    """

    feature_extractor_class = "Wav2Vec2FeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)
        # `current_processor` is what `__call__`/`pad` delegate to while inside the deprecated
        # `as_target_processor` context; outside of it, it points at the feature extractor.
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate the processor from `pretrained_model_name_or_path`.

        The parent class' `from_pretrained` is tried first. If it raises an `OSError`, a `FutureWarning` is emitted
        and the feature extractor and tokenizer are loaded directly with [`Wav2Vec2FeatureExtractor`] and
        [`Wav2Vec2CTCTokenizer`], each receiving the same `pretrained_model_name_or_path` and `**kwargs`.
        """
        try:
            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        except OSError:
            warnings.warn(
                f"Loading a tokenizer inside {cls.__name__} from a config that does not"
                " include a `tokenizer_class` attribute is deprecated and will be "
                "removed in v5. Please add `'tokenizer_class': 'Wav2Vec2CTCTokenizer'`"
                " attribute to either your `config.json` or `tokenizer_config.json` "
                "file to suppress this warning: ",
                FutureWarning,
            )

            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
            tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)

            return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)

    def __call__(self, *args, **kwargs):
        """
        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
        [`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context
        [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
        [`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.

        Args:
            audio: Input passed to the feature extractor. May be given as the first positional argument (which
                takes precedence over the keyword) or through the deprecated `raw_speech` keyword.
            sampling_rate: Forwarded to the feature extractor only.
            text: Input passed to the tokenizer.
            **kwargs: Remaining keyword arguments, forwarded to both the feature extractor and the tokenizer.

        Returns:
            The feature extractor output when only `audio` is given, the tokenizer output when only `text` is given,
            and the feature extractor output with an additional `"labels"` entry (the tokenizer's `"input_ids"`) when
            both are given.

        Raises:
            ValueError: If neither `audio` nor `text` is provided.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        if "raw_speech" in kwargs:
            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
            audio = kwargs.pop("raw_speech")
        else:
            audio = kwargs.pop("audio", None)
        sampling_rate = kwargs.pop("sampling_rate", None)
        text = kwargs.pop("text", None)
        # A leading positional argument is treated as the audio input; the rest go to the feature extractor.
        if len(args) > 0:
            audio = args[0]
            args = args[1:]

        if audio is None and text is None:
            raise ValueError("You need to specify either an `audio` or `text` input to process.")

        if audio is not None:
            inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        if text is None:
            return inputs
        elif audio is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def pad(self, *args, **kwargs):
        """
        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
        [`~Wav2Vec2FeatureExtractor.pad`] and returns its output. If used in the context
        [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
        [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.

        Args:
            input_features: Features padded with the feature extractor. May be given as the first positional
                argument, which takes precedence over the keyword.
            labels: Labels padded with the tokenizer.
            **kwargs: Remaining keyword arguments, forwarded to both `pad` calls.

        Returns:
            The padded features when only `input_features` is given, the padded labels when only `labels` is given,
            and the padded features with an additional `"labels"` entry (the padded labels' `"input_ids"`) when both
            are given. `None` is returned when neither is given.
        """
        # For backward compatibility
        if self._in_target_context_manager:
            return self.current_processor.pad(*args, **kwargs)

        input_features = kwargs.pop("input_features", None)
        labels = kwargs.pop("labels", None)
        # A leading positional argument is treated as the features; the rest go to the feature extractor.
        if len(args) > 0:
            input_features = args[0]
            args = args[1:]

        if input_features is not None:
            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
        if labels is not None:
            labels = self.tokenizer.pad(labels, **kwargs)

        if labels is None:
            return input_features
        elif input_features is None:
            return labels
        else:
            input_features["labels"] = labels["input_ids"]
            return input_features

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @contextmanager
    def as_target_processor(self):
        """
        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
        Wav2Vec2.

        The feature extractor is restored as the current processor when the context exits, including when the body
        of the `with` statement raises.
        """
        warnings.warn(
            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
            "your audio inputs, or in a separate call."
        )
        self._in_target_context_manager = True
        self.current_processor = self.tokenizer
        try:
            yield
        finally:
            # Always leave target mode; otherwise an exception inside the `with` body would keep routing
            # every later `__call__`/`pad` to the tokenizer.
            self.current_processor = self.feature_extractor
            self._in_target_context_manager = False