504 lines
18 KiB
Python
504 lines
18 KiB
Python
|
# coding: utf-8
|
||
|
# adapted from https://github.com/r9y9/tacotron_pytorch
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
from .attentions import init_attn
|
||
|
from .common_layers import Prenet
|
||
|
|
||
|
|
||
|
class BatchNormConv1d(nn.Module):
    r"""Explicitly padded ``Conv1d`` followed by ``BatchNorm1d`` and an optional activation.

    The layers are applied in this order: constant zero padding, bias-free
    convolution, batch normalisation, then ``activation`` (when one is given).

    Args:
        in_channels: number of channels of the input signal.
        out_channels: number of channels produced by the convolution.
        kernel_size: kernel size of the conv filters.
        stride: stride of the conv filters.
        padding: padding handed to ``nn.ConstantPad1d`` (e.g. ``[left, right]``);
            the convolution itself runs with ``padding=0``.
        activation: optional callable applied to the BatchNorm output.

    Shapes:
        - input: (B, in_channels, T)
        - output: (B, out_channels, T_out)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None):
        super().__init__()
        self.padding = padding
        # Padding is done by a dedicated layer so that asymmetric (left, right)
        # amounts can be expressed; the convolution below adds none of its own.
        self.padder = nn.ConstantPad1d(padding, 0)
        self.conv1d = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            bias=False,
        )
        # Numeric values copied from TensorFlow's BatchNorm defaults.
        # NOTE(review): PyTorch's `momentum` is the weight given to the *new*
        # batch statistic, the opposite of TensorFlow's decay convention, so
        # 0.99 here is not equivalent to TF's 0.99 -- confirm whether 0.01 was
        # intended before changing it.
        self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
        self.activation = activation
        # init_layers() is not invoked here; PyTorch's default Conv1d
        # initialisation stays in effect unless a caller invokes it.

    def init_layers(self):
        """Xavier-initialise the conv weight with a gain matching the activation.

        Raises:
            RuntimeError: if the activation is neither ``None``, ``ReLU`` nor ``Tanh``.
        """
        if self.activation is None:
            gain_name = "linear"
        elif isinstance(self.activation, torch.nn.ReLU):
            gain_name = "relu"
        elif isinstance(self.activation, torch.nn.Tanh):
            gain_name = "tanh"
        else:
            raise RuntimeError("Unknown activation function")
        gain = torch.nn.init.calculate_gain(gain_name)
        torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=gain)

    def forward(self, x):
        """Run pad -> conv -> batch norm -> (optional) activation on ``x``."""
        normalized = self.bn(self.conv1d(self.padder(x)))
        if self.activation is None:
            return normalized
        return self.activation(normalized)
|
||
|
|
||
|
|
||
|
class Highway(nn.Module):
    r"""Single highway layer, see https://arxiv.org/abs/1505.00387

    A sigmoid transform gate ``T`` blends a ReLU-activated projection ``H``
    with the untouched input: ``H * T + inputs * (1 - T)``.

    Args:
        in_features (int): size of each input sample
        out_feature (int): size of each output sample

    Shapes:
        - input: (B, *, H_in)
        - output: (B, *, H_out)

    Because the raw input is mixed with the projected branch, the two must be
    broadcast-compatible (in practice ``in_features == out_feature``).
    """

    # TODO: Try GLU layer
    def __init__(self, in_features, out_feature):
        super().__init__()
        # Candidate (H) branch, bias starts at zero.
        self.H = nn.Linear(in_features, out_feature)
        self.H.bias.data.zero_()
        # Gate (T) branch, bias starts at -1 so sigmoid(T) begins below 0.5.
        self.T = nn.Linear(in_features, out_feature)
        self.T.bias.data.fill_(-1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # init_layers() is not invoked here; the default nn.Linear weight
        # initialisation stays in effect unless a caller invokes it.

    def init_layers(self):
        """Xavier-initialise both projections with activation-matched gains."""
        relu_gain = torch.nn.init.calculate_gain("relu")
        sigmoid_gain = torch.nn.init.calculate_gain("sigmoid")
        torch.nn.init.xavier_uniform_(self.H.weight, gain=relu_gain)
        torch.nn.init.xavier_uniform_(self.T.weight, gain=sigmoid_gain)

    def forward(self, inputs):
        """Return the gated mix of the transformed and the original input."""
        candidate = self.relu(self.H(inputs))
        gate = self.sigmoid(self.T(inputs))
        carried = inputs * (1.0 - gate)
        return candidate * gate + carried
|
||
|
|
||
|
|
||
|
class CBHG(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - 1-d convolution projections + a residual connection to the input
        - Highway networks
        - Bidirectional gated recurrent units

    Args:
        in_features (int): number of input channels.
        K (int): max filter size in conv bank (kernel sizes 1..K are used).
        conv_bank_features (int): output channels of every conv bank layer.
        conv_projections (sequence of int): conv channel sizes for conv projections.
            The last entry must equal ``in_features`` because the projection
            output is added to the module input.
        highway_features (int): width of the highway layers.
        gru_features (int): hidden size of each GRU direction.
        num_highways (int): number of highways layers.

    Shapes:
        - input: (B, in_features, T_in)
        - output: (B, T_in, gru_features * 2)
    """

    def __init__(
        self,
        in_features,
        K=16,
        conv_bank_features=128,
        conv_projections=(128, 128),
        highway_features=128,
        gru_features=128,
        num_highways=4,
    ):
        super().__init__()
        # Work on a private list copy: the default is immutable and a
        # caller-owned list is never shared with (or mutated through) the module.
        conv_projections = list(conv_projections)
        self.in_features = in_features
        self.conv_bank_features = conv_bank_features
        self.highway_features = highway_features
        self.gru_features = gru_features
        self.conv_projections = conv_projections
        self.relu = nn.ReLU()
        # list of conv1d bank with filter size k=1...K; the asymmetric padding
        # [(k - 1) // 2, k // 2] keeps the time dimension unchanged for every k.
        # TODO: try dilational layers instead
        self.conv1d_banks = nn.ModuleList(
            [
                BatchNormConv1d(
                    in_features,
                    conv_bank_features,
                    kernel_size=k,
                    stride=1,
                    padding=[(k - 1) // 2, k // 2],
                    activation=self.relu,
                )
                for k in range(1, K + 1)
            ]
        )
        # setup conv1d projection layers: the first consumes the concatenated
        # bank outputs, every layer but the last is followed by ReLU.
        out_features = [K * conv_bank_features] + conv_projections[:-1]
        activations = [self.relu] * (len(conv_projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [
                BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac)
                for in_size, out_size, ac in zip(out_features, conv_projections, activations)
            ]
        )
        # setup Highway layers; a linear adapter is only created when the
        # projection width differs from the highway width.
        if self.highway_features != conv_projections[-1]:
            self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False)
        self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)])
        # bi-directional GRU layer. Its input is the highway stack's output,
        # whose width is highway_features (not gru_features), so that is the
        # input size that must be declared here.
        self.gru = nn.GRU(highway_features, gru_features, 1, batch_first=True, bidirectional=True)

    def forward(self, inputs):
        """Encode ``inputs`` (B, in_features, T_in) into (B, T_in, gru_features * 2)."""
        # (B, in_features, T_in) -> (B, conv_bank_features * K, T_in)
        # Concat conv1d bank outputs along the channel axis
        x = torch.cat([conv1d(inputs) for conv1d in self.conv1d_banks], dim=1)
        assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
        for conv1d in self.conv1d_projections:
            x = conv1d(x)
        # Residual connection to the module input (out-of-place so the
        # projection output is never modified underneath autograd).
        x = x + inputs
        # (B, C, T_in) -> (B, T_in, C) for the linear/highway/GRU layers
        x = x.transpose(1, 2)
        if self.highway_features != self.conv_projections[-1]:
            x = self.pre_highway(x)
        # TODO: try residual scaling as in Deep Voice 3
        # TODO: try plain residual layers
        for highway in self.highways:
            x = highway(x)
        # (B, T_in, gru_features * 2)
        # TODO: replace GRU with convolution as in Deep Voice 3
        self.gru.flatten_parameters()
        outputs, _ = self.gru(x)
        return outputs
