ai-content-maker/.venv/Lib/site-packages/TTS/tts/layers/tacotron/common_layers.py

120 lines
4.6 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
import torch
from torch import nn
from torch.nn import functional as F
class Linear(nn.Module):
    """Fully-connected layer whose weight is Xavier-uniform initialized.

    The wrapped ``torch.nn.Linear`` is stored as ``self.linear_layer``.

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): whether the layer has an additive bias. Defaults to True.
        init_gain (str, optional): name of the non-linearity that follows this layer; it is passed to
            ``torch.nn.init.calculate_gain`` to scale the Xavier initialization. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = nn.Linear(in_features, out_features, bias)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        """Overwrite the default weight init with Xavier-uniform scaled by ``init_gain``."""
        gain = nn.init.calculate_gain(init_gain)
        nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        """Apply the affine projection to the last dimension of ``x``."""
        projected = self.linear_layer(x)
        return projected
class LinearBN(nn.Module):
    """Xavier-initialized linear projection followed by 1D batch normalization.

    x -> linear -> BN -> o

    The submodules are stored as ``self.linear_layer`` and ``self.batch_normalization``.

    Args:
        in_features (int): number of channels in the input tensor.
        out_features (int): number of channels in the output tensor.
        bias (bool, optional): whether the linear layer has an additive bias. Defaults to True.
        init_gain (str, optional): non-linearity name passed to ``torch.nn.init.calculate_gain``
            to scale the Xavier weight initialization. Defaults to 'linear'.
    """

    def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
        super().__init__()
        self.linear_layer = nn.Linear(in_features, out_features, bias)
        self.batch_normalization = nn.BatchNorm1d(out_features, eps=1e-5, momentum=0.1)
        self._init_w(init_gain)

    def _init_w(self, init_gain):
        """Re-initialize the linear weight with Xavier-uniform scaled by ``init_gain``."""
        gain = nn.init.calculate_gain(init_gain)
        nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        """
        Shapes:
            x: [T, B, C] or [B, C]
            returns: same leading layout as ``x`` with ``out_features`` channels.
        """
        hidden = self.linear_layer(x)
        if hidden.dim() != 3:
            # [B, C] is already the layout BatchNorm1d expects.
            return self.batch_normalization(hidden)
        # BatchNorm1d normalizes dim 1, so move channels there: [T, B, C] -> [B, C, T],
        # normalize, then restore [T, B, C].
        normalized = self.batch_normalization(hidden.permute(1, 2, 0))
        return normalized.permute(2, 0, 1)
class Prenet(nn.Module):
    """Tacotron specific Prenet with an optional Batch Normalization.

    Note:
        Prenet with BN improves the model performance significantly, especially
        if it is enabled after learning a diagonal attention alignment with the original
        prenet. However, if the target dataset is high quality then it also works from
        the start. It is also suggested to disable dropout if BN is in use.

        prenet_type == "original"
            x -> [linear -> ReLU -> Dropout]xN -> o

        prenet_type == "bn"
            x -> [linear -> BN -> ReLU -> Dropout]xN -> o

    Args:
        in_features (int): number of channels in the input tensor (input size of the first block).
        prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original".
        prenet_dropout (bool, optional): apply dropout (fixed rate 0.5) after every block. Defaults to True.
        dropout_at_inference (bool, optional): keep dropout active in eval mode as well.
            It leads to a better quality for some models. Defaults to False.
        out_features (Sequence[int], optional): output channels of each prenet block. Its length
            defines the number of blocks; block i takes block i-1's output size as input.
            Defaults to (256, 256).
        bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True.

    Raises:
        ValueError: if ``prenet_type`` is neither "original" nor "bn".
    """

    def __init__(
        self,
        in_features,
        prenet_type="original",
        prenet_dropout=True,
        dropout_at_inference=False,
        out_features=(256, 256),
        bias=True,
    ):
        super().__init__()
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.dropout_at_inference = dropout_at_inference
        if prenet_type == "bn":
            layer_cls = LinearBN
        elif prenet_type == "original":
            layer_cls = Linear
        else:
            # Fail at construction instead of leaving `linear_layers` undefined and
            # crashing later in `forward`.
            raise ValueError(f' [!] Unknown prenet_type "{prenet_type}"; expected "original" or "bn".')
        # Copy so any sequence (list or tuple) works and the caller's object is never shared.
        out_features = list(out_features)
        # Each block consumes the previous block's output channels.
        in_sizes = [in_features] + out_features[:-1]
        self.linear_layers = nn.ModuleList(
            [layer_cls(in_size, out_size, bias=bias) for in_size, out_size in zip(in_sizes, out_features)]
        )

    def forward(self, x):
        """Run ``x`` through every block: linear (+BN) -> ReLU -> optional dropout."""
        for linear in self.linear_layers:
            x = F.relu(linear(x))
            if self.prenet_dropout:
                # `dropout_at_inference` keeps dropout on even when the module is in eval mode.
                x = F.dropout(x, p=0.5, training=self.training or self.dropout_at_inference)
        return x