ai-content-maker/.venv/Lib/site-packages/TTS/tts/layers/delightful_tts/encoders.py

262 lines
9.0 KiB
Python

from typing import List, Tuple, Union
import torch
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
from TTS.tts.layers.delightful_tts.networks import STL
def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lengths (torch.Tensor): 1-D tensor holding one length per batch item.

    Returns:
        torch.Tensor: Boolean tensor of shape ``(batch, max(lengths))`` that is
        ``True`` at positions at or beyond an item's length (i.e. padding) and
        ``False`` at positions inside the sequence.
    """
    n_items = lengths.size(0)
    longest = lengths.max().item()
    # Position indices 0..longest-1, created on the same device as the input.
    positions = torch.arange(0, longest, device=lengths.device)
    position_grid = positions.unsqueeze(0).expand(n_items, -1)
    length_grid = lengths.unsqueeze(1).expand(-1, longest)
    return position_grid >= length_grid
def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    """Return the sequence lengths that remain after downsampling by ``stride``.

    Each length is divided by ``stride`` and rounded up, so a partially filled
    final window still counts as one step.

    Args:
        lens (torch.Tensor): Tensor of sequence lengths.
        stride (int): Downsampling factor. Defaults to 2.

    Returns:
        torch.Tensor: ``ceil(lens / stride)`` as a 32-bit integer tensor.
    """
    downsampled = (lens / stride).ceil()
    return downsampled.to(torch.int32)
class ReferenceEncoder(nn.Module):
    """
    Reference encoder for utterance and phoneme prosody encoders. The reference encoder
    is made up of a stack of 1-D convolutions followed by a GRU.

    Args:
        num_mels (int): Number of input channels of the first conv layer (mel channels).
        ref_enc_filters (List[int]): Output channel sizes of the conv layers; one conv
            layer is built per entry.
        ref_enc_size (int): Kernel size of the conv layers.
        ref_enc_strides (List[int]): Strides of the second and following conv layers.
            The first conv layer always uses stride 1.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing mel vector
        - **lengths** (batch): Tensor containing the mel lengths.

    Returns:
        - **outputs** (batch, time, dim): Padded GRU outputs.
        - **memory** (1, batch, dim): Final GRU hidden state.
        - **masks** (batch, time): Padding mask for the downsampled sequence.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
    ):
        super().__init__()

        self.n_mel_channels = num_mels
        n_layers = len(ref_enc_filters)
        filters = [self.n_mel_channels] + ref_enc_filters
        strides = [1] + ref_enc_strides
        padding = ref_enc_size // 2

        # Strides actually applied by the conv stack, kept so that `forward` can shrink
        # the sequence lengths by the same factors instead of assuming a fixed layout.
        self.conv_strides = tuple(strides[:n_layers])

        # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=padding,
                with_r=True,
            )
        ]
        convs.extend(
            nn.Conv1d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=ref_enc_size,
                stride=strides[i],
                padding=padding,
            )
            for i in range(1, n_layers)
        )
        self.convs = nn.ModuleList(convs)
        self.norms = nn.ModuleList(
            [nn.InstanceNorm1d(num_features=n_channels, affine=True) for n_channels in ref_enc_filters]
        )

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Shapes:
            x: :math: `[N, n_mels, Ty]`
            mel_lens: :math: `[N]`

        Returns:
            Tuple of

            - padded GRU outputs `[N, Ty', ref_enc_gru_size]`
            - final GRU hidden state `[1, N, ref_enc_gru_size]`
            - padding mask `[N, Ty']`, True on padded frames

            where `Ty'` is the sequence length after the strided convolutions.
        """
        # Zero the padded frames before they enter the conv stack.
        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        x = x.masked_fill(mel_masks, 0)

        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, 0.3)  # [N, ref_enc_filters[i], Ty_i]
            x = norm(x)

        # Shrink the lengths once per strided conv. ceil(len / stride) equals the Conv1d
        # output length for an odd kernel with padding = kernel_size // 2. Stride-1 layers
        # leave the length unchanged and are skipped.
        for stride in self.conv_strides:
            if stride != 1:
                mel_lens = stride_lens(mel_lens, stride)

        mel_masks = get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))  # [N, Ty', channels] as expected by the batch-first GRU
        x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False)

        self.gru.flatten_parameters()
        x, memory = self.gru(x)  # x --- packed outputs, memory --- [1, N, ref_enc_gru_size]
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x, memory, mel_masks

    def calculate_channels(  # pylint: disable=no-self-use
        self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
    ) -> int:
        """Return the length left after applying ``n_convs`` identical convolutions.

        Args:
            L (int): Input length.
            kernel_size (int): Kernel size of each convolution.
            stride (int): Stride of each convolution.
            pad (int): Padding applied on both sides by each convolution.
            n_convs (int): Number of convolutions applied in sequence.

        Returns:
            int: Output length; ``L`` unchanged when ``n_convs`` is 0.
        """
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
class UtteranceLevelProsodyEncoder(nn.Module):
    """
    Encoder to extract prosody from an utterance. It is made up of a reference encoder
    followed by a linear projection, a style token layer, a linear bottleneck and dropout.

    Args:
        num_mels (int): Passed to the reference encoder as its number of mel channels.
        ref_enc_filters (List[int]): List of channel sizes for ref encoder layers.
        ref_enc_size (int): Size of the kernel for the ref encoder conv layers.
        ref_enc_strides (List[int]): List of strides to use for the ref encoder conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
        dropout (float): Probability of dropout.
        n_hidden (int): Size of hidden layers.
        bottleneck_size_u (int): Size of the bottleneck layer.
        token_num (int): Token count handed to the style token layer.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing mel vector
        - **lengths** (batch): Tensor containing the mel lengths.

    Returns:
        - **outputs** (batch, 1, dim): Tensor produced by Utterance Level Prosody Encoder.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        bottleneck_size_u: int,
        token_num: int,
    ):
        super().__init__()

        self.E = n_hidden
        self.d_q = n_hidden
        self.d_k = n_hidden

        self.encoder = ReferenceEncoder(
            num_mels=num_mels,
            ref_enc_filters=ref_enc_filters,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            ref_enc_gru_size=ref_enc_gru_size,
        )
        # Projects the GRU state down to half the hidden size before the style tokens.
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden // 2)
        self.stl = STL(n_hidden=n_hidden, token_num=token_num)
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size_u)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            mels: :math: `[B, C, T]`
            mel_lens: :math: `[B]`

        Returns:
            Tensor reshaped to `[-1, 1, D]`, where `D` is the size of dimension 3 of the
            bottleneck output (the style token layer output is indexed as a 4-D tensor here).
        """
        # Only the final GRU state of the reference encoder is used at utterance level.
        _, gru_state, _ = self.encoder(mels, mel_lens)

        # Bottleneck
        projected = self.encoder_prj(gru_state)

        # Style Token
        style = self.stl(projected)
        out = self.dropout(self.encoder_bottleneck(style))

        return out.view((-1, 1, out.shape[3]))
class PhonemeLevelProsodyEncoder(nn.Module):
    """
    Encoder to extract prosody at phoneme level. It is made up of a reference encoder,
    a linear projection, a multi-headed attention layer in which the phoneme-side input
    queries the encoded reference, and a linear bottleneck.

    Args:
        num_mels (int): Passed to the reference encoder as its number of mel channels.
        ref_enc_filters (List[int]): List of channel sizes for ref encoder layers.
        ref_enc_size (int): Size of the kernel for the ref encoder conv layers.
        ref_enc_strides (List[int]): List of strides to use for the ref encoder conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.
        dropout (float): Dropout probability handed to the attention layer.
        n_hidden (int): Size of hidden layers (attention model dimension).
        n_heads (int): Number of attention heads.
        bottleneck_size_p (int): Size of the bottleneck layer.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        n_heads: int,
        bottleneck_size_p: int,
    ):
        super().__init__()

        self.E = n_hidden
        self.d_q = n_hidden
        self.d_k = n_hidden

        self.encoder = ReferenceEncoder(
            num_mels=num_mels,
            ref_enc_filters=ref_enc_filters,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            ref_enc_gru_size=ref_enc_gru_size,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=n_hidden,
            num_heads=n_heads,
            dropout_p=dropout,
        )
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size_p)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        x --- [N, seq_len, encoder_embedding_dim], used as the attention query
        src_mask --- mask over x; positions where it is True are zeroed in the output
        mels --- forwarded unchanged to the reference encoder together with mel_lens
        encoding --- forwarded unchanged to the attention layer
        out --- [N, seq_len, bottleneck_size]
        """
        # Per-frame reference encoder outputs serve as keys and values.
        ref_outputs, _, mel_masks = self.encoder(mels, mel_lens)

        # Bottleneck
        ref_outputs = self.encoder_prj(ref_outputs)

        # Reshape the reference padding mask to [N, 1, 1, ref_len] for the attention layer.
        n_items = mel_masks.shape[0]
        attn_mask = mel_masks.view((n_items, 1, 1, -1))

        attended, _ = self.attention(
            query=x,
            key=ref_outputs,
            value=ref_outputs,
            mask=attn_mask,
            encoding=encoding,
        )
        out = self.encoder_bottleneck(attended)
        return out.masked_fill(src_mask.unsqueeze(-1), 0.0)