import torch
from torch import nn

from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder
from TTS.tts.utils.helpers import sequence_mask


class Decoder(nn.Module):
    """Uses the Glow decoder with some modifications.
    ::

        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze

    Args:
        in_channels (int): channels of input tensor.
        hidden_channels (int): hidden decoder channels.
        kernel_size (int): coupling block kernel size. (Wavenet filter kernel size.)
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number of coupling layers. (number of wavenet layers.)
        dropout_p (float): wavenet dropout rate.
        num_splits (int): number of channel splits used by the invertible 1x1 convolution layers.
        num_squeeze (int): number of time frames folded into the channel dimension by the squeeze layer.
        sigmoid_scale (bool): enable/disable sigmoid scaling in the coupling layer.
        c_in_channels (int): channels of the conditioning tensor (e.g. speaker embedding); 0 disables conditioning.
    """
    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()

        self.glow_decoder = GlowDecoder(
            in_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_flow_blocks,
            num_coupling_layers,
            dropout_p,
            num_splits,
            num_squeeze,
            sigmoid_scale,
            c_in_channels,
        )
        self.n_sqz = num_squeeze

    def forward(self, x, x_len, g=None, reverse=False):
        """
        Input shapes:
            - x: :math:`[B, C, T]`
            - x_len: :math:`[B]`
            - g: :math:`[B, C]`

        Output shapes:
            - x: :math:`[B, C, T]`
            - x_len: :math:`[B]`
            - logdet_tot: :math:`[B]`
        """
        x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max())
        x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype)
        x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse)
        return x, x_len, logdet_tot

    def preprocess(self, y, y_lengths, y_max_length):
        # Trim the time dimension and the lengths to the largest multiple of the squeeze factor.
        if y_max_length is not None:
            y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz
            y = y[:, :, :y_max_length]
        y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz
        return y, y_lengths, y_max_length

    def store_inverse(self):
        # Cache inverse weights (e.g. of the invertible 1x1 convolutions) for faster inverse passes.
        self.glow_decoder.store_inverse()
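

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# Assumptions: 80-band mel inputs and small, arbitrary flow sizes chosen purely for
# demonstration; real models pass these values in from their config.
if __name__ == "__main__":
    decoder = Decoder(
        in_channels=80,  # assumed number of mel channels
        hidden_channels=192,
        kernel_size=5,
        dilation_rate=1,
        num_flow_blocks=2,  # kept small for a quick smoke test
        num_coupling_layers=2,
    )

    spec = torch.randn(2, 80, 100)  # [B, C, T]
    spec_len = torch.tensor([100, 80])  # [B]

    # Forward (analysis) direction: returns the latent, the (possibly trimmed) lengths
    # and the total log-determinant of the flow, shaped [B].
    z, z_len, logdet = decoder(spec, spec_len, reverse=False)
    print(z.shape, z_len.tolist(), logdet.shape)

    # Inverse (synthesis) direction: cache the inverse weights first, then invert the flow.
    decoder.store_inverse()
    y, _, _ = decoder(z, z_len, reverse=True)
    print(y.shape)  # expected: torch.Size([2, 80, 100])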