873 lines
37 KiB
Python
873 lines
37 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2022 The OpenBMB Team and The HuggingFace Inc. team. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
""" PyTorch CPMAnt"""
|
||
|
|
||
|
|
||
|
import math
|
||
|
from typing import List, Optional, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
import torch.utils.checkpoint
|
||
|
from torch import nn
|
||
|
from torch.nn import CrossEntropyLoss
|
||
|
|
||
|
from ...activations import ACT2FN
|
||
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||
|
from ...modeling_utils import PreTrainedModel
|
||
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||
|
from .configuration_cpmant import CpmAntConfig
|
||
|
|
||
|
|
||
|
# Module-level logger obtained through the library's logging helper.
logger = logging.get_logger(__name__)

# Checkpoint identifier handed to `add_code_sample_docstrings` on the model `forward` methods.
_CHECKPOINT_FOR_DOC = "openbmb/cpm-ant-10b"
# Configuration class name handed to `add_code_sample_docstrings` on the model `forward` methods.
_CONFIG_FOR_DOC = "CpmAntConfig"
|
||
|
|
||
|
|
||
|
from ..deprecated._archive_maps import CPMANT_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
|
||
|
|
||
|
|
||
|
class CpmAntLayerNorm(nn.Module):
    """
    Root Mean Square (RMS) layer normalization, see https://arxiv.org/abs/1910.07467 for details.

    The input is multiplied by the reciprocal square root of its mean square over the last dimension (no mean
    subtraction, no bias) and then scaled by a learned per-feature weight.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()

        self.dim_norm = config.hidden_size
        self.eps = config.eps
        # `torch.empty` leaves the values uninitialized; `CpmAntPreTrainedModel._init_weights` fills them with ones.
        self.weight = nn.Parameter(torch.empty(self.dim_norm))

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`):
                Tensor to normalize; its last dimension must equal `config.hidden_size`.

        Returns:
            `torch.Tensor`: the normalized and re-scaled tensor.

        Raises:
            AssertionError: if the last dimension of `hidden_states` differs from `self.dim_norm`.
        """
        if hidden_states.size(-1) != self.dim_norm:
            raise AssertionError("hidden_states.size(-1) != self.dim_norm")

        input_dtype = hidden_states.dtype
        # The mean square is accumulated in float32, whatever the input precision is.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        normalized = (hidden_states * inverse_rms).to(input_dtype)
        return normalized * self.weight
