# File viewer metadata: 320 lines, 9.2 KiB, Python
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
|
|
|
|
from collections import namedtuple
|
|
from functools import wraps
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
from einops.layers.torch import Rearrange
|
|
from packaging import version
|
|
from torch import einsum, nn
|
|
|
|
|
|
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
|
|
|
|
def once(fn):
    """Wrap ``fn`` so that only the first call is forwarded to it.

    The first invocation calls ``fn`` with the given positional and keyword
    arguments and returns its result. Every later invocation does nothing and
    returns ``None``.

    Args:
        fn: The callable to guard.

    Returns:
        A wrapper carrying ``fn``'s metadata (via ``functools.wraps``).
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        # The flag is set before calling fn, so a call that raises still
        # counts as the single permitted invocation.
        called = True
        return fn(*args, **kwargs)

    return inner
|
|
|
|
|
|
# `print` guarded by once(): only the first call prints; later calls return None without printing.
print_once = once(print)
|
|
|
|
# main class
|
|
|
|
|
|
class Attend(nn.Module):
    """Core scaled dot-product attention with an optional fused PyTorch 2.0 path.

    Args:
        dropout: Dropout probability for the attention weights (active only in training mode).
        causal: If True, a query position may only attend to key positions at or before it.
        use_flash: If True, ``forward`` dispatches to ``F.scaled_dot_product_attention``
            (requires PyTorch >= 2.0); otherwise attention is computed with explicit einsums.
    """

    def __init__(self, dropout=0.0, causal=False, use_flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # Lazily-built causal mask cache, see get_mask(); excluded from the state dict.
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (
            use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
        ), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # determine efficient attention configs for cuda and cpu
        self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
        self.cpu_config = self.config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        # Compute capability 8.0 is treated as A100 by this code.
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
            self.cuda_config = self.config(True, False, False)
        else:
            print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
            self.cuda_config = self.config(False, True, True)

    def get_mask(self, n, device):
        """Return an ``(n, n)`` boolean mask that is True strictly above the diagonal.

        The largest mask built so far is kept in the non-persistent ``mask``
        buffer and sliced when a smaller ``n`` is requested.
        """
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask=None):
        """Compute attention through ``F.scaled_dot_product_attention``.

        Args:
            q: Queries of shape ``(b, h, i, d)``.
            k, v: Keys/values of shape ``(b, h, j, d)``, or ``(b, j, d)`` when a
                single key/value head is shared by all query heads.
            mask: Optional key padding mask of shape ``(b, j)``. The non-fused
                path in ``forward`` treats it as boolean with True = attend,
                and the same convention is assumed here.

        Returns:
            The attention output as returned by ``F.scaled_dot_product_attention``.
        """
        _, heads, q_len, _ = q.shape
        k_len, is_cuda = k.shape[-2], q.is_cuda

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
        # Only the head axis is expanded, so keys/values may have a different
        # sequence length than the queries (expand_as(q) required them to match).
        if k.ndim == 3:
            k = rearrange(k, "b ... -> b 1 ...").expand(-1, heads, -1, -1)

        if v.ndim == 3:
            v = rearrange(v, "b ... -> b 1 ...").expand(-1, heads, -1, -1)

        causal = self.causal

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L
        if exists(mask):
            mask = rearrange(mask, "b j -> b 1 1 j")
            mask = mask.expand(-1, heads, q_len, -1)

            if causal:
                # scaled_dot_product_attention raises when attn_mask and
                # is_causal=True are both given, so fold the causal constraint
                # (lower-triangular keep mask, top-left aligned like is_causal)
                # into the padding mask and disable the flag.
                keep_causal = torch.ones((q_len, k_len), device=q.device, dtype=torch.bool).tril()
                mask = mask & keep_causal
                causal = False

        # Check if there is a compatible device for flash attention
        config = self.cuda_config if is_cuda else self.cpu_config
        if config is None:
            # cuda_config is only filled in __init__ when use_flash=True and CUDA
            # was available; fall back to the config that enables every kernel.
            config = self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        # NOTE(review): newer PyTorch releases appear to deprecate sdp_kernel in
        # favour of torch.nn.attention.sdpa_kernel - confirm before upgrading.
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=causal
            )

        return out

    def forward(self, q, k, v, mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        q is ``(b, h, i, d)``; k and v are ``(b, h, j, d)`` or ``(b, j, d)``;
        mask, when given, is a boolean ``(b, j)`` key padding mask with
        True = attend. Returns a tensor of shape ``(b, h, i, d)``.
        """
        if self.use_flash:
            return self.flash_attn(q, k, v, mask=mask)

        n, device = q.shape[-2], q.device
        scale = q.shape[-1] ** -0.5

        kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, "b j -> b 1 1 j")
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            # get_mask yields an (n, n) mask with n = query length.
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
|
|
|
|
|
|
def Sequential(*mods):
    """Build an ``nn.Sequential`` from ``mods``, dropping any entry that is ``None``."""
    present = [module for module in mods if module is not None]
    return nn.Sequential(*present)
