1386 lines
64 KiB
Python
1386 lines
64 KiB
Python
# coding=utf-8
|
|
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" PyTorch StableLM model."""
|
|
import math
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
from ...activations import ACT2FN
|
|
from ...cache_utils import Cache, DynamicCache
|
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
|
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
|
|
from ...modeling_utils import PreTrainedModel
|
|
from ...utils import (
|
|
add_start_docstrings,
|
|
add_start_docstrings_to_model_forward,
|
|
is_flash_attn_2_available,
|
|
is_flash_attn_greater_or_equal_2_10,
|
|
logging,
|
|
replace_return_docstrings,
|
|
)
|
|
from .configuration_stablelm import StableLmConfig
|
|
|
|
|
|
if is_flash_attn_2_available():
    # Optional dependency: these names are only bound when `is_flash_attn_2_available()` is True. If it is False,
    # `StableLmFlashAttention2` below would fail with a NameError on `flash_attn_func` / `flash_attn_varlen_func`.
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


# Module-level logger used by the attention classes below (e.g. `logger.warning_once`).
logger = logging.get_logger(__name__)


# Name of the config class; presumably consumed by the docstring decorators imported above
# (`replace_return_docstrings`) further down the file — usage is outside this view.
_CONFIG_FOR_DOC = "StableLmConfig"
|
|
|
|
|
|
# Adapted from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
    """
    Compute the bookkeeping needed to drop padding tokens before calling variable-length flash attention.

    Args:
        attention_mask (`torch.Tensor`):
            Padding mask, expected `(batch_size, seq_len)`, where non-zero marks a real token and 0 marks padding.

    Returns:
        `tuple` of `(indices, cu_seqlens, max_seqlen_in_batch)`:
            - `indices`: positions of the non-padding tokens in the flattened mask.
            - `cu_seqlens`: int32 running totals of per-row lengths with a leading 0 (length `batch_size + 1`).
            - `max_seqlen_in_batch`: Python int, length of the longest unpadded row.
    """
    per_row_lengths = attention_mask.sum(dim=-1, dtype=torch.int32)
    kept_positions = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    longest = per_row_lengths.max().item()
    # Prepend a zero so that row i occupies [cu_seqlens[i], cu_seqlens[i + 1]) in the packed layout.
    running_total = torch.cumsum(per_row_lengths, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(running_total, (1, 0))
    return kept_positions, cu_seqlens, longest
|
|
|
|
|
|
# Adapted from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->StableLm
class StableLmRotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) with lazily-grown cos/sin lookup tables.

    Args:
        dim (`int`): Number of channels the tables cover (inverse frequencies are computed for every other channel
            and then duplicated).
        max_position_embeddings (`int`, *optional*, defaults to 2048): Number of positions cached at construction.
        base (`float`, *optional*, defaults to 10000): Base of the inverse-frequency geometric progression.
        device (*optional*): Device for the inverse-frequency buffer.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the non-persistent `cos_cached` / `sin_cached` buffers for positions `0 .. seq_len - 1`."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """
        Return the `(cos, sin)` tables for the first `seq_len` positions, cast to `x.dtype`.

        Args:
            x (`torch.Tensor`): [bs, num_attention_heads, seq_len, head_size]. Only its dtype and device are used,
                plus its sequence axis (`x.shape[-2]`) when `seq_len` is not given.
            seq_len (`int`, *optional*): Number of positions required. Defaults to `x.shape[-2]`.

        Returns:
            `tuple(torch.Tensor)`: `(cos, sin)`, each with `seq_len` rows.
        """
        if seq_len is None:
            # Without this, the `None` default reached the comparison below and raised a TypeError.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache on demand; the new tables are built directly in the input's dtype/device.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
|
|
|
|
|
# Adapted from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
    """
    StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev

    Behaves like the parent class, except that positions are divided by `scaling_factor` before the rotation
    angles are computed.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Must exist before the parent constructor runs: it builds the initial cache through the
        # `_set_cos_sin_cache` override below, which reads `scaling_factor`.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild `cos_cached` / `sin_cached` for `seq_len` positions, each position scaled by `1 / scaling_factor`."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
        positions = positions / self.scaling_factor

        angles = torch.outer(positions, self.inv_freq)
        # Both halves share the same angles, matching the half-split layout consumed by `rotate_half`.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
