124 lines
4.0 KiB
Python
124 lines
4.0 KiB
Python
|
import torch
|
||
|
from torch import nn
|
||
|
|
||
|
|
||
|
class LayerNorm(nn.Module):
    """Layer normalization over the channel (2nd) dimension of the input.

    Each time step of each batch item is normalized across its channels and
    then passed through a learned per-channel scale (``gamma``, initialized
    to 0.1) and shift (``beta``, initialized to 0).

    Args:
        channels (int): number of channels (2nd dimension) of the input.
        eps (float): added to the variance to prevent division by zero.

    Shapes:
        - input: (B, C, T)
        - output: (B, C, T)
    """

    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Parameters are shaped (1, C, 1) so they broadcast over batch and time.
        param_shape = (1, channels, 1)
        self.gamma = nn.Parameter(torch.ones(*param_shape) * 0.1)
        self.beta = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        """Normalize ``x`` along dim 1 and apply the learned scale and shift."""
        centered = x - torch.mean(x, 1, keepdim=True)
        # Biased (population) variance across the channel dimension.
        var = torch.mean(centered ** 2, 1, keepdim=True)
        normalized = centered * torch.rsqrt(var + self.eps)
        return normalized * self.gamma + self.beta
|
||
|
|
||
|
|
||
|
class LayerNorm2(nn.Module):
    """Layer normalization over the channel (2nd) dimension via the torch primitive.

    The input is transposed so that channels are last, passed through
    ``torch.nn.functional.layer_norm``, and transposed back.

    Args:
        channels (int): number of channels (2nd dimension) of the input.
        eps (float): added to the variance to prevent division by zero.

    Shapes:
        - input: (B, C, T)
        - output: (B, C, T)
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Per-channel scale and shift, starting as the identity transform.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Normalize ``x`` across dim 1 and return a tensor of the same shape."""
        # layer_norm normalizes over trailing dims, so move channels to the end.
        channels_last = x.transpose(1, -1)
        normalized = torch.nn.functional.layer_norm(
            channels_last,
            (self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        return normalized.transpose(1, -1)
|
||
|
|
||
|
|
||
|
class TemporalBatchNorm1d(nn.BatchNorm1d):
    """Normalize each channel separately over time and batch.

    Thin wrapper around ``nn.BatchNorm1d`` that swaps dims 1 and 2 before and
    after the parent call, so the last dimension of the input is the one
    treated as channels by the parent.

    Args:
        channels (int): number of features passed to ``nn.BatchNorm1d``.
        affine (bool): forwarded to ``nn.BatchNorm1d``.
        track_running_stats (bool): forwarded to ``nn.BatchNorm1d``.
        momentum (float): forwarded to ``nn.BatchNorm1d``.
    """

    def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
        super().__init__(
            channels,
            affine=affine,
            track_running_stats=track_running_stats,
            momentum=momentum,
        )

    def forward(self, x):
        """Apply batch norm with dims 1 and 2 swapped, then swap them back."""
        swapped = x.transpose(2, 1)
        normalized = super().forward(swapped)
        return normalized.transpose(2, 1)
|
||
|
|
||
|
|
||
|
class ActNorm(nn.Module):
    """Activation Normalization bijector as an alternative to Batch Norm.

    Applies the per-channel affine transform ``z = bias + exp(logs) * x``.
    With data-dependent initialization (DDI) enabled, the first batch passed
    to ``forward`` sets ``logs`` and ``bias`` from the masked mean and
    variance of that batch; afterwards both are ordinary trainable
    parameters.

    Args:
        channels (int): input channels.
        ddi (bool): data-dependent initialization flag. Defaults to False.

    Shapes:
        - inputs: (B, C, T)
        - outputs: (B, C, T)
    """

    def __init__(self, channels, ddi=False, **kwargs):  # pylint: disable=unused-argument
        super().__init__()
        self.channels = channels
        # Plain Python attribute (neither a parameter nor a buffer).
        self.initialized = not ddi

        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
        """Apply the affine transform or, with ``reverse=True``, its inverse.

        Args:
            x (Tensor): input of shape (B, C, T).
            x_mask (Tensor, optional): mask multiplied into the output. When
                omitted, an all-ones mask of shape (B, 1, T) is used.
            reverse (bool): run the inverse transform instead of the forward one.

        Returns:
            Tuple[Tensor, Optional[Tensor]]: the transformed tensor and the
            log-determinant of shape [B]; the log-determinant is ``None``
            when ``reverse`` is True.
        """
        if x_mask is None:
            # Allocate directly on x's device/dtype instead of creating a CPU
            # tensor and copying it over.
            x_mask = torch.ones(x.size(0), 1, x.size(2), device=x.device, dtype=x.dtype)

        if not self.initialized:
            self.initialize(x, x_mask)
            self.initialized = True

        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
        else:
            z = (self.bias + torch.exp(self.logs) * x) * x_mask
            # Number of unmasked steps per batch item; only the forward
            # direction needs it.
            x_len = torch.sum(x_mask, [1, 2])
            logdet = torch.sum(self.logs) * x_len  # [b]

        return z, logdet

    def store_inverse(self):
        """No-op placeholder; this layer keeps no cached inverse."""

    def set_ddi(self, ddi):
        """Enable (True) or disable (False) data-dependent init on the next forward call."""
        self.initialized = not ddi

    def initialize(self, x, x_mask):
        """Set ``logs`` and ``bias`` from the masked statistics of ``x``.

        After this call the forward transform maps the unmasked values of
        ``x`` to zero mean and unit variance per channel (variance is floored
        at 1e-6 before taking the log).

        Args:
            x (Tensor): input of shape (B, C, T).
            x_mask (Tensor): mask broadcastable against ``x``; the default
                built by ``forward`` has shape (B, 1, T).
        """
        with torch.no_grad():
            # An all-zero mask would otherwise produce 0/0 = NaN and poison
            # both parameters permanently. For a 0/1 mask with at least one
            # valid position the clamp changes nothing.
            denom = torch.clamp_min(torch.sum(x_mask, [0, 2]), 1)
            m = torch.sum(x * x_mask, [0, 2]) / denom
            m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
            v = m_sq - (m**2)
            # logs here is log(std) of the data; the stored parameter is its negation.
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

            bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape)
            logs_init = (-logs).view(*self.logs.shape)

            # In-place copy under no_grad; copy_ casts to the parameter dtype.
            self.bias.copy_(bias_init)
            self.logs.copy_(logs_init)
|