|
||
|
|
||
|
|
||
|
class CpmAntAttention(nn.Module):
    """Multi-head scaled dot-product attention with an additive position bias and an explicit attention mask."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.dim_model = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dim_head = config.dim_head

        self.project_q = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
        self.project_k = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)
        self.project_v = nn.Linear(self.dim_model, self.num_heads * self.dim_head, bias=False)

        self.attention_out = nn.Linear(self.num_heads * self.dim_head, self.dim_model, bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)

        if config.dropout_p is not None:
            self.dropout = torch.nn.Dropout(p=config.dropout_p)
        else:
            self.dropout = None

    def forward(
        self,
        hidden_q: torch.Tensor,
        hidden_kv: torch.Tensor,
        attention_mask: torch.BoolTensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_q (`torch.Tensor` of shape `(batch, len_q, dim_model)`):
                States the queries are projected from.
            hidden_kv (`torch.Tensor` of shape `(batch, len_k, dim_model)`):
                States the keys and values are projected from.
            attention_mask (`torch.Tensor` with `batch * len_q * len_k` elements):
                Positions whose entry is `False`/zero are excluded from attention. `len_k` includes the length of
                the cached keys when `past_key_values` is given.
            position_bias (`torch.Tensor` broadcastable to `(batch, num_heads, len_q, len_k)`):
                Bias added to the scaled attention scores before masking.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention probabilities.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached key and value projections of shape `(batch, num_heads, len_past, dim_head)`; the new keys and
                values are appended to them along the sequence dimension.
            use_cache (`bool`, *optional*):
                If set to `True`, the (concatenated) key and value states are returned so that they can be passed
                back as `past_key_values`.

        Returns:
            `tuple`: `(output, attn_weights, past_key_values)` where `output` has shape `(batch, len_q, dim_model)`,
            `attn_weights` is the `(batch, num_heads, len_q, len_k)` probability tensor or `None`, and
            `past_key_values` is `(key, value)` or `None`.
        """
        batch_size = hidden_q.size(0)
        len_q = hidden_q.size(1)
        len_k = hidden_kv.size(1)

        query = self.project_q(hidden_q)
        key = self.project_k(hidden_kv)
        value = self.project_v(hidden_kv)

        # (batch, len, num_heads * dim_head) -> (batch, num_heads, len, dim_head)
        query = query.view(batch_size, len_q, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        key = key.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)
        value = value.view(batch_size, len_k, self.num_heads, self.dim_head).permute(0, 2, 1, 3)

        if past_key_values is not None:
            key = torch.cat([past_key_values[0], key], dim=-2)
            value = torch.cat([past_key_values[1], value], dim=-2)
            len_k = key.size(-2)

        # (batch_size, num_heads, len_q, dim_head) @ (batch_size, num_heads, dim_head, len_k) -> (batch_size, num_heads, len_q, len_k)
        score = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.dim_head)
        score = score + position_bias

        # Boolean "blocked" mask, built once and reused below. `logical_not` (rather than `~`) is used on purpose:
        # it also handles integer 0/1 masks, for which a bitwise not would not yield a boolean mask.
        blocked = attention_mask.view(batch_size, 1, len_q, len_k).logical_not()

        score = score.masked_fill(blocked, float("-inf"))
        score = self.softmax(score)
        # A fully blocked query row is all -inf before the softmax and therefore NaN after it; writing zeros into
        # the blocked positions afterwards keeps such rows finite.
        score = score.masked_fill(blocked, 0.0)

        attn_weights = score if output_attentions else None

        if self.dropout is not None:
            score = self.dropout(score)

        # (batch_size, num_heads, len_q, len_k) @ (batch_size, num_heads, len_k, dim_head) -> (batch_size, num_heads, len_q, dim_head)
        score = torch.matmul(score, value)

        # Merge the heads back: (batch, num_heads, len_q, dim_head) -> (batch, len_q, num_heads * dim_head)
        score = score.permute(0, 2, 1, 3).contiguous().view(batch_size, len_q, self.num_heads * self.dim_head)

        score = self.attention_out(score)

        past_key_values = (key, value) if use_cache else None

        return score, attn_weights, past_key_values
|
||
|
|
||
|
|
||
|
class CpmAntSelfAttentionBlock(nn.Module):
    """Pre-normalized self-attention sub-layer: `x + dropout(attention(layernorm(x)))`."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_attention = CpmAntLayerNorm(config)
        self.self_attention = CpmAntAttention(config)
        # A falsy `dropout_p` (`None` or 0) disables the dropout layer entirely.
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Input of the block; it is normalized and then used as query, key and value source.
            attention_mask (`torch.Tensor`):
                Mask forwarded to the attention layer; masked-out positions do not take part in attention.
            position_bias (`torch.Tensor`):
                Position bias forwarded to the attention layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights.
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*):
                Cached past key and value projection states.
            use_cache (`bool`, *optional*):
                If set to `True`, the key value states are returned and can be used to speed up decoding.

        Returns:
            `tuple`: `(hidden_states, attn_weights, current_key_value)` as produced by the attention layer, with the
            residual connection applied to `hidden_states`.
        """
        normalized = self.layernorm_before_attention(hidden_states)
        attention_output, attn_weights, current_key_value = self.self_attention(
            normalized, normalized, attention_mask, position_bias, output_attentions, past_key_values, use_cache
        )

        if self.dropout is not None:
            attention_output = self.dropout(attention_output)

        # Residual connection around the attention sub-layer.
        return hidden_states + attention_output, attn_weights, current_key_value
