1720 lines
80 KiB
Python
1720 lines
80 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""PyTorch MarianMTModel model, ported from the Marian C++ repo."""
|
||
|
|
||
|
|
||
|
import copy
|
||
|
import math
|
||
|
from typing import Dict, List, Optional, Tuple, Union
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
from torch.nn import CrossEntropyLoss
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
|
||
|
from ...modeling_outputs import (
|
||
|
BaseModelOutput,
|
||
|
BaseModelOutputWithPastAndCrossAttentions,
|
||
|
CausalLMOutputWithCrossAttentions,
|
||
|
Seq2SeqLMOutput,
|
||
|
Seq2SeqModelOutput,
|
||
|
)
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...utils import (
|
||
|
add_end_docstrings,
|
||
|
add_start_docstrings,
|
||
|
add_start_docstrings_to_model_forward,
|
||
|
logging,
|
||
|
replace_return_docstrings,
|
||
|
)
|
||
|
from .configuration_marian import MarianConfig
|
||
|
|
||
|
|
||
|
logger = logging.get_logger(__name__)

# Documentation constants. Presumably consumed by the docstring decorators imported above
# (e.g. `replace_return_docstrings`) on the model classes later in this file — those usages are
# not visible here; confirm before renaming.
_CONFIG_FOR_DOC = "MarianConfig"
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
||
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Along dim 1 every token moves one slot to the right (the last token of each row is dropped) and
    `decoder_start_token_id` is written into position 0. Any `-100` that ends up in the result is then
    replaced by `pad_token_id`.

    Args:
        input_ids (`torch.Tensor`):
            Token ids with at least two dims; the shift is applied along dim 1 (typically
            `(batch_size, sequence_length)`).
        pad_token_id (`int`):
            Id written in place of `-100` entries. Must not be `None`.
        decoder_start_token_id (`int`):
            Id placed at position 0 of every row.

    Returns:
        `torch.Tensor`: a new tensor with the same shape and dtype as `input_ids`; `input_ids` is not modified.

    Raises:
        ValueError: if `pad_token_id` is `None` (checked only after the shift has been built).
    """
    # Fresh zero tensor on the same device/dtype, so the caller's `input_ids` is never mutated.
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
|
||
|
|
||
|
|
||
|
class MarianSinusoidalPositionalEmbedding(nn.Embedding):
    """This module produces sinusoidal positional embeddings of any length.

    The table is fixed (it is built once and has `requires_grad=False`). Row `pos` stores
    `sin(pos / 10000^(2i / dim))` in its first half and `cos(pos / 10000^(2i / dim))` in its second half,
    i.e. sin and cos features are *not* interleaved.
    """

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
        # `padding_idx` is accepted for call-site compatibility but intentionally not forwarded to
        # `nn.Embedding`: the sinusoidal table is overwritten below and is not trained, so a padding row
        # would have no effect.
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]

        The angle table is computed with NumPy broadcasting rather than a nested Python comprehension, avoiding
        `num_positions * dim` interpreter-level `np.power` calls every time a model is constructed.

        Args:
            out: the `(num_positions, dim)` parameter to fill in place.

        Returns:
            `out`, filled with the sinusoidal table and detached from autograd.
        """
        n_pos, dim = out.shape
        # angle[pos, j] = pos / 10000^(2 * (j // 2) / dim): columns 2i and 2i + 1 share one frequency.
        exponents = 2 * (np.arange(dim) // 2) / dim
        position_enc = np.arange(n_pos, dtype=np.float64)[:, None] / np.power(10000, exponents)[None, :]
        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
        # Even-indexed columns (ceil(dim / 2) of them) become the sin half; odd-indexed ones the cos half.
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out

    @torch.no_grad()
    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
        """`input_ids_shape` is expected to be [bsz x seqlen].

        Returns the embeddings for positions `past_key_values_length .. past_key_values_length + seqlen - 1`,
        shaped `(seqlen, embedding_dim)`; the offset lets incremental decoding continue from the cached length.
        """
        _, seq_len = input_ids_shape[:2]
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
        )
        return super().forward(positions)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian
|
||
|
class MarianAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Eager implementation (explicit `torch.bmm` + softmax). The same module is used for encoder self-attention,
    decoder self-attention and, when `forward` receives `key_value_states`, decoder cross-attention.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MarianConfig] = None,
    ):
        """
        Args:
            embed_dim: model width; must be divisible by `num_heads`.
            num_heads: number of heads; each head has width `embed_dim // num_heads`.
            dropout: dropout probability applied to the attention probabilities (training mode only).
            is_decoder: when True, `forward` returns the projected `(key_states, value_states)` as a cache.
            bias: whether the q/k/v/out projections carry a bias term.
            is_causal: stored on the module; not read by `forward` in this class.
            config: stored on the module; not read by the methods of this class.

        Raises:
            ValueError: if `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Queries are multiplied by 1/sqrt(head_dim) before Q @ K^T (scaled dot-product attention).
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split heads: `(bsz, seq_len, embed_dim)` -> contiguous `(bsz, num_heads, seq_len, head_dim)`.

        `seq_len` may be passed as -1 to infer it from the tensor size.
        """
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: query input of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states: if given, keys/values are projected from it (cross-attention); otherwise they are
                projected from `hidden_states` (self-attention).
            past_key_value: cached `(key_states, value_states)`, each `(bsz, num_heads, cached_len, head_dim)`.
            attention_mask: additive mask of shape `(bsz, 1, tgt_len, src_len)`, added to the raw scores
                before the softmax.
            layer_head_mask: per-head multiplier of shape `(num_heads,)`, applied to the softmax output.
            output_attentions: whether to also return the attention probabilities.

        Returns:
            `(attn_output, attn_weights, past_key_value)`:
            - `attn_output`: `(bsz, tgt_len, embed_dim)`, after `out_proj`.
            - `attn_weights`: `(bsz, num_heads, tgt_len, src_len)` probabilities (after softmax and head mask,
              before dropout), or None when `output_attentions` is False.
            - `past_key_value`: the new `(key_states, value_states)` cache when `is_decoder` is True;
              otherwise the `past_key_value` argument returned unchanged.

        Raises:
            ValueError: if the attention weights, `attention_mask`, `layer_head_mask` or output have an
                unexpected shape.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            # (project only the new positions, then append them after the cached ones along the sequence dim)
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold heads into the batch dim so a single bmm computes every head: (bsz * num_heads, len, head_dim).
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # Additive mask, broadcast over heads via its singleton dim 1.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # Scale each head's probabilities by its mask value (0 disables the head).
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Undo the head folding: (bsz * num_heads, tgt_len, head_dim) -> (bsz, tgt_len, num_heads, head_dim).
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN
|
||
|
class MarianEncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention block followed by a feed-forward block.

    Each block is `LayerNorm(residual + dropout(sublayer(x)))`, i.e. LayerNorm is applied after the residual add.
    """

    def __init__(self, config: MarianConfig):
        super().__init__()
        self.embed_dim = config.d_model

        # The attention class is picked by `config._attn_implementation` (see MARIAN_ATTENTION_CLASSES).
        self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is True; `hidden_states`
            keeps the input shape `(batch, seq_len, embed_dim)`.
        """
        residual = hidden_states
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout, then residual + LayerNorm.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        # fp16 guard: if activations overflowed, clamp them to just below the fp16 max so later layers stay finite.
        # NOTE(review): clamp bounds +/-inf; NaN entries are presumably left as NaN — confirm if NaNs matter here.
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
||
|
|
||
|
|
||
|
# Maps `config._attn_implementation` to the attention module class instantiated by the encoder/decoder layers.
# Only the eager implementation is registered, so any other value raises KeyError at layer construction.
MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
|
||
|
class MarianDecoderLayer(nn.Module):
    """One Transformer decoder layer.

    Causal self-attention, then (when `encoder_hidden_states` is given) cross-attention over the encoder output,
    then a feed-forward block. Each block is `LayerNorm(residual + dropout(sublayer(x)))`.
    """

    def __init__(self, config: MarianConfig):
        super().__init__()
        self.embed_dim = config.d_model

        # Decoder self-attention: `is_decoder=True` makes it return its key/value states as a cache.
        self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # Cross-attention over the encoder output; also cached (keys/values depend only on the encoder output).
        self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for self-attention heads in a given layer of size
                `(decoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states. Elements
                `[0:2]` are the self-attention keys/values; when cross-attention runs, the last two elements are
                the cross-attention keys/values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*): whether to append the present key/value cache to the outputs.

        Returns:
            A tuple starting with `hidden_states`, followed by `(self_attn_weights, cross_attn_weights)` when
            `output_attentions` is True (`cross_attn_weights` is None without `encoder_hidden_states`), followed
            by `present_key_value` when `use_cache` is True. `present_key_value` is the 2-tuple self-attention
            cache, extended to a 4-tuple with the cross-attention cache when `encoder_hidden_states` is given.
        """
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 (indices 0 and 1)
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
||
|
|
||
|
|
||
|
class MarianPreTrainedModel(PreTrainedModel):
    """Shared base for the Marian models: config class, weight-init policy and dummy inputs."""

    config_class = MarianConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
        """Initialise `module` in place.

        Linear and learned embedding weights are drawn from N(0, config.init_std); linear biases and the embedding
        padding row are zeroed. The fixed sinusoidal position table is left untouched. Other module types are ignored.
        """
        init_std = self.config.init_std

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        # Must be tested before the generic nn.Embedding case: the sinusoidal table subclasses nn.Embedding but is
        # computed, not learned, so it must not be overwritten with random values.
        if isinstance(module, MarianSinusoidalPositionalEmbedding):
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def dummy_inputs(self):
        """Tiny two-row batch (the second row ends in padding) on the model's device."""
        pad_id = self.config.pad_token_id
        ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_id]], device=self.device)
        return {
            "attention_mask": ids.ne(pad_id),
            "input_ids": ids,
            "decoder_input_ids": ids,
        }