|
|
|
|
|
|
# Adapted from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
    """
    StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    When the cache is rebuilt for more than `max_position_embeddings` positions, the RoPE base is enlarged and
    `inv_freq` is recomputed from it. The rescaled `inv_freq` replaces the buffer and is reused by later rebuilds.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Set before `super().__init__`, which already builds the first cache through the override below.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the cos/sin tables for `seq_len` positions, first rescaling the base if `seq_len` is too long."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ntk_base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
            self.register_buffer("inv_freq", 1.0 / (ntk_base**exponents), persistent=False)

        positions = torch.arange(seq_len, device=device, dtype=torch.int64).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        # Both halves share the same angles, matching the half-split layout consumed by `rotate_half`.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)
|
|
|
|
|
|
# Adapted from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """
    Split the last dimension into a front part `x1` and a back part `x2` and return `cat(-x2, x1)`.

    For an odd last dimension the back part is the longer one (the split point is `size // 2`).
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
|
|
# Adapted from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """
    Rotate query and key tensors by the rotary position embedding selected with `position_ids`.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table, indexed along its first dimension by `position_ids`.
        sin (`torch.Tensor`): Sine table, indexed the same way as `cos`.
        position_ids (`torch.Tensor`):
            Position of every query/key token. This lets callers pass offset positions, e.g. when decoding with a
            KV cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a singleton dimension is inserted into the gathered `cos`/`sin` so they broadcast against
            `q` and `k`. With gathered tables of shape `[batch_size, seq_len, head_dim]`, use 1 when `q`/`k` are
            `[batch_size, heads, seq_len, head_dim]` and 2 when they are `[batch_size, seq_len, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)`: `(q_embed, k_embed)`, each computed as `t * cos + rotate_half(t) * sin`.
    """
    cos_b = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_b = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_b) + (rotate_half(tensor) * sin_b)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
# Adapted from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->StableLm
class StableLmMLP(nn.Module):
    """
    Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.

    All three projections are bias-free.

    Args:
        config: Model config. Reads `hidden_size`, `intermediate_size` and `hidden_act`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        hidden, inner = self.hidden_size, self.intermediate_size
        # Registration order (gate, up, down) is kept so parameter order / state_dict layout is unchanged.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