|
|
|
|
|
|
def exists(x):
    """Tell whether ``x`` holds a value, i.e. is not ``None``.

    Note: a function with this name is also defined earlier in this file; this
    later definition is the one bound at module level after import.
    """
    return not (x is None)
|
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When the fallback is callable it is invoked with no arguments and its
    result is returned; otherwise ``d`` itself is returned.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """L2-normalize the last dimension and rescale by ``sqrt(dim)``.

    Args:
        dim: Size of the normalized (last) dimension.
        scale: If True, a learned per-feature gain initialised to ones is applied.
        dim_cond: If given, ``forward`` requires a ``cond`` tensor that is linearly
            projected to a multiplicative and an additive term applied to the output.
    """

    def __init__(self, dim, scale=True, dim_cond=None):
        super().__init__()
        self.cond = dim_cond is not None

        if self.cond:
            self.to_gamma_beta = nn.Linear(dim_cond, dim * 2)
        else:
            self.to_gamma_beta = None

        self.scale = dim**0.5

        if scale:
            self.gamma = nn.Parameter(torch.ones(dim))
        else:
            self.gamma = None

    def forward(self, x, cond=None):
        """Normalize ``x``; when conditioning is enabled, modulate it with ``cond``.

        ``cond`` is rearranged with the pattern ``"b d -> b 1 d"``, so it must be
        two-dimensional when conditioning is enabled.
        """
        gain = 1 if self.gamma is None else self.gamma
        normed = F.normalize(x, dim=-1) * self.scale * gain

        if self.cond:
            assert cond is not None
            projected = self.to_gamma_beta(cond)
            # First half multiplies, second half is added; both broadcast over the sequence axis.
            gamma, beta = (rearrange(part, "b d -> b 1 d") for part in projected.chunk(2, dim=-1))
            return normed * gamma + beta

        return normed