|
||
|
|
||
|
|
||
|
# Class-level docstring prefix for the Marian model classes; presumably applied via `add_start_docstrings`
# (imported above) on the model classes later in this file.
MARIAN_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MarianConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
||
|
|
||
|
MARIAN_GENERATION_EXAMPLE = r"""
|
||
|
Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints. Available
|
||
|
models are listed [here](https://huggingface.co/models?search=Helsinki-NLP).
|
||
|
|
||
|
Examples:
|
||
|
|
||
|
```python
|
||
|
>>> from transformers import AutoTokenizer, MarianMTModel
|
||
|
|
||
|
>>> src = "fr" # source language
|
||
|
>>> trg = "en" # target language
|
||
|
|
||
|
>>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
|
||
|
>>> model = MarianMTModel.from_pretrained(model_name)
|
||
|
>>> tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||
|
|
||
|
>>> sample_text = "où est l'arrêt de bus ?"
|
||
|
>>> batch = tokenizer([sample_text], return_tensors="pt")
|
||
|
|
||
|
>>> generated_ids = model.generate(**batch)
|
||
|
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||
|
"Where's the bus stop?"
|
||
|
```
|
||
|
"""
|
||
|
|
||
|
# Argument documentation for the Marian model `forward` methods; presumably applied via
# `add_start_docstrings_to_model_forward` (imported above) later in this file.
MARIAN_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Marian uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden-states
            at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.decoder_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||
