103 lines
3.6 KiB
Python
103 lines
3.6 KiB
Python
|
import torch
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
|
||
|
class Stretch2d(torch.nn.Module):
    """Resize the last two axes of a 4-D tensor by fixed scale factors.

    Args:
        x_scale: Scale factor applied to the last axis (T).
        y_scale: Scale factor applied to the second-to-last axis (F).
        mode: Interpolation algorithm name forwarded to ``F.interpolate``.
    """

    def __init__(self, x_scale, y_scale, mode="nearest"):
        super().__init__()
        self.x_scale, self.y_scale, self.mode = x_scale, y_scale, mode

    def forward(self, x):
        """Interpolate ``x`` along its frequency and time axes.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
        """
        # F.interpolate takes the factors in axis order: (F axis, T axis).
        scale_factor = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=scale_factor, mode=self.mode)
|
||
|
|
||
|
|
||
|
class UpsampleNetwork(torch.nn.Module):
    """Upsample features along time with stacked interpolate + conv stages.

    For every entry of ``upsample_factors`` the network appends a
    ``Stretch2d`` layer that stretches the time axis by that factor, a
    single-channel ``Conv2d`` over the stretched result, and optionally a
    nonlinearity.

    Args:
        upsample_factors: Iterable of integer time-axis scale factors, one per stage.
        nonlinear_activation: Name of a ``torch.nn`` activation class, or ``None``
            to add no activation.
        nonlinear_activation_params: Keyword arguments for the activation class.
            The dict is only read, never mutated.
        interpolate_mode: Interpolation mode passed to ``Stretch2d``.
        freq_axis_kernel_size: Conv kernel size along the frequency axis; must be odd.
        use_causal_conv: If ``True``, pad the time axis by the full kernel extent
            and trim the trailing frames in ``forward`` so the conv is causal.

    Raises:
        AssertionError: If ``freq_axis_kernel_size`` is even and at least one
            upsample stage is built.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
    ):
        super().__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_factors:
            # interpolation layer: stretch only the time axis (y_scale = 1)
            self.up_layers.append(Stretch2d(scale, 1, interpolate_mode))

            # conv layer
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            # Conv2d pads both sides. The causal variant pads by the full
            # kernel extent (2 * scale) so the leading outputs only see past
            # frames; the surplus trailing outputs are removed in forward().
            time_axis_padding = scale * 2 if use_causal_conv else scale
            padding = (freq_axis_padding, time_axis_padding)
            conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers.append(conv)

            # nonlinear
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers.append(nonlinear)

    def forward(self, c):
        """Upsample ``c`` along the time axis.

        Args:
            c (Tensor): Input features (B, C, T_in).

        Returns:
            Tensor: Upsampled features (B, C, T_upsample).
        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for layer in self.up_layers:
            if self.use_causal_conv and isinstance(layer, torch.nn.Conv2d):
                # The causal conv returns T + 2 * scale frames; keep the first
                # T so the length is preserved and no output depends on
                # future input frames.
                num_frames = c.size(-1)
                c = layer(c)[..., :num_frames]
            else:
                c = layer(c)
        return c.squeeze(1)  # (B, C, T')
|
||
|
|
||
|
|
||
|
class ConvUpsample(torch.nn.Module):
    """Context-gathering ``Conv1d`` followed by an ``UpsampleNetwork``.

    Args:
        upsample_factors: Forwarded to ``UpsampleNetwork``.
        nonlinear_activation: Forwarded to ``UpsampleNetwork``.
        nonlinear_activation_params: Forwarded to ``UpsampleNetwork``.
        interpolate_mode: Forwarded to ``UpsampleNetwork``.
        freq_axis_kernel_size: Forwarded to ``UpsampleNetwork``.
        aux_channels: Number of input and output channels of the input conv.
        aux_context_window: Context size that determines the input conv kernel size.
        use_causal_conv: Selects the causal kernel size and is forwarded to
            ``UpsampleNetwork``.
    """

    # pylint: disable=dangerous-default-value
    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
    ):
        super().__init__()
        self.aux_context_window = aux_context_window
        # Trailing frames are only trimmed in forward() when there is context to drop.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # To capture wide-context information in conditional features
        if use_causal_conv:
            context_kernel = aux_context_window + 1
        else:
            context_kernel = 2 * aux_context_window + 1

        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = torch.nn.Conv1d(
            aux_channels,
            aux_channels,
            kernel_size=context_kernel,
            bias=False,
        )

        upsample_kwargs = {
            "upsample_factors": upsample_factors,
            "nonlinear_activation": nonlinear_activation,
            "nonlinear_activation_params": nonlinear_activation_params,
            "interpolate_mode": interpolate_mode,
            "freq_axis_kernel_size": freq_axis_kernel_size,
            "use_causal_conv": use_causal_conv,
        }
        self.upsample = UpsampleNetwork(**upsample_kwargs)

    def forward(self, c):
        """Apply the input conv, then upsample along time.

        Args:
            c (Tensor): Input features (B, C, T_in).

        Returns:
            Tensor: Upsampled features (B, C, T_upsampled).
        """
        features = self.conv_in(c)
        if self.use_causal_conv:
            # Drop the last ``aux_context_window`` frames of the conv output.
            features = features[:, :, : -self.aux_context_window]
        return self.upsample(features)
|