# coding=utf-8
# Copyright 2023 The Kakao Enterprise Authors, the MMS-TTS Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for VITS."""

import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import is_phonemizer_available, logging


if is_phonemizer_available():
    import phonemizer


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}

# Matches any single character outside the 7-bit ASCII range. Compiled once at import time
# instead of on every call.
_NON_ASCII_PATTERN = re.compile(r"[^\x00-\x7F]")


def has_non_roman_characters(input_string):
    """
    Tell whether `input_string` contains at least one character outside the ASCII range.

    Args:
        input_string (`str`):
            The text to inspect.

    Returns:
        `bool`: `True` if any character falls outside `\\x00`-`\\x7F`, `False` otherwise (including for the empty
        string).
    """
    return _NON_ASCII_PATTERN.search(input_string) is not None


class VitsTokenizer(PreTrainedTokenizer):
    """
    Construct a VITS tokenizer. Also supports MMS-TTS.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        language (`str`, *optional*):
            Language identifier.
        add_blank (`bool`, *optional*, defaults to `True`):
            Whether to insert token id 0 in between the other tokens.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the input text by removing all casing and punctuation.
        phonemize (`bool`, *optional*, defaults to `True`):
            Whether to convert the input text into phonemes.
        is_uroman (`bool`, *optional*, defaults to `False`):
            Whether the `uroman` Romanizer needs to be applied to the input text prior to tokenizing.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        # NOTE(review): the empty-string defaults below look like placeholders whose markup was stripped
        # (e.g. "<pad>" / "<unk>") -- confirm against the upstream file before relying on them.
        pad_token="",
        unk_token="",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        # The vocabulary file is a JSON object mapping token -> id.
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        # Reverse mapping id -> token, used for decoding.
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize

        self.is_uroman = is_uroman

        # The attributes above are assigned before the superclass initializer runs, as in the original
        # ordering; presumably the base class consults the vocabulary during init -- TODO confirm.
        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of entries in the base vocabulary loaded from the vocabulary file."""
        return len(self.encoder)

    def get_vocab(self):
        """Return the token -> id mapping of the base vocabulary, updated with the added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """
        Lowercase the input string, respecting any special token ids that may be part or entirely upper-cased.

        The string is scanned left to right. Wherever a vocabulary entry (base vocabulary first, then added tokens)
        starts at the current position, it is copied verbatim; otherwise the single character at that position is
        lower-cased.

        Args:
            input_string (`str`):
                The text to normalize.

        Returns:
            `str`: The normalized text.
        """
        # Empty keys are skipped: an empty string matches at every position without advancing the cursor, which
        # would otherwise make the scan below loop forever.
        all_vocabulary = [
            token for token in [*self.encoder.keys(), *self.added_tokens_encoder.keys()] if token
        ]

        pieces = []
        i = 0
        while i < len(input_string):
            # First entry (in vocabulary order) that starts at position `i`, if any.
            matched = next((word for word in all_vocabulary if input_string.startswith(word, i)), None)
            if matched is None:
                pieces.append(input_string[i].lower())
                i += 1
            else:
                pieces.append(matched)
                i += len(matched)

        # Join once at the end rather than growing a string character by character.
        return "".join(pieces)

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages"""
        if self.language == "ron":
            # Replace t-with-comma-below by t-with-cedilla for Romanian.
            text = text.replace("ț", "ţ")
        return text

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.

        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
        `kwargs` at the end of the encoding process to be sure all the arguments have been used.

        Args:
            text (`str`):
                The text to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize.
            normalize (`bool`, *optional*, defaults to `None`):
                Whether or not to apply punctuation and casing normalization to the text inputs. Typically, VITS is
                trained on lower-cased and un-punctuated text. Hence, normalization is used to ensure that the input
                text consists only of lower-case characters.
            kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments to use for the tokenization.

        Returns:
            `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.

        Raises:
            `ImportError`: If `self.phonemize` is set but the `phonemizer` package is not available.
        """
        # A per-call value takes precedence over the tokenizer-level setting.
        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            # normalise for casing
            text = self.normalize_text(text)

        filtered_text = self._preprocess_char(text)

        if has_non_roman_characters(filtered_text) and self.is_uroman:
            logger.warning(
                "Text to the tokenizer contains non-Roman characters. Ensure the `uroman` Romanizer is "
                "applied to the text prior to passing it to the tokenizer. See "
                "`https://github.com/isi-nlp/uroman` for details."
            )

        if self.phonemize:
            if not is_phonemizer_available():
                raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")

            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            # Collapse every run of whitespace into a single space.
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # strip any chars outside of the vocab (punctuation)
            filtered_text = "".join(char for char in filtered_text if char in self.encoder).strip()

        return filtered_text, kwargs

    def _tokenize(self, text: str) -> List[str]:
        """
        Tokenize a string into characters.

        When `self.add_blank` is set, the token with id 0 is placed before the first character, between every pair
        of adjacent characters and after the last character.
        """
        tokens = list(text)

        if self.add_blank:
            # 2 * n + 1 slots: the blank token sits on the even indices, the characters on the odd ones.
            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2 + 1)
            interspersed[1::2] = tokens
            tokens = interspersed

        return tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join tokens into a string, keeping only the odd-indexed tokens when blanks were interspersed."""
        if self.add_blank and len(tokens) > 1:
            tokens = tokens[1::2]
        return "".join(tokens)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Unknown tokens fall back to the id of `unk_token` (which is `None` if that is not in the vocab either).
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Union[Tuple[str], None]:
        """
        Write the base vocabulary as JSON into `save_directory`.

        Args:
            save_directory (`str`):
                Existing directory in which to write the vocabulary file.
            filename_prefix (`str`, *optional*):
                Prefix prepended (followed by `-`) to the vocabulary file name.

        Returns:
            `Tuple[str]` or `None`: A one-element tuple with the path of the written file, or `None` (after logging
            an error) when `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return None

        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        return (vocab_file,)