|
|
||
|
|
||
|
class MarianEncoder(MarianPreTrainedModel):
    """
    Stack of *config.encoder_layers* [`MarianEncoderLayer`] self-attention blocks.

    Args:
        config: MarianConfig
        embed_tokens (nn.Embedding, *optional*): token embedding used to embed `input_ids`. When omitted, a fresh
            `nn.Embedding(config.vocab_size, config.d_model, config.pad_token_id)` is created.
    """

    def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        d_model = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(d_model) if config.scale_embedding else 1.0

        if embed_tokens is None:
            embed_tokens = nn.Embedding(config.vocab_size, d_model, self.padding_idx)
        self.embed_tokens = embed_tokens

        self.embed_positions = MarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings, d_model, self.padding_idx
        )
        self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module used to embed `input_ids`."""
        token_embedding = self.embed_tokens
        return token_embedding

    def set_input_embeddings(self, value):
        """Install `value` as the token embedding module used to embed `input_ids`."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
        r"""
        Encode a batch of source sequences.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Vocabulary indices of the source tokens. Padding will be ignored by default should you provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask selecting which positions may be attended to. Values in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Per-layer mask that nullifies selected attention heads. Values in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Pre-computed token embeddings to use instead of looking up `input_ids`. Mutually exclusive with
                `input_ids`. Note that `embed_scale` is only applied to embeddings looked up from `input_ids`.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention weights of every layer.
            output_hidden_states (`bool`, *optional*):
                Whether to also return the hidden states entering every layer plus the final output.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        else:
            input_shape = inputs_embeds.size()[:-1]

        hidden_states = inputs_embeds + self.embed_positions(input_shape)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        use_checkpointing = self.gradient_checkpointing and self.training
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states += (hidden_states,)

            layer_head_mask = head_mask[idx] if head_mask is not None else None
            # LayerDrop (https://arxiv.org/abs/1909.11556): while training, skip the whole layer with
            # probability `self.layerdrop`. The random draw only happens in training mode.
            skip_layer = self.training and torch.rand([]) < self.layerdrop

            if skip_layer:
                # A skipped layer leaves `hidden_states` untouched and contributes `None` as its attention.
                layer_attention = None
            else:
                if use_checkpointing:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )
                hidden_states = layer_outputs[0]
                layer_attention = layer_outputs[1] if output_attentions else None

            if output_attentions:
                all_attentions += (layer_attention,)

        if output_hidden_states:
            encoder_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
            )
        return tuple(value for value in (hidden_states, encoder_states, all_attentions) if value is not None)
|
||
|
|
||
|
|
||
|
class MarianDecoder(MarianPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MarianDecoderLayer`]

    Args:
        config: MarianConfig
        embed_tokens (nn.Embedding, *optional*): token embedding used to embed the decoder `input_ids`. When
            omitted, a fresh `nn.Embedding(config.decoder_vocab_size, config.d_model, config.pad_token_id)` is
            created.
    """

    def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx)

        self.embed_positions = MarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings, config.d_model, self.padding_idx
        )
        self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module used by the decoder."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module used by the decoder."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # Number of positions already cached: read from dim 2 of the first layer's first cached tensor.
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # embed positions, offset by the number of already-cached positions
        positions = self.embed_positions(input_shape, past_key_values_length)

        hidden_states = inputs_embeds + positions

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                # Report the size of the mask actually being validated. Formatting `head_mask` here would crash
                # with AttributeError (or report the wrong size) when only `cross_attn_head_mask` is malformed.
                assert attn_mask.size()[0] == (len(self.layers)), (
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # The cache is disabled on this path (see above), so no past_key_value is forwarded.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry's index depends on whether attention weights were returned; presumably
                # (hidden, self_attn, cross_attn, present) vs (hidden, present) — layout set by MarianDecoderLayer.
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
)
class MarianModel(MarianPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: MarianConfig):
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size

        # We always use self.shared for token embeddings to ensure compatibility with all marian models
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
        if self.config.share_encoder_decoder_embeddings:
            encoder_embed_tokens = decoder_embed_tokens = self.shared
        else:
            # Since the embeddings are not shared, deepcopy the embeddings here for encoder
            # and decoder to make sure they are not tied.
            encoder_embed_tokens = copy.deepcopy(self.shared)
            decoder_embed_tokens = copy.deepcopy(self.shared)
            self.shared = None

        self.encoder = MarianEncoder(config, encoder_embed_tokens)
        self.decoder = MarianDecoder(config, decoder_embed_tokens)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the encoder's token embeddings (the shared matrix when embeddings are shared)."""
        # This will return shared embeddings if they are shared else specific to encoder.
        return self.get_encoder().get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set the input embeddings: on encoder and decoder when shared, otherwise on the encoder only."""
        if self.config.share_encoder_decoder_embeddings:
            self.shared = value
            self.encoder.embed_tokens = self.shared
            self.decoder.embed_tokens = self.shared
        else:  # if not shared only set encoder embeddings
            self.encoder.embed_tokens = value

    def get_decoder_input_embeddings(self):
        """
        Return the decoder's own token embeddings.

        Raises:
            ValueError: if `config.share_encoder_decoder_embeddings` is `True`.
        """
        if self.config.share_encoder_decoder_embeddings:
            raise ValueError(
                "`get_decoder_input_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
                "is `True`. Please use `get_input_embeddings` instead."
            )
        return self.get_decoder().get_input_embeddings()

    def set_decoder_input_embeddings(self, value):
        """
        Replace the decoder's own token embeddings.

        Raises:
            ValueError: if `config.share_encoder_decoder_embeddings` is `True`.
        """
        if self.config.share_encoder_decoder_embeddings:
            raise ValueError(
                "`config.share_encoder_decoder_embeddings` is set to `True` meaning the decoder input embeddings "
                "are shared with the encoder. In order to set the decoder input embeddings, you should simply set "
                "the encoder input embeddings by calling `set_input_embeddings` with the appropriate embeddings."
            )
        self.decoder.embed_tokens = value

    def get_encoder(self):
        """Return the encoder sub-module."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder sub-module."""
        return self.decoder

    def resize_decoder_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        """
        Resize the (unshared) decoder token embeddings to `new_num_tokens` rows.

        Args:
            new_num_tokens: target vocabulary size; `None` returns the current embeddings without updating config.

        Returns:
            The decoder embedding module after resizing.

        Raises:
            ValueError: if `config.share_encoder_decoder_embeddings` is `True`.
        """
        if self.config.share_encoder_decoder_embeddings:
            raise ValueError(
                "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
                "is `True`. Please use `resize_token_embeddings` instead."
            )

        old_embeddings = self.get_decoder_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.set_decoder_input_embeddings(new_embeddings)

        model_embeds = self.get_decoder_input_embeddings()

        if new_num_tokens is None:
            return model_embeds

        # Update base model and current model config
        self.config.decoder_vocab_size = new_num_tokens

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds

    @add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Seq2SeqModelOutput:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MarianModel

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
        >>> model = MarianModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

        >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
        >>> decoder_inputs = tokenizer(
        ...     "<pad> Studien haben gezeigt dass es hilfreich ist einen Hund zu besitzen",
        ...     return_tensors="pt",
        ...     add_special_tokens=False,
        ... )
        >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 26, 512]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # A caller may pass a `BaseModelOutput` as `encoder_outputs` while requesting tuple output; a tuple
            # cannot be concatenated with a ModelOutput, so convert it to its tuple form first.
            if isinstance(encoder_outputs, BaseModelOutput):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
