324 lines
12 KiB
Python
324 lines
12 KiB
Python
|
from typing import List, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from torch import nn
|
||
|
from tqdm.auto import tqdm
|
||
|
|
||
|
from TTS.tts.layers.tacotron.common_layers import Linear
|
||
|
from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock
|
||
|
|
||
|
|
||
|
class Encoder(nn.Module):
    r"""Neural HMM Encoder

    Same as Tacotron 2 encoder but increases the input length by states per phone:
    every input token is expanded into ``state_per_phone`` output vectors.

    Args:
        num_chars (int): Number of characters in the input.
        state_per_phone (int): Number of states per phone.
        in_out_channels (int): number of input and output channels.
        n_convolutions (int): number of convolutional layers.
    """

    def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutions=3):
        super().__init__()

        self.state_per_phone = state_per_phone
        self.in_out_channels = in_out_channels

        self.emb = nn.Embedding(num_chars, in_out_channels)
        self.convolutions = nn.ModuleList()
        for _ in range(n_convolutions):
            self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu"))
        # Each direction emits (in_out_channels // 2) * state_per_phone features per token, so the
        # bidirectional output has in_out_channels * state_per_phone features per token, which the
        # reshape in forward/inference splits into state_per_phone vectors of in_out_channels.
        # NOTE(review): this only lines up when in_out_channels is even — confirm with callers.
        self.lstm = nn.LSTM(
            in_out_channels,
            (in_out_channels // 2) * state_per_phone,
            num_layers=1,
            batch_first=True,
            bias=True,
            bidirectional=True,
        )
        self.rnn_state = None

    def _embed_and_convolve(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Embed token ids and run the convolution stack; returns ``(b, T_in, in_out_channels)``."""
        o = self.emb(x).transpose(1, 2)  # convolutions work channel-first
        for layer in self.convolutions:
            o = layer(o)
        return o.transpose(1, 2)

    def forward(self, x: torch.LongTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Forward pass to the encoder.

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths. Need not be sorted.
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        # enforce_sorted=False: lengths may come in any order; the original batch order is
        # restored by pad_packed_sequence.
        o = nn.utils.rnn.pack_padded_sequence(o, x_len.cpu(), batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        o, _ = self.lstm(o)
        # total_length=T: without it the output is only max(x_len) steps long, and the reshape
        # below fails whenever the batch is padded beyond its longest sequence.
        o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True, total_length=T)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        return o, x_len * self.state_per_phone

    def inference(self, x, x_len):
        """Inference to the encoder.

        Unlike ``forward`` the LSTM runs on the un-packed tensor, so padding positions are
        processed as regular tokens (presumably meant for un-padded input — TODO confirm).

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths.
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        o, _ = self.lstm(o)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        return o, x_len * self.state_per_phone
|
||
|
|
||
|
|
||
|
class ParameterModel(nn.Module):
    r"""Feed-forward trunk of the output net: a stack of ReLU-activated ``Linear`` layers
    followed by a plain ``nn.Linear`` projection to ``output_size``.

    Note: Do not put dropout layers here, the model will not converge.

    Args:
        outputnet_size (List[int]): widths of the hidden layers, in order.
        input_size (int): width of the input to the first hidden layer.
        output_size (int): width of the final projection (the feature dimension of the output).
        frame_channels (int): feature dim used to slice the final bias during the flat start.
        flat_start_params (dict): must provide the keys ``"mean"``, ``"std"`` and ``"transition_p"``
            used to initialise the final bias.
    """

    def __init__(
        self,
        outputnet_size: List[int],
        input_size: int,
        output_size: int,
        frame_channels: int,
        flat_start_params: dict,
    ):
        super().__init__()
        self.frame_channels = frame_channels

        # Layer i maps widths_in[i] -> outputnet_size[i].
        widths_in = [input_size] + outputnet_size[:-1]
        hidden = []
        for d_in, d_out in zip(widths_in, outputnet_size):
            hidden.append(Linear(d_in, d_out))
        self.layers = nn.ModuleList(hidden)
        self.last_layer = nn.Linear(outputnet_size[-1], output_size)

        start_mean = flat_start_params["mean"]
        start_std = flat_start_params["std"]
        start_transition = flat_start_params["transition_p"]
        self.flat_start_output_layer(start_mean, start_std, start_transition)

    def flat_start_output_layer(self, mean, std, transition_p):
        """Zero the final weights and write ``mean`` / ``std`` / ``transition_p`` into its bias.

        The bias is laid out as ``[mean (frame_channels) | std (frame_channels) | rest]``; the std
        and transition parts are stored through the inverse softplus / inverse sigmoid.
        """
        fc = self.frame_channels
        self.last_layer.weight.data.zero_()
        bias = self.last_layer.bias.data  # shares storage with the parameter
        bias[0:fc] = mean
        bias[fc : 2 * fc] = OverflowUtils.inverse_softplus(std)
        bias[2 * fc :] = OverflowUtils.inverse_sigmod(transition_p)

    def forward(self, x):
        """Apply every hidden layer with ReLU, then the linear output projection."""
        hidden = x
        for linear in self.layers:
            hidden = F.relu(linear(hidden))
        return self.last_layer(hidden)
|
||
|
|
||
|
|
||
|
class Outputnet(nn.Module):
    r"""
    This network takes current state and previous observed values as input
    and returns its parameters, mean, standard deviation and probability
    of transition to the next state
    """

    def __init__(
        self,
        encoder_dim: int,
        memory_rnn_dim: int,
        frame_channels: int,
        outputnet_size: List[int],
        flat_start_params: dict,
        std_floor: float = 1e-2,
    ):
        super().__init__()

        self.frame_channels = frame_channels
        self.flat_start_params = flat_start_params
        self.std_floor = std_floor

        input_size = memory_rnn_dim + encoder_dim
        # frame_channels means + frame_channels stds + 1 transition value per state.
        output_size = 2 * frame_channels + 1

        self.parametermodel = ParameterModel(
            outputnet_size=outputnet_size,
            input_size=input_size,
            output_size=output_size,
            flat_start_params=flat_start_params,
            frame_channels=frame_channels,
        )

    def forward(self, ar_mels, inputs):
        r"""Inputs observation and returns the means, stds and transition probability for the current state

        Args:
            ar_mels (torch.FloatTensor): autoregressive input, shape (batch, prenet_dim)
            inputs (torch.FloatTensor): states, shape (batch, hidden_states, hidden_state_dim)

        The two feature widths must add up to the parameter model's input size
        (``memory_rnn_dim + encoder_dim``).

        Returns:
            means: means for the emission observation for each feature
                - shape: (batch, hidden_states, feature_size)
            stds: standard deviations (softplus, floored at ``std_floor``) for each feature
                - shape: (batch, hidden_states, feature_size)
            transition_vectors: raw transition output for the current hidden state; no
                activation is applied here (it is flat-started through an inverse sigmoid,
                so presumably a logit — TODO confirm)
                - shape: (batch, hidden_states)
        """
        batch_size, prenet_dim = ar_mels.shape[0], ar_mels.shape[1]
        N = inputs.shape[1]

        # Repeat the autoregressive input for every state, then append the state features.
        ar_mels_per_state = ar_mels.unsqueeze(1).expand(batch_size, N, prenet_dim)
        params = self.parametermodel(torch.cat((ar_mels_per_state, inputs), dim=2))

        fc = self.frame_channels
        mean = params[:, :, 0:fc]
        std = params[:, :, fc : 2 * fc]
        transition_vector = params[:, :, 2 * fc :].squeeze(2)

        std = self._floor_std(F.softplus(std))
        return mean, std, transition_vector

    def _floor_std(self, std):
        r"""
        It clamps the standard deviation to not to go below some level
        This removes the problem when the model tries to cheat for higher likelihoods by converting
        one of the gaussians to a point mass.

        Args:
            std (float Tensor): tensor containing the standard deviations to be floored.

        Returns:
            float Tensor: ``std`` with every value below ``self.std_floor`` raised to it.
        """
        # Compare against the floor directly instead of cloning the whole tensor to detect changes.
        if torch.any(std < self.std_floor):
            print(
                "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
            )
        return torch.clamp(std, min=self.std_floor)
|
||
|
|
||
|
|
||
|
class OverflowUtils:
    """Static helpers for the neural-HMM model: flat-start initialisation and tensor utilities."""

    @staticmethod
    def get_data_parameters_for_flat_start(
        data_loader: torch.utils.data.DataLoader, out_channels: int, states_per_phone: int
    ):
        """Generates data parameters for flat starting the HMM.

        Makes one pass over ``data_loader`` and accumulates token counts, mel frame counts and the
        first and second moments of the mel values.
        NOTE(review): padded mel frames are included in the sums, which assumes padding is zero —
        confirm against the collate function.

        Args:
            data_loader (torch.utils.data.DataLoader): loader whose batches are dicts containing
                ``"token_id_lengths"``, ``"mel"`` and ``"mel_lengths"``.
            out_channels (int): mel spectrogram channels
            states_per_phone (int): HMM states per phone

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: global mel mean, global mel standard
            deviation, and the initial transition probability per state.

        Raises:
            ValueError: if the loader yields no tokens or no mel frames.
        """
        # State related information for transition_p
        total_state_len = 0
        total_mel_len = 0

        # Useful for data mean an std
        total_mel_sum = 0
        total_mel_sq_sum = 0

        for batch in tqdm(data_loader, leave=False):
            mels = batch["mel"]
            total_state_len += torch.sum(batch["token_id_lengths"])
            total_mel_len += torch.sum(batch["mel_lengths"])
            total_mel_sum += torch.sum(mels)
            total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

        if total_state_len == 0 or total_mel_len == 0:
            raise ValueError(
                "Cannot compute flat-start parameters: the data loader yielded no tokens or no mel frames."
            )

        n_values = total_mel_len * out_channels
        data_mean = total_mel_sum / n_values
        data_std = torch.sqrt((total_mel_sq_sum / n_values) - torch.pow(data_mean, 2))

        # Mean number of mel frames per input token. Per-utterance averages would divide both
        # totals by the same dataset size, so the plain ratio is used; this avoids requiring
        # len(data_loader.dataset) (unavailable e.g. for an IterableDataset).
        frames_per_phone = total_mel_len / total_state_len
        # Each phone is split into states_per_phone states, so a state lasts
        # frames_per_phone / states_per_phone frames on average.
        init_transition_prob = states_per_phone / frames_per_phone
        return data_mean, data_std, init_transition_prob

    @staticmethod
    @torch.no_grad()
    def update_flat_start_transition(model, transition_p):
        """Re-run the flat start of ``model.neural_hmm.output_net.parametermodel`` with mean 0.0,
        std 1.0 and the given ``transition_p``."""
        model.neural_hmm.output_net.parametermodel.flat_start_output_layer(0.0, 1.0, transition_p)

    @staticmethod
    def log_clamped(x, eps=1e-04):
        """
        Avoids the log(0) problem

        Args:
            x (torch.tensor): input tensor
            eps (float, optional): lower bound. Defaults to 1e-04.

        Returns:
            torch.tensor: :math:`log(max(x, eps))`
        """
        clamped_x = torch.clamp(x, min=eps)
        return torch.log(clamped_x)

    @staticmethod
    def inverse_sigmod(x):
        r"""
        Inverse of the sigmoid function (logit), computed as ``log(x / (1 - x))`` with the ratio
        clamped below at 1e-4. Accepts tensors or Python scalars.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        return OverflowUtils.log_clamped(x / (1.0 - x))

    @staticmethod
    def inverse_softplus(x):
        r"""
        Inverse of the softplus function, computed as ``log(exp(x) - 1)`` with the argument
        clamped below at 1e-4. Accepts tensors or Python scalars.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        return OverflowUtils.log_clamped(torch.exp(x) - 1.0)

    @staticmethod
    def logsumexp(x, dim):
        r"""
        Differentiable LogSumExp: Does not create nan gradients
        when all the inputs are -inf; yields 0 gradients instead.

        Args:
            x : torch.Tensor - The input tensor
            dim: int - The dimension on which the log sum exp has to be applied
        """
        m, _ = x.max(dim=dim)
        mask = m == -float("inf")
        # Shift by the max for stability. Rows that are entirely -inf are shifted by 0 so that
        # (x - m) never evaluates (-inf) - (-inf) = nan.
        shift = m.masked_fill(mask, 0).unsqueeze(dim=dim)
        s = (x - shift).exp().sum(dim=dim)
        # For all -inf rows s == 0: replacing it by 1 keeps log() finite, and re-masking m to -inf
        # yields the correct -inf result; masked_fill also blocks gradients on those rows.
        return s.masked_fill(mask, 1).log() + m.masked_fill(mask, -float("inf"))

    @staticmethod
    def double_pad(list_of_different_shape_tensors):
        r"""
        Pads the list of tensors in 2 dimensions

        Every tensor's last dimension is right-padded with zeros to the largest last dimension,
        then the tensors are padded along their first dimension and stacked, giving
        ``(len(list), max_rows, max_cols)`` for 2-D inputs.
        """
        # shape[-1] (rather than len(x[0])) is the dimension F.pad pads, and works for 0-row tensors.
        second_dim_max = max(x.shape[-1] for x in list_of_different_shape_tensors)
        padded_x = [F.pad(x, (0, second_dim_max - x.shape[-1])) for x in list_of_different_shape_tensors]
        return nn.utils.rnn.pad_sequence(padded_x, batch_first=True)
|