|
|
|
|
|
|
class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` that left-pads its input so the output length equals the input length.

    Accepts the same constructor arguments as ``nn.Conv1d``; the stride must be 1.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Conv1d stores these settings as 1-tuples.
        (k,) = self.kernel_size
        (d,) = self.dilation
        (s,) = self.stride

        assert s == 1
        # Number of zeros prepended on the time axis in forward().
        self.causal_padding = (k - 1) * d

    def forward(self, x):
        """Zero-pad ``x`` on the left of its last axis, then apply the convolution."""
        left_padded = F.pad(x, (self.causal_padding, 0), value=0.0)
        return super().forward(left_padded)
|
|
|
|
|
|
class GEGLU(nn.Module):
    """GELU-gated linear unit: halves the last dimension of its input."""

    def forward(self, x):
        """Split ``x`` in two along the last axis and gate the first half with GELU of the second."""
        value, gate = x.chunk(2, dim=-1)
        activated_gate = F.gelu(gate)
        return activated_gate * value
|
|
|
|
|
|
def FeedForward(dim, mult=4, causal_conv=False):
    """Build a GEGLU feed-forward stack mapping ``dim`` features back to ``dim``.

    Args:
        dim: Input and output feature size.
        mult: Expansion factor; the hidden width is ``int(dim * mult * 2 / 3)``.
        causal_conv: If True, a kernel-size-3 ``CausalConv1d`` over the sequence
            axis is inserted between the gate and the output projection.

    Returns:
        The module produced by ``Sequential`` (entries that are ``None`` are dropped).
    """
    hidden = int(dim * mult * 2 / 3)

    # Built before the linear layers, in the same order as before, so seeded
    # parameter initialisation is unaffected.
    if causal_conv:
        sequence_conv = nn.Sequential(
            Rearrange("b n d -> b d n"),
            CausalConv1d(hidden, hidden, 3),
            Rearrange("b d n -> b n d"),
        )
    else:
        sequence_conv = None

    # The input projection emits 2 * hidden features, which GEGLU halves again.
    proj_in = nn.Linear(dim, hidden * 2)
    gate = GEGLU()
    proj_out = nn.Linear(hidden, dim)

    return Sequential(proj_in, gate, sequence_conv, proj_out)
|
|
|
|
|
|
class PerceiverResampler(nn.Module):
    """Summarise a variable-length sequence into a fixed number of learned latents.

    Args:
        dim: Feature size of the latents and of the output.
        depth: Number of (attention, feed-forward) layer pairs.
        dim_context: Feature size of the input sequence; defaults to ``dim``.
            When it differs from ``dim`` the input is linearly projected first.
        num_latents: Number of latent vectors, i.e. the output sequence length.
        dim_head: Per-head feature size of the attention layers.
        heads: Number of attention heads.
        ff_mult: Expansion factor passed to ``FeedForward``.
        use_flash_attn: Forwarded to ``Attention`` as ``use_flash``.
    """

    def __init__(
        self,
        *,
        dim,
        depth=2,
        dim_context=None,
        num_latents=32,
        dim_head=64,
        heads=8,
        ff_mult=4,
        use_flash_attn=False,
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        if dim_context != dim:
            self.proj_context = nn.Linear(dim_context, dim)
        else:
            self.proj_context = nn.Identity()

        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std=0.02)

        # Each entry is an [attention, feed-forward] pair; the latents attend
        # over themselves plus the context (cross_attn_include_queries=True).
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        Attention(
                            dim=dim,
                            dim_head=dim_head,
                            heads=heads,
                            use_flash=use_flash_attn,
                            cross_attn_include_queries=True,
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

        self.norm = RMSNorm(dim)

    def forward(self, x, mask=None):
        """Resample ``x`` into ``num_latents`` vectors per batch element.

        Args:
            x: Input whose first axis is the batch; projected by ``proj_context``.
            mask: Optional mask forwarded unchanged to every attention layer.

        Returns:
            The normalised latents, one ``(num_latents, dim)`` slab per batch element.
        """
        batch_size = x.shape[0]
        context = self.proj_context(x)

        # One copy of the learned latents per batch element.
        latents = repeat(self.latents, "n d -> b n d", b=batch_size)

        for layer in self.layers:
            attn, ff = layer
            # Residual connections around both sub-layers.
            latents = attn(latents, context, mask=mask) + latents
            latents = ff(latents) + latents

        return self.norm(latents)
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self- or cross-attention built on top of ``Attend``.

    Args:
        dim: Feature size of the query input and of the output.
        dim_context: Feature size of the key/value input; defaults to ``dim``.
        causal: Forwarded to ``Attend``.
        dim_head: Feature size of each head.
        heads: Number of heads.
        dropout: Attention dropout probability, forwarded to ``Attend``.
        use_flash: Forwarded to ``Attend``.
        cross_attn_include_queries: If True and a context is given, the query
            input is prepended to the context so the queries also attend to
            themselves.
    """

    def __init__(
        self,
        dim,
        *,
        dim_context=None,
        causal=False,
        dim_head=64,
        heads=8,
        dropout=0.0,
        use_flash=False,
        cross_attn_include_queries=False,
    ):
        super().__init__()
        # NOTE: not read anywhere in this class; Attend derives its own scaling from q.
        self.scale = dim_head**-0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
        self.to_q = nn.Linear(dim, dim_inner, bias=False)
        # Keys and values come from a single projection that is split in half.
        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
        self.to_out = nn.Linear(dim_inner, dim, bias=False)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: Query input of shape ``(b, n, dim)``.
            context: Optional key/value input of shape ``(b, m, dim_context)``.
            mask: Optional key padding mask. ``Attend`` consumes it as a boolean
                ``(b, j)`` tensor with True = attend. When the queries are
                prepended to the context, a mask covering only the original
                context is extended with True entries for the query positions;
                a mask that already has the concatenated length is left as is.

        Returns:
            Tensor of shape ``(b, n, dim)``.
        """
        h, has_context = self.heads, exists(context)

        context = default(context, x)

        if has_context and self.cross_attn_include_queries:
            context_len = context.shape[-2]
            context = torch.cat((x, context), dim=-2)

            # The keys now contain the query positions first, so the padding
            # mask must grow accordingly or it no longer lines up with the keys.
            if exists(mask) and mask.shape[-1] == context_len:
                query_keep = mask.new_ones((*mask.shape[:-1], x.shape[-2]))
                mask = torch.cat((query_keep, mask), dim=-1)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = self.attend(q, k, v, mask=mask)

        # Merge the heads back into the feature axis before the output projection.
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|