# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Original implementation: https://github.com/pytorch/fairseq/tree/master/examples/wmt19
# Authors:
# - @alexeib Alexei Baevski
# - @edunov Sergey Edunov
# - @michaelauli Michael Auli
# - @myleott Myle Ott
# - @nng555 Nathan Ng
# - David Grangier
# - Kyra Yee
#
# Paper: Facebook FAIR's WMT19 News Translation Task Submission https://arxiv.org/abs/1907.06616
#
"""PyTorch Fairseq model, ported from https://github.com/pytorch/fairseq/tree/master/examples/wmt19"""

import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss, LayerNorm

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_fsmt import FSMTConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/wmt19-ru-en"
_CONFIG_FOR_DOC = "FSMTConfig"

# See all FSMT models at https://huggingface.co/models?filter=fsmt

# Porting notes:
# this one is modeled after BartModel*
#
# Currently only translation (fairseq also has weights for LM)
#
# fairseq provides weights for ru-en, en-ru and de-en, en-de pairs. All have been ported.
# - ru-en, en-ru use asymmetric vocab
# - de-en, en-de use a merged single vocab (but the code works as if they are separate)
#
# Differences with Bart:
# - not using bos token
# - 2 separate vocabs (src and target)
# - embed weights aren't tied
# - uses a model Ensemble (but that part isn't ported/implemented yet) - so we
#   don't get as good a BLEU score
# - uses a projection layer at the end of the decoder
# - doesn't use final_logits_bias
# - beam search: fairseq stops as soon as num_beams == len(hypos), whereas transformers
#   keeps searching until the remaining candidates can no longer beat the finished hypotheses.
#   Comparing BLEU scores, the transformers algorithm is slightly superior, so it is used here.
#   If you want to match fairseq outputs, pass ``early_stopping=True`` to ``generate()``
#   (see the sketch just below).
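#
# A minimal sketch of matching fairseq's beam behaviour (illustrative; assumes `model` and
# `input_ids` prepared as in the generation example further below):
#
#     outputs = model.generate(input_ids, num_beams=5, early_stopping=True)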
#
# SinusoidalPositionalEmbedding is slightly different from Bart's - generates
# different embeddings. This implementation is copied verbatim from fairseq with
# some small changes to make it work here.
#
# Other changes:
# - doesn't support use_cache as Bart's version does
#
#
# FSMTConfig changes compared to BartConfig
#
# Differences with BART:
# - src/tgt vocabs aren't shared
# - token embeddings aren't shared
# - needs a language pair
# - scale_embedding is True
#
# some unused args were removed too
#
#
# TODO:
# - port model ensemble (fs uses 4 model checkpoints)
# - solve beam search discrepancies
# docstyle-ignore

"""
|
||
|
|
||
|
Here is how to compare BLEU scores against fairseq implementation:
|
||
|
|
||
|
# en-ru
|
||
|
|
||
|
export PAIR=en-ru
|
||
|
export DATA_DIR=data/$PAIR
|
||
|
export SAVE_DIR=data/$PAIR
|
||
|
export BS=8
|
||
|
export NUM_BEAMS=50
|
||
|
mkdir -p $DATA_DIR
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||
|
echo $PAIR
|
||
|
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||
|
|
||
|
# (fairseq BLEU: 36.4 http://matrix.statmt.org/matrix/output/1914?score_id=37605)
|
||
|
|
||
|
|
||
|
# ru-en
|
||
|
|
||
|
export PAIR=ru-en
|
||
|
export DATA_DIR=data/$PAIR
|
||
|
export SAVE_DIR=data/$PAIR
|
||
|
export BS=8
|
||
|
export NUM_BEAMS=50
|
||
|
mkdir -p $DATA_DIR
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||
|
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||
|
|
||
|
|
||
|
# (fairseq BLEU: 41.3 http://matrix.statmt.org/matrix/output/1907?run_id=6937)
|
||
|
|
||
|
|
||
|
# de-en
|
||
|
|
||
|
export PAIR=de-en
|
||
|
export DATA_DIR=data/$PAIR
|
||
|
export SAVE_DIR=data/$PAIR
|
||
|
export BS=8
|
||
|
export NUM_BEAMS=50
|
||
|
mkdir -p $DATA_DIR
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||
|
echo $PAIR
|
||
|
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||
|
|
||
|
# (fairseq BLEU: 42.3 http://matrix.statmt.org/matrix/output/1902?run_id=6750)
|
||
|
|
||
|
|
||
|
|
||
|
# en-de
|
||
|
|
||
|
export PAIR=en-de
|
||
|
export DATA_DIR=data/$PAIR
|
||
|
export SAVE_DIR=data/$PAIR
|
||
|
export BS=8
|
||
|
mkdir -p $DATA_DIR
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo src > $DATA_DIR/val.source
|
||
|
sacrebleu -t wmt19 -l $PAIR --echo ref > $DATA_DIR/val.target
|
||
|
echo $PAIR
|
||
|
PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval.py facebook/wmt19-$PAIR $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation --num_beams $NUM_BEAMS
|
||
|
|
||
|
# (fairseq BLEU: 43.1 http://matrix.statmt.org/matrix/output/1909?run_id=6862)
|
||
|
|
||
|
"""
|
||
|
|
||
|
|
||
|
FSMT_START_DOCSTRING = r"""
|
||
|
|
||
|
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||
|
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||
|
etc.)
|
||
|
|
||
|
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||
|
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||
|
and behavior.
|
||
|
|
||
|
Parameters:
|
||
|
config ([`FSMTConfig`]): Model configuration class with all the parameters of the model.
|
||
|
Initializing with a config file does not load the weights associated with the model, only the
|
||
|
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||
|
|
||
|
"""
|
||
|
FSMT_GENERATION_EXAMPLE = r"""
|
||
|
Translation example::
|
||
|
|
||
|
```python
|
||
|
>>> from transformers import AutoTokenizer, FSMTForConditionalGeneration
|
||
|
|
||
|
>>> mname = "facebook/wmt19-ru-en"
|
||
|
>>> model = FSMTForConditionalGeneration.from_pretrained(mname)
|
||
|
>>> tokenizer = AutoTokenizer.from_pretrained(mname)
|
||
|
|
||
|
>>> src_text = "Машинное обучение - это здорово, не так ли?"
|
||
|
>>> input_ids = tokenizer(src_text, return_tensors="pt").input_ids
|
||
|
>>> outputs = model.generate(input_ids, num_beams=5, num_return_sequences=3)
|
||
|
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||
|
"Machine learning is great, isn't it?"
|
||
|
```
|
||
|
|
||
|
"""
|
||
|
|
||
|
FSMT_INPUTS_DOCSTRING = r"""
|
||
|
Args:
|
||
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||
|
Indices of input sequence tokens in the vocabulary.
|
||
|
|
||
|
Indices can be obtained using [`FSMTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||
|
|
||
|
[What are input IDs?](../glossary#input-ids)
|
||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||
|
|
||
|
- 1 for tokens that are **not masked**,
|
||
|
- 0 for tokens that are **masked**.
|
||
|
|
||
|
[What are attention masks?](../glossary#attention-mask)
|
||
|
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||
|
Indices of decoder input sequence tokens in the vocabulary.
|
||
|
|
||
|
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||
|
[`PreTrainedTokenizer.__call__`] for details.
|
||
|
|
||
|
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||
|
|
||
|
FSMT uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
||
|
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
||
|
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||
|
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
||
|
be used by default.
|
||
|
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
||
|
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
|
||
|
|
||
|
- 1 indicates the head is **not masked**,
|
||
|
- 0 indicates the head is **masked**.
|
||
|
|
||
|
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||
|
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
|
||
|
|
||
|
- 1 indicates the head is **not masked**,
|
||
|
- 0 indicates the head is **masked**.
|
||
|
|
||
|
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
||
|
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
|
||
|
1]`:
|
||
|
|
||
|
- 1 indicates the head is **not masked**,
|
||
|
- 0 indicates the head is **masked**.
|
||
|
|
||
|
encoder_outputs (`Tuple(torch.FloatTensor)`, *optional*):
|
||
|
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||
|
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states at
|
||
|
the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||
|
past_key_values (`Tuple(torch.FloatTensor)` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
|
||
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
||
|
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
||
|
model's internal embedding lookup matrix.
|
||
|
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
|
||
|
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
|
||
|
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
|
||
|
input (see `past_key_values`). This is useful if you want more control over how to convert
|
||
|
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
|
||
|
|
||
|
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
|
||
|
of `inputs_embeds`.
|
||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||
|
`past_key_values`).
|
||
|
output_attentions (`bool`, *optional*):
|
||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||
|
tensors for more detail.
|
||
|
output_hidden_states (`bool`, *optional*):
|
||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||
|
more detail.
|
||
|
return_dict (`bool`, *optional*):
|
||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||
|
"""
|
||
|
|
||
|
|
||
|
def invert_mask(attention_mask):
|
||
|
"""Turns 1->0, 0->1, False->True, True-> False"""
|
||
|
assert attention_mask.dim() == 2
|
||
|
return attention_mask.eq(0)
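# Illustrative example (values assumed): invert_mask(torch.tensor([[1, 1, 0]])) returns
# tensor([[False, False, True]]), i.e. True now marks the positions to be ignored.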
|
||
|
|
||
|
|
||
|
def triu_onnx(x, diagonal=0):
|
||
|
l = x.shape[0]
|
||
|
arange = torch.arange(l, device=x.device)
|
||
|
mask = arange.expand(l, l)
|
||
|
arange = arange.unsqueeze(-1)
|
||
|
if diagonal:
|
||
|
arange = arange + diagonal
|
||
|
mask = mask >= arange
|
||
|
return x.masked_fill(mask == 0, 0)
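# For square 2-D inputs this matches torch.triu(x, diagonal), written without the operator that
# older ONNX exports could not handle. A small sketch (values assumed):
#
#     triu_onnx(torch.ones(3, 3), diagonal=1)
#     # tensor([[0., 1., 1.],
#     #         [0., 0., 1.],
#     #         [0., 0., 0.]])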
|
||
|
|
||
|
|
||
|
def _prepare_fsmt_decoder_inputs(
|
||
|
config,
|
||
|
input_ids,
|
||
|
decoder_input_ids=None,
|
||
|
decoder_padding_mask=None,
|
||
|
causal_mask_dtype=torch.float32,
|
||
|
):
|
||
|
"""
|
||
|
Prepare masks that ignore padding tokens in the decoder and a causal mask for the decoder if none are provided.
|
||
|
This mimics the default behavior in fairseq. To override it pass in masks. Note: this is not called during
|
||
|
generation
|
||
|
"""
|
||
|
pad_token_id = config.pad_token_id
|
||
|
if decoder_input_ids is None:
|
||
|
decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
|
||
|
bsz, tgt_len = decoder_input_ids.size()
|
||
|
if decoder_padding_mask is None:
|
||
|
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
|
||
|
else:
|
||
|
decoder_padding_mask = invert_mask(decoder_padding_mask)
|
||
|
causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=causal_mask_dtype)), 1).to(
|
||
|
device=decoder_input_ids.device
|
||
|
)
|
||
|
return decoder_input_ids, decoder_padding_mask, causal_mask
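# Illustrative sketch of the returned causal_mask (shapes assumed): for tgt_len=3 it is a
# (3, 3) matrix with 0 on and below the diagonal and the most negative representable value
# (effectively -inf) above it,
#
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
#
# so that decoder position t can only attend to positions <= t.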
|
||
|
|
||
|
|
||
|
class PretrainedFSMTModel(PreTrainedModel):
|
||
|
config_class = FSMTConfig
|
||
|
base_model_prefix = "model"
|
||
|
|
||
|
def _init_weights(self, module):
|
||
|
std = self.config.init_std
|
||
|
if isinstance(module, nn.Linear):
|
||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||
|
if module.bias is not None:
|
||
|
module.bias.data.zero_()
|
||
|
elif isinstance(module, SinusoidalPositionalEmbedding):
|
||
|
pass
|
||
|
elif isinstance(module, nn.Embedding):
|
||
|
module.weight.data.normal_(mean=0.0, std=std)
|
||
|
if module.padding_idx is not None:
|
||
|
module.weight.data[module.padding_idx].zero_()
|
||
|
|
||
|
@property
|
||
|
def dummy_inputs(self):
|
||
|
pad_token = self.config.pad_token_id
|
||
|
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
||
|
dummy_inputs = {
|
||
|
"attention_mask": input_ids.ne(pad_token),
|
||
|
"input_ids": input_ids,
|
||
|
}
|
||
|
return dummy_inputs
|
||
|
|
||
|
|
||
|
def _make_linear_from_emb(emb):
|
||
|
vocab_size, emb_size = emb.weight.shape
|
||
|
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||
|
lin_layer.weight.data = emb.weight.data
|
||
|
return lin_layer
|
||
|
|
||
|
|
||
|
# Helper Functions, mostly for making masks
|
||
|
def _check_shapes(shape_1, shape2):
|
||
|
if shape_1 != shape2:
|
||
|
raise AssertionError(f"shape mismatch: {shape_1} != {shape2}")
|
||
|
|
||
|
|
||
|
def shift_tokens_right(input_ids, pad_token_id):
|
||
|
"""Shift input ids one token to the right, and wrap the last non pad token (usually <eos>)."""
|
||
|
|
||
|
# replace possible -100 values in labels by `pad_token_id`
|
||
|
input_ids.masked_fill_(input_ids == -100, pad_token_id)
|
||
|
|
||
|
prev_output_tokens = input_ids.clone()
|
||
|
index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
|
||
|
prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
|
||
|
prev_output_tokens[:, 1:] = input_ids[:, :-1]
|
||
|
return prev_output_tokens
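# Worked example (token ids assumed, pad_token_id=1, eos=2):
#
#     input_ids = torch.tensor([[5, 6, 7, 2, 1]])
#     shift_tokens_right(input_ids, 1)  # -> tensor([[2, 5, 6, 7, 2]])
#
# the last non-pad token (usually eos) is wrapped around to position 0 and the rest shifts right.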
|
||
|
|
||
|
|
||
|
def make_padding_mask(input_ids, padding_idx=1):
|
||
|
"""True for pad tokens"""
|
||
|
padding_mask = input_ids.eq(padding_idx)
|
||
|
if not padding_mask.any():
|
||
|
padding_mask = None
|
||
|
return padding_mask
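# Illustrative example (ids assumed, padding_idx=1): make_padding_mask(torch.tensor([[5, 6, 1]]))
# returns tensor([[False, False, True]]); when no padding is present the function returns None.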
|
||
|
|
||
|
|
||
|
# Helper Modules
|
||
|
|
||
|
|
||
|
class EncoderLayer(nn.Module):
|
||
|
def __init__(self, config: FSMTConfig):
|
||
|
super().__init__()
|
||
|
self.embed_dim = config.d_model
|
||
|
self.self_attn = Attention(self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout)
|
||
|
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||
|
self.dropout = config.dropout
|
||
|
self.activation_fn = ACT2FN[config.activation_function]
|
||
|
self.activation_dropout = config.activation_dropout
|
||
|
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
||
|
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
||
|
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||
|
|
||
|
def forward(self, x, encoder_padding_mask, layer_head_mask, output_attentions=False):
|
||
|
"""
|
||
|
Args:
|
||
|
x (`torch.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||
|
encoder_padding_mask (`torch.ByteTensor`): binary ByteTensor of shape
|
||
|
*(batch, src_len)* where padding elements are indicated by `1`: a value of `1` means the
corresponding source position is excluded from (masked out of) attention, and `0` means
it is included.
|
||
|
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
||
|
*(config.encoder_attention_heads,)*.
|
||
|
|
||
|
Returns:
|
||
|
encoded output of shape *(seq_len, batch, embed_dim)*
|
||
|
"""
|
||
|
residual = x
|
||
|
x, attn_weights = self.self_attn(
|
||
|
query=x,
|
||
|
key=x,
|
||
|
key_padding_mask=encoder_padding_mask,
|
||
|
layer_head_mask=layer_head_mask,
|
||
|
output_attentions=output_attentions,
|
||
|
)
|
||
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
||
|
x = residual + x
|
||
|
x = self.self_attn_layer_norm(x)
|
||
|
|
||
|
residual = x
|
||
|
x = self.activation_fn(self.fc1(x))
|
||
|
x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
|
||
|
x = self.fc2(x)
|
||
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
||
|
x = residual + x
|
||
|
x = self.final_layer_norm(x)
|
||
|
return x, attn_weights
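# Shape walk-through for a single EncoderLayer (dimensions assumed): given x of shape
# (seq_len, batch, d_model), the layer runs self-attention -> dropout -> residual -> LayerNorm,
# then fc1 (d_model -> encoder_ffn_dim) -> activation -> dropout -> fc2 -> dropout -> residual
# -> LayerNorm, returning a tensor of the same (seq_len, batch, d_model) shape plus, optionally,
# attention weights of shape (batch, num_heads, seq_len, seq_len).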
|
||
|
|
||
|
|
||
|
class FSMTEncoder(nn.Module):
|
||
|
"""
|
||
|
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`EncoderLayer`].
|
||
|
|
||
|
Args:
|
||
|
config: FSMTConfig
|
||
|
"""
|
||
|
|
||
|
def __init__(self, config: FSMTConfig, embed_tokens):
|
||
|
super().__init__()
|
||
|
self.dropout = config.dropout
|
||
|
self.layerdrop = config.encoder_layerdrop
|
||
|
self.padding_idx = embed_tokens.padding_idx
|
||
|
self.embed_tokens = embed_tokens
|
||
|
embed_dim = embed_tokens.embedding_dim
|
||
|
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||
|
self.embed_positions = SinusoidalPositionalEmbedding(
|
||
|
config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
|
||
|
)
|
||
|
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)]) # type: List[EncoderLayer]
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
input_ids: torch.Tensor,
|
||
|
attention_mask: Optional[torch.Tensor] = None,
|
||
|
inputs_embeds: torch.Tensor = None,
|
||
|
head_mask: Optional[torch.Tensor] = None,
|
||
|
output_attentions: bool = False,
|
||
|
output_hidden_states: bool = False,
|
||
|
return_dict: bool = True,
|
||
|
):
|
||
|
"""
|
||
|
Args:
|
||
|
input_ids (`torch.LongTensor`): tokens in the source language of shape
|
||
|
*(batch, src_len)*
|
||
|
attention_mask (`torch.LongTensor`): indicating which indices are padding tokens
|
||
|
inputs_embeds (`torch.FloatTensor`):
|
||
|
embedding vectors of shape *(batch, src_len, embed_dim)*
|
||
|
head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
|
||
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||
|
|
||
|
- 1 indicates the head is **not masked**,
|
||
|
- 0 indicates the head is **masked**.
|
||
|
|
||
|
Returns:
|
||
|
BaseModelOutput or Tuple comprised of:
|
||
|
|
||
|
- **x** (`torch.Tensor`): the last encoder layer's output of shape *(src_len, batch, embed_dim)*
|
||
|
- **encoder_states** (`Tuple(torch.FloatTensor)`): all intermediate hidden states of shape *(src_len,
  batch, embed_dim)*. Only populated if *output_hidden_states* is True.
- **all_attentions** (`Tuple(torch.FloatTensor)`): Attention weights for each layer. During training
  might not be of length n_layers because of layer dropout.
|
||
|
"""
|
||
|
# check attention mask and invert
|
||
|
if attention_mask is not None:
|
||
|
attention_mask = invert_mask(attention_mask)
|
||
|
|
||
|
if input_ids is not None and inputs_embeds is not None:
|
||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||
|
elif input_ids is not None:
|
||
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||
|
embed_pos = self.embed_positions(input_ids)
|
||
|
elif inputs_embeds is not None:
|
||
|
inputs_embeds = inputs_embeds * self.embed_scale
|
||
|
|
||
|
# We assume zeros hidden states correspond to padding tokens
|
||
|
# and create `position_ids` where inputs_embeds[:, :, 0] == 0
|
||
|
position_ids = inputs_embeds[:, :, 0].masked_fill(
|
||
|
inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
|
||
|
)
|
||
|
|
||
|
embed_pos = self.embed_positions(position_ids)
|
||
|
else:
|
||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||
|
|
||
|
x = inputs_embeds + embed_pos
|
||
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
||
|
|
||
|
# B x T x C -> T x B x C
|
||
|
x = x.transpose(0, 1)
|
||
|
|
||
|
encoder_states = () if output_hidden_states else None
|
||
|
all_attentions = () if output_attentions else None
|
||
|
# check if head_mask has a correct number of layers specified if desired
|
||
|
if head_mask is not None:
|
||
|
assert head_mask.size()[0] == (
|
||
|
len(self.layers)
|
||
|
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
|
||
|
for idx, encoder_layer in enumerate(self.layers):
|
||
|
if output_hidden_states:
|
||
|
x = x.transpose(0, 1) # T x B x C -> B x T x C
|
||
|
encoder_states += (x,)
|
||
|
x = x.transpose(0, 1) # B x T x C -> T x B x C
|
||
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||
|
dropout_probability = torch.rand([])
|
||
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||
|
attn = None
|
||
|
else:
|
||
|
x, attn = encoder_layer(
|
||
|
x,
|
||
|
attention_mask,
|
||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||
|
output_attentions=output_attentions,
|
||
|
)
|
||
|
|
||
|
if output_attentions:
|
||
|
all_attentions = all_attentions + (attn,)
|
||
|
|
||
|
# T x B x C -> B x T x C
|
||
|
x = x.transpose(0, 1)
|
||
|
|
||
|
if output_hidden_states:
|
||
|
encoder_states += (x,)
|
||
|
|
||
|
if not return_dict:
|
||
|
return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
|
||
|
return BaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
|
||
|
|
||
|
|
||
|
class DecoderLayer(nn.Module):
|
||
|
def __init__(self, config: FSMTConfig):
|
||
|
super().__init__()
|
||
|
self.embed_dim = config.d_model
|
||
|
|
||
|
self.self_attn = Attention(
|
||
|
embed_dim=self.embed_dim,
|
||
|
num_heads=config.decoder_attention_heads,
|
||
|
dropout=config.attention_dropout,
|
||
|
)
|
||
|
self.dropout = config.dropout
|
||
|
self.activation_fn = ACT2FN[config.activation_function]
|
||
|
self.activation_dropout = config.activation_dropout
|
||
|
|
||
|
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||
|
self.encoder_attn = Attention(
|
||
|
self.embed_dim,
|
||
|
config.decoder_attention_heads,
|
||
|
dropout=config.attention_dropout,
|
||
|
encoder_decoder_attention=True,
|
||
|
)
|
||
|
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
|
||
|
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
||
|
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
||
|
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
x,
|
||
|
encoder_hidden_states,
|
||
|
encoder_attn_mask=None,
|
||
|
layer_state=None,
|
||
|
causal_mask=None,
|
||
|
layer_head_mask=None,
|
||
|
cross_attn_layer_head_mask=None,
|
||
|
decoder_padding_mask=None,
|
||
|
output_attentions=False,
|
||
|
):
|
||
|
residual = x
|
||
|
|
||
|
if layer_state is None:
|
||
|
layer_state = {}
|
||
|
|
||
|
# Self Attention
|
||
|
x, self_attn_weights = self.self_attn(
|
||
|
query=x,
|
||
|
key=x,
|
||
|
layer_state=layer_state, # adds keys to layer state
|
||
|
key_padding_mask=decoder_padding_mask,
|
||
|
attn_mask=causal_mask,
|
||
|
layer_head_mask=layer_head_mask,
|
||
|
output_attentions=output_attentions,
|
||
|
)
|
||
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
||
|
x = residual + x
|
||
|
x = self.self_attn_layer_norm(x)
|
||
|
|
||
|
# Cross attention
|
||
|
residual = x
|
||
|
assert self.encoder_attn.cache_key != self.self_attn.cache_key
|
||
|
x, cross_attn_weights = self.encoder_attn(
|
||
|
query=x,
|
||
|
key=encoder_hidden_states,
|
||
|
key_padding_mask=encoder_attn_mask,
|
||
|
layer_state=layer_state, # mutates layer state
|
||
|
layer_head_mask=cross_attn_layer_head_mask,
|
||
|
output_attentions=output_attentions,
|
||
|
)
|
||
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
||
|
x = residual + x
|
||
|
x = self.encoder_attn_layer_norm(x)
|
||
|
|
||
|
# Fully Connected
|
||
|
residual = x
|
||
|
x = self.activation_fn(self.fc1(x))
|
||
|
x = nn.functional.dropout(x, p=self.activation_dropout, training=self.training)
|
||
|
x = self.fc2(x)
|
||
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
||
|
x = residual + x
|
||
|
x = self.final_layer_norm(x)
|
||
|
return (
|
||
|
x,
|
||
|
self_attn_weights,
|
||
|
layer_state,
|
||
|
cross_attn_weights,
|
||
|
) # layer_state = cache for decoding
|
||
|
|
||
|
|
||
|
class FSMTDecoder(nn.Module):
|
||
|
"""
|
||
|
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DecoderLayer`]
|
||
|
|
||
|
Args:
|
||
|
config: FSMTConfig
|
||
|
embed_tokens (nn.Embedding): output embedding
|
||
|
"""
|
||
|
|
||
|
def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
|
||
|
super().__init__()
|
||
|
self.dropout = config.dropout
|
||
|
self.layerdrop = config.decoder_layerdrop
|
||
|
self.padding_idx = embed_tokens.padding_idx
|
||
|
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
|
||
|
self.embed_tokens = embed_tokens
|
||
|
embed_dim = embed_tokens.embedding_dim
|
||
|
self.embed_positions = SinusoidalPositionalEmbedding(
|
||
|
config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx
|
||
|
)
|
||
|
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.decoder_layers)]) # type: List[DecoderLayer]
|
||
|
|
||
|
if is_deepspeed_zero3_enabled():
|
||
|
import deepspeed
|
||
|
|
||
|
with deepspeed.zero.GatheredParameters(self.embed_tokens.weight, modifier_rank=None):
|
||
|
embed_tokens_weight_shape = self.embed_tokens.weight.shape
|
||
|
else:
|
||
|
embed_tokens_weight_shape = self.embed_tokens.weight.shape
|
||
|
self.output_projection = nn.Linear(embed_tokens_weight_shape[1], embed_tokens_weight_shape[0], bias=False)
|
||
|
self.output_projection.weight = self.embed_tokens.weight
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
input_ids: torch.Tensor,
|
||
|
encoder_hidden_states: torch.Tensor,
|
||
|
encoder_padding_mask: torch.Tensor,
|
||
|
decoder_padding_mask: torch.Tensor,
|
||
|
decoder_causal_mask: torch.Tensor,
|
||
|
head_mask: Optional[torch.Tensor] = None,
|
||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||
|
use_cache: bool = False,
|
||
|
output_attentions: bool = False,
|
||
|
output_hidden_states: bool = False,
|
||
|
return_dict: bool = True,
|
||
|
):
|
||
|
"""
|
||
|
Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al.,
|
||
|
EMNLP 2019).
|
||
|
|
||
|
Args:
|
||
|
input_ids (`torch.LongTensor` of shape `(batch, tgt_len)`):
|
||
|
previous decoder outputs for teacher forcing
|
||
|
encoder_hidden_states: output from the encoder, used for
|
||
|
encoder-side attention
|
||
|
encoder_padding_mask: for ignoring pad tokens
|
||
|
past_key_values (dict or None): dictionary used for storing state during generation
|
||
|
head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
|
||
|
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
||
|
|
||
|
- 1 indicates the head is **not masked**,
|
||
|
- 0 indicates the head is **masked**.
|
||
|
|
||
|
cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, num_heads)`, *optional*):
|
||
|
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
|
||
|
|
||
|
- 1 indicates the head is **not masked**,
|
||
|
- 0 indicates the head is **masked**.
|
||
|
|
||
|
Returns:
|
||
|
BaseModelOutputWithPastAndCrossAttentions or tuple:
|
||
|
|
||
|
- the decoder's features of shape *(batch, tgt_len, embed_dim)*
|
||
|
- the cache
|
||
|
- hidden states
|
||
|
- attentions
|
||
|
"""
|
||
|
# check attention mask and invert
|
||
|
if encoder_padding_mask is not None:
|
||
|
encoder_padding_mask = invert_mask(encoder_padding_mask)
|
||
|
|
||
|
if input_ids is not None and inputs_embeds is not None:
|
||
|
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
||
|
elif input_ids is not None:
|
||
|
# embed positions
|
||
|
positions = self.embed_positions(input_ids)
|
||
|
if use_cache:
|
||
|
input_ids = input_ids[:, -1:]
|
||
|
positions = positions[:, -1:] # happens after we embed them
|
||
|
x = self.embed_tokens(input_ids) * self.embed_scale
|
||
|
elif inputs_embeds is not None:
|
||
|
# We assume zeros hidden states correspond to padding tokens
|
||
|
# and create `position_ids` where inputs_embeds[:, :, 0] == 0
|
||
|
position_ids = inputs_embeds[:, :, 0].masked_fill(
|
||
|
inputs_embeds[:, :, 0].eq(0), self.embed_positions.padding_idx
|
||
|
)
|
||
|
positions = self.embed_positions(position_ids)
|
||
|
x = inputs_embeds * self.embed_scale
|
||
|
else:
|
||
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||
|
|
||
|
x += positions
|
||
|
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
|
||
|
|
||
|
# Convert to FSMT output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
|
||
|
x = x.transpose(0, 1)
|
||
|
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||
|
|
||
|
# decoder layers
|
||
|
all_hidden_states = () if output_hidden_states else None
|
||
|
all_self_attns = () if output_attentions else None
|
||
|
all_cross_attns = () if output_attentions else None
|
||
|
next_decoder_cache = []
|
||
|
|
||
|
# check if head_mask has a correct number of layers specified if desired
|
||
|
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||
|
if attn_mask is not None:
|
||
|
assert attn_mask.size()[0] == (len(self.layers)), (
|
||
|
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
||
|
f" {head_mask.size()[0]}."
|
||
|
)
|
||
|
for idx, decoder_layer in enumerate(self.layers):
|
||
|
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||
|
if output_hidden_states:
|
||
|
x = x.transpose(0, 1)
|
||
|
all_hidden_states += (x,)
|
||
|
x = x.transpose(0, 1)
|
||
|
if self.training:
|
||
|
dropout_probability = torch.rand([])
|
||
|
if dropout_probability < self.layerdrop:
|
||
|
continue
|
||
|
|
||
|
layer_state = past_key_values[idx] if past_key_values is not None else None
|
||
|
|
||
|
x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
|
||
|
x,
|
||
|
encoder_hidden_states,
|
||
|
encoder_attn_mask=encoder_padding_mask,
|
||
|
decoder_padding_mask=decoder_padding_mask,
|
||
|
layer_state=layer_state,
|
||
|
causal_mask=decoder_causal_mask,
|
||
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||
|
cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
|
||
|
output_attentions=output_attentions,
|
||
|
)
|
||
|
|
||
|
if use_cache:
|
||
|
next_decoder_cache.append(layer_past.copy())
|
||
|
|
||
|
if output_attentions:
|
||
|
all_self_attns += (layer_self_attn,)
|
||
|
all_cross_attns += (layer_cross_attn,)
|
||
|
|
||
|
# add hidden states from the last decoder layer
|
||
|
if output_hidden_states:
|
||
|
x = x.transpose(0, 1)
|
||
|
all_hidden_states += (x,)
|
||
|
x = x.transpose(0, 1)
|
||
|
|
||
|
# Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
|
||
|
x = x.transpose(0, 1)
|
||
|
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
|
||
|
|
||
|
x = self.output_projection(x)
|
||
|
|
||
|
next_cache = next_decoder_cache if use_cache else None
|
||
|
|
||
|
if not return_dict:
|
||
|
return tuple(
|
||
|
v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attns] if v is not None
|
||
|
)
|
||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||
|
last_hidden_state=x,
|
||
|
past_key_values=next_cache,
|
||
|
hidden_states=all_hidden_states,
|
||
|
attentions=all_self_attns,
|
||
|
cross_attentions=all_cross_attns,
|
||
|
)
|
||
|
|
||
|
|
||
|
def _reorder_buffer(attn_cache, new_order):
|
||
|
for k, input_buffer_k in attn_cache.items():
|
||
|
if input_buffer_k is not None:
|
||
|
attn_cache[k] = input_buffer_k.index_select(0, new_order)
|
||
|
return attn_cache
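# Illustrative sketch (shapes assumed): the cached tensors ("prev_key", "prev_value") have shape
# (batch_size * num_beams, num_heads, seq_len, head_dim); index_select along dim 0 with the beam
# indices picked at each step (e.g. new_order = torch.tensor([0, 0, 2, 3])) keeps the cache
# aligned with the reordered beams during beam search.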
|
||
|
|
||
|
|
||
|
class Attention(nn.Module):
|
||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
embed_dim,
|
||
|
num_heads,
|
||
|
dropout=0.0,
|
||
|
bias=True,
|
||
|
encoder_decoder_attention=False, # otherwise self_attention
|
||
|
):
|
||
|
super().__init__()
|
||
|
self.embed_dim = embed_dim
|
||
|
self.num_heads = num_heads
|
||
|
self.dropout = dropout
|
||
|
self.head_dim = embed_dim // num_heads
|
||
|
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||
|
self.scaling = self.head_dim**-0.5
|
||
|
|
||
|
self.encoder_decoder_attention = encoder_decoder_attention
|
||
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||
|
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||
|
self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"
|
||
|
|
||
|
def _shape(self, tensor, seq_len, bsz):
|
||
|
return tensor.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
query,
|
||
|
key: Optional[Tensor],
|
||
|
key_padding_mask: Optional[Tensor] = None,
|
||
|
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
|
||
|
attn_mask: Optional[Tensor] = None,
|
||
|
layer_head_mask: Optional[Tensor] = None,
|
||
|
output_attentions=False,
|
||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||
|
"""Input shape: Time(SeqLen) x Batch x Channel"""
|
||
|
static_kv: bool = self.encoder_decoder_attention
|
||
|
tgt_len, bsz, embed_dim = query.size()
|
||
|
assert embed_dim == self.embed_dim
|
||
|
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
||
|
# get here for encoder-decoder attention because of static_kv
|
||
|
if layer_state is not None: # reuse k,v and encoder_padding_mask
|
||
|
saved_state = layer_state.get(self.cache_key, {})
|
||
|
if "prev_key" in saved_state and static_kv:
|
||
|
# previous time steps are cached - no need to recompute key and value if they are static
|
||
|
key = None
|
||
|
else:
|
||
|
saved_state = None
|
||
|
layer_state = {}
|
||
|
|
||
|
q = self.q_proj(query) * self.scaling
|
||
|
if static_kv:
|
||
|
if key is None:
|
||
|
k = v = None
|
||
|
else:
|
||
|
k = self.k_proj(key)
|
||
|
v = self.v_proj(key)
|
||
|
else:
|
||
|
k = self.k_proj(query)
|
||
|
v = self.v_proj(query)
|
||
|
|
||
|
q = self._shape(q, tgt_len, bsz)
|
||
|
if k is not None:
|
||
|
k = self._shape(k, -1, bsz)
|
||
|
if v is not None:
|
||
|
v = self._shape(v, -1, bsz)
|
||
|
|
||
|
if saved_state is not None:
|
||
|
k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)
|
||
|
|
||
|
# Update cache
|
||
|
layer_state[self.cache_key] = {
|
||
|
"prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
|
||
|
"prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
|
||
|
"prev_key_padding_mask": key_padding_mask if not static_kv else None,
|
||
|
}
|
||
|
|
||
|
assert k is not None
|
||
|
src_len = k.size(1)
|
||
|
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||
|
assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)
|
||
|
|
||
|
if attn_mask is not None:
|
||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
|
||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||
|
|
||
|
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
|
||
|
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
||
|
key_padding_mask = None
|
||
|
assert key_padding_mask is None or key_padding_mask.size()[:2] == (
|
||
|
bsz,
|
||
|
src_len,
|
||
|
)
|
||
|
|
||
|
if key_padding_mask is not None: # don't attend to padding symbols
|
||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||
|
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
|
||
|
attn_weights = attn_weights.masked_fill(reshaped, torch.finfo(attn_weights.dtype).min)
|
||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||
|
|
||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||
|
|
||
|
if layer_head_mask is not None:
|
||
|
assert layer_head_mask.size() == (
|
||
|
self.num_heads,
|
||
|
), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
||
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||
|
|
||
|
if output_attentions:
|
||
|
# make sure that attn_weights are included in graph
|
||
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||
|
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||
|
else:
|
||
|
attn_weights_reshaped = None
|
||
|
|
||
|
attn_probs = nn.functional.dropout(
|
||
|
attn_weights,
|
||
|
p=self.dropout,
|
||
|
training=self.training,
|
||
|
)
|
||
|
|
||
|
assert v is not None
|
||
|
attn_output = torch.bmm(attn_probs, v)
|
||
|
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
|
||
|
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
||
|
attn_output = self.out_proj(attn_output)
|
||
|
|
||
|
return attn_output, attn_weights_reshaped
|
||
|
|
||
|
def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
|
||
|
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
||
|
if "prev_key" in saved_state:
|
||
|
_prev_key = saved_state["prev_key"]
|
||
|
assert _prev_key is not None
|
||
|
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
||
|
if static_kv:
|
||
|
k = prev_key
|
||
|
else:
|
||
|
assert k is not None
|
||
|
k = torch.cat([prev_key, k], dim=1)
|
||
|
if "prev_value" in saved_state:
|
||
|
_prev_value = saved_state["prev_value"]
|
||
|
assert _prev_value is not None
|
||
|
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
||
|
if static_kv:
|
||
|
v = prev_value
|
||
|
else:
|
||
|
assert v is not None
|
||
|
v = torch.cat([prev_value, v], dim=1)
|
||
|
assert k is not None and v is not None
|
||
|
prev_key_padding_mask: Optional[Tensor] = saved_state.get("prev_key_padding_mask", None)
|
||
|
if prev_key_padding_mask is not None:
|
||
|
if static_kv:
|
||
|
new_key_padding_mask = prev_key_padding_mask
|
||
|
else:
|
||
|
new_key_padding_mask = torch.cat([prev_key_padding_mask, key_padding_mask], dim=1)
|
||
|
else:
|
||
|
new_key_padding_mask = key_padding_mask
|
||
|
return k, v, new_key_padding_mask
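# Shape summary for Attention.forward (dimensions assumed): query is (tgt_len, bsz, embed_dim);
# internally q/k/v are reshaped to (bsz * num_heads, seq_len, head_dim), the attention weights
# are (bsz * num_heads, tgt_len, src_len), and the returned attn_output is again
# (tgt_len, bsz, embed_dim). For cross-attention (encoder_decoder_attention=True) the key/value
# come from the encoder output and are cached across decoding steps via `layer_state`.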
|
||
|
|
||
|
|
||
|
def fill_with_neg_inf(t):
|
||
|
"""FP16-compatible function that fills a input_ids with -inf."""
|
||
|
return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)
|
||
|
|
||
|
|
||
|
# Public API
|
||
|
def _get_shape(t):
|
||
|
return getattr(t, "shape", None)
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
|
||
|
"The bare FSMT Model outputting raw hidden-states without any specific head on top.",
|
||
|
FSMT_START_DOCSTRING,
|
||
|
)
|
||
|
class FSMTModel(PretrainedFSMTModel):
|
||
|
_tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"]
|
||
|
|
||
|
def __init__(self, config: FSMTConfig):
|
||
|
super().__init__(config)
|
||
|
|
||
|
padding_idx = config.pad_token_id
|
||
|
encoder_embed_tokens = nn.Embedding(config.src_vocab_size, config.d_model, padding_idx)
|
||
|
decoder_embed_tokens = nn.Embedding(config.tgt_vocab_size, config.d_model, padding_idx)
|
||
|
|
||
|
self.encoder = FSMTEncoder(config, encoder_embed_tokens)
|
||
|
self.decoder = FSMTDecoder(config, decoder_embed_tokens)
|
||
|
|
||
|
# Initialize weights and apply final processing
|
||
|
self.post_init()
|
||
|
|
||
|
def get_encoder(self):
|
||
|
return self.encoder
|
||
|
|
||
|
def get_decoder(self):
|
||
|
return self.decoder
|
||
|
|
||
|
def _tie_weights(self):
|
||
|
if self.config.tie_word_embeddings:
|
||
|
self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
|
||
|
self._tie_or_clone_weights(self.decoder.output_projection, self.get_input_embeddings())
|
||
|
|
||
|
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
|
||
|
@add_code_sample_docstrings(
|
||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||
|
output_type=Seq2SeqModelOutput,
|
||
|
config_class=_CONFIG_FOR_DOC,
|
||
|
)
|
||
|
def forward(
|
||
|
self,
|
||
|
input_ids: torch.LongTensor,
|
||
|
attention_mask: Optional[torch.Tensor] = None,
|
||
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||
|
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
||
|
head_mask: Optional[torch.Tensor] = None,
|
||
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||
|
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
|
||
|
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||
|
use_cache: Optional[bool] = None,
|
||
|
output_attentions: Optional[bool] = None,
|
||
|
output_hidden_states: Optional[bool] = None,
|
||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||
|
return_dict: Optional[bool] = None,
|
||
|
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
||
|
if decoder_input_ids is None:
|
||
|
use_cache = False
|
||
|
|
||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||
|
output_hidden_states = (
|
||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||
|
)
|
||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||
|
|
||
|
# make masks if user doesn't supply
|
||
|
if not use_cache and input_ids is not None:
|
||
|
decoder_input_ids, decoder_padding_mask, causal_mask = _prepare_fsmt_decoder_inputs(
|
||
|
self.config,
|
||
|
input_ids,
|
||
|
decoder_input_ids=decoder_input_ids,
|
||
|
decoder_padding_mask=decoder_attention_mask,
|
||
|
causal_mask_dtype=self.decoder.embed_tokens.weight.dtype,
|
||
|
)
|
||
|
else:
|
||
|
decoder_padding_mask, causal_mask = None, None
|
||
|
|
||
|
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
||
|
raise ValueError("Make sure that `decoder_input_ids` or `decoder_inputs_embeds` are passed.")
|
||
|
|
||
|
if encoder_outputs is None:
|
||
|
encoder_outputs = self.encoder(
|
||
|
input_ids=input_ids,
|
||
|
attention_mask=attention_mask,
|
||
|
inputs_embeds=inputs_embeds,
|
||
|
head_mask=head_mask,
|
||
|
output_attentions=output_attentions,
|
||
|
output_hidden_states=output_hidden_states,
|
||
|
return_dict=return_dict,
|
||
|
)
|
||
|
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
||
|
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
||
|
encoder_outputs = BaseModelOutput(
|
||
|
last_hidden_state=encoder_outputs[0],
|
||
|
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
||
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
||
|
)
|
||
|
|
||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||
|
decoder_outputs = self.decoder(
|
||
|
decoder_input_ids,
|
||
|
encoder_outputs[0],
|
||
|
attention_mask,
|
||
|
decoder_padding_mask,
|
||
|
decoder_causal_mask=causal_mask,
|
||
|
inputs_embeds=decoder_inputs_embeds,
|
||
|
head_mask=decoder_head_mask,
|
||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||
|
past_key_values=past_key_values,
|
||
|
use_cache=use_cache,
|
||
|
output_attentions=output_attentions,
|
||
|
output_hidden_states=output_hidden_states,
|
||
|
return_dict=return_dict,
|
||
|
)
|
||
|
|
||
|
if not return_dict:
|
||
|
return decoder_outputs + encoder_outputs
|
||
|
|
||
|
return Seq2SeqModelOutput(
|
||
|
last_hidden_state=decoder_outputs.last_hidden_state,
|
||
|
past_key_values=decoder_outputs.past_key_values,
|
||
|
decoder_hidden_states=decoder_outputs.hidden_states,
|
||
|
decoder_attentions=decoder_outputs.attentions,
|
||
|
cross_attentions=decoder_outputs.cross_attentions,
|
||
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
||
|
encoder_hidden_states=encoder_outputs.hidden_states,
|
||
|
encoder_attentions=encoder_outputs.attentions,
|
||
|
)
|
||
|
|
||
|
def get_input_embeddings(self):
|
||
|
return self.encoder.embed_tokens
|
||
|
|
||
|
def set_input_embeddings(self, value):
|
||
|
self.encoder.embed_tokens = value
|
||
|
|
||
|
def get_output_embeddings(self):
|
||
|
return self.decoder.embed_tokens
|
||
|
|
||
|
def set_output_embeddings(self, value):
|
||
|
self.decoder.embed_tokens = value
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
|
||
|
"The FSMT Model with a language modeling head. Can be used for summarization.", FSMT_START_DOCSTRING
|
||
|
)
|
||
|
class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||
|
base_model_prefix = "model"
|
||
|
_tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"]
|
||
|
|
||
|
def __init__(self, config: FSMTConfig):
|
||
|
super().__init__(config)
|
||
|
base_model = FSMTModel(config)
|
||
|
self.model = base_model
|
||
|
|
||
|
# Initialize weights and apply final processing
|
||
|
self.post_init()
|
||
|
|
||
|
@add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
|
||
|
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||
|
@add_end_docstrings(FSMT_GENERATION_EXAMPLE)
|
||
|
def forward(
|
||
|
self,
|
||
|
input_ids: torch.LongTensor,
|
||
|
attention_mask: Optional[torch.Tensor] = None,
|
||
|
decoder_input_ids: Optional[torch.LongTensor] = None,
|
||
|
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
||
|
head_mask: Optional[torch.Tensor] = None,
|
||
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||
|
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
|
||
|
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||
|
decoder_inputs_embeds: Optional[torch.Tensor] = None,
|
||
|
labels: Optional[torch.LongTensor] = None,
|
||
|
use_cache: Optional[bool] = None,
|
||
|
output_attentions: Optional[bool] = None,
|
||
|
output_hidden_states: Optional[bool] = None,
|
||
|
return_dict: Optional[bool] = None,
|
||
|
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
||
|
r"""
|
||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||
|
|
||
|
Returns:
|
||
|
|
||
|
"""
|
||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||
|
|
||
|
if labels is not None:
|
||
|
use_cache = False
|
||
|
|
||
|
outputs = self.model(
|
||
|
input_ids,
|
||
|
inputs_embeds=inputs_embeds,
|
||
|
attention_mask=attention_mask,
|
||
|
decoder_input_ids=decoder_input_ids,
|
||
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||
|
encoder_outputs=encoder_outputs,
|
||
|
decoder_attention_mask=decoder_attention_mask,
|
||
|
head_mask=head_mask,
|
||
|
decoder_head_mask=decoder_head_mask,
|
||
|
cross_attn_head_mask=cross_attn_head_mask,
|
||
|
past_key_values=past_key_values,
|
||
|
use_cache=use_cache,
|
||
|
output_attentions=output_attentions,
|
||
|
output_hidden_states=output_hidden_states,
|
||
|
return_dict=return_dict,
|
||
|
)
|
||
|
lm_logits = outputs[0]
|
||
|
|
||
|
masked_lm_loss = None
|
||
|
if labels is not None:
|
||
|
loss_fct = CrossEntropyLoss()
|
||
|
# TODO(SS): do we need to ignore pad tokens in labels?
|
||
|
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.tgt_vocab_size), labels.view(-1))
|
||
|
|
||
|
if not return_dict:
|
||
|
output = (lm_logits,) + outputs[1:]
|
||
|
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
||
|
|
||
|
return Seq2SeqLMOutput(
|
||
|
loss=masked_lm_loss,
|
||
|
logits=lm_logits,
|
||
|
past_key_values=outputs.past_key_values,
|
||
|
decoder_hidden_states=outputs.decoder_hidden_states,
|
||
|
decoder_attentions=outputs.decoder_attentions,
|
||
|
cross_attentions=outputs.cross_attentions,
|
||
|
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
||
|
encoder_hidden_states=outputs.encoder_hidden_states,
|
||
|
encoder_attentions=outputs.encoder_attentions,
|
||
|
)
|
||
|
|
||
|
def prepare_inputs_for_generation(
|
||
|
self,
|
||
|
decoder_input_ids,
|
||
|
past_key_values=None,
|
||
|
attention_mask=None,
|
||
|
head_mask=None,
|
||
|
decoder_head_mask=None,
|
||
|
cross_attn_head_mask=None,
|
||
|
use_cache=None,
|
||
|
encoder_outputs=None,
|
||
|
**kwargs,
|
||
|
):
|
||
|
return {
|
||
|
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
||
|
"encoder_outputs": encoder_outputs,
|
||
|
"past_key_values": past_key_values,
|
||
|
"decoder_input_ids": decoder_input_ids,
|
||
|
"attention_mask": attention_mask,
|
||
|
"head_mask": head_mask,
|
||
|
"decoder_head_mask": decoder_head_mask,
|
||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||
|
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
||
|
}
|
||
|
|
||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
||
|
return shift_tokens_right(labels, self.config.pad_token_id)
|
||
|
|
||
|
@staticmethod
|
||
|
def _reorder_cache(past_key_values, beam_idx):
|
||
|
reordered_past = []
|
||
|
for layer_past in past_key_values:
|
||
|
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||
|
layer_past_new = {
|
||
|
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
|
||
|
}
|
||
|
reordered_past.append(layer_past_new)
|
||
|
return reordered_past
|
||
|
|
||
|
def get_encoder(self):
|
||
|
return self.model.encoder
|
||
|
|
||
|
def get_decoder(self):
|
||
|
return self.model.decoder
|
||
|
|
||
|
def get_output_embeddings(self):
|
||
|
return self.model.decoder.embed_tokens
|
||
|
|
||
|
def set_output_embeddings(self, value):
|
||
|
self.model.decoder.embed_tokens = value
|
||
|
|
||
|
|
||
|
class SinusoidalPositionalEmbedding(nn.Embedding):
|
||
|
"""
|
||
|
This module produces sinusoidal positional embeddings of any length.
|
||
|
|
||
|
We don't want to save the weight of this embedding since it's not trained (deterministic) and it can be huge.
|
||
|
|
||
|
Padding symbols are ignored.
|
||
|
|
||
|
These embeddings are automatically extended in forward if more positions are needed.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, num_positions, embedding_dim, padding_idx):
|
||
|
self.make_weight(num_positions, embedding_dim, padding_idx)
|
||
|
|
||
|
def make_weight(self, num_positions, embedding_dim, padding_idx):
|
||
|
weight = self.get_embedding(num_positions, embedding_dim, padding_idx)
|
||
|
if not hasattr(self, "weight"):
|
||
|
# in __init__
|
||
|
super().__init__(num_positions, embedding_dim, padding_idx, _weight=weight)
|
||
|
else:
|
||
|
# in forward put the weights on the correct dtype and device of the param
|
||
|
weight = weight.to(dtype=self.weight.dtype, device=self.weight.device)
|
||
|
self.weight = nn.Parameter(weight)
|
||
|
self.weight.detach_()
|
||
|
self.weight.requires_grad = False
|
||
|
|
||
|
@staticmethod
|
||
|
def get_embedding(num_embeddings, embedding_dim, padding_idx):
|
||
|
"""
|
||
|
Build sinusoidal embeddings.
|
||
|
|
||
|
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
|
||
|
"Attention Is All You Need".
|
||
|
"""
|
||
|
half_dim = embedding_dim // 2
|
||
|
emb = math.log(10000) / (half_dim - 1)
|
||
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
|
||
|
emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
|
||
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
|
||
|
if embedding_dim % 2 == 1:
|
||
|
# zero pad
|
||
|
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
||
|
if padding_idx is not None:
|
||
|
emb[padding_idx, :] = 0
|
||
|
return emb
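# A rough sketch of the table built above (notation assumed): for position `pos` and channel
# index i in [0, half_dim),
#
#     emb[pos, i]            = sin(pos * 10000^(-i / (half_dim - 1)))
#     emb[pos, half_dim + i] = cos(pos * 10000^(-i / (half_dim - 1)))
#
# i.e. all sines first, then all cosines (the tensor2tensor layout), rather than interleaving
# them as in the original "Attention Is All You Need" formulation.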
|
||
|
|
||
|
@staticmethod
|
||
|
def make_positions(tensor, padding_idx: int):
|
||
|
"""
|
||
|
Replace non-padding symbols with their position numbers.
|
||
|
|
||
|
Position numbers begin at padding_idx+1. Padding symbols are ignored.
|
||
|
"""
|
||
|
# The series of casts and type-conversions here are carefully
|
||
|
# balanced to both work with ONNX export and XLA. In particular XLA
|
||
|
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
|
||
|
# how to handle the dtype kwarg in cumsum.
|
||
|
mask = tensor.ne(padding_idx).int()
|
||
|
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
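# Worked example (ids assumed, padding_idx=1):
#
#     SinusoidalPositionalEmbedding.make_positions(torch.tensor([[7, 8, 9, 1, 1]]), padding_idx=1)
#     # -> tensor([[2, 3, 4, 1, 1]])
#
# non-padding tokens get consecutive positions starting at padding_idx + 1, while padding
# positions keep the value padding_idx.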
|
||
|
|
||
|
def forward(
|
||
|
self,
|
||
|
input,
|
||
|
incremental_state: Optional[Any] = None,
|
||
|
timestep: Optional[Tensor] = None,
|
||
|
):
|
||
|
"""Input is expected to be of size [bsz x seqlen]."""
|
||
|
bsz, seq_len = input.shape[:2]
|
||
|
max_pos = self.padding_idx + 1 + seq_len
|
||
|
if max_pos > self.weight.size(0):
|
||
|
# expand embeddings if needed
|
||
|
self.make_weight(max_pos, self.embedding_dim, self.padding_idx)
|
||
|
positions = self.make_positions(input, self.padding_idx)
|
||
|
return super().forward(positions)
|