|
||
|
|
||
|
|
||
|
class CpmAntDenseGatedACT(nn.Module):
    """Gated projection to the feed-forward width: `GELU(w_0(x)) * w_1(x)`."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        in_features, out_features = config.hidden_size, config.dim_ff
        self.w_0 = nn.Linear(in_features, out_features, bias=False)
        self.w_1 = nn.Linear(in_features, out_features, bias=False)
        self.act = torch.nn.GELU()

    def forward(self, hidden_states: torch.Tensor):
        """Transform an input tensor from one feature space to another via a gated nonlinear operation.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)

        Returns:
            `torch.Tensor`: the activated `w_0` branch multiplied element-wise with the linear `w_1` branch; its
            last dimension is `config.dim_ff`.
        """
        return self.act(self.w_0(hidden_states)) * self.w_1(hidden_states)
|
||
|
|
||
|
|
||
|
class CpmAntFeedForward(nn.Module):
    """Feed-forward network: gated up-projection, optional dropout, then a bias-free down-projection."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.w_in = CpmAntDenseGatedACT(config)
        # Only `None` disables dropout here; any numeric `dropout_p` (including 0) creates the layer.
        self.dropout = None if config.dropout_p is None else torch.nn.Dropout(config.dropout_p)

        self.w_out = nn.Linear(config.dim_ff, config.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, dim_in)`)

        Returns:
            `torch.Tensor`: the projected tensor whose last dimension is `config.hidden_size`.
        """
        expanded = self.w_in(hidden_states)
        if self.dropout is not None:
            expanded = self.dropout(expanded)
        return self.w_out(expanded)
|
||
|
|
||
|
|
||
|
class CpmAntFFNBlock(nn.Module):
    """Pre-normalized feed-forward sub-layer: `x + dropout(ffn(layernorm(x)))`."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.layernorm_before_ffn = CpmAntLayerNorm(config)
        self.ffn = CpmAntFeedForward(config)
        # A falsy `dropout_p` (`None` or 0) disables the dropout layer entirely.
        self.dropout = torch.nn.Dropout(config.dropout_p) if config.dropout_p else None

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch, len_seq, dim_model)`):
                Hidden states before feed forward layer.

        Returns:
            `torch.Tensor`: the input plus the (optionally dropped-out) feed-forward output.
        """
        ffn_output = self.ffn(self.layernorm_before_ffn(hidden_states))
        if self.dropout is not None:
            ffn_output = self.dropout(ffn_output)
        # Residual connection around the feed-forward sub-layer.
        return hidden_states + ffn_output
|
||
|
|
||
|
|
||
|
class CpmAntTransformerBlock(nn.Module):
    """One transformer layer: a self-attention block followed by a feed-forward block."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.self_att = CpmAntSelfAttentionBlock(config)
        self.ffn = CpmAntFFNBlock(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Mask forwarded to the self-attention block; masked-out positions do not take part in attention.
            position_bias (`torch.Tensor`):
                Position bias forwarded to the self-attention block.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights of this layer.
            past_key_values (`Tuple[torch.Tensor, torch.Tensor]`, *optional*):
                Cached past key and value projection states
            use_cache (`bool`, *optional*):
                If set to `True`, the key value states are returned and can be used to speed up decoding.

        Returns:
            `tuple`: `(hidden_states, attn_weights, current_key_value)`; the last two are passed through unchanged
            from the self-attention block.
        """
        attended, attn_weights, current_key_value = self.self_att(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        return self.ffn(attended), attn_weights, current_key_value
|
||
|
|
||
|
|
||
|
class CpmAntEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` transformer blocks followed by a final layer normalization."""

    def __init__(self, config: CpmAntConfig):
        super().__init__()
        self.num_layers = config.num_hidden_layers
        self.layers = nn.ModuleList([CpmAntTransformerBlock(config) for _ in range(self.num_layers)])

        self.output_layernorm = CpmAntLayerNorm(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor`):
                Input to the first layer of shape `(batch, seq_len, dim_model)`
            attention_mask (`torch.Tensor`):
                Mask forwarded to every layer; masked-out positions do not take part in attention.
            position_bias (`torch.Tensor`):
                Position bias forwarded to every layer.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            past_key_values (*optional*):
                Per-layer cache; entry `i` is handed to layer `i`. A falsy value means no cache.
            use_cache (`bool`, *optional*):
                If set to `True`, the key value states of every layer are returned and can be used to speed up
                decoding.

        Returns:
            `tuple`: `(hidden_states, current_key_values, all_hidden_states, all_self_attns)`. The last three are
            tuples when the matching option is enabled and `None` otherwise. `all_hidden_states` holds the input of
            every layer plus the final normalized output.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        current_key_values = () if use_cache else None

        for layer_index, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_past = past_key_values[layer_index] if past_key_values else None
            hidden_states, attn_weights, current_key_value = layer(
                hidden_states,
                attention_mask,
                position_bias,
                output_attentions=output_attentions,
                past_key_values=layer_past,
                use_cache=use_cache,
            )

            if output_attentions:
                all_self_attns += (attn_weights,)
            if current_key_value is not None:
                current_key_values = current_key_values + (current_key_value,)

        hidden_states = self.output_layernorm(hidden_states)

        # The normalized output is recorded as the last entry of the hidden-state history.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return hidden_states, current_key_values, all_hidden_states, all_self_attns
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt
class CpmAntIntermediate(nn.Module):
    """Linear projection from `config.hidden_size` to `config.intermediate_size` followed by an activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key into ACT2FN (when given as a string) or is used directly as the activation.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense projection and then the configured activation to `hidden_states`."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
|
||
|
|
||
|
|
||
|
class CpmAntSegmentPositionEmbedding(nn.Module):
    """
    Learned per-head attention bias indexed by a bucket that depends on position and segment.

    The bias table `relative_attention_bias` has `position_bias_num_buckets` rows for relative-position buckets,
    followed by `segment_types * segment_types` rows, one per (query segment, key segment) pair.
    """

    def __init__(self, config: CpmAntConfig):
        super().__init__()

        self.num_heads = config.num_attention_heads
        self.num_buckets = config.position_bias_num_buckets
        self.max_distance = config.position_bias_max_distance
        self.num_segments = config.segment_types

        self.relative_attention_bias = nn.Parameter(
            torch.empty(
                config.segment_types * config.segment_types + config.position_bias_num_buckets,
                config.num_attention_heads,
            )
        )

    def forward(
        self,
        key_pos: torch.Tensor,
        query_pos: torch.Tensor,
        key_segment: torch.Tensor,
        query_segment: torch.Tensor,
    ):
        """
        Args:
            key_pos (`torch.Tensor` of shape `(batch, keylen)`):
                Key positions. Only the shape is used; the position buckets are derived from index ranges.
            query_pos (`torch.Tensor` of shape `(batch, querylen)`):
                Query positions. Only the shape is used.
            key_segment (`torch.Tensor` of shape `(batch, keylen)`):
                Segment id of every key position.
            query_segment (`torch.Tensor` of shape `(batch, querylen)`):
                Segment id of every query position.

        Returns:
            `torch.Tensor` of shape `(batch, num_heads, querylen, keylen)`: the bias to add to attention scores.

        Raises:
            AssertionError: if the batch sizes of `key_pos` and `query_pos` differ, or if a segment tensor does not
                have the same length as the matching position tensor.
        """
        # Bucket indices are pure integer bookkeeping, so they are computed without tracking gradients.
        with torch.no_grad():
            batch = key_pos.size(0)
            keylen = key_pos.size(1)
            querylen = query_pos.size(1)

            if key_pos.size(0) != query_pos.size(0):
                raise AssertionError(
                    f"key_pos.size(0) should be equal to query_pos.size(0), but got {key_pos.size(0)} and {query_pos.size(0)}!"
                )
            # The key and query lengths are validated separately so that each mismatch reports its own message.
            if keylen != key_segment.size(1):
                raise AssertionError(
                    f"keylen should be equal to key_segment.size(1), but got {keylen} and {key_segment.size(1)}!"
                )
            if querylen != query_segment.size(1):
                raise AssertionError(
                    f"querylen should be equal to query_segment.size(1), but got {querylen} and {query_segment.size(1)}!"
                )

            # The reshaped positions are not read again below; the views are kept so that shape errors still
            # surface at this point.
            key_pos = key_pos.view(batch, -1, keylen)
            query_pos = query_pos.view(batch, querylen, -1)
            # Keys vary along the last axis and queries along the middle axis, so they broadcast to
            # (batch, querylen, keylen).
            key_segment = key_segment.view(batch, -1, keylen)
            query_segment = query_segment.view(batch, querylen, -1)

            # Segment-pair buckets live after the `num_buckets` position rows of the bias table.
            relative_position_bucket = self._segment_relative_position_bucket(query_segment, key_segment)
            relative_position_bucket = relative_position_bucket + self.num_buckets

            # (batch, len_q, len_k)
            absolute_position_bucket = self._position_bucket(
                torch.arange(keylen, dtype=torch.int32, device=relative_position_bucket.device)[None, :]
                - torch.arange(querylen, dtype=torch.int32, device=relative_position_bucket.device)[:, None],
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
            )
            # Within the same segment the position bucket is used; across segments the segment-pair bucket is.
            relative_position_bucket = torch.where(
                (key_segment == query_segment),
                absolute_position_bucket[None, :, :],
                relative_position_bucket,
            )

        # The table lookup happens outside `no_grad` so that `relative_attention_bias` receives gradients.
        # (batch, len_q, len_k, num_heads)
        embeds = F.embedding(relative_position_bucket, self.relative_attention_bias)
        # (batch, num_heads, len_q, len_k)
        embeds = embeds.permute(0, 3, 1, 2).contiguous()
        return embeds

    def _segment_relative_position_bucket(self, query_segment, key_segment):
        """Return a distinct bucket id for every (query segment, key segment) pair."""
        return query_segment * self.num_segments + key_segment

    def _position_bucket(self, relative_position, num_buckets=32, max_distance=128):
        """
        Map signed relative positions (key index minus query index) to bucket ids in `[0, num_buckets)`.

        Half of the buckets serve positive offsets and half serve the rest. Inside each half, offsets smaller than
        `max_exact` get one bucket each, and larger offsets share logarithmically sized buckets that saturate at
        the last bucket of the half.
        """
        # always bidirectional in CPMAnt
        num_buckets //= 2
        relative_buckets = (relative_position > 0).to(torch.int32) * num_buckets
        relative_position = torch.abs(relative_position)
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # For offsets below `max_exact` (including 0, where the log is -inf) this value is meaningless; those
        # entries are replaced by the exact offset in the `torch.where` below.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.int32)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )
        relative_buckets += torch.where(is_small, relative_position.to(torch.int32), relative_position_if_large)
        return relative_buckets
|
||
|
|
||
|
|
||
|
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt
class CpmAntOutput(nn.Module):
    """Projects from `config.intermediate_size` back to `config.hidden_size`, then dropout, residual and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Residual connection with `input_tensor`, normalized afterwards.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
|
||
|
|
||
|
|
||
|
class CpmAntPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = CpmAntConfig
    base_model_prefix = "cpmant"

    def _init_weights(self, module):
        """Initialize the weights of a single `module` according to its type.

        Linear and embedding weights, as well as the segment/position bias table, are drawn from a normal
        distribution with standard deviation `config.init_std`. Biases and the padding embedding row are zeroed, and
        normalization weights are set to one. Modules of any other type are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, CpmAntLayerNorm):
            # The RMS norm has a scale only, no bias.
            module.weight.data.fill_(1.0)
            return

        if isinstance(module, CpmAntSegmentPositionEmbedding):
            module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
|
||
|
|
||
|
|
||
|
# Shared introduction that `add_start_docstrings` attaches to the model classes in this file.
CPMANT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`~CpmAntConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
||
|
|
||
|
# Argument documentation that `add_start_docstrings_to_model_forward` attaches to `forward`.
CPMANT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`CpmAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    "The bare CPMAnt Model outputting raw hidden-states without any specific head on top.",
    CPMANT_START_DOCSTRING,
)
class CpmAntModel(CpmAntPreTrainedModel):
    # NOTE(review): the class and `forward` carry no docstring of their own; the decorators are handed the shared
    # documentation strings and presumably assemble the final docstrings. Explanations are therefore kept in `#`
    # comments so that the generated documentation is not altered.

    def __init__(self, config: CpmAntConfig):
        super().__init__(config)
        self.encoder = CpmAntEncoder(config)
        self.segment_embedding = nn.Embedding(config.segment_types, config.hidden_size)
        # One table for the vocabulary plus `prompt_types * prompt_length` additional prompt entries.
        self.input_embedding = nn.Embedding(
            config.vocab_size + config.prompt_types * config.prompt_length, config.hidden_size
        )
        self.position_bias = CpmAntSegmentPositionEmbedding(config)
        self.prompt_length = config.prompt_length
        self.vocab_size = config.vocab_size

        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.input_embedding

    def set_input_embeddings(self, embeddings, **kwargs):
        """Replace the token embedding module with `embeddings`; extra keyword arguments are ignored."""
        self.input_embedding = embeddings

    def _prepare_attention_mask(self, input_ids, span, context, length):
        """
        Build the `(batch, seqlen, seqlen)` attention mask for `input_ids` (prompt positions included).

        Args:
            input_ids: tensor of shape `(batch, seqlen)`; only its shape and device are used.
            span: `(batch, seqlen)` tensor; two positions may attend to each other only if their span ids match.
            context: `(batch, seqlen)` tensor; non-zero key positions are visible to every query, other positions
                follow the causal (key index <= query index) rule.
            length: `(batch,)` tensor with the number of non-padding tokens of every sequence (prompt excluded).

        Returns:
            The combined mask; a zero/`False` entry means the query (second axis) must not attend to the key (last
            axis).
        """
        batch = input_ids.size(0)
        seqlen = input_ids.size(1)
        device = input_ids.device
        # directional_mask_2d[i, j] is True when key j is not after query i.
        directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(-1, 1)
        attention_mask = context[:, None, :] | (
            context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
        )
        attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
        # mask for left padding: counting from the right end, only the last `length` tokens are real
        distance_from_end = torch.arange(seqlen - self.prompt_length - 1, -1, -1, device=device)
        mask_1d = distance_from_end[None, :] < length[:, None]
        # the prompt positions in front of the tokens are always valid
        prompt_mask = torch.ones(batch, self.prompt_length, dtype=torch.bool, device=device)
        mask_1d = torch.cat((prompt_mask, mask_1d), dim=1)
        attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
        return attention_mask

    @add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        # `input_ids` only defaults to None to fit the common signature; the model cannot run without it.
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # add prompts ahead
        if input_ids.dtype != torch.int32:
            input_ids = input_ids.to(torch.int32)
        dtype, device = input_ids.dtype, input_ids.device
        # Token id 0 is treated as padding: segment 0 and not counted in `length`; all other tokens get segment 2.
        segment = torch.where(input_ids != 0, 2, 0).to(dtype=dtype, device=device)
        length = (segment != 0).sum(-1).to(dtype=dtype, device=device)
        # The prompt ids are the `prompt_length` consecutive ids starting at `vocab_size + 2 * prompt_length`.
        input_ids = torch.cat(
            (
                torch.arange(
                    self.prompt_length * 2 + self.vocab_size,
                    self.prompt_length * 3 + self.vocab_size,
                    dtype=dtype,
                    device=device,
                ).repeat(input_ids.size(0), 1),
                input_ids,
            ),
            dim=1,
        )
        batch, seq_length = input_ids.size()
        segment = torch.cat((torch.zeros(batch, self.prompt_length, dtype=dtype, device=device), segment), dim=1)
        context = torch.full((batch, seq_length), 1, dtype=dtype, device=device)
        position = torch.arange(seq_length, dtype=dtype, device=device).repeat(batch, 1)
        span = torch.full((batch, seq_length), 0, dtype=dtype, device=device)

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.encoder.num_layers)
            input_ids = input_ids.contiguous()
            hidden_states = self.input_embedding(input_ids)
            segment_states = self.segment_embedding(segment)
            hidden_states = hidden_states + segment_states
        else:
            # The cached keys have shape (..., len_past, dim_head); their length is the number of cached positions.
            past_length = past_key_values[0][0].size(-2)
            segment_states = self.segment_embedding(segment)
            hidden_states = self.input_embedding(input_ids) + segment_states[:, -1:, :]

        attention_mask = self._prepare_attention_mask(input_ids, span, context, length)
        position_bias = self.position_bias(position, position, segment, segment)

        # Only the positions that are not cached yet act as queries.
        attention_mask = attention_mask[:, past_length:, :]
        position_bias = position_bias[:, :, past_length:, :]
        hidden_states = hidden_states[:, past_length:, :]

        hidden_states, present_key_values, all_hidden_states, all_attentions = self.encoder(
            hidden_states,
            attention_mask,
            position_bias,
            output_attentions,
            output_hidden_states,
            past_key_values,
            use_cache,
        )

        if past_length == 0:
            hidden_states = hidden_states[:, self.prompt_length :, :]
            # drop the prompt
            if all_attentions is not None:
                all_attentions = tuple(
                    attention[:, :, self.prompt_length :, self.prompt_length :] for attention in all_attentions
                )
            if all_hidden_states is not None:
                all_hidden_states = tuple(
                    hidden_state[:, self.prompt_length :, :] for hidden_state in all_hidden_states
                )

        if not return_dict:
            return tuple(
                v for v in [hidden_states, present_key_values, all_hidden_states, all_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=present_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
||
|
|
||
|
|
||
|
@add_start_docstrings(
    """
    The CPMAnt Model with a language modeling head on top (linear layer with weights tied to the input embeddings).
    """,
    CPMANT_START_DOCSTRING,
)
class CpmAntForCausalLM(CpmAntPreTrainedModel):
    # Parameters listed here are declared as tied to another weight (the input embedding, per the comment in
    # `__init__`).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: CpmAntConfig):
        # Builds the backbone (`CpmAntModel`) and a bias-free projection from hidden states to token logits.
        super().__init__(config)
        self.cpmant = CpmAntModel(config)

        # lm_head.weight is tied to cpmant.input_embedding.weight
        # The output dimension covers the regular vocabulary plus one slot per prompt token of every prompt type.
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size + config.prompt_types * config.prompt_length, bias=False
        )
        self.post_init()

    @add_start_docstrings_to_model_forward(CPMANT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,  # dummy parameter for text-generation pipeline
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`CpmAntTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. The loss is the cross-entropy between the logits
                and `labels` at the same positions; no shifting is applied inside this method.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                CPMAnt will process attention mask automatically, this parameter is a dummy parameter for
                text-generation pipeline.

        Example:

        Text Generation with CpmAntForCausalLM.
        ```python
        >>> from transformers import CpmAntTokenizer, CpmAntForCausalLM

        >>> texts = "今天天气不错,"
        >>> model = CpmAntForCausalLM.from_pretrained("openbmb/cpm-ant-10b")
        >>> tokenizer = CpmAntTokenizer.from_pretrained("openbmb/cpm-ant-10b")
        >>> input_ids = tokenizer(texts, return_tensors="pt")
        >>> outputs = model.generate(**input_ids)
        >>> output_texts = tokenizer.batch_decode(outputs)
        >>> print(output_texts)
        ['今天天气不错,阳光明媚,我和妈妈一起去超市买东西。\n在超市里,我看到了一个很好玩的玩具,它的名字叫“机器人”。它有一个圆圆的脑袋,两只圆圆的眼睛,还有一个圆圆的']
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # `attention_mask` and `**kwargs` are accepted for API compatibility only and are not forwarded.
        model_output = self.cpmant(
            input_ids, output_attentions, output_hidden_states, past_key_values, use_cache, return_dict
        )
        hidden_states = model_output.last_hidden_state if return_dict else model_output[0]

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_func = CrossEntropyLoss()
            # `reshape` (rather than `view`) also accepts non-contiguous label tensors, e.g. slices.
            loss = loss_func(logits.view(-1, logits.size(-1)), labels.reshape(-1))

        if not return_dict:
            # Tuple layout: (loss?, logits, *remaining backbone outputs).
            output = (logits,) + model_output[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_output.past_key_values,
            hidden_states=model_output.hidden_states,
            attentions=model_output.attentions,
        )

    def get_input_embeddings(self):
        # Returns the backbone's token embedding module.
        return self.cpmant.input_embedding

    def set_input_embeddings(self, embeddings):
        # Replaces the backbone's token embedding module.
        self.cpmant.input_embedding = embeddings

    def get_output_embeddings(self):
        # Returns the language modeling head.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Replaces the language modeling head.
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """
        Build the keyword arguments passed to `forward` for one generation step.

        Only `input_ids` (cast to int32), `use_cache` and `past_key_values` are forwarded. Any `attention_mask`
        supplied by the caller is intentionally dropped, because `forward` treats it as a dummy parameter.

        Args:
            input_ids: Token ids generated so far; must support `.int()`.
            **kwargs: Generation state. `use_cache` and `past_key_values` are read if present and default to
                `None` otherwise.

        Returns:
            `dict` with the keys `"input_ids"`, `"use_cache"` and `"past_key_values"`.
        """
        # NOTE: the original code overwrote `kwargs["attention_mask"]` with a tiny tensor "to save memory", but
        # `kwargs` is a local dict that is never returned, so that assignment had no effect and was removed.
        return {
            "input_ids": input_ids.int(),
            # `.get` instead of indexing: a caller that omits `use_cache` must not trigger a KeyError.
            "use_cache": kwargs.get("use_cache"),
            "past_key_values": kwargs.get("past_key_values"),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """
        Reorder the cached key/value states along their first dimension to follow `beam_idx`.

        Args:
            past_key_values: Iterable with one entry per layer. Each entry is either `None` or a sequence whose
                first two items are the key and value states (assumed to be indexable by `beam_idx` along
                dim 0 -- TODO confirm against the attention implementation).
            beam_idx: Index applied to the first dimension of every key and value state.

        Returns:
            `list` with one entry per layer: `None` where the input entry was `None`, otherwise a list whose
            first two items are the reordered key and value states (any further items are kept unchanged).
        """
        reordered = [list(layer_cache) if layer_cache is not None else None for layer_cache in past_key_values]
        for layer_cache in reordered:
            if layer_cache is None:
                # Layers without a cache are passed through; indexing them would raise a TypeError.
                continue
            layer_cache[0] = layer_cache[0][beam_idx]
            layer_cache[1] = layer_cache[1][beam_idx]
        return reordered
|