import math
import os
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torchaudio
from coqpit import Coqpit
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset, _parse_sample
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.io import load_fsspec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results

##############################
# IO / Feature extraction
##############################

# pylint: disable=global-statement
hann_window = {}
mel_basis = {}


@torch.no_grad()
def weights_reset(m: nn.Module):
    # check if the current module has reset_parameters and if it is reset the weight
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()


def get_module_weights_sum(mdl: nn.Module):
    dict_sums = {}
    for name, w in mdl.named_parameters():
        if "weight" in name:
            value = w.data.sum().item()
            dict_sums[name] = value
    return dict_sums


def load_audio(file_path):
    """Load the audio file normalized in [-1, 1]

    Return Shapes:
        - x: :math:`[1, T]`
    """
    x, sr = torchaudio.load(file_path)
    assert (x > 1).sum() + (x < -1).sum() == 0
    return x, sr


def _amp_to_db(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def _db_to_amp(x, C=1):
    return torch.exp(x) / C


def amp_to_db(magnitudes):
    output = _amp_to_db(magnitudes)
    return output


def db_to_amp(magnitudes):
    output = _db_to_amp(magnitudes)
    return output
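

# Illustrative sketch (not part of the original implementation): the two public
# helpers above form a natural-log compander pair, so converting magnitudes to
# the dB-like scale and back recovers the input up to the `clip_val` floor.
def _example_amp_db_roundtrip():
    mags = torch.rand(2, 80, 100) + 1e-3  # fake magnitude spectrogram, [B, C, T]
    db = amp_to_db(mags)                  # log-compressed magnitudes
    recovered = db_to_amp(db)             # exp() inverts the log compression
    assert torch.allclose(mags, recovered, atol=1e-6)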


def wav_to_spec(y, n_fft, hop_length, win_length, center=False):
    """
    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    y = y.squeeze(1)

    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_length) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
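

# Illustrative sketch (not part of the original implementation): with the
# hyper-parameters used by `VitsAudioConfig` below, one second of audio maps to
# roughly T_wav // hop_length STFT frames and n_fft // 2 + 1 frequency bins.
def _example_wav_to_spec():
    wav = torch.rand(1, 1, 22050) * 2 - 1  # fake batch with one 1-second clip, [B, 1, T]
    spec = wav_to_spec(wav, n_fft=1024, hop_length=256, win_length=1024, center=False)
    assert spec.shape[1] == 1024 // 2 + 1  # 513 linear-frequency bins
    assert spec.shape[2] == 22050 // 256   # ~86 frames for one second of audio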


def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax):
    """
    Args Shapes:
        - spec : :math:`[B,C,T]`

    Return Shapes:
        - mel : :math:`[B,C,T]`
    """
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    mel = torch.matmul(mel_basis[fmax_dtype_device], spec)
    mel = amp_to_db(mel)
    return mel


def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False):
    """
    Args Shapes:
        - y : :math:`[B, 1, T]`

    Return Shapes:
        - spec : :math:`[B,C,T]`
    """
    y = y.squeeze(1)

    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_length) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = amp_to_db(spec)
    return spec
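

# Illustrative sketch (not part of the original implementation): `wav_to_mel`
# fuses the two steps above, so with identical settings it matches
# `spec_to_mel(wav_to_spec(...))`.
def _example_wav_to_mel_equivalence():
    wav = torch.rand(1, 1, 22050) * 2 - 1
    spec = wav_to_spec(wav, n_fft=1024, hop_length=256, win_length=1024, center=False)
    mel_a = spec_to_mel(spec, n_fft=1024, num_mels=80, sample_rate=22050, fmin=0, fmax=None)
    mel_b = wav_to_mel(wav, 1024, 80, 22050, 256, 1024, fmin=0, fmax=None, center=False)
    assert mel_a.shape == mel_b.shape == (1, 80, spec.shape[-1])
    assert torch.allclose(mel_a, mel_b, atol=1e-4)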


#############################
# CONFIGS
#############################


@dataclass
class VitsAudioConfig(Coqpit):
    fft_size: int = 1024
    sample_rate: int = 22050
    win_length: int = 1024
    hop_length: int = 256
    num_mels: int = 80
    mel_fmin: int = 0
    mel_fmax: int = None
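

# Illustrative sketch (not part of the original implementation): the defaults
# above describe 22.05 kHz audio analysed with a 1024-point FFT every 256
# samples, i.e. roughly 86 spectrogram frames per second; values below are
# examples, not recommended settings.
def _example_audio_config():
    cfg = VitsAudioConfig(sample_rate=16000, mel_fmax=8000)
    frames_per_second = cfg.sample_rate / cfg.hop_length
    assert cfg.fft_size // 2 + 1 == 513  # linear spectrogram bins fed to the posterior encoder
    assert round(frames_per_second) == 62  # 16000 / 256 = 62.5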


##############################
# DATASET
##############################


def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
    """Create inverse frequency weights for balancing the dataset.
    Use `multi_dict` to scale relative weights."""
    attr_names_samples = np.array([item[attr_name] for item in items])
    unique_attr_names = np.unique(attr_names_samples).tolist()
    attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
    attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
    weight_attr = 1.0 / attr_count
    dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
    if multi_dict is not None:
        # check if all keys are in the multi_dict
        for k in multi_dict:
            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
        # scale weights
        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
        dataset_samples_weight *= multiplier_samples
    return (
        torch.from_numpy(dataset_samples_weight).float(),
        unique_attr_names,
        np.unique(dataset_samples_weight).tolist(),
    )
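

# Illustrative sketch (not part of the original implementation): with a 3:1
# speaker imbalance, the under-represented speaker gets three times the
# per-sample weight, which `Vits.get_sampler` below feeds to a WeightedRandomSampler.
def _example_balancer_weights():
    items = [{"speaker_name": "spk_a"}] * 3 + [{"speaker_name": "spk_b"}]
    weights, names, _ = get_attribute_balancer_weights(items, "speaker_name")
    assert names == ["spk_a", "spk_b"]
    assert torch.isclose(weights[3] / weights[0], torch.tensor(3.0))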


class VitsDataset(TTSDataset):
    def __init__(self, model_args, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pad_id = self.tokenizer.characters.pad_id
        self.model_args = model_args

    def __getitem__(self, idx):
        item = self.samples[idx]
        raw_text = item["text"]

        wav, _ = load_audio(item["audio_file"])
        if self.model_args.encoder_sample_rate is not None:
            if wav.size(1) % self.model_args.encoder_sample_rate != 0:
                wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)]

        wav_filename = os.path.basename(item["audio_file"])

        token_ids = self.get_token_ids(idx, item["text"])

        # after phonemization the text length may change
        # this is a shameful 🤭 hack to prevent longer phonemes
        # TODO: find a better fix
        if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len:
            self.rescue_item_idx += 1
            return self.__getitem__(self.rescue_item_idx)

        return {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "token_len": len(token_ids),
            "wav": wav,
            "wav_file": wav_filename,
            "speaker_name": item["speaker_name"],
            "language_name": item["language"],
            "audio_unique_name": item["audio_unique_name"],
        }

    @property
    def lengths(self):
        lens = []
        for item in self.samples:
            _, wav_file, *_ = _parse_sample(item)
            audio_len = os.path.getsize(wav_file) / 16 * 8  # assuming 16-bit audio (2 bytes per sample)
            lens.append(audio_len)
        return lens

    def collate_fn(self, batch):
        """
        Return Shapes:
            - tokens: :math:`[B, T]`
            - token_lens: :math:`[B]`
            - token_rel_lens: :math:`[B]`
            - waveform: :math:`[B, 1, T]`
            - waveform_lens: :math:`[B]`
            - waveform_rel_lens: :math:`[B]`
            - speaker_names: :math:`[B]`
            - language_names: :math:`[B]`
            - audiofile_paths: :math:`[B]`
            - raw_texts: :math:`[B]`
            - audio_unique_names: :math:`[B]`
        """
        # convert list of dicts to dict of lists
        B = len(batch)
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True
        )

        max_text_len = max([len(x) for x in batch["token_ids"]])
        token_lens = torch.LongTensor(batch["token_len"])
        token_rel_lens = token_lens / token_lens.max()

        wav_lens = [w.shape[1] for w in batch["wav"]]
        wav_lens = torch.LongTensor(wav_lens)
        wav_lens_max = torch.max(wav_lens)
        wav_rel_lens = wav_lens / wav_lens_max

        token_padded = torch.LongTensor(B, max_text_len)
        wav_padded = torch.FloatTensor(B, 1, wav_lens_max)
        token_padded = token_padded.zero_() + self.pad_id
        wav_padded = wav_padded.zero_() + self.pad_id
        for i in range(len(ids_sorted_decreasing)):
            token_ids = batch["token_ids"][i]
            token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids)

            wav = batch["wav"][i]
            wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav)

        return {
            "tokens": token_padded,
            "token_lens": token_lens,
            "token_rel_lens": token_rel_lens,
            "waveform": wav_padded,  # (B x 1 x T)
            "waveform_lens": wav_lens,  # (B)
            "waveform_rel_lens": wav_rel_lens,
            "speaker_names": batch["speaker_name"],
            "language_names": batch["language_name"],
            "audio_files": batch["wav_file"],
            "raw_text": batch["raw_text"],
            "audio_unique_names": batch["audio_unique_name"],
        }
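

# Illustrative sketch (not part of the original implementation): the padding
# logic above right-pads every utterance to the longest one in the batch, so a
# toy batch with lengths 3 and 5 collates to a [2, 1, 5] waveform tensor.
def _example_collate_padding():
    wavs = [torch.rand(1, 3), torch.rand(1, 5)]
    max_len = max(w.size(1) for w in wavs)
    wav_padded = torch.zeros(len(wavs), 1, max_len)
    for i, w in enumerate(wavs):
        wav_padded[i, :, : w.size(1)] = w
    assert wav_padded.shape == (2, 1, 5)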


##############################
# MODEL DEFINITION
##############################


@dataclass
class VitsArgs(Coqpit):
    """VITS model arguments.

    Args:

        num_chars (int):
            Number of characters in the vocabulary. Defaults to 100.

        out_channels (int):
            Number of output channels of the decoder. Defaults to 513.

        spec_segment_size (int):
            Decoder input segment size. Defaults to 32 `(32 * hop_length = waveform length)`.

        hidden_channels (int):
            Number of hidden channels of the model. Defaults to 192.

        hidden_channels_ffn_text_encoder (int):
            Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 768.

        num_heads_text_encoder (int):
            Number of attention heads of the text encoder transformer. Defaults to 2.

        num_layers_text_encoder (int):
            Number of transformer layers in the text encoder. Defaults to 6.

        kernel_size_text_encoder (int):
            Kernel size of the text encoder transformer FFN layers. Defaults to 3.

        dropout_p_text_encoder (float):
            Dropout rate of the text encoder. Defaults to 0.1.

        dropout_p_duration_predictor (float):
            Dropout rate of the duration predictor. Defaults to 0.5.

        kernel_size_posterior_encoder (int):
            Kernel size of the posterior encoder's WaveNet layers. Defaults to 5.

        dilation_rate_posterior_encoder (int):
            Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1.

        num_layers_posterior_encoder (int):
            Number of posterior encoder's WaveNet layers. Defaults to 16.

        kernel_size_flow (int):
            Kernel size of the Residual Coupling layers of the flow network. Defaults to 5.

        dilation_rate_flow (int):
            Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1.

        num_layers_flow (int):
            Number of Residual Coupling WaveNet layers of the flow network. Defaults to 4.

        resblock_type_decoder (str):
            Type of the residual block in the decoder network. Defaults to "1".

        resblock_kernel_sizes_decoder (List[int]):
            Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`.

        resblock_dilation_sizes_decoder (List[List[int]]):
            Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`.

        upsample_rates_decoder (List[int]):
            Upsampling rates for each consecutive upsampling layer in the decoder network. The product of these
            values must equal the hop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`.

        upsample_initial_channel_decoder (int):
            Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512.

        upsample_kernel_sizes_decoder (List[int]):
            Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`.

        periods_multi_period_discriminator (List[int]):
            Periods values for Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`.

        use_sdp (bool):
            Use Stochastic Duration Predictor. Defaults to True.

        noise_scale (float):
            Noise scale used for the sample noise tensor in training. Defaults to 1.0.

        inference_noise_scale (float):
            Noise scale used for the sample noise tensor in inference. Defaults to 0.667.

        length_scale (float):
            Scale factor for the predicted duration values. Smaller values result in faster speech. Defaults to 1.

        noise_scale_dp (float):
            Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0.

        inference_noise_scale_dp (float):
            Noise scale for the Stochastic Duration Predictor in inference. Defaults to 1.0.

        max_inference_len (int):
            Maximum inference length to limit the memory use. Defaults to None.

        init_discriminator (bool):
            Initialize the discriminator network if set True. Set False for inference. Defaults to True.

        use_spectral_norm_disriminator (bool):
            Use spectral normalization over weight norm in the discriminator. Defaults to False.

        use_speaker_embedding (bool):
            Enable/Disable speaker embedding for multi-speaker models. Defaults to False.

        num_speakers (int):
            Number of speakers for the speaker embedding layer. Defaults to 0.

        speakers_file (str):
            Path to the speaker mapping file for the Speaker Manager. Defaults to None.

        speaker_embedding_channels (int):
            Number of speaker embedding channels. Defaults to 256.

        use_d_vector_file (bool):
            Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.

        d_vector_file (List[str]):
            List of paths to the files including pre-computed speaker embeddings. Defaults to None.

        d_vector_dim (int):
            Number of d-vector channels. Defaults to 0.

        detach_dp_input (bool):
            Detach duration predictor's input from the network for stopping the gradients. Defaults to True.

        use_language_embedding (bool):
            Enable/Disable language embedding for multilingual models. Defaults to False.

        embedded_language_dim (int):
            Number of language embedding channels. Defaults to 4.

        num_languages (int):
            Number of languages for the language embedding layer. Defaults to 0.

        language_ids_file (str):
            Path to the language mapping file for the Language Manager. Defaults to None.

        use_speaker_encoder_as_loss (bool):
            Enable/Disable Speaker Consistency Loss (SCL). Defaults to False.

        speaker_encoder_config_path (str):
            Path to the speaker encoder config file, used for SCL. Defaults to "".

        speaker_encoder_model_path (str):
            Path to the speaker encoder checkpoint file, used for SCL. Defaults to "".

        condition_dp_on_speaker (bool):
            Condition the duration predictor on the speaker embedding. Defaults to True.

        freeze_encoder (bool):
            Freeze the encoder weights during training. Defaults to False.

        freeze_DP (bool):
            Freeze the duration predictor weights during training. Defaults to False.

        freeze_PE (bool):
            Freeze the posterior encoder weights during training. Defaults to False.

        freeze_flow_decoder (bool):
            Freeze the flow weights during training. Defaults to False.

        freeze_waveform_decoder (bool):
            Freeze the waveform decoder weights during training. Defaults to False.

        encoder_sample_rate (int):
            If not None this sample rate will be used for training the Posterior Encoder,
            flow, text_encoder and duration predictor. The decoder part (vocoder) will be
            trained with the `config.audio.sample_rate`. Defaults to None.

        interpolate_z (bool):
            If `encoder_sample_rate` is not None and this parameter is True, nearest interpolation
            is used to upsample the latent variable z from the sampling rate `encoder_sample_rate`
            to `config.audio.sample_rate`. If it is False, you will need to add extra
            `upsample_rates_decoder` values to match the shape. Defaults to True.

    """

    num_chars: int = 100
    out_channels: int = 513
    spec_segment_size: int = 32
    hidden_channels: int = 192
    hidden_channels_ffn_text_encoder: int = 768
    num_heads_text_encoder: int = 2
    num_layers_text_encoder: int = 6
    kernel_size_text_encoder: int = 3
    dropout_p_text_encoder: float = 0.1
    dropout_p_duration_predictor: float = 0.5
    kernel_size_posterior_encoder: int = 5
    dilation_rate_posterior_encoder: int = 1
    num_layers_posterior_encoder: int = 16
    kernel_size_flow: int = 5
    dilation_rate_flow: int = 1
    num_layers_flow: int = 4
    resblock_type_decoder: str = "1"
    resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11])
    resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
    upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2])
    upsample_initial_channel_decoder: int = 512
    upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
    periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11])
    use_sdp: bool = True
    noise_scale: float = 1.0
    inference_noise_scale: float = 0.667
    length_scale: float = 1
    noise_scale_dp: float = 1.0
    inference_noise_scale_dp: float = 1.0
    max_inference_len: int = None
    init_discriminator: bool = True
    use_spectral_norm_disriminator: bool = False
    use_speaker_embedding: bool = False
    num_speakers: int = 0
    speakers_file: str = None
    d_vector_file: List[str] = None
    speaker_embedding_channels: int = 256
    use_d_vector_file: bool = False
    d_vector_dim: int = 0
    detach_dp_input: bool = True
    use_language_embedding: bool = False
    embedded_language_dim: int = 4
    num_languages: int = 0
    language_ids_file: str = None
    use_speaker_encoder_as_loss: bool = False
    speaker_encoder_config_path: str = ""
    speaker_encoder_model_path: str = ""
    condition_dp_on_speaker: bool = True
    freeze_encoder: bool = False
    freeze_DP: bool = False
    freeze_PE: bool = False
    freeze_flow_decoder: bool = False
    freeze_waveform_decoder: bool = False
    encoder_sample_rate: int = None
    interpolate_z: bool = True
    reinit_DP: bool = False
    reinit_text_encoder: bool = False
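

# Illustrative sketch (not part of the original implementation): `VitsArgs` is a
# Coqpit dataclass, so multi-speaker training is enabled by overriding a few
# fields; the values here are examples, not recommended settings.
def _example_vits_args():
    args = VitsArgs(
        num_chars=120,
        use_speaker_embedding=True,
        num_speakers=10,
        speaker_embedding_channels=256,
    )
    assert args.use_sdp and args.hidden_channels == 192  # untouched fields keep their defaults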


class Vits(BaseTTS):
    """VITS TTS model

    Paper::
        https://arxiv.org/pdf/2106.06103.pdf

    Paper Abstract::
        Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel
        sampling have been proposed, but their sample quality does not match that of two-stage TTS systems.
        In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than
        current two-stage models. Our method adopts variational inference augmented with normalizing flows and
        an adversarial training process, which improves the expressive power of generative modeling. We also propose a
        stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the
        uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the
        natural one-to-many relationship in which a text input can be spoken in multiple ways
        with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS)
        on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly
        available TTS systems and achieves a MOS comparable to ground truth.

    Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments.

    Examples:
        >>> from TTS.tts.configs.vits_config import VitsConfig
        >>> from TTS.tts.models.vits import Vits
        >>> config = VitsConfig()
        >>> model = Vits(config)
    """

    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor" = None,
        tokenizer: "TTSTokenizer" = None,
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
    ):
        super().__init__(config, ap, tokenizer, speaker_manager, language_manager)

        self.init_multispeaker(config)
        self.init_multilingual(config)
        self.init_upsampling()

        self.length_scale = self.args.length_scale
        self.noise_scale = self.args.noise_scale
        self.inference_noise_scale = self.args.inference_noise_scale
        self.inference_noise_scale_dp = self.args.inference_noise_scale_dp
        self.noise_scale_dp = self.args.noise_scale_dp
        self.max_inference_len = self.args.max_inference_len
        self.spec_segment_size = self.args.spec_segment_size

        self.text_encoder = TextEncoder(
            self.args.num_chars,
            self.args.hidden_channels,
            self.args.hidden_channels,
            self.args.hidden_channels_ffn_text_encoder,
            self.args.num_heads_text_encoder,
            self.args.num_layers_text_encoder,
            self.args.kernel_size_text_encoder,
            self.args.dropout_p_text_encoder,
            language_emb_dim=self.embedded_language_dim,
        )

        self.posterior_encoder = PosteriorEncoder(
            self.args.out_channels,
            self.args.hidden_channels,
            self.args.hidden_channels,
            kernel_size=self.args.kernel_size_posterior_encoder,
            dilation_rate=self.args.dilation_rate_posterior_encoder,
            num_layers=self.args.num_layers_posterior_encoder,
            cond_channels=self.embedded_speaker_dim,
        )

        self.flow = ResidualCouplingBlocks(
            self.args.hidden_channels,
            self.args.hidden_channels,
            kernel_size=self.args.kernel_size_flow,
            dilation_rate=self.args.dilation_rate_flow,
            num_layers=self.args.num_layers_flow,
            cond_channels=self.embedded_speaker_dim,
        )

        if self.args.use_sdp:
            self.duration_predictor = StochasticDurationPredictor(
                self.args.hidden_channels,
                192,
                3,
                self.args.dropout_p_duration_predictor,
                4,
                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
                language_emb_dim=self.embedded_language_dim,
            )
        else:
            self.duration_predictor = DurationPredictor(
                self.args.hidden_channels,
                256,
                3,
                self.args.dropout_p_duration_predictor,
                cond_channels=self.embedded_speaker_dim,
                language_emb_dim=self.embedded_language_dim,
            )

        self.waveform_decoder = HifiganGenerator(
            self.args.hidden_channels,
            1,
            self.args.resblock_type_decoder,
            self.args.resblock_dilation_sizes_decoder,
            self.args.resblock_kernel_sizes_decoder,
            self.args.upsample_kernel_sizes_decoder,
            self.args.upsample_initial_channel_decoder,
            self.args.upsample_rates_decoder,
            inference_padding=0,
            cond_channels=self.embedded_speaker_dim,
            conv_pre_weight_norm=False,
            conv_post_weight_norm=False,
            conv_post_bias=False,
        )

        if self.args.init_discriminator:
            self.disc = VitsDiscriminator(
                periods=self.args.periods_multi_period_discriminator,
                use_spectral_norm=self.args.use_spectral_norm_disriminator,
            )

    @property
    def device(self):
        return next(self.parameters()).device

    def init_multispeaker(self, config: Coqpit):
        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
        or with external `d_vectors` computed from a speaker encoder model.

        You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.

        Args:
            config (Coqpit): Model configuration.
        """
        self.embedded_speaker_dim = 0
        self.num_speakers = self.args.num_speakers
        self.audio_transform = None

        if self.speaker_manager:
            self.num_speakers = self.speaker_manager.num_speakers

        if self.args.use_speaker_embedding:
            self._init_speaker_embedding()

        if self.args.use_d_vector_file:
            self._init_d_vector()

        # TODO: make this a function
        if self.args.use_speaker_encoder_as_loss:
            if self.speaker_manager.encoder is None and (
                not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
            ):
                raise RuntimeError(
                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                )

            self.speaker_manager.encoder.eval()
            print(" > External Speaker Encoder Loaded !!")

            if (
                hasattr(self.speaker_manager.encoder, "audio_config")
                and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"]
            ):
                self.audio_transform = torchaudio.transforms.Resample(
                    orig_freq=self.config.audio.sample_rate,
                    new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
                )

    def _init_speaker_embedding(self):
        # pylint: disable=attribute-defined-outside-init
        if self.num_speakers > 0:
            print(" > initialization of speaker-embedding layers.")
            self.embedded_speaker_dim = self.args.speaker_embedding_channels
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)

    def _init_d_vector(self):
        # pylint: disable=attribute-defined-outside-init
        if hasattr(self, "emb_g"):
            raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
        self.embedded_speaker_dim = self.args.d_vector_dim

    def init_multilingual(self, config: Coqpit):
        """Initialize multilingual modules of a model.

        Args:
            config (Coqpit): Model configuration.
        """
        if self.args.language_ids_file is not None:
            self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)

        if self.args.use_language_embedding and self.language_manager:
            print(" > initialization of language-embedding layers.")
            self.num_languages = self.language_manager.num_languages
            self.embedded_language_dim = self.args.embedded_language_dim
            self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
            torch.nn.init.xavier_uniform_(self.emb_l.weight)
        else:
            self.embedded_language_dim = 0

    def init_upsampling(self):
        """
        Initialize upsampling modules of a model.
        """
        if self.args.encoder_sample_rate:
            self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate
            self.audio_resampler = torchaudio.transforms.Resample(
                orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate
            )  # pylint: disable=W0201

    def on_epoch_start(self, trainer):  # pylint: disable=W0613
        """Freeze layers at the beginning of an epoch"""
        self._freeze_layers()
        # set the device of speaker encoder
        if self.args.use_speaker_encoder_as_loss:
            self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device)

    def on_init_end(self, trainer):  # pylint: disable=W0613
        """Reinitialize layers if needed"""
        if self.args.reinit_DP:
            before_dict = get_module_weights_sum(self.duration_predictor)
            # Applies weights_reset recursively to every submodule of the duration predictor
            self.duration_predictor.apply(fn=weights_reset)
            after_dict = get_module_weights_sum(self.duration_predictor)
            for key, value in after_dict.items():
                if value == before_dict[key]:
                    raise RuntimeError(" [!] The weights of the Duration Predictor were not reinitialized. Check it!")
            print(" > Duration Predictor was reinitialized.")

        if self.args.reinit_text_encoder:
            before_dict = get_module_weights_sum(self.text_encoder)
            # Applies weights_reset recursively to every submodule of the text encoder
            self.text_encoder.apply(fn=weights_reset)
            after_dict = get_module_weights_sum(self.text_encoder)
            for key, value in after_dict.items():
                if value == before_dict[key]:
                    raise RuntimeError(" [!] The weights of the Text Encoder were not reinitialized. Check it!")
            print(" > Text Encoder was reinitialized.")

    def get_aux_input(self, aux_input: Dict):
        sid, g, lid, _ = self._set_cond_input(aux_input)
        return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid}

    def _freeze_layers(self):
        if self.args.freeze_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

            if hasattr(self, "emb_l"):
                for param in self.emb_l.parameters():
                    param.requires_grad = False

        if self.args.freeze_PE:
            for param in self.posterior_encoder.parameters():
                param.requires_grad = False

        if self.args.freeze_DP:
            for param in self.duration_predictor.parameters():
                param.requires_grad = False

        if self.args.freeze_flow_decoder:
            for param in self.flow.parameters():
                param.requires_grad = False

        if self.args.freeze_waveform_decoder:
            for param in self.waveform_decoder.parameters():
                param.requires_grad = False

    @staticmethod
    def _set_cond_input(aux_input: Dict):
        """Set the speaker conditioning input based on the multi-speaker mode."""
        sid, g, lid, durations = None, None, None, None
        if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None:
            sid = aux_input["speaker_ids"]
            if sid.ndim == 0:
                sid = sid.unsqueeze_(0)
        if "d_vectors" in aux_input and aux_input["d_vectors"] is not None:
            g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1)
            if g.ndim == 2:
                g = g.unsqueeze_(0)

        if "language_ids" in aux_input and aux_input["language_ids"] is not None:
            lid = aux_input["language_ids"]
            if lid.ndim == 0:
                lid = lid.unsqueeze_(0)

        if "durations" in aux_input and aux_input["durations"] is not None:
            durations = aux_input["durations"]

        return sid, g, lid, durations

    def _set_speaker_input(self, aux_input: Dict):
        d_vectors = aux_input.get("d_vectors", None)
        speaker_ids = aux_input.get("speaker_ids", None)

        if d_vectors is not None and speaker_ids is not None:
            raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")

        if speaker_ids is not None and not hasattr(self, "emb_g"):
            raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")

        g = speaker_ids if speaker_ids is not None else d_vectors
        return g

    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb):
        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            o_scale = torch.exp(-2 * logs_p)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp2 + logp3 + logp1 + logp4
            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']

        # duration predictor
        attn_durations = attn.sum(3)
        if self.args.use_sdp:
            loss_duration = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                attn_durations,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
            )
            loss_duration = loss_duration / torch.sum(x_mask)
        else:
            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
            log_durations = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
        outputs["loss_duration"] = loss_duration
        return outputs, attn
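
    # Note on the MAS score above (added explanation, not original code): `logp`
    # is the log-likelihood of each decoder frame z_p under the per-token prior
    # N(m_p, exp(logs_p)**2), split so every term is a cheap matmul over channels:
    #
    #   log N(z; m, s^2) = -0.5 * log(2*pi) - log(s)   -> logp1
    #                      - 0.5 * z^2 / s^2           -> logp2
    #                      + m * z / s^2               -> logp3
    #                      - 0.5 * m^2 / s^2           -> logp4
    #
    # with o_scale = 1 / s^2 = exp(-2 * logs_p). `maximum_path` then picks the
    # monotonic alignment that maximizes the summed score.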

    def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None):
        spec_segment_size = self.spec_segment_size
        if self.args.encoder_sample_rate:
            # recompute the slices and spec_segment_size if needed
            slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids
            spec_segment_size = spec_segment_size * int(self.interpolate_factor)
            # interpolate z if needed
            if self.args.interpolate_z:
                z = torch.nn.functional.interpolate(
                    z.unsqueeze(0), scale_factor=[1, self.interpolate_factor], mode="nearest"
                ).squeeze(0)
                # recompute the mask if needed
                if y_lengths is not None and y_mask is not None:
                    y_mask = (
                        sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1)
                    )  # [B, 1, T_dec_resampled]

        return z, spec_segment_size, slice_ids, y_mask
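
    # Worked example (added explanation, not original code): with
    # config.audio.sample_rate = 22050 and encoder_sample_rate = 11025, the
    # interpolate_factor is 2, so a spec_segment_size of 32 becomes 64 decoder
    # frames, slice_ids are doubled, and z of shape [B, C, T] is stretched to
    # [B, C, 2 * T] before it reaches the HiFi-GAN waveform decoder.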

    def forward(  # pylint: disable=dangerous-default-value
        self,
        x: torch.tensor,
        x_lengths: torch.tensor,
        y: torch.tensor,
        y_lengths: torch.tensor,
        waveform: torch.tensor,
        aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None},
    ) -> Dict:
        """Forward pass of the model.

        Args:
            x (torch.tensor): Batch of input character sequence IDs.
            x_lengths (torch.tensor): Batch of input character sequence lengths.
            y (torch.tensor): Batch of input spectrograms.
            y_lengths (torch.tensor): Batch of input spectrogram lengths.
            waveform (torch.tensor): Batch of ground truth waveforms per sample.
            aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training.
                Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}.

        Returns:
            Dict: model outputs keyed by the output name.

        Shapes:
            - x: :math:`[B, T_seq]`
            - x_lengths: :math:`[B]`
            - y: :math:`[B, C, T_spec]`
            - y_lengths: :math:`[B]`
            - waveform: :math:`[B, 1, T_wav]`
            - d_vectors: :math:`[B, C, 1]`
            - speaker_ids: :math:`[B]`
            - language_ids: :math:`[B]`

        Return Shapes:
            - model_outputs: :math:`[B, 1, T_wav]`
            - alignments: :math:`[B, T_seq, T_dec]`
            - z: :math:`[B, C, T_dec]`
            - z_p: :math:`[B, C, T_dec]`
            - m_p: :math:`[B, C, T_dec]`
            - logs_p: :math:`[B, C, T_dec]`
            - m_q: :math:`[B, C, T_dec]`
            - logs_q: :math:`[B, C, T_dec]`
            - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]`
            - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
            - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]`
        """
        outputs = {}
        sid, g, lid, _ = self._set_cond_input(aux_input)
        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)

        # flow layers
        z_p = self.flow(z, y_mask, g=g)

        # duration predictor
        outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb)

        # expand prior
        m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
        logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])

        # select a random feature segment for the waveform decoder
        z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True)

        # interpolate z if needed
        z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids)

        o = self.waveform_decoder(z_slice, g=g)

        wav_seg = segment(
            waveform,
            slice_ids * self.config.audio.hop_length,
            spec_segment_size * self.config.audio.hop_length,
            pad_short=True,
        )

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
            # concatenate generated and GT waveforms
            wavs_batch = torch.cat((wav_seg, o), dim=0)

            # resample audio to speaker encoder sample_rate
            # pylint: disable=W0105
            if self.audio_transform is not None:
                wavs_batch = self.audio_transform(wavs_batch)

            pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)

            # split generated and GT speaker embeddings
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        else:
            gt_spk_emb, syn_spk_emb = None, None

        outputs.update(
            {
                "model_outputs": o,
                "alignments": attn.squeeze(1),
                "m_p": m_p,
                "logs_p": logs_p,
                "z": z,
                "z_p": z_p,
                "m_q": m_q,
                "logs_q": logs_q,
                "waveform_seg": wav_seg,
                "gt_spk_emb": gt_spk_emb,
                "syn_spk_emb": syn_spk_emb,
                "slice_ids": slice_ids,
            }
        )
        return outputs
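
    # Data-flow summary (added explanation, not original code): during training
    # the text encoder yields the prior (m_p, logs_p), the posterior encoder
    # yields z from the linear spectrogram, the flow maps z to z_p, MAS aligns
    # z_p with the prior to supervise the duration predictor, and a random
    # `spec_segment_size` slice of z is vocoded so that the adversarial and mel
    # losses only see a short waveform segment.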

    @staticmethod
    def _set_x_lengths(x, aux_input):
        if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
            return aux_input["x_lengths"]
        return torch.tensor(x.shape[1:2]).to(x.device)

    @torch.no_grad()
    def inference(
        self,
        x,
        aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None},
    ):  # pylint: disable=dangerous-default-value
        """
        Note:
            To run in batch mode, provide `x_lengths`; otherwise the model assumes a batch size of 1.

        Shapes:
            - x: :math:`[B, T_seq]`
            - x_lengths: :math:`[B]`
            - d_vectors: :math:`[B, C]`
            - speaker_ids: :math:`[B]`

        Return Shapes:
            - model_outputs: :math:`[B, 1, T_wav]`
            - alignments: :math:`[B, T_seq, T_dec]`
            - z: :math:`[B, C, T_dec]`
            - z_p: :math:`[B, C, T_dec]`
            - m_p: :math:`[B, C, T_dec]`
            - logs_p: :math:`[B, C, T_dec]`
        """
        sid, g, lid, durations = self._set_cond_input(aux_input)
        x_lengths = self._set_x_lengths(x, aux_input)

        # speaker embedding
        if self.args.use_speaker_embedding and sid is not None:
            g = self.emb_g(sid).unsqueeze(-1)

        # language embedding
        lang_emb = None
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

        if durations is None:
            if self.args.use_sdp:
                logw = self.duration_predictor(
                    x,
                    x_mask,
                    g=g if self.args.condition_dp_on_speaker else None,
                    reverse=True,
                    noise_scale=self.inference_noise_scale_dp,
                    lang_emb=lang_emb,
                )
            else:
                logw = self.duration_predictor(
                    x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
                )
            w = torch.exp(logw) * x_mask * self.length_scale
        else:
            assert durations.shape[-1] == x.shape[-1]
            w = durations.unsqueeze(0)

        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1)  # [B, 1, T_dec]

        attn_mask = x_mask * y_mask.transpose(1, 2)  # [B, 1, T_enc] * [B, T_dec, 1]
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

        m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)

        # upsampling if needed
        z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)

        o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g)

        outputs = {
            "model_outputs": o,
            "alignments": attn.squeeze(1),
            "durations": w_ceil,
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "y_mask": y_mask,
        }
        return outputs
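
    # Usage sketch (added, not original code; the token IDs below are placeholders):
    #   >>> model = Vits(VitsConfig())
    #   >>> tokens = torch.randint(0, model.args.num_chars, (1, 42))  # [B, T_seq]
    #   >>> out = model.inference(tokens, aux_input={"x_lengths": torch.tensor([42])})
    #   >>> out["model_outputs"].shape  # [1, 1, T_wav] with T_wav = sum(durations) * hop_length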

    @torch.no_grad()
    def inference_voice_conversion(
        self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None
    ):
        """Inference for voice conversion

        Args:
            reference_wav (Tensor): Reference waveform. Tensor of shape [B, T]
            speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B]
            d_vector (Tensor): d_vector embedding of the target speaker. Tensor of shape `[B, C]`
            reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B]
            reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. Tensor of shape `[B, C]`
        """
        # compute spectrograms
        y = wav_to_spec(
            reference_wav,
            self.config.audio.fft_size,
            self.config.audio.hop_length,
            self.config.audio.win_length,
            center=False,
        )
        y_lengths = torch.tensor([y.size(-1)]).to(y.device)
        speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
        speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
        wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
        return wav

    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
        """Forward pass for voice conversion

        TODO: create an end-point for voice conversion

        Args:
            y (Tensor): Reference spectrograms. Tensor of shape [B, C, T]
            y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B]
            speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,]
            speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,]
        """
        assert self.num_speakers > 0, "num_speakers has to be larger than 0."
        # speaker embedding
        if self.args.use_speaker_embedding and not self.args.use_d_vector_file:
            g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1)
            g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1)
        elif not self.args.use_speaker_embedding and self.args.use_d_vector_file:
            g_src = F.normalize(speaker_cond_src).unsqueeze(-1)
            g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
        else:
            raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")

        z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
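
    # Flow summary (added explanation, not original code): voice conversion
    # encodes the reference spectrogram with the source speaker's conditioning,
    # maps the latent z to the speaker-independent prior space with the
    # normalizing flow, inverts the flow with the target speaker's conditioning,
    # and vocodes the result, so the linguistic content is kept while the
    # speaker identity changes.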

    def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.
            optimizer_idx (int): Index of optimizer to use. 0 for the generator and 1 for the discriminator networks.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """

        spec_lens = batch["spec_lens"]

        if optimizer_idx == 0:
            tokens = batch["tokens"]
            token_lengths = batch["token_lens"]
            spec = batch["spec"]

            d_vectors = batch["d_vectors"]
            speaker_ids = batch["speaker_ids"]
            language_ids = batch["language_ids"]
            waveform = batch["waveform"]

            # generator pass
            outputs = self.forward(
                tokens,
                token_lengths,
                spec,
                spec_lens,
                waveform,
                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
            )

            # cache tensors for the generator pass
            self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init

            # compute scores and features
            scores_disc_fake, _, scores_disc_real, _ = self.disc(
                outputs["model_outputs"].detach(), outputs["waveform_seg"]
            )

            # compute loss
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    scores_disc_real,
                    scores_disc_fake,
                )
            return outputs, loss_dict

        if optimizer_idx == 1:
            mel = batch["mel"]

            # compute melspec segment
            with autocast(enabled=False):
                if self.args.encoder_sample_rate:
                    spec_segment_size = self.spec_segment_size * int(self.interpolate_factor)
                else:
                    spec_segment_size = self.spec_segment_size

                mel_slice = segment(
                    mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True
                )
                mel_slice_hat = wav_to_mel(
                    y=self.model_outputs_cache["model_outputs"].float(),
                    n_fft=self.config.audio.fft_size,
                    sample_rate=self.config.audio.sample_rate,
                    num_mels=self.config.audio.num_mels,
                    hop_length=self.config.audio.hop_length,
                    win_length=self.config.audio.win_length,
                    fmin=self.config.audio.mel_fmin,
                    fmax=self.config.audio.mel_fmax,
                    center=False,
                )

            # compute discriminator scores and features
            scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc(
                self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"]
            )

            # compute losses
            with autocast(enabled=False):  # use float32 for the criterion
                loss_dict = criterion[optimizer_idx](
                    mel_slice_hat=mel_slice.float(),
                    mel_slice=mel_slice_hat.float(),
                    z_p=self.model_outputs_cache["z_p"].float(),
                    logs_q=self.model_outputs_cache["logs_q"].float(),
                    m_p=self.model_outputs_cache["m_p"].float(),
                    logs_p=self.model_outputs_cache["logs_p"].float(),
                    z_len=spec_lens,
                    scores_disc_fake=scores_disc_fake,
                    feats_disc_fake=feats_disc_fake,
                    feats_disc_real=feats_disc_real,
                    loss_duration=self.model_outputs_cache["loss_duration"],
                    use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                    gt_spk_emb=self.model_outputs_cache["gt_spk_emb"],
                    syn_spk_emb=self.model_outputs_cache["syn_spk_emb"],
                )

            return self.model_outputs_cache, loss_dict

        raise ValueError(" [!] Unexpected `optimizer_idx`.")

    def _log(self, ap, batch, outputs, name_prefix="train"):  # pylint: disable=unused-argument,no-self-use
        y_hat = outputs[1]["model_outputs"]
        y = outputs[1]["waveform_seg"]
        figures = plot_results(y_hat, y, ap, name_prefix)
        sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
        audios = {f"{name_prefix}/audio": sample_voice}

        alignments = outputs[1]["alignments"]
        align_img = alignments[0].data.cpu().numpy().T

        figures.update(
            {
                "alignment": plot_alignment(align_img, output_fig=False),
            }
        )
        return figures, audios

    def train_log(
        self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
    ):  # pylint: disable=no-self-use
        """Create visualizations and waveform examples.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        Args:
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        figures, audios = self._log(self.ap, batch, outputs, "train")
        logger.train_figures(steps, figures)
        logger.train_audios(steps, audios, self.ap.sample_rate)

    @torch.no_grad()
    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int):
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
        figures, audios = self._log(self.ap, batch, outputs, "eval")
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def get_aux_input_from_test_sentences(self, sentence_info):
        if hasattr(self.config, "model_args"):
            config = self.config.model_args
        else:
            config = self.config

        # extract speaker and language info
        text, speaker_name, style_wav, language_name = None, None, None, None

        if isinstance(sentence_info, list):
            if len(sentence_info) == 1:
                text = sentence_info[0]
            elif len(sentence_info) == 2:
                text, speaker_name = sentence_info
            elif len(sentence_info) == 3:
                text, speaker_name, style_wav = sentence_info
            elif len(sentence_info) == 4:
                text, speaker_name, style_wav, language_name = sentence_info
        else:
            text = sentence_info

        # get speaker id/d_vector
        speaker_id, d_vector, language_id = None, None, None
        if hasattr(self, "speaker_manager"):
            if config.use_d_vector_file:
                if speaker_name is None:
                    d_vector = self.speaker_manager.get_random_embedding()
                else:
                    d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
            elif config.use_speaker_embedding:
                if speaker_name is None:
                    speaker_id = self.speaker_manager.get_random_id()
                else:
                    speaker_id = self.speaker_manager.name_to_id[speaker_name]

        # get language id
        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.name_to_id[language_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": language_id,
            "language_name": language_name,
        }

    @torch.no_grad()
    def test_run(self, assets) -> Tuple[Dict, Dict]:
        """Generic test run for `tts` models used by `Trainer`.

        You can override this for a different behaviour.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        for idx, s_info in enumerate(test_sentences):
            aux_inputs = self.get_aux_input_from_test_sentences(s_info)
            wav, alignment, _, _ = synthesis(
                self,
                aux_inputs["text"],
                self.config,
                "cuda" in str(next(self.parameters()).device),
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                language_id=aux_inputs["language_id"],
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()
            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False)
        return {"figures": test_figures, "audios": test_audios}

    def test_log(
        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
        logger.test_figures(steps, outputs["figures"])

    def format_batch(self, batch: Dict) -> Dict:
        """Compute speaker, language IDs and d_vectors for the batch if necessary."""
        speaker_ids = None
        language_ids = None
        d_vectors = None

        # get numerical speaker ids from speaker names
        if self.speaker_manager is not None and self.speaker_manager.name_to_id and self.args.use_speaker_embedding:
            speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]]

        if speaker_ids is not None:
            speaker_ids = torch.LongTensor(speaker_ids)

        # get d_vectors from audio file names
        if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
            d_vector_mapping = self.speaker_manager.embeddings
            d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]]
            d_vectors = torch.FloatTensor(d_vectors)

        # get language ids from language names
        if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding:
            language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]]

        if language_ids is not None:
            language_ids = torch.LongTensor(language_ids)

        batch["language_ids"] = language_ids
        batch["d_vectors"] = d_vectors
        batch["speaker_ids"] = speaker_ids
        return batch

    def format_batch_on_device(self, batch):
        """Compute spectrograms on the device."""
        ac = self.config.audio

        if self.args.encoder_sample_rate:
            wav = self.audio_resampler(batch["waveform"])
        else:
            wav = batch["waveform"]

        # compute spectrograms
        batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False)

        if self.args.encoder_sample_rate:
            # recompute spec with the high sampling rate for the loss
            spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False)
            # remove extra stft frames if needed
            if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor):
                spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)]
            else:
                batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)]
        else:
            spec_mel = batch["spec"]

        batch["mel"] = spec_to_mel(
            spec=spec_mel,
            n_fft=ac.fft_size,
            num_mels=ac.num_mels,
            sample_rate=ac.sample_rate,
            fmin=ac.mel_fmin,
            fmax=ac.mel_fmax,
        )

        if self.args.encoder_sample_rate:
            assert batch["spec"].shape[2] == int(
                batch["mel"].shape[2] / self.interpolate_factor
            ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"
        else:
            assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}"

        # compute spectrogram frame lengths
        batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int()
        batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int()

        if self.args.encoder_sample_rate:
            assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0
        else:
            assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0

        # zero the padding frames
        batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1)
        batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
        return batch
|
|
|
|
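    # Worked example (comments only, numbers assumed): with audio.sample_rate=22050 and
    # model_args.encoder_sample_rate=11025, and assuming interpolate_factor is the
    # audio-to-encoder sample-rate ratio (22050 / 11025 = 2 here), a clip whose mel at the
    # original rate has batch["mel"].shape[2] == 200 frames should yield
    # batch["spec"].shape[2] == 200 / 2 == 100 frames at the encoder rate, which is exactly
    # what the assertions above verify.
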
    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
        weights = None
        data_items = dataset.samples
        if getattr(config, "use_weighted_sampler", False):
            for attr_name, alpha in config.weighted_sampler_attrs.items():
                print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
                multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
                print(multi_dict)
                weights, attr_names, attr_weights = get_attribute_balancer_weights(
                    attr_name=attr_name, items=data_items, multi_dict=multi_dict
                )
                weights = weights * alpha
                print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")

        # input_audio_lengths = [os.path.getsize(x["audio_file"]) for x in data_items]

        if weights is not None:
            w_sampler = WeightedRandomSampler(weights, len(weights))
            batch_sampler = BucketBatchSampler(
                w_sampler,
                data=data_items,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                sort_key=lambda x: os.path.getsize(x["audio_file"]),
                drop_last=True,
            )
        else:
            batch_sampler = None
        # sampler for DDP
        if batch_sampler is None:
            batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        else:  # if a sampler is already defined, use it together with the DDP sampler
            batch_sampler = (
                DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
            )  # TODO: check batch_sampler with multi-gpu
        return batch_sampler

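    # Hedged config sketch (comments only): the fields `get_sampler` reads when weighted
    # sampling is enabled. The attribute names and numbers below are examples, not defaults.
    #
    #   config.use_weighted_sampler = True
    #   config.weighted_sampler_attrs = {"speaker_name": 1.0, "language": 0.5}
    #   config.weighted_sampler_multipliers = {"speaker_name": {"vctk_p225": 2.0}}
    #
    # Each attribute value is balanced by its inverse frequency, scaled by alpha, and can be
    # boosted per value through the multiplier dict.
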
    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # init dataloader
            dataset = VitsDataset(
                model_args=self.args,
                samples=samples,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                precompute_num_workers=config.precompute_num_workers,
                verbose=verbose,
                tokenizer=self.tokenizer,
                start_by_longest=config.start_by_longest,
            )

            # wait for all DDP processes to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)
            if sampler is None:
                loader = DataLoader(
                    dataset,
                    batch_size=config.eval_batch_size if is_eval else config.batch_size,
                    shuffle=False,  # shuffle is done in the dataset.
                    collate_fn=dataset.collate_fn,
                    drop_last=False,  # setting this False might cause issues in AMP training.
                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                    pin_memory=False,
                )
            else:
                if num_gpus > 1:
                    loader = DataLoader(
                        dataset,
                        sampler=sampler,
                        batch_size=config.eval_batch_size if is_eval else config.batch_size,
                        collate_fn=dataset.collate_fn,
                        num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                        pin_memory=False,
                    )
                else:
                    loader = DataLoader(
                        dataset,
                        batch_sampler=sampler,
                        collate_fn=dataset.collate_fn,
                        num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                        pin_memory=False,
                    )
        return loader

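    # Illustrative call (comments only; `model`, `config`, and `train_samples` are assumed to
    # exist already, e.g. from the usual dataset loading step):
    #
    #   loader = model.get_data_loader(
    #       config=config, assets={}, is_eval=False,
    #       samples=train_samples, verbose=True, num_gpus=1,
    #   )
    #   batch = next(iter(loader))  # collated dict, later refined by format_batch()
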
    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.

        It returns 2 optimizers in a list. The first one is for the discriminator
        and the second one is for the generator.

        Returns:
            List: optimizers.
        """
        # discriminator optimizer
        optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)

        # generator optimizer (every parameter except the discriminator's)
        gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
        )
        return [optimizer0, optimizer1]

    def get_lr(self) -> List:
        """Return the initial learning rates for each optimizer.

        Returns:
            List: learning rates for each optimizer.
        """
        return [self.config.lr_disc, self.config.lr_gen]

    def get_scheduler(self, optimizer) -> List:
        """Create the schedulers for each optimizer.

        Args:
            optimizer (List[`torch.optim.Optimizer`]): List of optimizers.

        Returns:
            List: Schedulers, one for each optimizer.
        """
        scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0])
        scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1])
        return [scheduler_D, scheduler_G]

    def get_criterion(self):
        """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
        `train_step()`."""
        from TTS.tts.layers.losses import (  # pylint: disable=import-outside-toplevel
            VitsDiscriminatorLoss,
            VitsGeneratorLoss,
        )

        return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)]

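    # Index convention sketch (comments only, assumed training loop): the objects returned by
    # the four methods above are paired by position, so index 0 is always the discriminator
    # and index 1 the generator.
    #
    #   optimizers = model.get_optimizer()            # [disc_optimizer, gen_optimizer]
    #   lrs        = model.get_lr()                   # [lr_disc, lr_gen]
    #   schedulers = model.get_scheduler(optimizers)  # [scheduler_D, scheduler_G]
    #   criteria   = model.get_criterion()            # [VitsDiscriminatorLoss, VitsGeneratorLoss]
    #   # train_step(..., optimizer_idx=0) -> discriminator pass, optimizer_idx=1 -> generator pass
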
    def load_checkpoint(
        self, config, checkpoint_path, eval=False, strict=True, cache=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load the model checkpoint and set up for training or inference."""
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        # compat band-aid so that pre-trained models do not load the speaker encoder baked into the checkpoint
        # TODO: consider baking the speaker encoder into the model and calling it from there,
        # as it is probably easier for model distribution.
        state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k}

        if self.args.encoder_sample_rate is not None and eval:
            # the audio resampler is not used at inference time
            self.audio_resampler = None

        # handle fine-tuning from a checkpoint with additional speakers
        if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape:
            num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0]
            print(f" > Loading checkpoint with {num_new_speakers} additional speakers.")
            emb_g = state["model"]["emb_g.weight"]
            new_row = torch.randn(num_new_speakers, emb_g.shape[1])
            emb_g = torch.cat([emb_g, new_row], dim=0)
            state["model"]["emb_g.weight"] = emb_g
        # load the model weights
        self.load_state_dict(state["model"], strict=strict)

        if eval:
            self.eval()
            assert not self.training

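    # Illustrative usage (comments only; the paths are placeholders, not real files):
    #
    #   config = VitsConfig()
    #   config.load_json("/path/to/config.json")
    #   model = Vits.init_from_config(config)
    #   model.load_checkpoint(config, "/path/to/model_file.pth", eval=True)
    #   # the model is now in eval mode, with speaker-encoder weights stripped from the state dict
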
    def load_fairseq_checkpoint(
        self, config, checkpoint_dir, eval=False, strict=True
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms

        Performs some changes for compatibility.

        Args:
            config (Coqpit): 🐸TTS model config.
            checkpoint_dir (str): Path to the checkpoint directory.
            eval (bool, optional): Set to True for evaluation. Defaults to False.
        """
        import json

        from TTS.tts.utils.text.cleaners import basic_cleaners

        self.disc = None
        # set paths
        config_file = os.path.join(checkpoint_dir, "config.json")
        checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
        vocab_file = os.path.join(checkpoint_dir, "vocab.txt")
        # set config params
        with open(config_file, "r", encoding="utf-8") as file:
            # load the JSON data as a dictionary
            config_org = json.load(file)
        self.config.audio.sample_rate = config_org["data"]["sampling_rate"]
        # self.config.add_blank = config['add_blank']
        # set tokenizer
        vocab = FairseqVocab(vocab_file)
        self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
        self.tokenizer = TTSTokenizer(
            use_phonemes=False,
            text_cleaner=basic_cleaners,
            characters=vocab,
            phonemizer=None,
            add_blank=config_org["data"]["add_blank"],
            use_eos_bos=False,
        )
        # load fairseq checkpoint
        new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
        self.load_state_dict(new_chk, strict=strict)
        if eval:
            self.eval()
            assert not self.training

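    # Illustrative usage (comments only; the directory is a placeholder that is assumed to hold
    # the MMS release layout used above: config.json, G_100000.pth and vocab.txt side by side):
    #
    #   model = Vits(VitsConfig())
    #   model.load_fairseq_checkpoint(model.config, "/path/to/mms/eng", eval=True)
    #   # the tokenizer is rebuilt from vocab.txt and the fairseq weights are rehashed into
    #   # 🐸TTS layer names before loading
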
    @staticmethod
    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
        """Initialize the model from a config.

        Args:
            config (VitsConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples used to parse speaker ids.
                Defaults to None.
        """
        from TTS.utils.audio import AudioProcessor

        upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item()

        if not config.model_args.encoder_sample_rate:
            assert (
                upsample_rate == config.audio.hop_length
            ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}"
        else:
            encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate
            effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor
            assert (
                upsample_rate == effective_hop_length
            ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}"

        ap = AudioProcessor.init_from_config(config, verbose=verbose)
        tokenizer, new_config = TTSTokenizer.init_from_config(config)
        speaker_manager = SpeakerManager.init_from_config(config, samples)
        language_manager = LanguageManager.init_from_config(config)

        if config.model_args.speaker_encoder_model_path:
            speaker_manager.init_encoder(
                config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
            )
        return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)

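    # Worked check behind the assertion above (numbers assumed): with hop_length=256 and
    # upsample_rates_decoder=[8, 8, 2, 2], the product 8*8*2*2 = 256 matches the hop length.
    # With encoder_sample_rate=11025 and sample_rate=22050, the effective hop length becomes
    # 256 * (22050 / 11025) = 512, so the decoder upsample rates must multiply to 512 instead.
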
    def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
        """Export model to ONNX format for inference

        Args:
            output_path (str): Path to save the exported model.
            verbose (bool): Print verbose information. Defaults to True.
        """

        # save values to roll back after export
        _forward = self.forward
        disc = None
        if hasattr(self, "disc"):
            disc = self.disc
        training = self.training

        # set export mode
        self.disc = None
        self.eval()

        def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
            noise_scale = scales[0]
            length_scale = scales[1]
            noise_scale_dp = scales[2]
            self.noise_scale = noise_scale
            self.length_scale = length_scale
            self.noise_scale_dp = noise_scale_dp
            return self.inference(
                text,
                aux_input={
                    "x_lengths": text_lengths,
                    "d_vectors": None,
                    "speaker_ids": sid,
                    "language_ids": langid,
                    "durations": None,
                },
            )["model_outputs"]

        self.forward = onnx_inference

        # set dummy inputs
        dummy_input_length = 100
        sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long)
        sequence_lengths = torch.LongTensor([sequences.size(1)])
        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
        dummy_input = (sequences, sequence_lengths, scales)
        input_names = ["input", "input_lengths", "scales"]

        if self.num_speakers > 0:
            speaker_id = torch.LongTensor([0])
            dummy_input += (speaker_id,)
            input_names.append("sid")

        if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0:
            language_id = torch.LongTensor([0])
            dummy_input += (language_id,)
            input_names.append("langid")

        # export to ONNX
        torch.onnx.export(
            model=self,
            args=dummy_input,
            opset_version=15,
            f=output_path,
            verbose=verbose,
            input_names=input_names,
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "phonemes"},
                "input_lengths": {0: "batch_size"},
                "output": {0: "batch_size", 1: "time1", 2: "time2"},
            },
        )

        # roll back
        self.forward = _forward
        if training:
            self.train()
        if disc is not None:
            self.disc = disc

    def load_onnx(self, model_path: str, cuda=False):
        import onnxruntime as ort

        providers = [
            "CPUExecutionProvider"
            if cuda is False
            else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
        ]
        sess_options = ort.SessionOptions()
        self.onnx_sess = ort.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=providers,
        )

    def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
        """ONNX inference"""

        if isinstance(x, torch.Tensor):
            x = x.cpu().numpy()

        if x_lengths is None:
            x_lengths = np.array([x.shape[1]], dtype=np.int64)

        if isinstance(x_lengths, torch.Tensor):
            x_lengths = x_lengths.cpu().numpy()
        scales = np.array(
            [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
            dtype=np.float32,
        )
        input_params = {"input": x, "input_lengths": x_lengths, "scales": scales}
        if speaker_id is not None:
            input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
        if language_id is not None:
            input_params["langid"] = torch.tensor([language_id]).cpu().numpy()

        audio = self.onnx_sess.run(
            ["output"],
            input_params,
        )
        return audio[0][0]

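    # End-to-end ONNX sketch (comments only; file name and input sentence are assumed):
    #
    #   model.export_onnx(output_path="coqui_vits.onnx")
    #   model.load_onnx("coqui_vits.onnx", cuda=False)
    #   token_ids = np.asarray(model.tokenizer.text_to_ids("Hello world."), dtype=np.int64)[None, :]
    #   wav = model.inference_onnx(token_ids)  # numpy waveform for the first batch item
    #   # pass speaker_id / language_id for multi-speaker or multi-lingual models
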

##################################
# VITS CHARACTERS
##################################


class VitsCharacters(BaseCharacters):
    """Characters class for the VITS model to keep compatibility with pre-trained models."""

    def __init__(
        self,
        graphemes: str = _characters,
        punctuations: str = _punctuations,
        pad: str = _pad,
        ipa_characters: str = _phonemes,
    ) -> None:
        if ipa_characters is not None:
            graphemes += ipa_characters
        super().__init__(graphemes, punctuations, pad, None, None, "<BLNK>", is_unique=False, is_sorted=True)

    def _create_vocab(self):
        self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank]
        self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)}
        # pylint: disable=unnecessary-comprehension
        self._id_to_char = {idx: char for idx, char in enumerate(self.vocab)}

    @staticmethod
    def init_from_config(config: Coqpit):
        if config.characters is not None:
            _pad = config.characters["pad"]
            _punctuations = config.characters["punctuations"]
            _letters = config.characters["characters"]
            _letters_ipa = config.characters["phonemes"]
            return (
                VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad),
                config,
            )
        characters = VitsCharacters()
        new_config = replace(config, characters=characters.to_config())
        return characters, new_config

    def to_config(self) -> "CharactersConfig":
        return CharactersConfig(
            characters=self._characters,
            punctuations=self._punctuations,
            pad=self._pad,
            eos=None,
            bos=None,
            blank=self._blank,
            is_unique=False,
            is_sorted=True,
        )

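# Vocabulary layout sketch (comments only): `_create_vocab` places the pad symbol at index 0,
# then the punctuations, then the characters, and the "<BLNK>" token last, e.g.
#
#   chars = VitsCharacters()
#   chars.char_to_id(chars.pad)  # -> 0
#   chars.char_to_id("<BLNK>")   # -> len(chars.vocab) - 1
#
# Pre-trained VITS checkpoints depend on this exact ordering, which is why `is_sorted=True`
# and `is_unique=False` are fixed in `__init__` above.

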
class FairseqVocab(BaseVocabulary):
    def __init__(self, vocab: str):
        # NOTE: the parent initializer is not invoked here; the `vocab` setter below fills in
        # the blank/pad symbols and the id maps directly from the vocab file.
        super(FairseqVocab).__init__()
        self.vocab = vocab

    @property
    def vocab(self):
        """Return the vocabulary as a list of symbols."""
        return self._vocab

    @vocab.setter
    def vocab(self, vocab_file):
        with open(vocab_file, encoding="utf-8") as f:
            self._vocab = [x.replace("\n", "") for x in f.readlines()]
            self.blank = self._vocab[0]
            self.pad = " "
            self._char_to_id = {s: i for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
            self._id_to_char = {i: s for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension