262 lines
9.0 KiB
Python
262 lines
9.0 KiB
Python
from typing import List, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn # pylint: disable=consider-using-from-import
|
|
import torch.nn.functional as F
|
|
|
|
from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
|
|
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
|
|
from TTS.tts.layers.delightful_tts.networks import STL
|
|
|
|
|
|
def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lengths (torch.Tensor): 1-D tensor holding one length per batch item.

    Returns:
        torch.Tensor: Boolean tensor of shape ``(batch, max(lengths))`` that is
        ``True`` wherever the position index is greater than or equal to the
        item's length (i.e. on padded positions) and ``False`` elsewhere.
    """
    n_items = lengths.shape[0]
    longest = torch.max(lengths).item()

    # Position indices 0..longest-1, replicated once per batch item.
    positions = torch.arange(0, longest, device=lengths.device)
    position_grid = positions.unsqueeze(0).expand(n_items, -1)

    # Each item's length, replicated across every position.
    length_grid = lengths.unsqueeze(1).expand(-1, longest)

    return position_grid >= length_grid
|
|
|
|
|
|
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    """Downsample sequence lengths by ``stride``, rounding up.

    Args:
        lens (torch.Tensor): Tensor of sequence lengths.
        stride (int): Downsampling factor. Defaults to 2.

    Returns:
        torch.Tensor: ``ceil(lens / stride)`` converted to 32-bit integers.
    """
    quotient = lens / stride
    rounded_up = torch.ceil(quotient)
    return rounded_up.int()
|
|
|
|
|
|
class ReferenceEncoder(nn.Module):
    """
    Reference encoder for utterance and phoneme prosody encoders. Reference encoder
    made up of convolution and RNN layers.

    Args:
        num_mels (int): Number of mel channels of the input; used as the input channel
            count of the first conv layer.
        ref_enc_filters (list[int]): List of channel sizes for encoder layers.
        ref_enc_size (int): Size of the kernel for the conv layers.
        ref_enc_strides (List[int]): List of strides to use for conv layers. A stride of 1
            is prepended to this list, so the first conv layer always uses stride 1.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing mel vector
        - **lengths** (batch): Tensor containing the mel lengths.
    Returns:
        - **outputs** (batch, reduced_time, ref_enc_gru_size): Padded GRU outputs.
        - **memory** (1, batch, ref_enc_gru_size): Final hidden state of the GRU.
        - **mel_masks** (batch, reduced_time): Boolean mask, ``True`` on padded positions
          of the downsampled sequence.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
    ):
        super().__init__()

        n_mel_channels = num_mels
        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels] + ref_enc_filters
        strides = [1] + ref_enc_strides
        # Strides actually handed to the K conv layers. Kept so that `forward` can
        # downsample the lengths the same way the convolutions downsample the time axis.
        self.conv_strides = strides[:K]
        # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[0 + 1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            )
        ]
        convs2 = [
            nn.Conv1d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=ref_enc_size,
                stride=strides[i],
                padding=ref_enc_size // 2,
            )
            for i in range(1, K)
        ]
        convs.extend(convs2)
        self.convs = nn.ModuleList(convs)

        self.norms = nn.ModuleList([nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)])

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        inputs --- [N, n_mels, timesteps]
        mel_lens --- [N]
        outputs --- (x [N, Ty', gru_size], memory [1, N, gru_size], mel_masks [N, Ty'])
        where Ty' is the time length after the strided convolutions.
        """

        # Zero out padded frames before the conv stack.
        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, 0.3)  # [N, filters[i + 1], time after this layer's stride]
            x = norm(x)

        # Downsample the lengths once per conv layer using that layer's own stride instead
        # of assuming exactly two stride-2 layers. With an odd kernel size and
        # padding = kernel_size // 2 a conv maps a length L to ceil(L / stride), which is
        # what `stride_lens` computes (stride 1 leaves the value unchanged).
        for stride in self.conv_strides:
            mel_lens = stride_lens(mel_lens, stride)

        mel_masks = get_mask_from_lengths(mel_lens)

        # Zero out padded positions of the downsampled sequence and move time to dim 1.
        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))
        x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False)

        self.gru.flatten_parameters()
        x, memory = self.gru(x)  # x --- packed [N, Ty', gru_size], memory --- [1, N, gru_size]
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x, memory, mel_masks

    def calculate_channels(  # pylint: disable=no-self-use
        self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
    ) -> int:
        """Return the size left after applying ``n_convs`` identical convolutions to a size ``L``.

        Each step applies ``(L - kernel_size + 2 * pad) // stride + 1``.
        """
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
|
|
|
|
|
|
class UtteranceLevelProsodyEncoder(nn.Module):
    """
    Encoder to extract prosody from an utterance. It is made up of a reference encoder
    followed by a linear projection, a style token layer, a linear bottleneck and dropout.

    Args:
        num_mels (int): Forwarded to the reference encoder as ``num_mels``.
        ref_enc_filters (list[int]): List of channel sizes for ref encoder layers.
        ref_enc_size (int): Size of the kernel for the ref encoder conv layers.
        ref_enc_strides (List[int]): List of strides to use for the ref encoder conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
        dropout (float): Probability of dropout.
        n_hidden (int): Size of hidden layers.
        bottleneck_size_u (int): Size of the bottleneck layer.
        token_num (int): Forwarded to the style token layer as ``token_num``.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing mel vector
        - **lengths** (batch): Tensor containing the mel lengths.
    Returns:
        - **outputs** (-1, 1, dim): Tensor produced by Utterance Level Prosody Encoder.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[Union[int, int, int, int, int, int]],
        ref_enc_size: int,
        ref_enc_strides: List[Union[int, int, int, int, int]],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        bottleneck_size_u: int,
        token_num: int,
    ):
        super().__init__()

        self.E = n_hidden
        self.d_q = n_hidden
        self.d_k = n_hidden

        # Sub-modules are created in a fixed order: encoder, projection, style tokens,
        # bottleneck, dropout.
        self.encoder = ReferenceEncoder(
            num_mels=num_mels,
            ref_enc_filters=ref_enc_filters,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            ref_enc_gru_size=ref_enc_gru_size,
        )
        half_hidden = self.E // 2
        self.encoder_prj = nn.Linear(ref_enc_gru_size, half_hidden)
        self.stl = STL(n_hidden=n_hidden, token_num=token_num)
        self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size_u)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            mels: :math: `[B, C, T]`
            mel_lens: :math: `[B]`

        Returns a tensor viewed as ``[-1, 1, D]`` where ``D`` is the size of the
        fourth dimension of the bottleneck output.
        """
        # Only the second value returned by the reference encoder (the GRU state) is used.
        _, gru_state, _ = self.encoder(mels, mel_lens)

        # Projection of the GRU state to half the hidden size.
        projected = self.encoder_prj(gru_state)

        # Style tokens, then bottleneck and dropout.
        style = self.stl(projected)
        out = self.dropout(self.encoder_bottleneck(style))

        # NOTE(review): indexing shape[3] implies the style token layer yields a 4-D tensor.
        feature_dim = out.shape[3]
        return out.view((-1, 1, feature_dim))
|
|
|
|
|
|
class PhonemeLevelProsodyEncoder(nn.Module):
    """
    Encoder to extract phoneme-level prosody. A reference encoder turns the mel input into
    a sequence, which is projected to ``n_hidden`` and attended to by the query sequence
    through a multi-headed attention layer; a linear bottleneck produces the output.

    Args:
        num_mels (int): Forwarded to the reference encoder as ``num_mels``.
        ref_enc_filters (list[int]): List of channel sizes for ref encoder layers.
        ref_enc_size (int): Size of the kernel for the ref encoder conv layers.
        ref_enc_strides (List[int]): List of strides to use for the ref encoder conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
        dropout (float): Dropout probability handed to the attention layer.
        n_hidden (int): Model size of the attention layer and of the projection output.
        n_heads (int): Number of attention heads.
        bottleneck_size_p (int): Size of the bottleneck layer.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[Union[int, int, int, int, int, int]],
        ref_enc_size: int,
        ref_enc_strides: List[Union[int, int, int, int, int]],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        n_heads: int,
        bottleneck_size_p: int,
    ):
        super().__init__()

        self.E = n_hidden
        self.d_q = n_hidden
        self.d_k = n_hidden

        # Sub-modules are created in a fixed order: encoder, projection, attention, bottleneck.
        self.encoder = ReferenceEncoder(
            num_mels=num_mels,
            ref_enc_filters=ref_enc_filters,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            ref_enc_gru_size=ref_enc_gru_size,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=n_hidden,
            num_heads=n_heads,
            dropout_p=dropout,
        )
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size_p)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        x --- query sequence, presumably [N, seq_len, encoder_embedding_dim] (confirm with caller)
        src_mask --- boolean mask over x; masked positions are zeroed in the output
        mels --- mel input handed to the reference encoder, which treats the last axis as time
        mel_lens --- [N]
        encoding --- passed through unchanged to the attention layer
        out --- x after attention and bottleneck, last dim = bottleneck_size
        """
        # Sequence output and padding mask of the reference encoder; the GRU state is unused here.
        ref_seq, _, ref_pad_mask = self.encoder(mels, mel_lens)

        # Projection of the reference sequence to the attention model size.
        ref_seq = self.encoder_prj(ref_seq)

        # Reshape the padding mask to [N, 1, 1, ref_len] for the attention layer.
        batch = ref_pad_mask.shape[0]
        attn_mask = ref_pad_mask.view((batch, 1, 1, -1))

        attended, _ = self.attention(
            query=x,
            key=ref_seq,
            value=ref_seq,
            mask=attn_mask,
            encoding=encoding,
        )
        out = self.encoder_bottleneck(attended)
        return out.masked_fill(src_mask.unsqueeze(-1), 0.0)
|