1105 lines
52 KiB
Python
1105 lines
52 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch Splinter model."""
|
||
|
|
||
|
|
||
|
import math
|
||
|
from dataclasses import dataclass
|
||
|
from typing import List, Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
from torch.nn import CrossEntropyLoss
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
||
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||
|
from .configuration_splinter import SplinterConfig
|
||
|
|
||
|
|
||
|
# Module-level logger (used e.g. by SplinterEncoder to warn about cache/gradient-checkpointing conflicts).
logger = logging.get_logger(__name__)

# Checkpoint and config-class names passed to `add_code_sample_docstrings` when generating the
# example code in the model forward docstrings.
_CHECKPOINT_FOR_DOC = "tau/splinter-base"
_CONFIG_FOR_DOC = "SplinterConfig"
|
||
|
|
||
|
|
||
|
from ..deprecated._archive_maps import SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
class SplinterEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three embeddings are summed (position embeddings only when `position_embedding_type == "absolute"`),
    then passed through LayerNorm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed a batch of tokens.

        Args:
            input_ids: token indices; used to look up word embeddings when `inputs_embeds` is not given.
            token_type_ids: segment indices; defaults to all zeros with the input's (batch, seq) shape.
            position_ids: position indices; defaults to `past_key_values_length ... past_key_values_length + seq_len`.
            inputs_embeds: precomputed word embeddings, used instead of looking up `input_ids`.
            past_key_values_length: offset of the first position, for incremental decoding with a cache.

        Returns:
            The normalized (and dropout-applied) embedding tensor of shape `input_shape + (hidden_size,)`.

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is provided.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]

        if position_ids is None:
            # Slice the registered buffer so cached decoding continues numbering after the past tokens.
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            # Relative position types are handled inside the attention layers instead.
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
class SplinterSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Acts as self-attention, or as cross-attention when `encoder_hidden_states` is passed to `forward`. Optionally
    adds relative-position scores ("relative_key" / "relative_key_query"). When `config.is_decoder` is set, the
    (key, value) pair is returned as the last output so callers can cache it.
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # The hidden size must split evenly across heads, unless the config defines `embedding_size`.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # An explicit argument wins over the config value; the final fallback is "absolute".
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One row per signed query-key distance in [-(max-1), max-1]; forward() shifts distances by max-1.
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dim into (num_heads, head_size) and move the head axis to position 1.

        Expects a 3-D input, presumably (batch, seq_len, all_head_size), and returns
        (batch, num_heads, seq_len, head_size).
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute attention over `hidden_states` (or over `encoder_hidden_states` for cross-attention).

        Args:
            hidden_states: queries (and keys/values for self-attention).
            attention_mask: additive mask added to the raw scores. For cross-attention it is replaced by
                `encoder_attention_mask`.
            head_mask: multiplicative mask applied to the attention probabilities.
            encoder_hidden_states: if given, keys/values come from here (cross-attention).
            encoder_attention_mask: additive mask used for cross-attention.
            past_key_value: cached (key, value) pair, indexed as [0] and [1].
            output_attentions: also return the attention probabilities.

        Returns:
            `(context_layer,)`, plus `attention_probs` if `output_attentions`, plus the `(key, value)` pair if
            `self.is_decoder`.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Self-attention with a cache: append the new keys/values along dim 2 (the sequence axis after
            # transpose_for_scores).
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Here "use_cache" only means "a cache was passed in"; it drives the relative-position query indexing below.
        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                # With a cache, the query position is taken to be the last key position.
                # NOTE(review): this assumes one new query token per cached step — confirm for multi-token steps.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Shift signed distances into the non-negative index range of distance_embedding.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the SplinterModel forward() function).
            # The mask is additive.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Undo transpose_for_scores: move heads back next to head_size and merge them into all_head_size.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Splinter
class SplinterSelfOutput(nn.Module):
    """Projection that follows self-attention: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # Submodule names/order are part of the checkpoint (state_dict) layout; keep them stable.
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, add the residual `input_tensor`, and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
class SplinterAttention(nn.Module):
    """Attention sub-layer: SplinterSelfAttention followed by its output projection/residual block.

    Also supports permanently pruning attention heads.
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = SplinterSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from the Q/K/V projections and the output projection.

        Heads already recorded in `self.pruned_heads` are accounted for by `find_pruneable_heads_and_indices`.
        An empty `heads` is a no-op.
        """
        if not len(heads):
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Q/K/V are pruned along their default dim; the output projection along dim=1 (its input side).
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the shrunken layers.
        attn.num_attention_heads -= len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention, then the output block.

        Returns the projected output followed by any extra outputs of the inner attention (probabilities and/or
        cache).
        """
        context, *extras = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # The residual connection inside the output block uses the block's own input.
        projected = self.output(context, hidden_states)
        return (projected, *extras)
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Splinter
class SplinterIntermediate(nn.Module):
    """Feed-forward expansion: dense projection to `intermediate_size` followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is resolved through the ACT2FN registry; anything else is used directly as the callable.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense projection, then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.bert.modeling_bert.BertOutput with Bert->Splinter
