180 lines
6.7 KiB
Python
180 lines
6.7 KiB
Python
|
import math
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
from TTS.tts.layers.generic.gated_conv import GatedConvBlock
|
||
|
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
|
||
|
from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
|
||
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||
|
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
|
||
|
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
||
|
from TTS.tts.utils.helpers import sequence_mask
|
||
|
|
||
|
|
||
|
class Encoder(nn.Module):
    """Glow-TTS encoder module.

    ::

        embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
                                                             |
                                                             |-> proj_var
                                                             |
                                                             |-> concat -> duration_predictor
                                                                    ↑
                                                              speaker_embed

    Args:
        num_chars (int): number of characters (size of the embedding table).
        out_channels (int): number of output channels of the mean / log-scale projections.
        hidden_channels (int): encoder's embedding size, also used as the channel
            width of the prenet, encoder module and postnet.
        hidden_channels_dp (int): second argument passed to ``DurationPredictor``
            (presumably its hidden channel size -- confirm against that class).
        encoder_type (str): one of ``'rel_pos_transformer'``, ``'gated_conv'``,
            ``'residual_conv_bn'`` or ``'time_depth_separable'`` (case-insensitive).
        encoder_params (dict): keyword arguments forwarded to the selected encoder
            module. See the suggested values below.
        dropout_p_dp (float): dropout rate passed to the duration predictor.
        mean_only (bool): if True, output only mean values and use constant std
            (the returned log-scale is all zeros).
        use_prenet (bool): if True, use pre-convolutional layers before the encoder
            module. Ignored for ``'gated_conv'``, which defines no prenet.
        c_in_channels (int): number of channels in conditional input.

    Shapes:
        - input: (B, T) character ids

    ::

        suggested encoder params...

        for encoder_type == 'rel_pos_transformer'
            encoder_params={
                'kernel_size':3,
                'dropout_p': 0.1,
                'num_layers': 6,
                'num_heads': 2,
                'hidden_channels_ffn': 768,  # 4 times the hidden_channels
                'input_length': None
            }

        for encoder_type == 'gated_conv'
            encoder_params={
                'kernel_size':5,
                'dropout_p': 0.1,
                'num_layers': 9,
            }

        for encoder_type == 'residual_conv_bn'
            encoder_params={
                "kernel_size": 4,
                "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }

        for encoder_type == 'time_depth_separable'
            encoder_params={
                "kernel_size": 5,
                'num_layers': 9,
            }

    Raises:
        ValueError: if ``encoder_type`` is not one of the supported types.
    """

    def __init__(
        self,
        num_chars,
        out_channels,
        hidden_channels,
        hidden_channels_dp,
        encoder_type,
        encoder_params,
        dropout_p_dp=0.1,
        mean_only=False,
        use_prenet=True,
        c_in_channels=0,
    ):
        super().__init__()
        # class arguments
        self.num_chars = num_chars
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.hidden_channels_dp = hidden_channels_dp
        self.dropout_p_dp = dropout_p_dp
        self.mean_only = mean_only
        self.use_prenet = use_prenet
        self.c_in_channels = c_in_channels
        self.encoder_type = encoder_type
        # embedding layer
        self.emb = nn.Embedding(num_chars, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        # init encoder module; normalize the type string once instead of per branch
        enc_type = encoder_type.lower()
        if enc_type == "rel_pos_transformer":
            if use_prenet:
                self.prenet = ResidualConv1dLayerNormBlock(
                    hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
                )
            self.encoder = RelativePositionTransformer(
                hidden_channels, hidden_channels, hidden_channels, **encoder_params
            )
        elif enc_type == "gated_conv":
            self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
        elif enc_type == "residual_conv_bn":
            if use_prenet:
                # NOTE: this prenet is a plain ``nn.Sequential`` and is NOT mask-aware;
                # ``_apply_prenet`` calls it without the mask.
                self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
            self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params)
            self.postnet = nn.Sequential(
                nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels)
            )
        elif enc_type == "time_depth_separable":
            if use_prenet:
                self.prenet = ResidualConv1dLayerNormBlock(
                    hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
                )
            self.encoder = TimeDepthSeparableConvBlock(
                hidden_channels, hidden_channels, hidden_channels, **encoder_params
            )
        else:
            raise ValueError(f" [!] Unknown encoder type: {encoder_type}")

        # final projection layers
        self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
        # duration predictor; its input is the encoder output optionally
        # concatenated with the conditioning input, hence the widened channels
        self.duration_predictor = DurationPredictor(
            hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp
        )

    def _apply_prenet(self, x, x_mask):
        """Run the optional prenet on ``x`` and return the result.

        The layer-norm prenets take ``(x, x_mask)``, while the
        ``residual_conv_bn`` prenet is an ``nn.Sequential`` whose ``forward``
        accepts a single input. Passing the mask to it raises ``TypeError``, so
        it is called with ``x`` only and the output is masked here, mirroring
        how the postnet output is masked in ``forward``.
        """
        if not (hasattr(self, "prenet") and self.use_prenet):
            return x
        if isinstance(self.prenet, nn.Sequential):
            return self.prenet(x) * x_mask
        return self.prenet(x, x_mask)

    def forward(self, x, x_lengths, g=None):
        """Encode character ids into prior statistics and log-durations.

        Shapes:
            - x: :math:`[B, T]` (character ids fed to the embedding layer)
            - x_lengths: :math:`[B]`
            - g (optional): :math:`[B, C_g, 1]` or :math:`[B, C_g, T]`; it is
              expanded along the last axis to :math:`T` and concatenated to the
              encoder output on the channel axis.

        Returns:
            tuple: ``(x_m, x_logs, logw, x_mask)`` where ``x_m`` and ``x_logs``
            are the masked outputs of the mean / log-scale projections
            (``x_logs`` is all zeros when ``mean_only`` is set), ``logw`` is the
            duration predictor output and ``x_mask`` is the :math:`[B, 1, T]`
            sequence mask cast to the embedding dtype.
        """
        # embedding layer
        # [B ,T, D]
        x = self.emb(x) * math.sqrt(self.hidden_channels)
        # [B, D, T]
        x = torch.transpose(x, 1, -1)
        # compute input sequence mask
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        # prenet
        x = self._apply_prenet(x, x_mask)
        # encoder
        x = self.encoder(x, x_mask)
        # postnet
        if hasattr(self, "postnet"):
            x = self.postnet(x) * x_mask
        # set duration predictor input; detached so the duration predictor's
        # gradients do not flow back into the encoder
        if g is not None:
            g_exp = g.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([x.detach(), g_exp], 1)
        else:
            x_dp = x.detach()
        # final projection layer
        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            x_logs = torch.zeros_like(x_m)
        # duration predictor
        logw = self.duration_predictor(x_dp, x_mask)
        return x_m, x_logs, logw, x_mask
|