|
||
|
"The Marian Model with a language modeling head. Can be used for summarization.", MARIAN_START_DOCSTRING
|
||
|
)
|
||
|
class MarianMTModel(MarianPreTrainedModel):
|
||
|
base_model_prefix = "model"
|
||
|
_keys_to_ignore_on_load_missing = [
|
||
|
"final_logits_bias",
|
||
|
"encoder.embed_positions.weight",
|
||
|
"decoder.embed_positions.weight",
|
||
|
]
|
||
|
_keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
|
||
|
_tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
|
||
|
|
||
|
def __init__(self, config: MarianConfig):
    """Build the seq2seq backbone plus a bias-free LM head sized to the target vocabulary."""
    super().__init__(config)
    self.model = MarianModel(config)

    # With shared embeddings the output vocabulary is the joint one; otherwise the decoder has its own.
    if config.share_encoder_decoder_embeddings:
        output_vocab_size = config.vocab_size
    else:
        output_vocab_size = config.decoder_vocab_size

    self.register_buffer("final_logits_bias", torch.zeros((1, output_vocab_size)))
    self.lm_head = nn.Linear(config.d_model, output_vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
|
||
|
|
||
|
def get_encoder(self):
    """Return the encoder of the wrapped `MarianModel`."""
    backbone = self.model
    return backbone.get_encoder()
|
||
|
|
||
|
def get_decoder(self):
    """Return the decoder of the wrapped `MarianModel`."""
    backbone = self.model
    return backbone.get_decoder()
|
||
|
|
||
|
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
    """
    Resize the (shared) input token embeddings and keep `final_logits_bias` in sync.

    Args:
        new_num_tokens: requested vocabulary size, or `None` to leave the size unchanged.
        pad_to_multiple_of: if set, the embedding matrix is padded up to a multiple of this value.

    Returns:
        The resulting input embedding module.
    """
    new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
    if self.config.share_encoder_decoder_embeddings:
        # Size the bias from the resulting embedding matrix, not from `new_num_tokens`: they differ when
        # `pad_to_multiple_of` rounds the size up (the lm_head is resized to the padded size), and
        # `new_num_tokens` may be None, which `_resize_final_logits_bias` cannot compare against.
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
    return new_embeddings
|
||
|
|
||
|
def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None) -> nn.Embedding:
    """
    Resize the input embeddings and, when they are shared, the decoder vocab size and an untied LM head.

    Returns the input embedding module after resizing.
    """
    resized = self._get_resized_embeddings(self.get_input_embeddings(), new_num_tokens, pad_to_multiple_of)
    self.set_input_embeddings(resized)

    # The effective size can exceed the request when `pad_to_multiple_of` rounds it up.
    new_num_tokens = resized.weight.shape[0]

    if self.config.share_encoder_decoder_embeddings:
        # The decoder reads the shared matrix, so its vocabulary size follows the new size.
        self.config.decoder_vocab_size = new_num_tokens

        # An LM head that is not tied to the embeddings must be resized on its own.
        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
            resized_head = self._get_resized_lm_head(self.get_output_embeddings(), new_num_tokens)
            self.set_output_embeddings(resized_head)

    return self.get_input_embeddings()
