219 lines
8.7 KiB
Python
219 lines
8.7 KiB
Python
# coding=utf-8
|
|
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" GPT-J model configuration"""
|
|
from collections import OrderedDict
|
|
from typing import Any, List, Mapping, Optional
|
|
|
|
from ... import PreTrainedTokenizer, TensorType, is_torch_available
|
|
from ...configuration_utils import PretrainedConfig
|
|
from ...onnx import OnnxConfigWithPast, PatchingSpec
|
|
from ...utils import logging
|
|
|
|
|
|
# Module-level logger from the library's logging utilities, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
from ..deprecated._archive_maps import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP # noqa: F401, E402
|
|
|
|
|
|
class GPTJConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the GPT-J
    [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from
    [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50400):
            Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`GPTJModel`].
        n_positions (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 4096):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer.
        n_head (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer.
        rotary_dim (`int`, *optional*, defaults to 64):
            Number of dimensions in the embedding that Rotary Position Embedding is applied to.
        n_inner (`int`, *optional*, defaults to `None`):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd.
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 50256):
            The id of the beginning-of-stream token.
        eos_token_id (`int`, *optional*, defaults to 50256):
            The id of the end-of-stream token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied. Forwarded to [`PretrainedConfig`].
        kwargs (*optional*):
            Any other keyword arguments are forwarded to [`PretrainedConfig`].

    Example:

    ```python
    >>> from transformers import GPTJModel, GPTJConfig

    >>> # Initializing a GPT-J 6B configuration
    >>> configuration = GPTJConfig()

    >>> # Initializing a model from the configuration
    >>> model = GPTJModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gptj"
    # Aliases from common config attribute names to GPT-J's field names, so that e.g. `config.hidden_size`
    # resolves to `config.n_embd` (relied upon by `GPTJOnnxConfig.generate_dummy_inputs`).
    attribute_map = {
        "max_position_embeddings": "n_positions",
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=50400,
        n_positions=2048,
        n_embd=4096,
        n_layer=28,
        n_head=16,
        rotary_dim=64,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.rotary_dim = rotary_dim
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        # Special-token ids and embedding tying are also passed to the base class explicitly.
        super().__init__(
            bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
        )
|
|
|
|
|
|
# Adapted from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig. Deliberately not a "Copied from" marker:
# this class now diverges (docs, `Optional` annotation, restructured dummy-input builder), so `make fix-copies`
# must not overwrite it.
class GPTJOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for GPT-J, optionally exposing past key/values as extra model inputs."""

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: Optional[List[PatchingSpec]] = None,
        use_past: bool = False,
    ):
        """
        Args:
            config: The model configuration to export.
            task: The export task name.
            patching_specs: Optional patches to apply during export; `None` means no patches.
            use_past: Whether past key/values are exported as inputs.
        """
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        # Fall back to pad_token_id = 0 when the config has none (or a falsy one).
        # NOTE(review): this writes to `self._config`, presumably the `config` passed in, so the caller's
        # config object may be mutated.
        # TODO: how to do that better?
        if not getattr(self._config, "pad_token_id", None):
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Dynamic-axes spec of the model inputs: input name -> {axis index: symbolic axis name}.

        With `use_past`, past key/value entries are inserted (by the base-class helper) between `input_ids` and
        `attention_mask`, and the mask's second axis covers past plus current tokens.
        """
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        mask_seq_axis = "sequence"
        if self.use_past:
            # Must run before `attention_mask` is inserted so the key order matches forward().
            self.fill_with_past_key_values_(common_inputs, direction="inputs")
            mask_seq_axis = "past_sequence + sequence"
        common_inputs["attention_mask"] = {0: "batch", 1: mask_seq_axis}

        return common_inputs

    @property
    def num_layers(self) -> int:
        """Number of layers, read from the config's `n_layer`."""
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        """Number of attention heads, read from the config's `n_head`."""
        return self._config.n_head

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """
        Build dummy inputs for export, ordered as they appear in the model's forward().

        Args:
            tokenizer: Tokenizer used to produce the dummy `input_ids` / `attention_mask`.
            batch_size: Requested batch size (passed through to the base implementation).
            seq_length: Requested sequence length (passed through to the base implementation).
            is_pair: Whether to build sentence-pair inputs (passed through).
            framework: Tensor framework for the generated tensors (passed through).

        Returns:
            An ordered mapping with `input_ids`, then (if `use_past`) `past_key_values`, then `attention_mask`.

        Raises:
            ValueError: If `use_past` is set but PyTorch is not available.
        """
        # Deliberately skip OnnxConfigWithPast's own override and use the next implementation in the MRO;
        # past key/values are built below instead.
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # We need to order the inputs in the way they appear in forward().
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
        attention_mask = common_inputs["attention_mask"]

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            import torch

            batch, seqlen = common_inputs["input_ids"].shape
            # Intentionally use a past length different from the input length.
            past_key_values_length = seqlen + 2
            # (batch, num_heads, past_length, head_dim)
            past_shape = (
                batch,
                self.num_attention_heads,
                past_key_values_length,
                self._config.hidden_size // self.num_attention_heads,
            )
            ordered_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
            ]
            # Extend the mask so it also covers the past positions.
            past_mask = torch.ones(batch, past_key_values_length, dtype=attention_mask.dtype)
            attention_mask = torch.cat([attention_mask, past_mask], dim=1)

        ordered_inputs["attention_mask"] = attention_mask
        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        """Default ONNX opset version used for export."""
        return 13
|