120 lines
4.5 KiB
Python
120 lines
4.5 KiB
Python
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
"""A streamable transformer."""
|
||
|
|
||
|
import typing as tp
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
|
||
|
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
    """Create time embedding for the given positions, target dimension `dim`.

    Args:
        positions (torch.Tensor): positions to embed. They are divided by a
            `[1, 1, dim // 2]` tensor of periods, so they must broadcast against
            that shape (e.g. `[1, T, 1]`).
        dim (int): size of the last dimension of the returned embedding, must be even.
        max_period (float): largest divisor applied to the positions; the divisors are
            geometrically spaced from 1 up to `max_period`.

    Returns:
        torch.Tensor: the cosines of the phases followed by their sines, concatenated
            along the last dimension.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
    # With a single frequency (dim == 2), `half_dim - 1` is 0 and the exponent would be
    # 0 / 0 = NaN. Clamping the denominator keeps that lone frequency at exponent 0
    # and leaves every dim >= 4 unchanged.
    exponent = adim / max(half_dim - 1, 1)
    phase = positions / (max_period ** exponent)
    return torch.cat([
        torch.cos(phase),
        torch.sin(phase),
    ], dim=-1)
|
||
|
|
||
|
|
||
|
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Transformer encoder layer whose self-attention also attends to cached past frames.

    The code treats axis 1 of its inputs as the time axis, so the layer is expected to be
    built with `batch_first=True`.
    """

    def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: tp.Optional[int]):  # type: ignore
        """Run the layer on the current frames `x`, using `x_past` as extra attention context.

        Args:
            x (torch.Tensor): current frames, 3 dimensional with time on axis 1.
            x_past (torch.Tensor): cached self-attention inputs of the previous frames,
                3 dimensional with time on axis 1.
            past_context (int or None): maximum distance, in frames, between a query and
                the keys it may attend to. If None, only causality is enforced.

        Returns:
            tuple: the layer output, and the tensor that was fed to self-attention
                (`norm1(x)` when `norm_first` is set, the raw `x` otherwise), which is
                what a caller should cache as the past for the next call.
        """
        if self.norm_first:
            sa_input = self.norm1(x)
            x = x + self._sa_block(sa_input, x_past, past_context)
            x = x + self._ff_block(self.norm2(x))
        else:
            sa_input = x
            x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
            x = self.norm2(x + self._ff_block(x))

        return x, sa_input

    # self-attention block
    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: tp.Optional[int]):  # type: ignore
        """Attend from the frames of `x` to the concatenation of `x_past` and `x`.

        A query may only access keys that are not in its future and, when `past_context`
        is not None, that are at most `past_context` frames behind it.
        """
        _, T, _ = x.shape
        _, H, _ = x_past.shape

        queries = x
        keys = torch.cat([x_past, x], dim=1)
        values = keys

        # Queries sit after the H cached frames on the shared time axis.
        queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
        keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
        delta = queries_pos - keys_pos
        valid_access = delta >= 0
        if past_context is not None:
            # Comparing against None would raise; None means an unbounded look-back.
            valid_access = valid_access & (delta <= past_context)
        # The mask is inverted because a True entry in a boolean `attn_mask` blocks attention.
        x = self.self_attn(queries, keys, values,
                           attn_mask=~valid_access,
                           need_weights=False)[0]
        return self.dropout1(x)
|
||
|
|
||
|
|
||
|
class StreamingTransformerEncoder(nn.Module):
    """TransformerEncoder with streaming support.

    Args:
        dim (int): dimension of the data.
        hidden_scale (float): intermediate dimension of FF module is this times the dimension.
        num_heads (int): number of heads.
        num_layers (int): number of layers.
        max_period (float): maximum period of cosines in the positional embedding.
        past_context (int or None): receptive field for the causal mask, infinite if None.
        gelu (bool): if true uses GeLUs, otherwise use ReLUs.
        norm_in (bool): normalize the input.
        dropout (float): dropout probability.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, dim: int, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
                 max_period: float = 10000, past_context: tp.Optional[int] = 1000, gelu: bool = True,
                 norm_in: bool = True, dropout: float = 0., **kwargs):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)

        self.max_period = max_period
        self.past_context = past_context
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
        else:
            self.norm_in = nn.Identity()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                StreamingTransformerEncoderLayer(
                    dim, num_heads, hidden_dim,
                    activation=activation, batch_first=True, dropout=dropout, **kwargs))

    def forward(self, x: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None,
                offset: tp.Union[int, torch.Tensor] = 0):
        """Encode a chunk of frames, optionally continuing from a previous call.

        Args:
            x (torch.Tensor): input chunk of shape `[B, T, C]`.
            states (list of torch.Tensor or None): per-layer cached attention inputs
                returned by a previous call. If None, every layer starts from a single
                all-zero frame.
            offset (int or torch.Tensor): position of the first frame of `x`, used for
                the positional embedding.

        Returns:
            tuple: the encoded chunk, the updated per-layer states (one entry per layer,
                trimmed to the last `past_context` frames unless `past_context` is None),
                and `offset + T`, to be passed as `offset` for the next chunk.
        """
        B, T, C = x.shape
        if states is None:
            # One state per layer: that is all the loop below consumes.
            states = [torch.zeros_like(x[:, :1]) for _ in self.layers]

        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)

        new_state: tp.List[torch.Tensor] = []
        x = self.norm_in(x)
        x = x + pos_emb

        for layer_state, layer in zip(states, self.layers):
            if self.past_context is None:
                # Infinite context: a window as long as past + current frames never masks
                # any causal key, and only an int ever reaches the layer.
                layer_context = layer_state.shape[1] + T
            else:
                layer_context = self.past_context
            x, new_layer_state = layer(x, layer_state, layer_context)
            new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
            if self.past_context is not None:
                # Explicit start index rather than `[-past_context:]`: with a context of 0
                # the latter is `[-0:]`, which keeps everything and grows without bound.
                keep_from = max(new_layer_state.shape[1] - self.past_context, 0)
                new_layer_state = new_layer_state[:, keep_from:, :]
            new_state.append(new_layer_state)
        return x, new_state, offset + T
|