|
||
|
|
||
|
def resize_decoder_token_embeddings(self, new_num_tokens):
    """
    Resize the decoder-only token embeddings (unshared setup), an untied LM head, and `final_logits_bias`.

    Args:
        new_num_tokens: target decoder vocabulary size; `None` performs no config/bias update.

    Returns:
        The decoder embedding module after resizing.

    Raises:
        ValueError: if `config.share_encoder_decoder_embeddings` is `True`.
    """
    if self.config.share_encoder_decoder_embeddings:
        raise ValueError(
            "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
            "is `True`. Please use `resize_token_embeddings` instead."
        )

    current = self.model.get_decoder_input_embeddings()
    self.model.set_decoder_input_embeddings(self._get_resized_embeddings(current, new_num_tokens))

    # An LM head that is not tied to the decoder embeddings must be resized on its own.
    lm_head = self.get_output_embeddings()
    if lm_head is not None and not self.config.tie_word_embeddings:
        self.set_output_embeddings(self._get_resized_lm_head(lm_head, new_num_tokens))

    decoder_embeddings = self.model.get_decoder_input_embeddings()

    if new_num_tokens is not None:
        # Update base model and current model config, re-tie, then grow/shrink the logits bias.
        self.config.decoder_vocab_size = new_num_tokens
        self.tie_weights()
        self._resize_final_logits_bias(new_num_tokens)

    return decoder_embeddings
