import torch
from torch import nn
from torch.nn import functional as F

class Linear(nn.Module):
    """Linear layer with a specific initialization.

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the layer. Defaults to True.
        init_gain (str, optional): method to compute the gain in the weight initialization
            based on the nonlinear activation used afterwards. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        # Xavier (Glorot) uniform init, scaled by the gain of the following nonlinearity.
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        return self.linear_layer(x)
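
# A minimal usage sketch (illustration only, not part of the original module): `init_gain`
# should name the nonlinearity applied after this layer so that
# torch.nn.init.calculate_gain returns the matching Xavier scaling factor
# (e.g. "relu", "tanh", or the default "linear").
def _example_linear_usage():  # hypothetical helper, safe to remove
    layer = Linear(in_features=80, out_features=256, init_gain="relu")
    x = torch.randn(4, 80)  # [B, C]
    y = layer(x)
    assert y.shape == (4, 256)
    return y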

class LinearBN(nn.Module):
    """Linear layer with Batch Normalization.

    x -> linear -> BN -> o

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): enable/disable bias in the linear layer. Defaults to True.
        init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
        self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

    def forward(self, x):
        """
        Shapes:
            x: [T, B, C] or [B, C]
        """
        out = self.linear_layer(x)
        if len(out.shape) == 3:
            # [T, B, C] -> [B, C, T] so BatchNorm1d normalizes over the channel dim.
            out = out.permute(1, 2, 0)
        out = self.batch_normalization(out)
        if len(out.shape) == 3:
            # [B, C, T] -> [T, B, C]
            out = out.permute(2, 0, 1)
        return out
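
# A minimal usage sketch (illustration only, not part of the original module): LinearBN
# accepts [B, C] or time-major [T, B, C] inputs; for the 3D case the forward pass permutes
# to [B, C, T] so nn.BatchNorm1d normalizes over the channel dimension, then permutes back.
def _example_linear_bn_usage():  # hypothetical helper, safe to remove
    layer = LinearBN(in_features=80, out_features=256)
    x = torch.randn(50, 4, 80)  # [T, B, C]: 50 time steps, batch of 4, 80 channels
    y = layer(x)
    assert y.shape == (50, 4, 256)
    return y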

class Prenet(nn.Module):
    """Tacotron specific Prenet with an optional Batch Normalization.

    Note:
        Prenet with BN improves the model performance significantly, especially
        if it is enabled after the model learns a diagonal attention alignment with
        the original prenet. However, if the target dataset is high quality, it also
        works from the start. It is also suggested to disable dropout if BN is in use.

        prenet_type == "original"
            x -> [linear -> ReLU -> Dropout]xN -> o

        prenet_type == "bn"
            x -> [linear -> BN -> ReLU -> Dropout]xN -> o

    Args:
        in_features (int): number of channels in the input tensor and the inner layers.
        prenet_type (str, optional): prenet type, "original" or "bn". Defaults to "original".
        prenet_dropout (bool, optional): enable/disable dropout in the prenet blocks. Defaults to True.
        dropout_at_inference (bool, optional): keep dropout active at inference. It leads to
            better quality for some models. Defaults to False.
        out_features (list, optional): list of output channels for each prenet block. Its length
            also defines the number of prenet blocks. Defaults to [256, 256].
        bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        in_features,
        prenet_type="original",
        prenet_dropout=True,
        dropout_at_inference=False,
        out_features=[256, 256],
        bias=True,
    ):
        super().__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.dropout_at_inference = dropout_at_inference
        # Input size of each block is the output size of the previous one.
        in_features = [in_features] + out_features[:-1]
        if prenet_type == "bn":
            self.linear_layers = nn.ModuleList(
                [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
            )
        elif prenet_type == "original":
            self.linear_layers = nn.ModuleList(
                [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
            )
        else:
            raise ValueError(f"Unknown prenet_type: {prenet_type}")

    def forward(self, x):
        for linear in self.linear_layers:
            if self.prenet_dropout:
                # Dropout stays active at inference when `dropout_at_inference` is set.
                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference)
            else:
                x = F.relu(linear(x))
        return x
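
# A minimal usage sketch (illustration only, not part of the original module): with the
# default out_features=[256, 256] the prenet stacks two blocks, and setting
# dropout_at_inference=True keeps dropout active at inference time, which the docstring
# suggests helps output quality for some models.
def _example_prenet_usage():  # hypothetical helper, safe to remove
    prenet = Prenet(in_features=80, prenet_type="original", out_features=[256, 256])
    x = torch.randn(4, 50, 80)  # [B, T, C]: batch of 4 sequences, 50 frames, 80 mel channels
    y = prenet(x)
    assert y.shape == (4, 50, 256)
    return y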