class SplinterOutput(nn.Module):
    """Feed-forward contraction: dense back to `hidden_size`, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Submodule names/order are part of the checkpoint (state_dict) layout; keep them stable.
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Contract `hidden_states`, add the residual `input_tensor`, and normalize."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter
class SplinterLayer(nn.Module):
    """One transformer block.

    Runs self-attention, then optional cross-attention (decoder only), then a feed-forward network. The
    feed-forward network is applied through `apply_chunking_to_forward` along the sequence dimension.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # dimension along which apply_chunking_to_forward splits the input
        self.attention = SplinterAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention always uses absolute positions, regardless of the config's position type.
            self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
        self.intermediate = SplinterIntermediate(config)
        self.output = SplinterOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Apply the block.

        `past_key_value` layout: the first two entries are the self-attention (key, value) cache. The last two
        entries, when cross-attention is used, are the cross-attention cache.

        Returns:
            `(layer_output,)`, followed by self-attention probabilities and, when cross-attention ran, cross-attention
            probabilities (if `output_attentions`). When `self.is_decoder`, the present key/value tuple comes last:
            2 entries, or 4 with cross-attention.

        Raises:
            ValueError: if `encoder_hidden_states` is given to a decoder built without cross-attention.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 (first two entries)
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # NOTE(review): this variable is only read inside the branch below; the assignment is kept as-is.
        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple (last two entries)
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Feed-forward network applied to one chunk: intermediate expansion, then output block with residual."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
|
||
|
|
||
|
|
||
|
# Adapted from transformers.models.bert.modeling_bert.BertEncoder with Bert->Splinter
class SplinterEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` SplinterLayer blocks applied in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run every layer and optionally collect per-layer hidden states, attentions and caches.

        `head_mask` and `past_key_values`, when given, are indexed per layer. Caching is disabled, with a
        warning, while training with gradient checkpointing.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions` if `return_dict` is truthy. Otherwise a tuple of the
            non-None entries among (last hidden state, cache, hidden states, self-attentions, cross-attentions).
        """
        collect_cross = output_attentions and self.config.add_cross_attention
        hidden_history = () if output_hidden_states else None
        self_attn_history = () if output_attentions else None
        cross_attn_history = () if collect_cross else None

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        present_cache = () if use_cache else None
        for idx, block in enumerate(self.layer):
            if output_hidden_states:
                hidden_history += (hidden_states,)

            block_args = (
                hidden_states,
                attention_mask,
                None if head_mask is None else head_mask[idx],
                encoder_hidden_states,
                encoder_attention_mask,
                None if past_key_values is None else past_key_values[idx],
                output_attentions,
            )
            if checkpointing:
                block_outputs = self._gradient_checkpointing_func(block.__call__, *block_args)
            else:
                block_outputs = block(*block_args)

            hidden_states = block_outputs[0]
            if use_cache:
                # The layer's present key/value tuple is always its last output.
                present_cache += (block_outputs[-1],)
            if output_attentions:
                self_attn_history += (block_outputs[1],)
                if collect_cross:
                    cross_attn_history += (block_outputs[2],)

        if output_hidden_states:
            hidden_history += (hidden_states,)

        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=present_cache,
                hidden_states=hidden_history,
                attentions=self_attn_history,
                cross_attentions=cross_attn_history,
            )
        candidates = (hidden_states, present_cache, hidden_history, self_attn_history, cross_attn_history)
        return tuple(item for item in candidates if item is not None)
|
||
|
|
||
|
|
||
|
class SplinterPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that hooks Splinter into the shared `PreTrainedModel` machinery (config class, checkpoint
    prefix, gradient-checkpointing support) and defines how newly created weights are initialized.
    """

    config_class = SplinterConfig
    base_model_prefix = "splinter"
    supports_gradient_checkpointing = True

    # Adapted from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize `module`'s parameters in place.

        - Linear layers: normal(0, initializer_range) weights, zero bias.
        - Embeddings: normal(0, initializer_range) weights, zeroed padding row.
        - LayerNorm: weight 1, bias 0.

        Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