|
||
|
|
||
|
|
||
|
class EncoderCBHG(nn.Module):
    r"""CBHG module preconfigured with the encoder's fixed hyper-parameters.

    Wraps a ``CBHG`` built for 128 input channels, a bank of 16 kernel sizes
    and 128-wide projections, highways and GRU.
    """

    def __init__(self):
        super().__init__()
        encoder_settings = dict(
            K=16,
            conv_bank_features=128,
            conv_projections=[128, 128],
            highway_features=128,
            gru_features=128,
            num_highways=4,
        )
        self.cbhg = CBHG(128, **encoder_settings)

    def forward(self, x):
        """Delegate to the wrapped CBHG."""
        encoded = self.cbhg(x)
        return encoded
|
||
|
|
||
|
|
||
|
class Encoder(nn.Module):
    r"""Encoder: a Prenet followed by the encoder CBHG.

    Args:
        in_features (int): size of the embedding features fed to the Prenet.

    Shapes:
        - inputs: (B, T, D_in)
        - outputs: (B, T, 128 * 2)
    """

    def __init__(self, in_features):
        super().__init__()
        prenet_layer_sizes = [256, 128]
        self.prenet = Prenet(in_features, out_features=prenet_layer_sizes)
        self.cbhg = EncoderCBHG()

    def forward(self, inputs):
        """Apply the Prenet, then run the CBHG on the channel-first result."""
        # B x T x prenet_dim
        prenet_out = self.prenet(inputs)
        # The CBHG is called with dims 1 and 2 swapped: B x prenet_dim x T
        channels_first = prenet_out.transpose(1, 2)
        return self.cbhg(channels_first)
