167 lines
6.0 KiB
Python
167 lines
6.0 KiB
Python
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from torch import nn
|
||
|
from torch.nn.utils.parametrizations import weight_norm
|
||
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||
|
|
||
|
|
||
|
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` with orthogonal weight init and zero bias init.

    Accepts exactly the same arguments as :class:`torch.nn.Conv1d`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.orthogonal_(self.weight)
        # With ``bias=False`` the parent leaves ``self.bias`` as None, and
        # zeros_ would fail on it, so only initialise an existing bias.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
|
||
|
|
||
|
|
||
|
class PositionalEncoding(nn.Module):
    """Positional encoding with noise level conditioning.

    ``forward`` returns ``x + noise_level + pe / C``, where ``pe`` is a
    sinusoidal ``(channels, time)`` table built lazily from the input and
    cached on ``self.pe``.
    """

    def __init__(self, n_channels, max_len=10000):
        super().__init__()
        self.n_channels = n_channels
        # NOTE(review): max_len is stored but not read by this class; the
        # table is sized from the input length instead.
        self.max_len = max_len
        # Divisor applied to the positional table only (not to x or noise).
        self.C = 5000
        # Empty placeholder, rebuilt on the first forward. It is a plain
        # attribute (not a registered buffer): it is not in state_dict and
        # is not moved by ``module.to()``.
        self.pe = torch.zeros(0, 0)

    def forward(self, x, noise_level):
        """Add the noise level and the scaled positional table to ``x``.

        Args:
            x: rank-3 tensor indexed as ``(batch, channels, time)``.
            noise_level: tensor expanded with ``[..., None, None]`` so it
                broadcasts over channels and time; presumably shape
                ``(batch,)`` — confirm against callers.

        Returns:
            ``x + noise_level[..., None, None] + pe[:, :time] / self.C``.
        """
        # Rebuild the cache when it is too short, or when it no longer matches
        # x's channel count, device or dtype (e.g. after moving the module).
        if (
            x.shape[2] > self.pe.shape[1]
            or self.pe.shape[0] != x.shape[1]
            or self.pe.device != x.device
            or self.pe.dtype != x.dtype
        ):
            self.init_pe_matrix(x.shape[1], x.shape[2], x)
        # The (channels, time) slice broadcasts over the batch dimension, so no
        # explicit repeat/copy is needed.
        return x + noise_level[..., None, None] + self.pe[:, : x.size(2)] / self.C

    def init_pe_matrix(self, n_channels, max_len, x):
        """Build and cache the ``(n_channels, max_len)`` sinusoidal table.

        Even channel indices hold sines, odd ones cosines. The table is
        converted to ``x``'s device and dtype via ``.to(x)``.
        """
        pe = torch.zeros(max_len, n_channels)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)

        pe[:, 0::2] = torch.sin(position / div_term)
        # For odd n_channels there is one fewer cosine column than sine column;
        # trim div_term so the assignment shapes match.
        pe[:, 1::2] = torch.cos(position / div_term[: n_channels // 2])
        self.pe = pe.transpose(0, 1).to(x)
|
||
|
|
||
|
|
||
|
class FiLM(nn.Module):
    """Predict ``(shift, scale)`` modulation tensors from a conditioning input.

    ``x`` (``input_size`` channels) goes through a 3-tap conv and a leaky ReLU.
    Noise-conditioned positional encoding is then added, and a second 3-tap
    conv emits ``2 * output_size`` channels. Those channels are split in half:
    the first half is ``shift`` and the second half is ``scale``.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.encoding = PositionalEncoding(input_size)
        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)

        # Xavier-uniform weights and zero biases for both convolutions.
        for conv in (self.input_conv, self.output_conv):
            nn.init.xavier_uniform_(conv.weight)
            nn.init.zeros_(conv.bias)

    def forward(self, x, noise_scale):
        """Return the ``(shift, scale)`` pair, each with ``output_size`` channels."""
        hidden = F.leaky_relu(self.input_conv(x), 0.2)
        hidden = self.encoding(hidden, noise_scale)
        projected = self.output_conv(hidden)
        shift, scale = projected.chunk(2, dim=1)
        return shift, scale

    def remove_weight_norm(self):
        """Strip the weight-norm parametrization from both convolutions."""
        for conv in (self.input_conv, self.output_conv):
            remove_parametrizations(conv, "weight")

    def apply_weight_norm(self):
        """Wrap both convolutions with the weight-norm parametrization."""
        for attr in ("input_conv", "output_conv"):
            setattr(self, attr, weight_norm(getattr(self, attr)))
|
||
|
|
||
|
|
||
|
@torch.jit.script
def shif_and_scale(x, scale, shift):
    """Apply the affine modulation ``shift + scale * x`` (TorchScript-compiled).

    Unannotated arguments are typed as Tensor by TorchScript; the usual
    broadcasting rules apply between the three operands.
    """
    return shift + scale * x
|
||
|
|
||
|
|
||
|
class UBlock(nn.Module):
    """Upsampling block with ``shift``/``scale`` modulation.

    The input is upsampled along its last axis by ``factor`` with
    ``F.interpolate`` (default nearest mode). It then passes through four
    dilated 3-tap convolutions, interleaved with leaky ReLU and
    ``shift + scale * o`` modulation. The result is combined with a
    1x1-conv residual path.

    Args:
        input_size: number of input channels.
        hidden_size: number of output channels of every conv in the block.
        factor: multiplier applied to ``x.shape[-1]`` to form the
            ``F.interpolate`` target size, so presumably an int.
        dilation: list or tuple of exactly 4 dilation rates, used in order by
            ``main_block[0]``, ``main_block[1]``, ``out_block[0]``,
            ``out_block[1]``.
    """

    def __init__(self, input_size, hidden_size, factor, dilation):
        super().__init__()
        assert isinstance(dilation, (list, tuple)), "dilation must be a list or tuple"
        assert len(dilation) == 4, "dilation must have exactly 4 entries"

        self.factor = factor
        # 1x1 projection of the upsampled input for the residual path.
        self.res_block = Conv1d(input_size, hidden_size, 1)
        # For kernel size 3, padding == dilation keeps the time length unchanged.
        self.main_block = nn.ModuleList(
            [
                Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]),
            ]
        )
        self.out_block = nn.ModuleList(
            [
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]),
                Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]),
            ]
        )

    def forward(self, x, shift, scale):
        """Upsample ``x`` and apply the modulated conv stack.

        Args:
            x: input whose last axis is upsampled by ``self.factor``;
                presumably ``(batch, input_size, time)`` — confirm.
            shift: additive modulation term; must broadcast against the
                ``hidden_size``-channel, ``time * factor``-long activations.
            scale: multiplicative modulation term; same broadcasting
                requirement as ``shift``.

        Returns:
            Tensor with ``hidden_size`` channels and ``time * factor`` steps.
        """
        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
        res = self.res_block(x_inter)

        # x_inter already has the target length, so no second interpolation
        # is needed here (nearest-mode resize to the same size is an identity).
        o = F.leaky_relu(x_inter, 0.2)
        o = self.main_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.main_block[1](o)
        res2 = res + o

        o = shif_and_scale(res2, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[0](o)
        o = shif_and_scale(o, scale, shift)
        o = F.leaky_relu(o, 0.2)
        o = self.out_block[1](o)
        return o + res2

    def remove_weight_norm(self):
        """Strip weight-norm parametrizations from every conv in the block.

        Layers with an empty ``state_dict`` are skipped.
        """
        remove_parametrizations(self.res_block, "weight")
        for block in (self.main_block, self.out_block):
            for layer in block:
                if len(layer.state_dict()) != 0:
                    remove_parametrizations(layer, "weight")

    def apply_weight_norm(self):
        """Wrap every conv in the block with the weight-norm parametrization.

        Layers with an empty ``state_dict`` are skipped.
        """
        self.res_block = weight_norm(self.res_block)
        for block in (self.main_block, self.out_block):
            for idx, layer in enumerate(block):
                if len(layer.state_dict()) != 0:
                    block[idx] = weight_norm(layer)
|
||
|
|
||
|
|
||
|
class DBlock(nn.Module):
    """Downsampling block with a 1x1-conv residual path.

    The input is resized along its last axis to ``length // factor``
    (``F.interpolate``, default nearest mode). It then goes through three
    3-tap convolutions with dilations 1, 2 and 4, each preceded by a leaky
    ReLU. The resized 1x1 projection of the input is added at the end.

    Args:
        input_size: number of input channels.
        hidden_size: number of output channels of every conv in the block.
        factor: integer divisor applied to the last-axis length.
    """

    def __init__(self, input_size, hidden_size, factor):
        super().__init__()
        self.factor = factor
        self.res_block = Conv1d(input_size, hidden_size, 1)

        # padding == dilation keeps the time length unchanged for kernel size 3.
        layers = []
        channels_in = input_size
        for rate in (1, 2, 4):
            layers.append(Conv1d(channels_in, hidden_size, 3, dilation=rate, padding=rate))
            channels_in = hidden_size
        self.main_block = nn.ModuleList(layers)

    def forward(self, x):
        """Return the downsampled, convolved input plus its residual projection."""
        target_len = x.shape[-1] // self.factor
        skip = F.interpolate(self.res_block(x), size=target_len)

        hidden = F.interpolate(x, size=target_len)
        for conv in self.main_block:
            hidden = conv(F.leaky_relu(hidden, 0.2))
        return hidden + skip

    def remove_weight_norm(self):
        """Strip weight-norm parametrizations from every conv in the block.

        Layers with an empty ``state_dict`` are skipped.
        """
        remove_parametrizations(self.res_block, "weight")
        for conv in self.main_block:
            if conv.state_dict():
                remove_parametrizations(conv, "weight")

    def apply_weight_norm(self):
        """Wrap every conv in the block with the weight-norm parametrization.

        Layers with an empty ``state_dict`` are skipped.
        """
        self.res_block = weight_norm(self.res_block)
        for position, conv in enumerate(self.main_block):
            if conv.state_dict():
                self.main_block[position] = weight_norm(conv)
|