1850 lines
77 KiB
Python
1850 lines
77 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2021 The Fairseq Authors, Microsoft Research, and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch WavLM model."""
|
||
|
|
||
|
import math
|
||
|
import warnings
|
||
|
from typing import Optional, Tuple, Union
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
from torch.nn import CrossEntropyLoss
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
||
|
from ...modeling_outputs import (
|
||
|
BaseModelOutput,
|
||
|
CausalLMOutput,
|
||
|
SequenceClassifierOutput,
|
||
|
TokenClassifierOutput,
|
||
|
Wav2Vec2BaseModelOutput,
|
||
|
XVectorOutput,
|
||
|
)
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...utils import (
|
||
|
add_code_sample_docstrings,
|
||
|
add_start_docstrings,
|
||
|
add_start_docstrings_to_model_forward,
|
||
|
is_peft_available,
|
||
|
logging,
|
||
|
)
|
||
|
from .configuration_wavlm import WavLMConfig
|
||
|
|
||
|
|
||
|
logger = logging.get_logger(__name__)

# NOTE(review): presumably the tuple index at which hidden states start in the base
# model's non-dict outputs (read by the task heads) — confirm against the heads.
_HIDDEN_STATES_START_POSITION = 2

# ---------------------------------------------------------------------------
# Constants used by the documentation / code-sample docstrings.
# ---------------------------------------------------------------------------

# Shared by every model class.
_CONFIG_FOR_DOC = "WavLMConfig"

# Base model: checkpoint and the expected shape of its last hidden state.
_CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]

# CTC head: expected transcription (verbatim model output) and loss.
_CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
_CTC_EXPECTED_LOSS = 12.51

# Audio frame classification head.
_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
_FRAME_EXPECTED_OUTPUT = [0, 0]

# Speaker verification (x-vector) head.
_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv"
_XVECTOR_EXPECTED_OUTPUT = 0.97
|
||
|
|
||
|
|
||
|
from ..deprecated._archive_maps import WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
||
|
def _compute_mask_indices(
|
||
|
shape: Tuple[int, int],
|
||
|
mask_prob: float,
|
||
|
mask_length: int,
|
||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||
|
min_masks: int = 0,
|
||
|
) -> np.ndarray:
|
||
|
"""
|
||
|
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
|
||
|
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
|
||
|
CPU as part of the preprocessing during training.
|
||
|
|
||
|
Args:
|
||
|
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
|
||
|
the first element is the batch size and the second element is the length of the axis to span.
|
||
|
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
|
||
|
independently generated mask spans of length `mask_length` is computed by
|
||
|
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
|
||
|
actual percentage will be smaller.
|
||
|
mask_length: size of the mask
|
||
|
min_masks: minimum number of masked spans
|
||
|
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
|
||
|
each batch dimension.
|
||
|
"""
|
||
|
batch_size, sequence_length = shape
|
||
|
|
||
|
if mask_length < 1:
|
||
|
raise ValueError("`mask_length` has to be bigger than 0.")
|
||
|
|
||
|
if mask_length > sequence_length:
|
||
|
raise ValueError(
|
||
|
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
|
||
|
f" and `sequence_length`: {sequence_length}`"
|
||
|
)
|
||
|
|
||
|
# epsilon is used for probabilistic rounding
|
||
|
epsilon = np.random.rand(1).item()
|
||
|
|
||
|
def compute_num_masked_span(input_length):
|
||
|
"""Given input length, compute how many spans should be masked"""
|
||
|
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
|
||
|
num_masked_span = max(num_masked_span, min_masks)
|
||
|
|
||
|
# make sure num masked span <= sequence_length
|
||
|
if num_masked_span * mask_length > sequence_length:
|
||
|
num_masked_span = sequence_length // mask_length
|
||
|
|
||
|
# make sure num_masked span is also <= input_length - (mask_length - 1)
|
||
|
if input_length - (mask_length - 1) < num_masked_span:
|
||
|
num_masked_span = max(input_length - (mask_length - 1), 0)
|
||
|
|
||
|
return num_masked_span
|
||
|
|
||
|
# compute number of masked spans in batch
|
||
|
input_lengths = (
|
||
|
attention_mask.sum(-1).detach().tolist()
|
||
|
if attention_mask is not None
|
||
|
else [sequence_length for _ in range(batch_size)]
|
||
|
)
|
||
|
|
||
|
# SpecAugment mask to fill
|
||
|
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
|
||
|
spec_aug_mask_idxs = []
|
||
|
|
||
|
max_num_masked_span = compute_num_masked_span(sequence_length)
|
||
|
|
||
|
if max_num_masked_span == 0:
|
||
|
return spec_aug_mask
|
||
|
|
||
|
for input_length in input_lengths:
|
||
|
# compute num of masked spans for this input
|
||
|
num_masked_span = compute_num_masked_span(input_length)
|
||
|
|
||
|
# get random indices to mask
|
||
|
spec_aug_mask_idx = np.random.choice(
|
||
|
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
|
||
|
)
|
||
|
|
||
|
# pick first sampled index that will serve as a dummy index to pad vector
|
||
|
# to ensure same dimension for all batches due to probabilistic rounding
|
||
|
# Picking first sample just pads those vectors twice.
|
||
|
if len(spec_aug_mask_idx) == 0:
|
||
|
# this case can only happen if `input_length` is strictly smaller then
|
||
|
# `sequence_length` in which case the last token has to be a padding
|
||
|
# token which we can use as a dummy mask id
|
||
|
dummy_mask_idx = sequence_length - 1
|
||
|
else:
|
||
|
dummy_mask_idx = spec_aug_mask_idx[0]
|
||
|
|
||
|
spec_aug_mask_idx = np.concatenate(
|
||
|
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
|
||
|
)
|
||
|
spec_aug_mask_idxs.append(spec_aug_mask_idx)
|
||
|
|
||
|
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
|
||
|
|
||
|
# expand masked indices to masked spans
|
||
|
spec_aug_mask_idxs = np.broadcast_to(
|
||
|
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
|
||
|
)
|
||
|
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
|
||
|
|
||
|
# add offset to the starting indexes so that indexes now create a span
|
||
|
offsets = np.arange(mask_length)[None, None, :]
|
||
|
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
|
||
|
batch_size, max_num_masked_span * mask_length
|
||
|
)
|
||
|
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
||
|
|
||
|
# ensure that we cannot have indices larger than sequence_length
|
||
|
if spec_aug_mask_idxs.max() > sequence_length - 1:
|
||
|
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
|
||
|
|
||
|
# scatter indices to mask
|
||
|
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
||
|
|
||
|
return spec_aug_mask
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->WavLM
|
||
|
class WavLMNoLayerNormConvLayer(nn.Module):
    """One un-normalized block of the convolutional feature encoder: ``activation(conv(x))``.

    Args:
        config: model config; reads ``conv_dim``, ``conv_kernel``, ``conv_stride``,
            ``conv_bias`` and ``feat_extract_activation``.
        layer_id: position of this block in the conv stack. Block 0 consumes a single
            input channel; later blocks consume the previous block's ``conv_dim`` entry.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        if layer_id > 0:
            self.in_conv_dim = config.conv_dim[layer_id - 1]
        else:
            self.in_conv_dim = 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        return self.activation(self.conv(hidden_states))
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->WavLM
|
||
|
class WavLMLayerNormConvLayer(nn.Module):
    """One block of the convolutional feature encoder with per-frame LayerNorm.

    Computes ``activation(LayerNorm(conv(x)))``, where the LayerNorm is applied over the
    channel axis of the conv output.

    Args:
        config: model config; reads ``conv_dim``, ``conv_kernel``, ``conv_stride``,
            ``conv_bias`` and ``feat_extract_activation``.
        layer_id: position in the conv stack (block 0 has a single input channel).
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        conv_out = self.conv(hidden_states)
        # nn.LayerNorm normalizes the last dim, so swap channels to the end and back.
        normed = self.layer_norm(conv_out.transpose(-2, -1)).transpose(-2, -1)
        return self.activation(normed)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->WavLM
|
||
|
class WavLMGroupNormConvLayer(nn.Module):
    """First block of the "group"-normalized feature encoder: ``activation(GroupNorm(conv(x)))``.

    The GroupNorm uses one group per output channel.

    Args:
        config: model config; reads ``conv_dim``, ``conv_kernel``, ``conv_stride``,
            ``conv_bias`` and ``feat_extract_activation``.
        layer_id: position in the conv stack (block 0 has a single input channel).
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]
        channels = self.out_conv_dim

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            channels,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.activation = ACT2FN[config.feat_extract_activation]

        self.layer_norm = nn.GroupNorm(num_groups=channels, num_channels=channels, affine=True)

    def forward(self, hidden_states):
        return self.activation(self.layer_norm(self.conv(hidden_states)))
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->WavLM
|
||
|
class WavLMPositionalConvEmbedding(nn.Module):
    """Convolutional positional embedding.

    Applies a weight-normalized, grouped ``Conv1d`` over the time axis. The output is
    trimmed by ``WavLMSamePadLayer`` and passed through the feature-extractor activation.
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # Prefer the parametrization-based weight norm when this PyTorch provides it;
        # fall back to the legacy hook-based implementation otherwise.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        if is_deepspeed_zero3_enabled():
            import deepspeed

            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)
            # The g/v factors live under different attributes depending on which weight
            # norm was applied. Parametrizations store them as `original0` (g) and
            # `original1` (v); the legacy hook stores them as `weight_g` / `weight_v`.
            if hasattr(self.conv, "parametrizations"):
                weight_g = self.conv.parametrizations.weight.original0
                weight_v = self.conv.parametrizations.weight.original1
            else:
                weight_g = self.conv.weight_g
                weight_v = self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)
        else:
            self.conv = weight_norm(self.conv, name="weight", dim=2)

        self.padding = WavLMSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        # Conv1d wants channels before time; presumably the input is
        # (batch, seq_len, hidden_size) — TODO confirm at call sites.
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.conv(hidden_states)
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)

        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->WavLM
|
||
|
class WavLMSamePadLayer(nn.Module):
    """Drops the trailing time step that a ``padding=k // 2`` convolution adds when ``k`` is even.

    Args:
        num_conv_pos_embeddings: kernel size of the positional convolution. An even value
            removes one trailing frame; an odd value leaves the input untouched.
    """

    def __init__(self, num_conv_pos_embeddings):
        super().__init__()
        self.num_pad_remove = int(num_conv_pos_embeddings % 2 == 0)

    def forward(self, hidden_states):
        if self.num_pad_remove <= 0:
            return hidden_states
        # Trim along the last (time) axis.
        return hidden_states[:, :, : -self.num_pad_remove]
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->WavLM
|
||
|
class WavLMFeatureEncoder(nn.Module):
    """Convolutional encoder that turns the raw waveform into frame-level features.

    ``config.feat_extract_norm`` selects the stack:

    * ``"group"`` — a group-normalized first block followed by un-normalized blocks;
    * ``"layer"`` — every block uses LayerNorm.

    Raises:
        ValueError: for any other ``feat_extract_norm`` value.
    """

    def __init__(self, config):
        super().__init__()

        norm_mode = config.feat_extract_norm
        if norm_mode == "group":
            conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)]
            conv_layers.extend(
                WavLMNoLayerNormConvLayer(config, layer_id=layer_id)
                for layer_id in range(1, config.num_feat_extract_layers)
            )
        elif norm_mode == "layer":
            conv_layers = [
                WavLMLayerNormConvLayer(config, layer_id=layer_id)
                for layer_id in range(config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )
        self.conv_layers = nn.ModuleList(conv_layers)
        self.gradient_checkpointing = False
        self._requires_grad = True

    def _freeze_parameters(self):
        """Stop gradient flow into every conv parameter and remember that the encoder is frozen."""
        self.requires_grad_(False)
        self._requires_grad = False

    def forward(self, input_values):
        # Add a channel axis for the first conv (input presumably (batch, num_samples)).
        hidden_states = input_values[:, None]

        trainable = self._requires_grad and self.training
        # Gradient checkpointing needs an input that requires grad.
        if trainable:
            hidden_states.requires_grad = True
        use_checkpointing = trainable and self.gradient_checkpointing

        for conv_layer in self.conv_layers:
            if use_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    conv_layer.__call__,
                    hidden_states,
                )
            else:
                hidden_states = conv_layer(hidden_states)

        return hidden_states
|
||
|
|
||
|
|
||
|
class WavLMFeatureExtractor(WavLMFeatureEncoder):
    """Deprecated alias of ``WavLMFeatureEncoder``; emits a ``FutureWarning`` when constructed."""

    def __init__(self, config):
        super().__init__(config)
        # Point users at the base class, which this alias adds no behavior to.
        warnings.warn(
            f"The class `{self.__class__.__name__}` has been deprecated "
            "and will be removed in Transformers v5. "
            f"Use `{self.__class__.__bases__[0].__name__}` instead.",
            FutureWarning,
        )
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->WavLM
|
||
|
class WavLMFeatureProjection(nn.Module):
    """LayerNorm -> Linear -> Dropout map from the last conv width to ``hidden_size``.

    ``forward`` returns ``(projected, normalized)``. The normalized, un-projected
    features are returned as well because they are needed for quantization.
    """

    def __init__(self, config):
        super().__init__()
        feature_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(feature_dim, eps=config.layer_norm_eps)
        self.projection = nn.Linear(feature_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        normalized = self.layer_norm(hidden_states)
        projected = self.dropout(self.projection(normalized))
        return projected, normalized
|
||
|
|
||
|
|
||
|
class WavLMAttention(nn.Module):
    """Multi-head self-attention with WavLM's gated relative-position bias.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: attention-probability dropout handed to ``F.multi_head_attention_forward``.
        num_buckets: number of relative-position buckets, split evenly between negative
            and positive relative offsets.
        max_distance: relative distance at which the logarithmic buckets saturate;
            larger distances are clamped into the last bucket of their half.
        has_relative_position_bias: allocate the ``rel_attn_embed`` table used by
            ``compute_bias``. Without it, callers must always pass ``position_bias``.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        num_buckets: int = 320,
        max_distance: int = 800,
        has_relative_position_bias: bool = True,
    ):
        super().__init__()
        head_dim = embed_dim // num_heads
        if head_dim * num_heads != embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {embed_dim}"
                f" and `num_heads`: {num_heads})."
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = head_dim
        self.scaling = head_dim**-0.5

        # Separate projections; keep this creation order so seeded initialization is stable.
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_buckets = num_buckets
        self.max_distance = max_distance

        # Gating parameters that rescale the relative-position bias per head and per token.
        self.gru_rel_pos_const = nn.Parameter(torch.ones(1, num_heads, 1, 1))
        self.gru_rel_pos_linear = nn.Linear(head_dim, 8)

        if has_relative_position_bias:
            self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        index=0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Gated relative-position self-attention.

        Args:
            hidden_states: rank-3 input ``(batch, seq_len, embed_dim)``.
            attention_mask: optional key mask; entries not equal to 1 are treated as padding.
            position_bias: un-gated bias of shape ``(batch * num_heads, seq_len, seq_len)``
                from an earlier layer. When None it is built here via ``compute_bias``.
            output_attentions: request attention weights from the fused attention call.
            index: layer index; accepted for call compatibility, unused here.

        Returns:
            ``(attn_output, attn_weights, position_bias)``. The returned ``position_bias`` is
            the un-gated bias, so the next layer can gate it with its own hidden states.
        """
        batch_size, seq_len, _ = hidden_states.size()

        if position_bias is None:
            per_head_bias = self.compute_bias(seq_len, seq_len)
            position_bias = (
                per_head_bias.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(batch_size * self.num_heads, seq_len, seq_len)
            )

        # Split the width into heads: (B, T, C) -> (B, T, H, C/H) -> (B, H, T, C/H).
        per_head_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1)).permute(0, 2, 1, 3)

        # Project every head vector to 8 values and sum them in two groups of 4 -> (B, H, T, 2).
        gate_logits = self.gru_rel_pos_linear(per_head_states)
        gate_logits = gate_logits.view(per_head_states.shape[:-1] + (2, 4)).sum(-1)

        gate_a, gate_b = torch.sigmoid(gate_logits).chunk(2, dim=-1)
        gate = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0

        # Scale each query row of the bias by its per-token gate.
        gated_bias = gate.view(batch_size * self.num_heads, -1, 1) * position_bias
        gated_bias = gated_bias.view((-1, seq_len, seq_len))

        attn_output, attn_weights = self.torch_multi_head_self_attention(
            hidden_states, attention_mask, gated_bias, output_attentions
        )
        return attn_output, attn_weights, position_bias

    def torch_multi_head_self_attention(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Union[torch.LongTensor, torch.BoolTensor],
        gated_position_bias: torch.FloatTensor,
        output_attentions: bool,
    ) -> (torch.FloatTensor, torch.FloatTensor):
        """Self-attention through ``F.multi_head_attention_forward``, using the gated bias as additive mask."""
        # The functional API is sequence-first: (T, B, C).
        seq_first = hidden_states.transpose(0, 1)
        key_padding_mask = None if attention_mask is None else attention_mask.ne(1)
        packed_in_proj_bias = torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias))

        attn_output, attn_weights = F.multi_head_attention_forward(
            query=seq_first,
            key=seq_first,
            value=seq_first,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            # Packed weight is ignored because separate q/k/v weights are supplied below.
            in_proj_weight=torch.empty([0]),
            in_proj_bias=packed_in_proj_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=output_attentions,
            attn_mask=gated_position_bias,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

        # Back to batch-first: (T, B, C) -> (B, T, C).
        attn_output = attn_output.transpose(0, 1)

        if attn_weights is not None:
            # PyTorch returns head-averaged weights (https://github.com/pytorch/pytorch/issues/32590);
            # broadcast them to a per-head layout. Every head therefore holds the same values.
            per_head_shape = attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
            attn_weights = attn_weights[:, None].broadcast_to(per_head_shape)

        return attn_output, attn_weights

    def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
        """Look up the relative-position bias; returns shape ``(num_heads, query_length, key_length)``."""
        query_positions = torch.arange(query_length, dtype=torch.long).unsqueeze(1)
        key_positions = torch.arange(key_length, dtype=torch.long).unsqueeze(0)
        buckets = self._relative_positions_bucket(key_positions - query_positions)
        buckets = buckets.to(self.rel_attn_embed.weight.device)
        return self.rel_attn_embed(buckets).permute([2, 0, 1])

    def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
        """Map signed relative offsets to bucket ids (exact for small distances, log-spaced beyond)."""
        half_buckets = self.num_buckets // 2

        # Positive offsets (key after query) use the upper half of the bucket range.
        buckets = (relative_positions > 0).to(torch.long) * half_buckets
        distance = torch.abs(relative_positions)

        max_exact = half_buckets // 2
        is_small = distance < max_exact

        log_ratio = torch.log(distance.float() / max_exact) / math.log(self.max_distance / max_exact)
        large_bucket = (max_exact + log_ratio * (half_buckets - max_exact)).to(torch.long)
        large_bucket = torch.min(large_bucket, torch.full_like(large_bucket, half_buckets - 1))

        buckets += torch.where(is_small, distance, large_bucket)
        return buckets
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->WavLM
|
||
|
class WavLMFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: dense -> activation -> dropout -> dense -> dropout."""

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a registry key into ACT2FN or an already-callable activation.
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        expanded = self.intermediate_dropout(expanded)
        return self.output_dropout(self.output_dense(expanded))