|
||
|
|
||
|
|
||
|
class PostCBHG(nn.Module):
    r"""CBHG module preconfigured for post-processing mel frames.

    Args:
        mel_dim (int): number of channels of the input; also used as the last
            conv projection size of the wrapped CBHG.
    """

    def __init__(self, mel_dim):
        super().__init__()
        postnet_settings = dict(
            K=8,
            conv_bank_features=128,
            conv_projections=[256, mel_dim],
            highway_features=128,
            gru_features=128,
            num_highways=4,
        )
        self.cbhg = CBHG(mel_dim, **postnet_settings)

    def forward(self, x):
        """Delegate to the wrapped CBHG."""
        processed = self.cbhg(x)
        return processed
|
||
|
|
||
|
|
||
|
class Decoder(nn.Module):
    """Tacotron decoder.

    Args:
        in_channels (int): number of input channels.
        frame_channels (int): number of feature frame channels.
        r (int): number of outputs per time step (reduction rate).
        memory_size (int): size of the past window. if <= 0 memory_size = r
        attn_type (string): type of attention used in decoder.
        attn_windowing (bool): if true, define an attention window centered to maximum
            attention response. It provides more robust attention alignment especially
            at inference time.
        attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'.
        prenet_type (string): 'original' or 'bn'.
        prenet_dropout (float): prenet dropout rate.
        forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736
        trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736
        forward_attn_mask (bool): if true, mask attention values smaller than a threshold.
        location_attn (bool): if true, use location sensitive attention.
        attn_K (int): number of attention heads for GravesAttention.
        separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
        max_decoder_steps (int): Maximum number of steps allowed for the decoder at inference.
    """

    # Pylint gets confused by PyTorch conventions here
    # pylint: disable=attribute-defined-outside-init

    def __init__(
        self,
        in_channels,
        frame_channels,
        r,
        memory_size,
        attn_type,
        attn_windowing,
        attn_norm,
        prenet_type,
        prenet_dropout,
        forward_attn,
        trans_agent,
        forward_attn_mask,
        location_attn,
        attn_K,
        separate_stopnet,
        max_decoder_steps,
    ):
        super().__init__()
        # r_init fixes the layer sizes; r may later be changed through set_r().
        self.r_init = r
        self.r = r
        self.in_channels = in_channels
        self.frame_channels = frame_channels
        self.max_decoder_steps = max_decoder_steps
        self.separate_stopnet = separate_stopnet
        self.query_dim = 256
        # A positive memory_size turns on the memory queue; otherwise the
        # window falls back to r frames.
        self.use_memory_queue = memory_size > 0
        self.memory_size = memory_size if self.use_memory_queue else r
        # memory -> |Prenet| -> processed_memory
        if self.use_memory_queue:
            prenet_dim = frame_channels * self.memory_size
        else:
            prenet_dim = frame_channels
        self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128])
        # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
        # attention_rnn generates queries for the attention mechanism
        self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
        self.attention = init_attn(
            attn_type=attn_type,
            query_dim=self.query_dim,
            embedding_dim=in_channels,
            attention_dim=128,
            location_attention=location_attn,
            attention_location_n_filters=32,
            attention_location_kernel_size=31,
            windowing=attn_windowing,
            norm=attn_norm,
            forward_attn=forward_attn,
            trans_agent=trans_agent,
            forward_attn_mask=forward_attn_mask,
            attn_K=attn_K,
        )
        # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
        self.project_to_decoder_in = nn.Linear(256 + in_channels, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state
        self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)])
        # RNN_state -> |Linear| -> mel_spec
        self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init)
        # (RNN_state | mel_spec) -> |StopNet| -> stop token
        self.stopnet = StopNet(256 + frame_channels * self.r_init)

    def set_r(self, new_r):
        """Change the number of frames emitted per decoder step."""
        self.r = new_r

    def _reshape_memory(self, memory):
        """
        Reshape the spectrograms for given 'r'
        """
        # Grouping multiple frames if necessary: r consecutive frames are
        # folded into one decoder step.
        if memory.size(-1) == self.frame_channels:
            grouped_steps = memory.size(1) // self.r
            memory = memory.view(memory.shape[0], grouped_steps, -1)
        # Time first (T_decoder, B, frame_channels * r)
        return memory.transpose(0, 1)

    def _init_states(self, inputs):
        """
        Initialization of decoder states
        """
        batch_size = inputs.size(0)
        device = inputs.device
        # go frame as zeros matrix
        if self.use_memory_queue:
            memory_width = self.frame_channels * self.memory_size
        else:
            memory_width = self.frame_channels
        self.memory_input = torch.zeros(batch_size, memory_width, device=device)
        # decoder states
        self.attention_rnn_hidden = torch.zeros(batch_size, 256, device=device)
        self.decoder_rnn_hiddens = [torch.zeros(batch_size, 256, device=device) for _ in self.decoder_rnns]
        self.context_vec = inputs.new_zeros(batch_size, self.in_channels)
        # cache attention inputs
        self.processed_inputs = self.attention.preprocess_inputs(inputs)

    def _parse_outputs(self, outputs, attentions, stop_tokens):
        """Stack the per-step lists and return batch-first tensors."""
        # Back to batch first
        attentions = torch.stack(attentions).transpose(0, 1)
        stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
        frames = torch.stack(outputs).transpose(0, 1).contiguous()
        # Unfold the r frames packed into every step, then put channels
        # before time: (B, frame_channels, T)
        frames = frames.view(frames.size(0), -1, self.frame_channels).transpose(1, 2)
        return frames, attentions, stop_tokens

    def decode(self, inputs, mask=None):
        """Run one decoder step and return (frames, stop token, attention weights)."""
        # Prenet
        processed_memory = self.prenet(self.memory_input)
        # Attention RNN
        attention_rnn_input = torch.cat((processed_memory, self.context_vec), -1)
        self.attention_rnn_hidden = self.attention_rnn(attention_rnn_input, self.attention_rnn_hidden)
        self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
        # Concat RNN output and attention context vector
        decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))

        # Pass through the decoder RNNs
        for idx, rnn_cell in enumerate(self.decoder_rnns):
            hidden = rnn_cell(decoder_input, self.decoder_rnn_hiddens[idx])
            self.decoder_rnn_hiddens[idx] = hidden
            # Residual connection
            decoder_input = hidden + decoder_input
        decoder_output = decoder_input

        # predict mel vectors from decoder vectors
        output = self.proj_to_mel(decoder_output)
        # predict stop token (from the full r_init-sized projection)
        stopnet_input = torch.cat([decoder_output, output], -1)
        if self.separate_stopnet:
            stopnet_input = stopnet_input.detach()
        stop_token = self.stopnet(stopnet_input)
        # Only the first r frames are returned, r may be smaller than r_init.
        output = output[:, : self.r * self.frame_channels]
        return output, stop_token, self.attention.attention_weights

    def _update_memory_input(self, new_memory):
        """Refresh ``self.memory_input`` with the frames of the latest step."""
        if not self.use_memory_queue:
            # use only the last frame prediction
            self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :]
            return
        if self.memory_size > self.r:
            # memory queue size is larger than number of frames per decoder iter:
            # push the new frames in front and keep the newest remainder.
            kept = (self.memory_size - self.r) * self.frame_channels
            self.memory_input = torch.cat([new_memory, self.memory_input[:, :kept].clone()], dim=-1)
            return
        # memory queue size smaller than number of frames per decoder iter
        self.memory_input = new_memory[:, : self.memory_size * self.frame_channels]

    def forward(self, inputs, memory, mask):
        """
        Teacher-forced decoding: one decoder step per (grouped) frame of ``memory``.

        Args:
            inputs: Encoder outputs.
            memory: Ground-truth decoder memory used as the autoregressive input.
                It is reshaped unconditionally, so it must be a tensor here;
                use ``inference`` to decode without it.
            mask: Attention mask for sequence padding.

        Shapes:
            - inputs: (B, T, D_out_enc)
            - memory: (B, T_mel, D_mel)
        """
        memory = self._reshape_memory(memory)
        outputs, attentions, stop_tokens = [], [], []
        self._init_states(inputs)
        self.attention.init_states(inputs)
        for t in range(memory.size(0)):
            if t > 0:
                # feed the previous ground-truth step
                self._update_memory_input(memory[t - 1])

            output, stop_token, attention = self.decode(inputs, mask)
            outputs.append(output)
            attentions.append(attention)
            stop_tokens.append(stop_token.squeeze(1))
        return self._parse_outputs(outputs, attentions, stop_tokens)

    def inference(self, inputs):
        """
        Free-running decoding: every step consumes the previous prediction.

        Args:
            inputs: encoder outputs.
        Shapes:
            - inputs: batch x time x encoder_out_dim

        The stop test converts tensors to Python scalars (``.item()`` and the
        truth value of a comparison), which only works for single-element
        tensors, i.e. a batch of one.
        """
        outputs, attentions, stop_tokens = [], [], []
        t = 0
        self._init_states(inputs)
        self.attention.init_states(inputs)
        while True:
            if t > 0:
                self._update_memory_input(outputs[-1])
            output, stop_token, attention = self.decode(inputs, None)
            stop_token = torch.sigmoid(stop_token.data)
            outputs.append(output)
            attentions.append(attention)
            stop_tokens.append(stop_token)
            t += 1
            # Stopping is only considered after a quarter of the input length.
            if t > inputs.shape[1] / 4:
                if stop_token > 0.6 or attention[:, -1].item() > 0.6:
                    break
            if t > self.max_decoder_steps:
                print(" | > Decoder stopped with 'max_decoder_steps")
                break
        return self._parse_outputs(outputs, attentions, stop_tokens)
|
||
|
|
||
|
|
||
|
class StopNet(nn.Module):
    r"""Stopnet signalling decoder to stop inference.

    Dropout followed by a linear projection to a single value per sample. No
    sigmoid is applied here, so the output is a raw score.

    Args:
        in_features (int): feature dimension of input.
    """

    def __init__(self, in_features):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(in_features, 1)
        linear_gain = torch.nn.init.calculate_gain("linear")
        torch.nn.init.xavier_uniform_(self.linear.weight, gain=linear_gain)

    def forward(self, inputs):
        """Return the stop score of shape (..., 1) for ``inputs``."""
        return self.linear(self.dropout(inputs))
|