217 lines
8.8 KiB
Python
217 lines
8.8 KiB
Python
|
from typing import Callable, Dict, List, Union
|
||
|
|
||
|
from TTS.tts.utils.text import cleaners
|
||
|
from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes
|
||
|
from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name
|
||
|
from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer
|
||
|
from TTS.utils.generic_utils import get_import_path, import_class
|
||
|
|
||
|
|
||
|
class TTSTokenizer:
    """🐸TTS tokenizer to convert input characters to token IDs and back.

    Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later.

    Args:
        use_phonemes (bool):
            Whether to use phonemes instead of characters. Defaults to False.

        characters (Characters):
            A Characters object to use for character-to-ID and ID-to-character mappings.

        text_cleaner (callable):
            A function to pre-process the text before tokenization and phonemization. Defaults to None.

        phonemizer (Phonemizer):
            A phonemizer object used when `use_phonemes` is True. It must provide `phonemize()` and
            `print_logs()`, since this class calls both. Defaults to None.

        add_blank (bool):
            Whether to intersperse the blank token ID between (and around) the token IDs. Defaults to False.

        use_eos_bos (bool):
            Whether to wrap the token IDs with the BOS and EOS token IDs. Defaults to False.

    Example:

        >>> from TTS.tts.utils.text.tokenizer import TTSTokenizer
        >>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes())
        >>> text = "Hello world!"
        >>> ids = tokenizer.text_to_ids(text)
        >>> text_hat = tokenizer.ids_to_text(ids)
        >>> assert text == text_hat
    """

    def __init__(
        self,
        use_phonemes=False,
        text_cleaner: Callable = None,
        characters: "BaseCharacters" = None,
        phonemizer: Union["Phonemizer", Dict] = None,
        add_blank: bool = False,
        use_eos_bos=False,
    ):
        self.text_cleaner = text_cleaner
        self.use_phonemes = use_phonemes
        self.add_blank = add_blank
        self.use_eos_bos = use_eos_bos
        # Assigning through the property also derives `pad_id` and `blank_id`.
        self.characters = characters
        # Unique out-of-vocabulary characters seen by `encode`, in first-seen order.
        self.not_found_characters = []
        self.phonemizer = phonemizer

    @property
    def characters(self):
        return self._characters

    @characters.setter
    def characters(self, new_characters):
        """Set the character set and refresh the cached `pad_id` / `blank_id`.

        `None` is accepted (it is the constructor default); both cached IDs are then `None`.
        """
        self._characters = new_characters
        if new_characters is None:
            self.pad_id = None
            self.blank_id = None
            return
        self.pad_id = new_characters.char_to_id(new_characters.pad) if new_characters.pad else None
        self.blank_id = new_characters.char_to_id(new_characters.blank) if new_characters.blank else None

    def encode(self, text: str) -> List[int]:
        """Encodes a string of text as a sequence of IDs.

        Characters missing from the vocabulary (``char_to_id`` raises ``KeyError``) are dropped.
        Each distinct missing character is recorded once in `self.not_found_characters`, and a
        warning is printed the first time it is seen.
        """
        token_ids = []
        for char in text:
            try:
                token_ids.append(self.characters.char_to_id(char))
            except KeyError:
                # discard but store not found characters
                if char not in self.not_found_characters:
                    self.not_found_characters.append(char)
                    print(f" [!] Character {repr(char)} not found in the vocabulary. Discarding it.")
        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """Decodes a sequence of IDs to a string of text."""
        return "".join(self.characters.id_to_char(token_id) for token_id in token_ids)

    def text_to_ids(self, text: str, language: str = None) -> List[int]:
        """Converts a string of text to a sequence of token IDs.

        Args:
            text(str):
                The text to convert to token IDs.

            language(str):
                The language code of the text, forwarded to the phonemizer. Defaults to None.

        Steps:
            1. Text normalization (if a text cleaner is set)
            2. Phonemization (if use_phonemes is True)
            3. Text to token IDs
            4. Add blank ID between IDs (if add_blank is True)
            5. Add BOS and EOS IDs (if use_eos_bos is True)

        TODO:
            - Add support for language-specific processing.
        """
        # TODO: text cleaner should pick the right routine based on the language
        if self.text_cleaner is not None:
            text = self.text_cleaner(text)
        if self.use_phonemes:
            text = self.phonemizer.phonemize(text, separator="", language=language)
        token_ids = self.encode(text)
        if self.add_blank:
            token_ids = self.intersperse_blank_char(token_ids, True)
        if self.use_eos_bos:
            token_ids = self.pad_with_bos_eos(token_ids)
        return token_ids

    def ids_to_text(self, id_sequence: List[int]) -> str:
        """Converts a sequence of token IDs to a string of text."""
        return self.decode(id_sequence)

    def pad_with_bos_eos(self, char_sequence: List[int]):
        """Pads a sequence of token IDs with the special BOS and EOS token IDs."""
        return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]

    def intersperse_blank_char(self, char_sequence: List[int], use_blank_char: bool = False):
        """Intersperses a filler token ID between (and around) the IDs of a sequence.

        The filler is the blank token ID when `use_blank_char` is True, otherwise the pad token ID.
        The result has length ``2 * len(char_sequence) + 1``.
        """
        # Both branches must yield an integer ID: the sequence holds token IDs, not characters.
        filler_id = self.characters.blank_id if use_blank_char else self.characters.pad_id
        result = [filler_id] * (len(char_sequence) * 2 + 1)
        result[1::2] = char_sequence
        return result

    def print_logs(self, level: int = 0):
        """Print the tokenizer settings and any out-of-vocabulary characters seen so far."""
        indent = "\t" * level
        print(f"{indent}| > add_blank: {self.add_blank}")
        print(f"{indent}| > use_eos_bos: {self.use_eos_bos}")
        print(f"{indent}| > use_phonemes: {self.use_phonemes}")
        if self.use_phonemes:
            print(f"{indent}| > phonemizer:")
            self.phonemizer.print_logs(level + 1)
        if self.not_found_characters:
            print(f"{indent}| > {len(self.not_found_characters)} not found characters:")
            for char in self.not_found_characters:
                print(f"{indent}| > {char}")

    @staticmethod
    def _load_text_cleaner(cleaner_names):
        """Resolve cleaner name(s) from the `cleaners` module to one callable, or None.

        A single name returns that cleaner. A non-empty list returns a function applying the
        cleaners in list order. Anything else (e.g. None or an empty list) returns None.
        Unknown names raise ``AttributeError``.
        """
        if isinstance(cleaner_names, str):
            return getattr(cleaners, cleaner_names)
        if not isinstance(cleaner_names, list) or not cleaner_names:
            return None
        cleaner_funcs = [getattr(cleaners, name) for name in cleaner_names]
        if len(cleaner_funcs) == 1:
            return cleaner_funcs[0]

        def _chained_cleaner(text: str) -> str:
            for cleaner in cleaner_funcs:
                text = cleaner(text)
            return text

        return _chained_cleaner

    @staticmethod
    def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None):
        """Init Tokenizer object from config

        Args:
            config (Coqpit): Coqpit model config.
            characters (BaseCharacters): Defines the model character set. If not set, use the default options based on
                the config values. Defaults to None.

        Returns:
            Tuple[TTSTokenizer, Coqpit]: The tokenizer and the config returned by the characters'
            `init_from_config`, updated with the characters class path (and, when chosen
            automatically, the phonemizer name).

        Raises:
            ValueError: if the multi-phonemizer is requested but a dataset has no language, or if
                no default phonemizer exists for `config.phoneme_language`.
        """
        # init cleaners (a single name or a list of names applied in order)
        text_cleaner = TTSTokenizer._load_text_cleaner(config.text_cleaner)

        # init characters
        if characters is None:
            if config.characters and config.characters.characters_class:
                # set characters based on defined characters class
                CharactersClass = import_class(config.characters.characters_class)
                characters, new_config = CharactersClass.init_from_config(config)
            elif config.use_phonemes:
                # init phoneme set
                characters, new_config = IPAPhonemes().init_from_config(config)
            else:
                # init character set
                characters, new_config = Graphemes().init_from_config(config)
        else:
            characters, new_config = characters.init_from_config(config)

        # set characters class
        new_config.characters.characters_class = get_import_path(characters)

        # init phonemizer
        phonemizer = None
        if config.use_phonemes:
            if "phonemizer" in config and config.phonemizer == "multi_phonemizer":
                lang_to_phonemizer_name = {}
                for dataset in config.datasets:
                    if dataset.language != "":
                        lang_to_phonemizer_name[dataset.language] = dataset.phonemizer
                    else:
                        raise ValueError("Multi phonemizer requires language to be set for each dataset.")
                phonemizer = MultiPhonemizer(lang_to_phonemizer_name)
            else:
                phonemizer_kwargs = {"language": config.phoneme_language}
                if "phonemizer" in config and config.phonemizer:
                    phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs)
                else:
                    try:
                        phonemizer = get_phonemizer_by_name(
                            DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs
                        )
                        new_config.phonemizer = phonemizer.name()
                    except KeyError as e:
                        raise ValueError(
                            f"""No phonemizer found for language {config.phoneme_language}.
                            You may need to install a third party library for this language."""
                        ) from e

        return (
            TTSTokenizer(
                config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars
            ),
            new_config,
        )
|