|
||
|
|
||
|
|
||
|
class WavLMEncoderLayer(nn.Module):
    """Post-LayerNorm transformer block used by ``WavLMEncoder``.

    Computes ``h = LN(x + Dropout(Attn(x)))`` and then ``FinalLN(h + FFN(h))``.
    ``has_relative_position_bias`` is forwarded to ``WavLMAttention``, which only builds
    its relative-position embedding table when it is True.
    """

    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        """Return ``(hidden_states, position_bias)``, plus the attention weights if ``output_attentions``."""
        attn_out, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        # Residual connection + post-norm around the attention sub-layer.
        hidden_states = self.layer_norm(hidden_states + self.dropout(attn_out))
        # Residual connection + post-norm around the feed-forward sub-layer.
        hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))

        if not output_attentions:
            return (hidden_states, position_bias)
        return (hidden_states, position_bias, attn_weights)
|
||
|
|
||
|
|
||
|
class WavLMEncoderLayerStableLayerNorm(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``h = x + Dropout(Attn(LN(x)))`` and then ``h + FFN(FinalLN(h))``.
    ``has_relative_position_bias`` is forwarded to ``WavLMAttention``.
    """

    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        """Return ``(hidden_states, position_bias)``, plus the attention weights if ``output_attentions``."""
        residual = hidden_states
        attn_out, attn_weights, position_bias = self.attention(
            self.layer_norm(residual),
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = residual + self.dropout(attn_out)

        ffn_out = self.feed_forward(self.final_layer_norm(hidden_states))
        hidden_states = hidden_states + ffn_out

        extras = (attn_weights,) if output_attentions else ()
        return (hidden_states, position_bias) + extras
|
||
|
|
||
|
|
||
|
class WavLMEncoder(nn.Module):
    """Post-LayerNorm WavLM transformer encoder.

    Adds convolutional positional embeddings to the input, applies LayerNorm and dropout,
    then runs the stack of ``WavLMEncoderLayer`` modules. Only the first layer owns a
    relative-position embedding table. The position bias it produces is threaded through
    every following layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Run the encoder stack.

        Args:
            hidden_states: input features. NOTE: modified in place when ``attention_mask`` is given.
            attention_mask: used as ``~attention_mask`` for indexing, so it is expected to be a
                boolean tensor with True for positions to keep — TODO confirm all callers pass bool.
                Masked positions are zeroed before the positional convolution, and the mask is
                forwarded to every layer.
            output_attentions: collect each layer's attention weights (``None`` for layers
                dropped by LayerDrop).
            output_hidden_states: collect the input to every layer plus the final output.
            return_dict: return a ``BaseModelOutput`` instead of a plain tuple.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states[~attention_mask] = 0.0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
        # Created by the first layer, reused by all later ones.
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            # Layer 0 is never dropped: it is the layer that builds `position_bias`.
            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_bias,
                        output_attentions,
                        i,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_bias=position_bias,
                        output_attentions=output_attentions,
                        index=i,
                    )

                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                # Layer outputs are (hidden_states, position_bias[, attn_weights]); keep three
                # slots so the `layer_outputs[2]` read below is valid for a dropped layer.
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
|
||
|
|
||
|
|
||
|
class WavLMEncoderStableLayerNorm(nn.Module):
    """
    WavLM transformer encoder used when `config.do_stable_layer_norm` is set.

    Adds convolutional positional embeddings to the input, runs `config.num_hidden_layers` instances of
    `WavLMEncoderLayerStableLayerNorm` (with LayerDrop during training) and applies a final LayerNorm to the output of
    the last layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # Only the first layer is built with `has_relative_position_bias=True`; the `position_bias` it returns is
        # fed to every following layer in `forward`.
        self.layers = nn.ModuleList(
            [
                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
                for i in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """
        Run the encoder stack.

        Args:
            hidden_states: Input features. Presumably `(batch_size, sequence_length, hidden_size)` with a boolean
                `(batch_size, sequence_length)` `attention_mask` (it is used as `hidden_states[~attention_mask]`)
                -- TODO confirm against callers.
            attention_mask: Optional mask. Positions where it is False are set to 0 **in place** on `hidden_states`
                before the positional convolution, and the mask is forwarded unchanged to every layer.
            output_attentions: If True, collect the third element returned by each layer (None for dropped layers).
            output_hidden_states: If True, collect the input of every layer plus the final layer-normed output.
            return_dict: If True return a `BaseModelOutput`, otherwise a tuple of its non-None fields.

        Returns:
            `BaseModelOutput(last_hidden_state, hidden_states, attentions)` or the equivalent tuple.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            hidden_states[~attention_mask] = 0

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
        # Filled in by the first layer and threaded through the rest of the stack.
        position_bias = None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            # Layer 0 is never dropped -- presumably because it is the only layer built with a relative position
            # bias, and later layers consume the `position_bias` it returns.
            skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                # Note: in that case a "skipped" layer is still executed and its output is kept.
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_bias,
                        output_attentions,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        output_attentions=output_attentions,
                        position_bias=position_bias,
                    )
                hidden_states, position_bias = layer_outputs[:2]

            if skip_the_layer:
                # Same arity as a real layer output (hidden_states, position_bias, attentions) so that
                # `layer_outputs[2]` below yields None instead of raising IndexError.
                layer_outputs = (None, None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )
|
||
|
|
||
|
|
||
|
class WavLMGumbelVectorQuantizer(nn.Module):
    """
    Quantizes each frame by selecting one entry from each of `config.num_codevector_groups` codebooks (of
    `config.num_codevectors_per_group` entries each) and concatenating the selections. Selection uses a hard
    Gumbel-softmax while training and a plain argmax otherwise. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
    """

    def __init__(self, config):
        super().__init__()
        groups = config.num_codevector_groups
        self.num_groups = groups
        self.num_vars = config.num_codevectors_per_group

        if config.codevector_dim % groups != 0:
            raise ValueError(
                f"`config.codevector_dim {config.codevector_dim} must be divisible"
                f" by `config.num_codevector_groups` {self.num_groups} "
                "for concatenation."
            )

        # Codebook storage: one row per (group, entry) pair, each of width codevector_dim // num_groups.
        # Allocated uninitialised here; `WavLMPreTrainedModel._init_weights` fills it.
        self.codevectors = nn.Parameter(
            torch.FloatTensor(1, groups * self.num_vars, config.codevector_dim // groups)
        )
        # Per-frame logits over every (group, entry) pair.
        self.weight_proj = nn.Linear(config.conv_dim[-1], groups * self.num_vars)

        # Gumbel-softmax temperature; can be decayed for training
        self.temperature = 2

    @staticmethod
    def _compute_perplexity(probs):
        """Sum over groups of exp(entropy) of the code distribution averaged over the first dimension."""
        marginal = probs.mean(dim=0)
        entropy = -(marginal * torch.log(marginal + 1e-7)).sum(dim=-1)
        return entropy.exp().sum()

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Tensor unpacked as `(batch_size, sequence_length, features)`.

        Returns:
            `(codevectors, perplexity)` where `codevectors` has shape `(batch_size, sequence_length, -1)` and
            `perplexity` is a scalar tensor.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        num_frames = batch_size * sequence_length

        # One row of logits per (frame, group) pair.
        logits = self.weight_proj(hidden_states).view(num_frames * self.num_groups, -1)

        if not self.training:
            # Deterministic one-hot selection of the best entry in every group.
            best_entry = logits.argmax(dim=-1)
            one_hot = logits.new_zeros(*logits.shape).scatter_(-1, best_entry.view(-1, 1), 1.0)
            codevector_probs = one_hot.view(num_frames, self.num_groups, -1)
            perplexity = self._compute_perplexity(codevector_probs)
        else:
            # Hard (one-hot) samples that remain differentiable through the straight-through Gumbel-softmax.
            sampled = nn.functional.gumbel_softmax(logits.float(), tau=self.temperature, hard=True)
            codevector_probs = sampled.type_as(logits)
            # Perplexity is measured on the soft distribution, not on the samples.
            soft_distribution = torch.softmax(logits.view(num_frames, self.num_groups, -1).float(), dim=-1)
            perplexity = self._compute_perplexity(soft_distribution)

        # Weight every codebook row by its (one-hot) probability, then sum within each group.
        weighted_rows = codevector_probs.view(num_frames, -1).unsqueeze(-1) * self.codevectors
        per_group = weighted_rows.view(num_frames, self.num_groups, self.num_vars, -1).sum(-2)
        return per_group.view(batch_size, sequence_length, -1), perplexity
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->WavLM
class WavLMAdapter(nn.Module):
    """
    Stack of `WavLMAdapterLayer`s, preceded by an optional Linear + LayerNorm projection from
    `config.hidden_size` to `config.output_hidden_size` (only built when the two sizes differ).
    """

    def __init__(self, config):
        super().__init__()

        needs_projection = config.output_hidden_size != config.hidden_size
        if needs_projection:
            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
            self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
        else:
            self.proj = None
            self.proj_layer_norm = None

        self.layers = nn.ModuleList([WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers)])
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
        """
        Optionally project `hidden_states`, then run the adapter layers on dims (1, 2) swapped.

        While training, each layer is skipped when a `np.random.random()` draw is `<= self.layerdrop`.
        """
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        # The adapter layers are Conv1d modules, which expect channels on dim 1.
        channels_first = hidden_states.transpose(1, 2)
        for adapter_layer in self.layers:
            # A draw is consumed for every layer, even in eval mode.
            draw = np.random.random()
            if self.training and draw <= self.layerdrop:
                continue  # LayerDrop: skip this layer
            channels_first = adapter_layer(channels_first)

        return channels_first.transpose(1, 2)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->WavLM