|
|
|
|
|
|
class StableLmLayerNormPerHead(nn.Module):
|
|
def __init__(self, dim, num_heads, eps=1e-5, bias=False):
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.num_heads = num_heads
|
|
self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)])
|
|
|
|
def forward(self, hidden_states: torch.Tensor):
|
|
# Split along the num_heads axis to get per-head inputs
|
|
# [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
|
|
states_per_heads = torch.split(hidden_states, 1, dim=1)
|
|
# Normalize and merge the heads back together
|
|
return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1)
|
|
|
|
|
|
# Adapted from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along dim 1.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`: the shape goes from
    `(batch, num_key_value_heads, seqlen, head_dim)` to `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`.
    When `n_rep == 1` the input tensor object itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it with `expand`, then fold it into the head axis.
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
|
class StableLmAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper

    Eager StableLM attention. It supports grouped-query attention (`num_key_value_heads` may be smaller than
    `num_attention_heads`) and an optional per-head LayerNorm on queries and keys (`config.qk_layernorm`). Rotary
    embeddings are partial: only the first `int(partial_rotary_factor * head_dim)` channels of each head are rotated.
    """

    def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        # Position of this layer in the decoder stack; passed to the KV `Cache` to select this layer's entries.
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each key/value head (1 means plain multi-head attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.qk_layernorm = config.qk_layernorm
        if self.qk_layernorm:
            # One LayerNorm per query head and one per key/value head.
            self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps)
            self.k_layernorm = StableLmLayerNormPerHead(
                self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
            )

        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self._init_rope()

    # Picks the rotary embedding from `config.rope_scaling`: None -> plain RoPE, "linear" or "dynamic" -> scaled
    # variants, anything else -> ValueError. The rotary size is `int(partial_rotary_factor * head_dim)` channels.
    # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention._init_rope with Persimmon->StableLm
    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = StableLmRotaryEmbedding(
                int(self.partial_rotary_factor * self.head_dim),
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = StableLmLinearScalingRotaryEmbedding(
                    int(self.partial_rotary_factor * self.head_dim),
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = StableLmDynamicNTKScalingRotaryEmbedding(
                    int(self.partial_rotary_factor * self.head_dim),
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Eager (explicit matmul + softmax) attention forward pass.

        Args:
            hidden_states (`torch.Tensor`): `(batch_size, q_len, hidden_size)`.
            attention_mask (`torch.Tensor`, *optional*): Additive mask of shape `(batch_size, 1, q_len, kv_seq_len)`.
                It is added to the raw scores before the softmax, so masked positions should hold large negative
                values.
            position_ids (`torch.LongTensor`, *optional*): Indices into the rotary cos/sin tables for each token.
            past_key_value (`Cache`, *optional*): KV cache. The new keys/values are passed to
                `past_key_value.update(...)` for `self.layer_idx`, and the keys/values it returns (presumably cached +
                new) are attended to.
            output_attentions (`bool`): If `True`, also return the attention probabilities (after dropout).
            use_cache (`bool`): Accepted for interface compatibility; not read in this method.

        Returns:
            `(attn_output, attn_weights, past_key_value)`, where `attn_output` is `(batch_size, q_len, hidden_size)`
            and `attn_weights` is `(batch_size, num_heads, q_len, kv_seq_len)` or `None`.

        Raises:
            ValueError: If a cache is given but `layer_idx` is None, or if the scores/mask/output shapes are not the
                expected ones.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch_size, q_len, heads * head_dim) -> (batch_size, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        # Key length seen by attention: the new tokens plus (presumably) the tokens already cached for this layer.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, heads, seq_length, rotary_dim] with rotary_dim = int(head_dim * config.partial_rotary_factor)
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, heads, seq_length, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            # Specific to RoPE models with partial rotation
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores: (batch_size, num_heads, q_len, kv_seq_len)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # Additive mask, broadcast over the head dimension.
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # (batch_size, num_heads, q_len, head_dim) -> (batch_size, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
|
|
|
|
|
class StableLmSdpaAttention(StableLmAttention):
    """
    StableLM attention backed by `torch.nn.functional.scaled_dot_product_attention` (SDPA).

    Weights and the projection / qk-layernorm / partial-rotary / cache steps are the same as in the eager parent.
    Only the score-softmax-value product is delegated to SDPA. When `output_attentions=True`, this class falls back
    to the parent implementation, because SDPA does not return attention probabilities.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        SDPA attention forward pass.

        Args:
            hidden_states (`torch.Tensor`): `(batch_size, q_len, hidden_size)`.
            attention_mask (`torch.Tensor`, *optional*): Passed to SDPA as `attn_mask`. When it is `None` and
                `q_len > 1`, SDPA's built-in causal masking (`is_causal=True`) is used instead.
            position_ids (`torch.LongTensor`, *optional*): Indices into the rotary cos/sin tables for each token.
            past_key_value (`Cache`, *optional*): KV cache, updated via `past_key_value.update(...)` for
                `self.layer_idx`.
            output_attentions (`bool`): If `True`, delegates to the eager parent implementation (with a warning).
            use_cache (`bool`): Only forwarded to the parent in the fallback path; not otherwise read.

        Returns:
            `(attn_output, None, past_key_value)` with `attn_output` of shape `(batch_size, q_len, hidden_size)`.
            In the fallback path, whatever the parent returns.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (batch_size, q_len, heads * head_dim) -> (batch_size, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        # Key length seen by attention: the new tokens plus (presumably) the tokens already cached for this layer.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        # [batch_size, heads, seq_length, rotary_dim] with rotary_dim = int(head_dim * config.partial_rotary_factor)
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, heads, seq_length, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            # Specific to RoPE models with partial rotation
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout.p if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        # (batch_size, num_heads, q_len, head_dim) -> (batch_size, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
|
|
|
|
|
|
class StableLmFlashAttention2(StableLmAttention):
    """
    StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Flash-attention forward pass.

        Args:
            hidden_states (`torch.Tensor`): `(batch_size, q_len, hidden_size)`.
            attention_mask (`torch.LongTensor`, *optional*): 2D padding mask covering the key sequence, with 1 for
                real tokens and 0 for padding. `None` means no padding, and the dense flash kernel is used.
            position_ids (`torch.LongTensor`, *optional*): Indices into the rotary cos/sin tables for each token.
            past_key_value (`Cache`, *optional*): KV cache, updated via `past_key_value.update(...)` for
                `self.layer_idx`.
            output_attentions (`bool`): Ignored. Flash attention never materializes attention probabilities, so the
                returned weights are always `None`.
            use_cache (`bool`), **kwargs: Accepted for interface compatibility; not read here.

        Returns:
            `(attn_output, None, past_key_value)` with `attn_output` of shape `(batch_size, q_len, hidden_size)`.
        """
        # StableLmFlashAttention2 attention does not support output_attentions
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to [batch_size, heads, seq_length, head_dim] for the layernorm / rotary / cache code below. The
        # [batch_size, seq_length, heads, head_dim] layout that flash attention expects is restored further down.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Partial rotary embedding
        query_rot, query_pass = (
            query_states[..., : self.rotary_emb.dim],
            query_states[..., self.rotary_emb.dim :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_emb.dim],
            key_states[..., self.rotary_emb.dim :],
        )
        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)

        # [batch_size, heads, seq_length, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # flash-attn kernels take fp16/bf16 inputs with q, k and v in the same dtype. Queries/keys can silently end up
        # in float32 even when the projections run in half precision. This happens when `qk_layernorm` runs in fp32
        # under autocast, or when LayerNorm / embedding layers were upcast to fp32 (e.g. for PEFT training). Cast
        # q/k/v back to the intended half-precision dtype (same policy as the Llama and Phi flash-attention modules).
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        dropout_rate = self.attention_dropout.p if self.training else 0.0

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # `output_attentions` is forced to False above, so no attention weights are ever returned.
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
|
|
|
|
|
|
# Maps `config._attn_implementation` to the attention module class instantiated by each `StableLmDecoderLayer`.
ATTENTION_CLASSES = {
    "eager": StableLmAttention,
    "sdpa": StableLmSdpaAttention,
    "flash_attention_2": StableLmFlashAttention2,
}
|
|
|
|
|
|
class StableLmDecoderLayer(nn.Module):
    """
    A single StableLM transformer block: LayerNorm -> self-attention, plus an MLP, with residual connections.

    Two residual layouts are supported, selected by `config.use_parallel_residual`:

    - parallel: `x = x + attn(ln1(x)) + mlp(ln1(x))` (the MLP reuses the output of `input_layernorm`);
    - sequential: `x = x + attn(ln1(x))`, then `x = x + mlp(ln2(x))` (uses `post_attention_layernorm`).

    Args:
        config ([`StableLmConfig`]): model configuration.
        layer_idx (`int`): index of this layer in the decoder stack, forwarded to the attention module.
    """

    def __init__(self, config: StableLmConfig, layer_idx: int):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.hidden_size = config.hidden_size
        # The attention backend ("eager", "sdpa" or "flash_attention_2") is chosen from the config.
        self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        self.mlp = StableLmMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # A second LayerNorm is only needed in the sequential layout, where the MLP reads the post-attention stream.
        self.post_attention_layernorm = None
        if not self.use_parallel_residual:
            self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask prepared by [`StableLmModel`]. For the eager and SDPA backends this is a mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values;
                for `flash_attention_2` it is the 2D padding mask, or `None` when nothing is padded.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
                `[0, config.max_position_embeddings - 1]`.

                [What are position IDs?](../glossary#position-ids)
            past_key_value (`Cache` or `Tuple(torch.FloatTensor)`, *optional*):
                cached past key and value projection states, forwarded unchanged to the attention module
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).

        Returns:
            `tuple`: `(hidden_states,)`, followed by the attention weights when `output_attentions=True`, followed by
            the cache returned by the attention module when `use_cache=True`.
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        self_attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
        if self.use_parallel_residual:
            # x = x + attn(ln1(x)) + mlp(ln1(x))
            # Fully Connected
            mlp_output = self.mlp(hidden_states)
            mlp_output = self.dropout(mlp_output)
            hidden_states = residual + self_attn_output + mlp_output
        else:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            residual = residual + self_attn_output
            # Fully Connected
            mlp_output = self.mlp(self.post_attention_layernorm(residual))
            mlp_output = self.dropout(mlp_output)
            hidden_states = residual + mlp_output

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
|
|
|
|
|
# Shared class-level docstring prefix, prepended to the StableLM model classes via `add_start_docstrings`.
STABLELM_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`StableLmConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare StableLm Model outputting raw hidden-states without any specific head on top.",
    STABLELM_START_DOCSTRING,
)
class StableLmPreTrainedModel(PreTrainedModel):
    # Class-level flags read by `PreTrainedModel`: config type, submodule prefix of the base model, modules that must
    # not be split across devices, and which attention/cache features this architecture supports.
    config_class = StableLmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["StableLmDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """
        Initialize the parameters of a single submodule.

        - `nn.Linear`: weights drawn from N(0, `config.initializer_range`), bias zeroed.
        - `nn.Embedding`: weights drawn from N(0, `config.initializer_range`), the `padding_idx` row zeroed.
        - `nn.LayerNorm`: weight set to 1 and bias set to 0. Without this branch, LayerNorm parameters would never
          be (re)initialized by this hook.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # `weight`/`bias` are None when the LayerNorm was built without affine parameters.
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
|
# Shared argument documentation, prepended to the `forward` methods of the StableLM models via
# `add_start_docstrings_to_model_forward`.
STABLELM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
            sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous
            stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2
            tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as
            the legacy cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
|
|
@add_start_docstrings(
    "The bare StableLm Model outputting raw hidden-states without any specific head on top.",
    STABLELM_START_DOCSTRING,
)
class StableLmModel(StableLmPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableLmDecoderLayer`].

    Args:
        config: StableLmConfig
    """

    def __init__(self, config: StableLmConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self._attn_implementation = config._attn_implementation
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Unset flags fall back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = 0
        # Whether the caller passed the legacy tuple cache format; if so, the same format is returned at the end.
        use_legacy_cache = False

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache:
            # Work internally with a `Cache` object, converting a legacy tuple (or None) if needed.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            # New tokens continue numbering right after the cached prefix.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers; dropped entirely when no position is padded
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        # for output_attentions case used fallback to eager attention realization
        elif self._attn_implementation == "sdpa" and not output_attentions:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # `use_cache` is always False here (forced above), so the layer's default is correct.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Layer output layout: (hidden_states, [attn_weights], [cache]).
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
|
|
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm
|
|
class StableLmForCausalLM(StableLmPreTrainedModel):
    """
    StableLM decoder (`self.model`) topped with a bias-free linear language-modeling head (`self.lm_head`) that maps
    final hidden states to vocabulary logits.
    """

    # Listed so `PreTrainedModel` can treat `lm_head.weight` as tied to the input embeddings
    # (presumably governed by `config.tie_word_embeddings` -- confirm in `PreTrainedModel.tie_weights`).
    _tied_weights_keys = ["lm_head.weight"]

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm
    def __init__(self, config):
        super().__init__(config)
        self.model = StableLmModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
    def get_input_embeddings(self):
        return self.model.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
    def get_output_embeddings(self):
        return self.lm_head

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
    def set_decoder(self, decoder):
        self.model = decoder

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, StableLmForCausalLM

        >>> model = StableLmForCausalLM.from_pretrained("stabilityai/stablelm-3b-4e1t")
        >>> tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-3b-4e1t")

        >>> prompt = "The weather is always wonderful in"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'The weather is always wonderful in the summer in the city of San Diego. The city is located on the coast of the Pacific Ocean and is surrounded by'
        ```"""

        # Unset flags fall back to the config defaults (`use_cache` is resolved inside `self.model`).
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First element is the final (normed) hidden state in both tuple and ModelOutput form.
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            # Tuple layout: ([loss], logits, *remaining base-model outputs).
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Build the keyword arguments for one `forward` call during generation.

        When a cache is present, the method does three things:
        - it drops the prefix of `input_ids` the cache has already processed;
        - it crops `attention_mask` if the cache would overflow its maximum length;
        - it derives `position_ids` from `attention_mask` when none are given.

        `inputs_embeds` is only used while there is no cache, i.e. on the first generation step.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: the sequence length is dim 2 of the first layer's key tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            # (padding positions get a dummy value of 1; they are masked out anyway)
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                # Only the positions of the tokens actually being fed this step are needed.
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """
        Reorder every cached tensor along dim 0 according to `beam_idx` (used by beam search).

        Iterates `past_key_values` as a sequence of per-layer sequences of tensors, i.e. the legacy tuple format
        (assumed -- a `Cache` object is not handled specially here). Returns a new tuple of tuples.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
|
|
|
|
|
|
@add_start_docstrings(
    """
    The StableLm transformer with a sequence classification head on top (linear layer).

    [`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    STABLELM_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->STABLELM,Llama->StableLm
class StableLmForSequenceClassification(StableLmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = StableLmModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-position class scores; a single position per row is pooled below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the right-most non-padding token in each row. Multiplying the position indices by the
                # non-pad mask zeroes out pad positions, so argmax picks the last real token. This is correct for
                # left padding, right padding, and pad tokens that also occur inside the text, and it avoids reverse
                # indexing for ONNX compatibility.
                non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
                token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
                sequence_lengths = (token_indices * non_pad_mask).argmax(-1)
            else:
                # Padding cannot be detected from embeddings: fall back to the last position.
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once from num_labels / label dtype and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|