|
||
|
|
||
|
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
    """
    Truncate or zero-pad the `final_logits_bias` buffer along its last dimension to `new_num_tokens`.

    The replacement buffer keeps the existing bias's dtype and device.
    """
    old_bias = self.final_logits_bias
    old_num_tokens = old_bias.shape[-1]
    if new_num_tokens <= old_num_tokens:
        new_bias = old_bias[:, :new_num_tokens]
    else:
        # Match dtype as well as device: otherwise torch.cat type-promotes a half-precision bias to float32.
        extra_bias = torch.zeros(
            (1, new_num_tokens - old_num_tokens), dtype=old_bias.dtype, device=old_bias.device
        )
        new_bias = torch.cat([old_bias, extra_bias], dim=1)
    self.register_buffer("final_logits_bias", new_bias)
|
||
|
|
||
|
def get_output_embeddings(self):
    """Return the language-modeling head that projects decoder states onto the vocabulary."""
    output_projection = self.lm_head
    return output_projection
|
||
|
|
||
|
def set_output_embeddings(self, new_embeddings: nn.Linear):
    """
    Replace the language-modeling head.

    Args:
        new_embeddings: the new output projection. `lm_head` is built as an `nn.Linear`, so the annotation now
            says `nn.Linear` (it previously said `nn.Embedding`); runtime behavior is unchanged.
    """
    self.lm_head = new_embeddings