class WavLMAdapterLayer(nn.Module):
    """A strided Conv1d that doubles the channel count, followed by a GLU that halves it back."""

    def __init__(self, config):
        super().__init__()
        width = config.output_hidden_size
        self.conv = nn.Conv1d(
            in_channels=width,
            out_channels=2 * width,
            kernel_size=config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1,
        )

    def forward(self, hidden_states):
        """Convolve, then gate with GLU along dim 1 (the channel dimension of the Conv1d output)."""
        return nn.functional.glu(self.conv(hidden_states), dim=1)
|
||
|
|
||
|
|
||
|
class WavLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = WavLMConfig
    base_model_prefix = "wavlm"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of `module`, dispatching on its type (first matching branch wins)."""
        # gumbel softmax requires special init
        if isinstance(module, WavLMGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, WavLMPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, WavLMFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers.

        Args:
            input_lengths: Length(s) of the raw inputs, e.g. `attention_mask.sum(-1)`.
            add_adapter: Whether to also account for the adapter convolutions. Defaults to `config.add_adapter`.

        Returns:
            The length(s) after the feature-encoder convolutions and, if requested, the adapter convolutions.
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride, padding=0):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html (dilation = 1)
            return torch.div(input_length + 2 * padding - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # Each `WavLMAdapterLayer` is a Conv1d with `config.adapter_kernel_size`, `config.adapter_stride` and
            # `padding=1`. Using the real kernel size and padding keeps the lengths correct for any kernel size;
            # for the usual kernel size 3 this equals the "kernel 1, no padding" shortcut.
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(
                    input_lengths, self.config.adapter_kernel_size, self.config.adapter_stride, padding=1
                )

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """
        Reduce a raw-input attention mask to a boolean `(batch_size, feature_vector_length)` mask whose first
        `output_length` positions are True for each row.
        """
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        # NOTE(review): an output length of 0 indexes position -1 (the last one), which makes the whole row
        # attended -- confirm callers never pass inputs that short.
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask
|
||
|
|
||
|
|
||
|
# Shared class-level documentation passed to `add_start_docstrings` on the public WavLM model classes.
WAVLM_START_DOCSTRING = r"""
    WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled
    Data](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo
    Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian,
    Jian Wu, Michael Zeng, Xiangzhan Yu, Furu Wei.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`WavLMConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
||
|
|
||
|
|
||
|
WAVLM_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
||
|
Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
|
||
|
into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
|
||
|
soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
|
||
|
conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
|
||
|
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||
|
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
||
|
1]`:
|
||
|
|
||
|
- 1 for tokens that are **not masked**,
|
||
|
- 0 for tokens that are **masked**.
|
||
|
|
||
|
[What are attention masks?](../glossary#attention-mask)
|
||
|
|
||
|
<Tip warning={true}>
|
||
|
|
||
|
`attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
|
||
|
True`. For all models whose processor has `config.return_attention_mask == False`, `attention_mask` should
|
||
|
**not** be passed to avoid degraded performance when doing batched inference. For such models
|
||
|
`input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware that these
|
||
|
models also yield slightly different results depending on whether `input_values` is padded or not.
|
||
|
|
||
|
</Tip>
|
||
|
|
||
|
output_attentions (`bool`, *optional*):
|
||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||
|
tensors for more detail.
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare WavLM Model transformer outputting raw hidden-states without any specific head on top.",
    WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM, WavLMBaseModelOutput->Wav2Vec2BaseModelOutput
class WavLMModel(WavLMPreTrainedModel):
    # Pipeline: feature_extractor -> feature_projection -> (SpecAugment masking) -> encoder -> optional adapter.

    def __init__(self, config: WavLMConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = WavLMFeatureEncoder(config)
        self.feature_projection = WavLMFeatureProjection(config)

        # The learned masking vector only exists when some SpecAugment probability is non-zero.
        spec_augment_enabled = config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0
        if spec_augment_enabled:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        encoder_class = WavLMEncoderStableLayerNorm if config.do_stable_layer_norm else WavLMEncoder
        self.encoder = encoder_class(config)

        if config.add_adapter:
            self.adapter = WavLMAdapter(config)
        else:
            self.adapter = None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: emits a `FutureWarning`, then freezes the feature encoder.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Turn off gradient computation for the feature encoder so its parameters stay fixed during training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Apply [SpecAugment](https://arxiv.org/abs/1904.08779) masking in place.

        Time steps are replaced by `masked_spec_embed`, either at the caller-supplied `mask_time_indices` or at
        randomly sampled positions (training only). Feature channels are zeroed at randomly sampled positions
        (training only). Returns the (possibly modified) `hidden_states`.
        """
        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        batch_size, sequence_length, hidden_size = hidden_states.size()

        # Time axis: explicit indices win; otherwise sample them while training.
        time_mask = mask_time_indices
        if time_mask is None and self.config.mask_time_prob > 0 and self.training:
            sampled_time_mask = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            time_mask = torch.tensor(sampled_time_mask, device=hidden_states.device, dtype=torch.bool)
        if time_mask is not None:
            hidden_states[time_mask] = self.masked_spec_embed.to(hidden_states.dtype)

        # Feature axis: one mask per batch row, broadcast over every time step.
        if self.training and self.config.mask_feature_prob > 0:
            sampled_feature_mask = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            feature_mask = torch.tensor(sampled_feature_mask, device=hidden_states.device, dtype=torch.bool)
            hidden_states[feature_mask[:, None].expand(-1, sequence_length, -1)] = 0

        return hidden_states

    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Wav2Vec2BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        # Any output flag left as None falls back to the model config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Swap dims 1 and 2; `shape[1]` is then used as the reduced (feature-vector) sequence length.
        extract_features = self.feature_extractor(input_values).transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors (the adapter runs after the encoder)
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        if self.adapter is not None:
            last_hidden_state = self.adapter(last_hidden_state)

        if return_dict:
            return Wav2Vec2BaseModelOutput(
                last_hidden_state=last_hidden_state,
                extract_features=extract_features,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, extract_features) + encoder_outputs[1:]
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
class WavLMForCTC(WavLMPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        super().__init__(config)

        self.wavlm = WavLMModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `WavLMForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # When an adapter is configured, `WavLMModel` returns adapter outputs of width `output_hidden_size`.
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        """

        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
        # correctly load adapter layers for WavLM so that we do not have to introduce a new API to
        # [`PreTrainedModel`]. While slightly hacky, WavLM never has to tie input and output embeddings, so that it is
        # ok to repurpose this function here.
        target_lang = self.target_lang

        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
            logger.info("By default `target_lang` is set to 'eng'.")
        elif target_lang is not None:
            # NOTE(review): `load_adapter` is not defined on `WavLMPreTrainedModel` in this file -- confirm it is
            # provided elsewhere before relying on this branch.
            self.load_adapter(target_lang, force_load=True)

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Valid label ids are 0 .. vocab_size - 1, so the upper bound is strict.
            if labels.max() >= self.config.vocab_size:
                raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            # Lengths after the conv stack (and adapter, per `config.add_adapter`), matching the logits.
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
|
||
|
"""
|
||
|
WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
|
||
|
SUPERB Keyword Spotting.
|
||
|
""",
|
||
|
WAVLM_START_DOCSTRING,
|
||
|
)
|
||
|
class WavLMForSequenceClassification(WavLMPreTrainedModel):
|
||
|
def __init__(self, config):
|
||
|
super().__init__(config)
|
||
|
|
||
|
if hasattr(config, "add_adapter") and config.add_adapter:
|
||
|
raise ValueError(
|
||
|
"Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)"
|
||
|
)
|
||
|
self.wavlm = WavLMModel(config)
|
||
|
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
||
|
if config.use_weighted_layer_sum:
|
||
|
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
||
|
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
||
|
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
||
|
|
||
|
# Initialize weights and apply final processing
|
||
|
self.post_init()
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
|
||
|
def freeze_feature_extractor(self):
|
||
|
"""
|
||
|
Calling this function will disable the gradient computation for the feature encoder so that its parameters will
|
||
|
not be updated during training.
|
||
|
"""
|
||
|
warnings.warn(
|
||
|
"The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
|
||
|
"Please use the equivalent `freeze_feature_encoder` method instead.",
|
||
|
FutureWarning,
|
||
|
)
|
||
|
self.freeze_feature_encoder()
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm
|
||
|
def freeze_feature_encoder(self):
|
||
|
"""
|
||
|
Calling this function will disable the gradient computation for the feature encoder so that its parameter will
|
||
|
not be updated during training.
|
||
|
"""
|
||
|
self.wavlm.feature_extractor._freeze_parameters()
|
||
|
|
||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm
|
||
|
def freeze_base_model(self):
|
||
|
"""
|
||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||
|
be updated during training. Only the classification head will be updated.
|
||
|
"""
|
||
|
for param in self.wavlm.parameters():
|
||
|
param.requires_grad = False
|
||
|
|
||
|
# Based on transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=SequenceClassifierOutput,
    config_class=_CONFIG_FOR_DOC,
    modality="audio",
)
def forward(
    self,
    input_values: Optional[torch.Tensor],
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    labels: Optional[torch.Tensor] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the sequence classification loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. The loss is always a cross-entropy loss over `config.num_labels` classes.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict
    if self.config.use_weighted_layer_sum:
        # The learned layer mix needs the hidden states of every layer.
        output_hidden_states = True

    base_outputs = self.wavlm(
        input_values,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    if not self.config.use_weighted_layer_sum:
        features = base_outputs[0]
    else:
        # Stack all layers on dim 1 and take a softmax-weighted sum over that dim.
        stacked_layers = torch.stack(base_outputs[_HIDDEN_STATES_START_POSITION], dim=1)
        layer_mix = nn.functional.softmax(self.layer_weights, dim=-1).view(-1, 1, 1)
        features = (stacked_layers * layer_mix).sum(dim=1)

    features = self.projector(features)

    # Mean-pool over frames, ignoring frames that correspond to padding when a mask is given.
    if attention_mask is not None:
        frame_mask = self._get_feature_vector_attention_mask(features.shape[1], attention_mask)
        features[~frame_mask] = 0.0
        pooled_output = features.sum(dim=1) / frame_mask.sum(dim=1).view(-1, 1)
    else:
        pooled_output = features.mean(dim=1)

    logits = self.classifier(pooled_output)

    loss = None
    if labels is not None:
        loss = CrossEntropyLoss()(logits.view(-1, self.config.num_labels), labels.view(-1))

    if return_dict:
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=base_outputs.hidden_states,
            attentions=base_outputs.attentions,
        )

    tuple_output = (logits,) + base_outputs[_HIDDEN_STATES_START_POSITION:]
    return tuple_output if loss is None else (loss,) + tuple_output
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    WAVLM_START_DOCSTRING,
)
# Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification. Unlike that
# version, the tuple output (`return_dict=False`) includes the loss when `labels` are given, and `post_init` is used.
class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if getattr(config, "add_adapter", False):
            raise ValueError(
                "Audio frame classification does not support the use of WavLM adapters (config.add_adapter=True)"
            )
        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.num_labels = config.num_labels

        # Initialize weights and apply final processing (same entry point as the other WavLM heads)
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: disables gradient computation for the feature encoder so that
        its parameters are not updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters
        will not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_FRAME_CLASS_CHECKPOINT,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_FRAME_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*):
            Per-frame label scores (for example one-hot or multi-hot vectors), where `sequence_length` is the number
            of output frames. They are reshaped to `(-1, config.num_labels)` and the target class of each frame is
            their `argmax`, so they must hold as many elements as `logits`. The loss is a cross-entropy loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The learned layer mix needs the hidden states of every layer.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        # One prediction per frame: no pooling over time.
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), dim=1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            # Keep the loss in the tuple output, consistent with the dict output and the other WavLM heads.
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||
|
|
||
|
|
||
|
# Based on transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
class AMSoftmaxLoss(nn.Module):
    """
    Additive-margin softmax loss.

    Inputs and class weight columns are L2-normalised. The margin is subtracted from the cosine similarity of the
    target class only, and the scaled result is fed to a standard cross-entropy.

    Args:
        input_dim: size of the vectors passed to `forward` (rows of `self.weight`).
        num_labels: number of target classes (columns of `self.weight`).
        scale: multiplier applied to the margin-adjusted cosine logits.
        margin: value subtracted from the target-class cosine similarity.
    """

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        """Return the AM-Softmax loss of `hidden_states` (2D, one row per example) against integer `labels`."""
        targets = labels.flatten()
        unit_weight = nn.functional.normalize(self.weight, dim=0)
        unit_inputs = nn.functional.normalize(hidden_states, dim=1)
        cosine = torch.mm(unit_inputs, unit_weight)

        # Penalise only the target class by the margin.
        is_target = nn.functional.one_hot(targets, self.num_labels).bool()
        adjusted = torch.where(is_target, cosine - self.margin, cosine)
        return self.loss(self.scale * adjusted, targets)
|
||
|
|
||
|
|
||
|
# Based on transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
class TDNNLayer(nn.Module):
    """
    One time-delay layer of the x-vector head: an unpadded, dilated 1D convolution followed by a ReLU.

    The weights are stored in an `nn.Linear` of shape `(out_conv_dim, kernel_size * in_conv_dim)` for backward
    compatibility. The forward pass reinterprets them as a `conv1d` kernel for speed.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        dims = config.tdnn_dim
        # The first layer maps tdnn_dim[0] -> tdnn_dim[0]; later layers consume the previous layer's width.
        self.in_conv_dim = dims[layer_id] if layer_id <= 0 else dims[layer_id - 1]
        self.out_conv_dim = dims[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def _warn_if_lora_wrapped(self) -> None:
        """Warn when PEFT has wrapped `self.kernel` in a LoRA layer, whose extra weights the conv1d path skips."""
        if not is_peft_available():
            return
        from peft.tuners.lora import LoraLayer

        if isinstance(self.kernel, LoraLayer):
            warnings.warn(
                "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
                "You should exclude TDNNLayer from LoRA's target modules.",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        self._warn_if_lora_wrapped()

        # conv1d expects channels first; presumably hidden_states is (batch, time, in_conv_dim) — TODO confirm
        channels_first = hidden_states.transpose(1, 2)
        conv_weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
        convolved = nn.functional.conv1d(channels_first, conv_weight, self.kernel.bias, dilation=self.dilation)
        return self.activation(convolved.transpose(1, 2))
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    WAVLM_START_DOCSTRING,
)
# Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector. Unlike that version, the TDNN
# output lengths account for each layer's dilation, and `post_init` is used for weight initialization.
class WavLMForXVector(WavLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # Statistics pooling concatenates the per-channel mean and std, hence the factor 2.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        # Initialize weights and apply final processing (same entry point as the other WavLM heads)
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Deprecated alias of `freeze_feature_encoder`: disables gradient computation for the feature encoder so that
        its parameters are not updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters
        will not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the number of frames left after the TDNN layers.

        Each TDNN layer is an unpadded, stride-1 1D convolution with dilation, so a layer with kernel size `k` and
        dilation `d` removes `d * (k - 1)` frames.

        Args:
            input_lengths: number of frames entering the first TDNN layer (per-example tensor or int).

        Returns:
            The number of frames produced by the last TDNN layer, with the same type as `input_lengths`.
        """

        def _conv_out_length(input_length, kernel_size, stride, dilation=1):
            # 1D convolutional layer output length formula (no padding) taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        # Walk the layers `forward` actually runs so that dilation is accounted for. Ignoring it overestimates
        # the valid length and lets padding-dependent frames leak into the statistics pooling.
        for tdnn_layer in self.tdnn:
            input_lengths = _conv_out_length(input_lengths, tdnn_layer.kernel_size, 1, tdnn_layer.dilation)

        return input_lengths

    @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_XVECTOR_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, XVectorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Class (speaker) indices used to compute the additive-margin softmax (AM-Softmax) loss. Indices should be
            in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The learned layer mix needs the hidden states of every layer.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling: mean and std over the frames (dim 1), restricted to valid frames when a mask is given
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|