"""
|
||
|
The MIT License (MIT)
|
||
|
|
||
|
Copyright (c) Microsoft Corporation
|
||
|
|
||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||
|
of this software and associated documentation files (the "Software"), to deal
|
||
|
in the Software without restriction, including without limitation the rights
|
||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||
|
copies of the Software, and to permit persons to whom the Software is
|
||
|
furnished to do so, subject to the following conditions:
|
||
|
|
||
|
The above copyright notice and this permission notice shall be included in all
|
||
|
copies or substantial portions of the Software.
|
||
|
|
||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||
|
SOFTWARE.
|
||
|
"""
|
||
|
|
||
|
import math
|
||
|
from typing import Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
from torch import nn, Tensor
|
||
|
|
||
|
|
||
|
class WavLMSelfAttention(nn.Module):
    """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`.

    Wraps around ``torch.nn.MultiheadAttention``, creating relative position embeddings and passing them to
    multi-headed attention as a mask.
    Source: https://github.com/microsoft/unilm/blob/2d8302f09c99bca2b82e6e868d81d4281cceebc8/wavlm/modules.py#L303-L763

    Args:
        embed_dim (int): Total dimension of the model.
        num_heads (int): The number of heads.
        dropout (float, optional): Dropout probability on ``attn_output_weights``. (Default: ``0.0``)
        bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``)
        has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding.
            Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``)
        num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``)
        max_distance (int, optional): Maximum distance for relative position embedding. (Default: ``128``)
        gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. (Default: ``True``)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 32,
        max_distance: int = 128,
        gru_rel_pos: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance

        if has_relative_attention_bias:
            self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)
        else:
            self.rel_attn_embed = None

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        self.dropout = dropout
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True)

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
            self.gru_rel_pos_const = nn.Parameter(torch.ones(1, num_heads, 1, 1))
        self.has_position_bias = True

    def compute_bias(self, query_length: int, key_length: int) -> Tensor:
        """Compute relative position embeddings for WavLM model.

        Args:
            query_length (int): Query position can take values between 0 and ``query_length - 1``.
            key_length (int): Key position can take values between 0 and ``key_length - 1``.

        Returns:
            Tensor of shape ``(num_heads, query_length, key_length)``: relative position embeddings.
        """
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # Shape (query_length, key_length)
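        # Illustrative note (not in the upstream source): for query_length = key_length = 3,
        # relative_position is
        #   [[ 0,  1,  2],
        #    [-1,  0,  1],
        #    [-2, -1,  0]],
        # i.e. entry (i, j) is the signed offset of key position j relative to query position i.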
        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
        values = self.rel_attn_embed(relative_position_bucket)  # Shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1])
        return values

    def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True):
        """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM
        paper :cite:`chen2022wavlm`.

        Args:
            relative_positions (Tensor): Relative offsets between query and key positions,
                of shape ``(query_length, key_length)``.
            bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the
                resulting matrix. If ``False``, the elements above the diagonal (i.e. with negative relative
                offsets) will be set to zero. (Default: ``True``)

        Returns:
            Tensor of shape ``(query_length, key_length)`` filled with bucketed values of the relative positions.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        # Shape (query_length, key_length)
        relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_position_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
        return relative_buckets

    def forward(
        self,
        query: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
            query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``.
            key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape
                ``(batch, src_len)``, where padding elements are indicated by 1s. (Default: ``None``)
            attention_mask (Tensor or None, optional): Needs to be ``None``. The argument exists for
                compatibility with ``EncoderLayer``. (Default: ``None``)
            position_bias (Tensor or None, optional): Position bias of shape
                ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be
                generated in the first layer and then passed from each encoder layer to the next one.
                (Default: ``None``)

        Returns:
            attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``.
            position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``.
        """
        bsz, seq_len, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert attention_mask is None

        if self.rel_attn_embed is not None and position_bias is None:
            position_bias = self.compute_bias(seq_len, seq_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1)

        attn_mask_rel_pos: Optional[Tensor] = None
        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos:  # Apply gating on relative position bias
                query_layer = query.view(bsz, seq_len, self.num_heads, -1)
                query_layer = query_layer.permute(0, 2, 1, 3)

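                # Explanatory note (not in the upstream source): the linear layer below projects
                # each head's query vector from head_dim to 8 values, which are summed in two
                # groups of 4 and squashed with a sigmoid, giving two scalars per head and
                # position (gate_a, gate_b). These are combined into a single multiplicative
                # gate (gate_a_1) that rescales the position bias per head and per query position.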
gate_a, gate_b = torch.sigmoid(
|
||
|
self.gru_rel_pos_linear(query_layer).view(bsz, self.num_heads, seq_len, 2, 4).sum(-1, keepdim=False)
|
||
|
).chunk(2, dim=-1)
|
||
|
gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
|
||
|
attn_mask_rel_pos = gate_a_1.view(bsz, self.num_heads, -1, 1) * position_bias
|
||
|
|
||
|
attn_mask_rel_pos = attn_mask_rel_pos.view((bsz, self.num_heads, seq_len, seq_len))
|
||
|
|
||
|
if attn_mask_rel_pos is not None and key_padding_mask is not None:
|
||
|
key_padding_mask = key_padding_mask.view(bsz, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
|
||
|
key_padding_mask = torch.nn.functional._canonical_mask(
|
||
|
mask=key_padding_mask,
|
||
|
mask_name="key_padding_mask",
|
||
|
other_type=torch.nn.functional._none_or_dtype(attn_mask_rel_pos),
|
||
|
other_name="",
|
||
|
target_type=query.dtype,
|
||
|
)
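            # Explanatory note (not in the upstream source): assuming a boolean padding mask as
            # documented above, ``_canonical_mask`` converts it to a float mask with ``-inf`` at
            # padded positions and 0 elsewhere, matching the dtype of ``attn_mask_rel_pos`` so the
            # two masks can simply be added below.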
        if attn_mask_rel_pos is not None and key_padding_mask is not None:
            attn_mask_rel_pos = attn_mask_rel_pos + key_padding_mask
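        # Explanatory note (not in the upstream source): instead of calling
        # ``self.attention.forward``, the code below reuses the wrapped module's input/output
        # projection weights and calls ``scaled_dot_product_attention`` directly, which accepts
        # the full float mask of shape ``(batch, num_heads, seq_len, seq_len)`` built above
        # (gated position bias plus padding mask).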
        query_projected = torch.nn.functional.linear(query, self.attention.in_proj_weight, self.attention.in_proj_bias)
        query, key, value = query_projected.chunk(3, -1)
        shape = (bsz, seq_len, self.num_heads, self.head_dim)
        query = query.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        key = key.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        value = value.view(shape).transpose(2, 1)  # (batch, num_heads, seq_len, head_dim)
        dropout = self.dropout if self.training else 0.0
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask_rel_pos,
            dropout_p=dropout,
            is_causal=False,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, -1, self.num_heads * self.head_dim)
        attn_output = self.attention.out_proj(attn_output)
        return attn_output, position_bias
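

# Minimal usage sketch (not part of the upstream module): shapes follow the docstrings above;
# ``embed_dim``, ``num_heads`` and the input sizes are arbitrary example values.
if __name__ == "__main__":
    attention = WavLMSelfAttention(embed_dim=768, num_heads=12, has_relative_attention_bias=True)
    waveform_features = torch.randn(2, 50, 768)  # (batch_size, src_len, embed_dim)
    padding_mask = torch.zeros(2, 50, dtype=torch.bool)  # True marks padded positions
    padding_mask[1, 40:] = True  # second sequence is only 40 frames long
    output, position_bias = attention(waveform_features, key_padding_mask=padding_mask)
    print(output.shape)  # torch.Size([2, 50, 768])
    # ``position_bias`` would be passed on to the next encoder layer in the full WavLM encoder.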