|
||
|
|
||
|
def tie_weights(self):
    """
    Tie the output projection to the decoder input embeddings and run any module-level tying hooks.

    If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
    weights instead.
    """
    output_embeddings = self.get_output_embeddings()
    if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
        # With shared embeddings this is the shared matrix; otherwise it is the decoder's own embed_tokens.
        decoder_embeddings = self.get_decoder().get_input_embeddings()
        self._tie_or_clone_weights(output_embeddings, decoder_embeddings)

    # Model whose submodules receive `_tie_weights()` below. It narrows to the base model when
    # encoder/decoder weights are tied (equivalent to the former rebinding of `self`).
    tie_root = self
    if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
        if hasattr(self, self.base_model_prefix):
            tie_root = getattr(self, self.base_model_prefix)
        tied_weights = tie_root._tie_encoder_decoder_weights(
            tie_root.encoder, tie_root.decoder, tie_root.base_model_prefix, "encoder"
        )
        # Stored per instance rather than in the class attribute `_tied_weights_keys`: mutating the class
        # attribute would leak into every other instance and into subsequent calls.
        tie_root._dynamic_tied_weights_keys = tied_weights

    for submodule in tie_root.modules():
        if hasattr(submodule, "_tie_weights"):
            submodule._tie_weights()
|
||
|
|
||
|
@add_start_docstrings_to_model_forward(MARIAN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(MARIAN_GENERATION_EXAMPLE)
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.decoder_vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
        ignored (masked), the loss is only computed for the tokens with labels in `[0, ...,
        config.decoder_vocab_size]`. If `decoder_input_ids` and `decoder_inputs_embeds` are both omitted, the
        decoder inputs are derived from `labels` with `shift_tokens_right`.

    Returns:

    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if labels is not None:
        # The cache is disabled whenever a loss is computed.
        if use_cache:
            logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
        use_cache = False
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

    outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        head_mask=head_mask,
        decoder_head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # outputs[0] is projected by the LM head; final_logits_bias is added on top.
    lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

    masked_lm_loss = None
    if labels is not None:
        # When the model is split across devices, the logits may not live on the labels' device;
        # move the labels so the loss can be computed (mirrors MarianForCausalLM.forward).
        labels = labels.to(lm_logits.device)
        loss_fct = CrossEntropyLoss()
        masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))

    if not return_dict:
        output = (lm_logits,) + outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return Seq2SeqLMOutput(
        loss=masked_lm_loss,
        logits=lm_logits,
        past_key_values=outputs.past_key_values,
        decoder_hidden_states=outputs.decoder_hidden_states,
        decoder_attentions=outputs.decoder_attentions,
        cross_attentions=outputs.cross_attentions,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        encoder_hidden_states=outputs.encoder_hidden_states,
        encoder_attentions=outputs.encoder_attentions,
    )