||
|
|
||
|
|
||
|
# Class-level documentation prepended to Splinter model classes via `add_start_docstrings`.
SPLINTER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`SplinterConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared forward-argument documentation. `{0}` is filled via `.format(...)` with the comma-separated input
# dimensions (e.g. "batch_size, sequence_length"), so every shape is written as `({0})` to render as a tuple.
SPLINTER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare Splinter Model transformer outputting raw hidden-states without any specific head on top.",
    SPLINTER_START_DOCSTRING,
)
class SplinterModel(SplinterPreTrainedModel):
    """
    The model is an encoder (with only self-attention) following the architecture described in [Attention is all you
    need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embedding layer feeding the stack of transformer layers.
        self.embeddings = SplinterEmbeddings(config)
        self.encoder = SplinterEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token (word) embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Swap in `value` as the token (word) embedding module."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_idx, head_indices in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(head_indices)

    @add_start_docstrings_to_model_forward(SPLINTER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        cfg = self.config

        # Fall back to config defaults for any flag the caller left unset.
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Caching only applies when the model runs as a decoder.
        if not cfg.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = cfg.use_cache

        # Exactly one of input_ids / inputs_embeds must be provided; it also fixes the shape and device.
        if input_ids is not None:
            if inputs_embeds is not None:
                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        # Number of already-cached positions (dimension 2 of the first cached key tensor).
        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Turn the user mask into an additive mask broadcastable over all attention heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # Cross-attention mask: only built when running as a decoder over encoder states.
        encoder_extended_attention_mask = None
        if cfg.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Expand head_mask ([num_heads] or [num_hidden_layers x num_heads]) to one entry per layer;
        # 1.0 means the head is kept.
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=sequence_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )
        return (sequence_output,) + encoder_outputs[1:]
|
||
|
|
||
|
|
||
|
class SplinterFullyConnectedLayer(nn.Module):
    """Feed-forward projection block computing ``LayerNorm(act(dense(x)))``.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        hidden_act: key into ``ACT2FN`` selecting the activation (default ``"gelu"``).
    """

    def __init__(self, input_dim, output_dim, hidden_act="gelu"):
        super().__init__()

        self.input_dim, self.output_dim = input_dim, output_dim

        # Registration order (dense, act_fn, LayerNorm) matches the checkpoint layout.
        self.dense = nn.Linear(input_dim, output_dim)
        self.act_fn = ACT2FN[hidden_act]
        self.LayerNorm = nn.LayerNorm(output_dim)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        projected = self.act_fn(self.dense(inputs))
        return self.LayerNorm(projected)
|
||
|
|
||
|
|
||
|
class QuestionAwareSpanSelectionHead(nn.Module):
    """
    Implementation of Question-Aware Span Selection (QASS) head, described in Splinter's paper: [Few-Shot Question
    Answering by Pretraining Span Selection](https://arxiv.org/abs/2101.00438).

    For each selected (question) position, every token of the sequence is scored as a span start and as a span end
    through a bilinear product between the transformed question representation and the transformed token
    representation.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size

        # Projections applied to the gathered question-token representations.
        self.query_start_transform = SplinterFullyConnectedLayer(hidden, hidden)
        self.query_end_transform = SplinterFullyConnectedLayer(hidden, hidden)
        # Projections applied to every token of the sequence.
        self.start_transform = SplinterFullyConnectedLayer(hidden, hidden)
        self.end_transform = SplinterFullyConnectedLayer(hidden, hidden)

        # Bilinear weight matrices (no bias) for start / end scoring.
        self.start_classifier = nn.Linear(hidden, hidden, bias=False)
        self.end_classifier = nn.Linear(hidden, hidden, bias=False)

    def forward(self, inputs, positions):
        """
        Args:
            inputs: hidden states of shape `(batch_size, seq_length, dim)`.
            positions: indices of shape `(batch_size, num_positions)` selecting the question tokens.

        Returns:
            `(start_logits, end_logits)`, each of shape `(batch_size, num_positions, seq_length)`.
        """
        _, _, hidden_dim = inputs.size()

        # Pick out the representation at each requested position: [batch_size, num_positions, dim]
        gather_index = positions.unsqueeze(-1).repeat(1, 1, hidden_dim)
        question_reps = torch.gather(inputs, dim=1, index=gather_index)

        # [batch_size, num_positions, dim] @ [batch_size, dim, seq_length] -> [batch_size, num_positions, seq_length]
        start_logits = torch.matmul(
            self.start_classifier(self.query_start_transform(question_reps)),
            self.start_transform(inputs).transpose(1, 2),
        )
        end_logits = torch.matmul(
            self.end_classifier(self.query_end_transform(question_reps)),
            self.end_transform(inputs).transpose(1, 2),
        )
        return start_logits, end_logits
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    Splinter Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    question-aware span selection head on top of the hidden-states output to compute `span start logits` and `span end
    logits`).
    """,
    SPLINTER_START_DOCSTRING,
)
class SplinterForQuestionAnswering(SplinterPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.splinter = SplinterModel(config)
        self.splinter_qass = QuestionAwareSpanSelectionHead(config)
        self.question_token_id = config.question_token_id

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(SPLINTER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        question_positions: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss. When `question_positions` is given, one label per
            question is expected, i.e. shape `(batch_size, num_questions)`.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss. When `question_positions` is given, one label per
            question is expected, i.e. shape `(batch_size, num_questions)`.
        question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
            num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
            the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
            sequence_length)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        question_positions_were_none = False
        if question_positions is None:
            if input_ids is not None:
                # argmax over a 0/1 indicator returns the index of the first question token (0 if none is present).
                question_position_for_each_example = torch.argmax(
                    (torch.eq(input_ids, self.question_token_id)).int(), dim=-1
                )
            else:
                # Without input_ids the question token cannot be located; fall back to position 0.
                question_position_for_each_example = torch.zeros(
                    inputs_embeds.size(0), dtype=torch.long, layout=inputs_embeds.layout, device=inputs_embeds.device
                )
            question_positions = question_position_for_each_example.unsqueeze(-1)
            question_positions_were_none = True

        outputs = self.splinter(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # [batch_size, num_questions, sequence_length]
        start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)

        if question_positions_were_none:
            # Single implicit question: drop the num_questions axis -> [batch_size, sequence_length]
            start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1)

        if attention_mask is not None:
            logits_mask = attention_mask
            if start_logits.dim() == 3 and logits_mask.dim() == 2:
                # Logits are [batch, num_questions, seq]; a plain [batch, seq] mask would broadcast as
                # [1, batch, seq] and misalign batch with questions, so insert the question axis explicitly.
                logits_mask = logits_mask.unsqueeze(1)
            start_logits = start_logits + (1 - logits_mask) * torch.finfo(start_logits.dtype).min
            end_logits = end_logits + (1 - logits_mask) * torch.finfo(end_logits.dtype).min

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # Sometimes the start/end positions are outside our model inputs; those are clamped onto
            # `ignored_index` and excluded from the loss. Out-of-place clamp: do not mutate the caller's labels.
            sequence_length = start_logits.size(-1)
            ignored_index = sequence_length
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Flatten any question axis so both the 2-D and 3-D logits layouts feed CrossEntropyLoss as
            # [N, sequence_length] with [N] targets.
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits.reshape(-1, sequence_length), start_positions.reshape(-1))
            end_loss = loss_fct(end_logits.reshape(-1, sequence_length), end_positions.reshape(-1))
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||
|
|
||
|
|
||
|
@dataclass
class SplinterForPreTrainingOutput(ModelOutput):
    """
    Class for outputs of Splinter as a span selection model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
            Total span extraction loss, computed as the mean of the Cross-Entropy losses for the start and end
            positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order defines the tuple order returned by ModelOutput.to_tuple(); do not reorder.
    loss: Optional[torch.FloatTensor] = None
    start_logits: Optional[torch.FloatTensor] = None
    end_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task
    is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans
    instead.
    """,
    SPLINTER_START_DOCSTRING,
)
class SplinterForPreTraining(SplinterPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.splinter = SplinterModel(config)
        self.splinter_qass = QuestionAwareSpanSelectionHead(config)
        self.question_token_id = config.question_token_id

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length")
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        question_positions: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, SplinterForPreTrainingOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
            The positions of all question tokens. Required when computing the loss and when `inputs_embeds` is used
            instead of `input_ids`. If None, it is built from every occurrence of `config.question_token_id` in
            `input_ids`, right-padding rows with fewer question tokens with `config.pad_token_id`. start_logits and
            end_logits are of shape `(batch_size, num_questions, sequence_length)`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if question_positions is None and start_positions is not None and end_positions is not None:
            raise TypeError("question_positions must be specified in order to calculate the loss")

        elif question_positions is None and input_ids is None:
            raise TypeError("question_positions must be specified when input_embeds is used")

        elif question_positions is None:
            question_positions = self._prepare_question_positions(input_ids)

        outputs = self.splinter(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        batch_size, sequence_length, _ = sequence_output.size()
        # [batch_size, num_questions, sequence_length]
        start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)

        num_questions = question_positions.size(1)
        if attention_mask is not None:
            # Repeat the [batch_size, sequence_length] mask for every question before masking the logits.
            attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
                batch_size, num_questions, sequence_length
            )
            start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min
            end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min

        total_loss = None
        # [batch_size, num_questions, sequence_length]
        if start_positions is not None and end_positions is not None:
            # Sometimes the start/end positions are outside our model inputs; clamp them into the sequence.
            # Out-of-place clamp so the caller's label tensors are not mutated.
            max_position = max(0, sequence_length - 1)
            start_positions = start_positions.clamp(0, max_position)
            end_positions = end_positions.clamp(0, max_position)

            # Ignore pad positions in the loss. NOTE(review): the original intent is to ignore position zero
            # (Splinter never predicts it and it pads question tokens / their labels); this holds only if
            # config.pad_token_id == 0 — confirm for custom configs.
            loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            start_loss = loss_fct(
                start_logits.reshape(batch_size * num_questions, sequence_length),
                start_positions.reshape(batch_size * num_questions),
            )
            end_loss = loss_fct(
                end_logits.reshape(batch_size * num_questions, sequence_length),
                end_positions.reshape(batch_size * num_questions),
            )
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return SplinterForPreTrainingOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Collect the indices of every `config.question_token_id` occurrence in each row of `input_ids`.

        Returns a `(batch_size, max_num_questions)` long tensor on `input_ids.device`. Rows with fewer question tokens
        are right-padded with `config.pad_token_id`. If no row contains a question token, the result has zero columns.
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        # torch.where on a 2-D mask yields matches in row-major order, so `rows` is non-decreasing.
        rows, flat_positions = torch.where(input_ids == self.config.question_token_id)

        # Question count per example; minlength keeps rows (including trailing ones) that have no question token.
        num_questions = torch.bincount(rows, minlength=batch_size)
        max_questions = int(num_questions.max()) if batch_size > 0 else 0

        positions = torch.full(
            (batch_size, max_questions),
            self.config.pad_token_id,
            dtype=torch.long,
            device=device,
        )

        # Column of each match = its rank within its row = global index - index of the row's first match.
        row_starts = torch.cumsum(num_questions, dim=0) - num_questions
        cols = torch.arange(rows.numel(), device=device) - row_starts[rows]
        positions[rows, cols] = flat_positions
        return positions
|