|
||
|
|
||
|
def prepare_inputs_for_generation(
    self,
    decoder_input_ids: torch.LongTensor,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    decoder_head_mask: Optional[torch.Tensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    encoder_outputs: Optional[Union[Tuple[torch.Tensor], BaseModelOutput]] = None,
    **kwargs,
) -> Dict:
    """
    Build the keyword arguments for one generation step of the seq2seq model.

    If a key/value cache is supplied, the leading `decoder_input_ids` columns already represented in the cache
    are dropped. Their count is read from dim 2 of the first cached tensor. When the caller already passed no
    more columns than that, only the final column is kept. Extra `**kwargs` are ignored.
    """
    if past_key_values is not None:
        cached_len = past_key_values[0][0].shape[2]
        seq_len = decoder_input_ids.shape[1]
        # Some generation methods already pass only the last input ID; then keep just the final column.
        keep_from = cached_len if seq_len > cached_len else seq_len - 1
        decoder_input_ids = decoder_input_ids[:, keep_from:]

    step_inputs = {
        "input_ids": None,  # encoder_outputs is defined. input_ids not needed
        "encoder_outputs": encoder_outputs,
        "past_key_values": past_key_values,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
        "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
    }
    return step_inputs
|
||
|
|
||
|
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    """Derive decoder input IDs from `labels` via `shift_tokens_right` with the configured pad/start IDs."""
    cfg = self.config
    return shift_tokens_right(labels, cfg.pad_token_id, cfg.decoder_start_token_id)
|
||
|
|
||
|
@staticmethod
|
||
|
def _reorder_cache(past_key_values, beam_idx):
|
||
|
reordered_past = ()
|
||
|
for layer_past in past_key_values:
|
||
|
# cached cross_attention states don't have to be reordered -> they are always the same
|
||
|
reordered_past += (
|
||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
|
||
|
+ layer_past[2:],
|
||
|
)
|
||
|
return reordered_past
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian
class MarianDecoderWrapper(MarianPreTrainedModel):
    """
    Thin holder that exposes a `MarianDecoder` under the attribute name `decoder`.

    It exists so pretrained checkpoints load correctly when the causal language model is used together with the
    [`EncoderDecoderModel`] framework.
    """

    def __init__(self, config):
        super().__init__(config)
        decoder = MarianDecoder(config)
        self.decoder = decoder

    def forward(self, *args, **kwargs):
        """Pass all positional and keyword arguments straight through to the wrapped decoder."""
        decoder = self.decoder
        return decoder(*args, **kwargs)
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
class MarianForCausalLM(MarianPreTrainedModel):
    """
    Decoder-only Marian model with a language-modeling head on top.

    The given configuration is deep-copied and switched to decoder-only mode (`is_decoder=True`,
    `is_encoder_decoder=False`), so the caller's config object is not mutated.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Work on a copy so the caller's config is left untouched by the flags set below.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = MarianDecoderWrapper(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                if the model is configured as a decoder.
            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `input_ids` of shape `(batch_size, sequence_length)`.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MarianForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en")
        >>> model = MarianForCausalLM.from_pretrained("Helsinki-NLP/opus-mt-fr-en", add_cross_attention=False)
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # With model parallelism the logits may live on a different device than the labels.
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten using the actual logits width rather than `config.vocab_size`: identical when they agree,
            # and still correct if the output projection's size differs from `config.vocab_size`.
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # Without an explicit mask, attend to every position of `input_ids`.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        return {
            "input_ids": input_ids,  # with a cache, only the tokens not yet covered by it
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Reindex every cached tensor of every layer along dim 0 to follow the selected beams.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
|