# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ..cache_utils import DynamicCache


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from .configuration_utils import GenerationConfig
    from .logits_process import LogitsProcessorList


class CandidateGenerator:
    """Abstract base class for all candidate generators that can be applied during assisted generation."""

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
        )

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or the log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
            "`update_candidate_strategy`."
        )
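

# Illustrative sketch (not part of the upstream module, and not used anywhere below): the two methods above are
# the whole contract a custom candidate generator has to satisfy. A minimal subclass could, for example, always
# propose the last token repeated `n` more times and skip the logits entirely.
class _RepeatLastTokenCandidateGenerator(CandidateGenerator):
    """Toy candidate generator that proposes the most recent token `n` more times."""

    def __init__(self, n: int = 3):
        self.n = n

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        # Candidates must extend `input_ids`; returning `None` for the logits is allowed.
        repeated = input_ids[:, -1:].repeat(1, self.n)
        return torch.cat([input_ids, repeated], dim=-1), None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        # Nothing to adapt for this toy strategy.
        return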


class AssistedCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
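
    Examples:

    A sketch of how assisted generation is typically triggered from `generate()`, by passing a smaller compatible
    model as `assistant_model` (the checkpoint names below are only placeholders; any pair of models sharing a
    tokenizer works):

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
    >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
    >>> assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    >>> inputs = tokenizer("The capital of France is", return_tensors="pt")
    >>> outputs = model.generate(**inputs, assistant_model=assistant_model)
    ```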
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        logits_processor: "LogitsProcessorList",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
    ):
        # Make sure all data is on the same device as the assistant model
        device = assistant_model.device
        input_ids = input_ids.to(device)
        if inputs_tensor is not None:
            inputs_tensor = inputs_tensor.to(device)

        # Prepare the assistant and the starting number of candidate tokens
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

        # Prepare the kwargs for the assistant model
        assistant_kwargs = {}
        for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
            if key not in ("encoder_outputs", "assistant_encoder_outputs"):
                assistant_kwargs[key] = (
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )

        if "assistant_encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
        elif assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs

        # Prepare assistant model's keys of inputs
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            self.input_ids_key = "decoder_input_ids"
        elif "encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            self.input_ids_key = "input_ids"
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
            # both are decoder-only
            self.input_ids_key = "input_ids"

        # Prepare generation-related options.
        self.logits_processor = logits_processor
        self.generation_config = copy.deepcopy(generation_config)
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

        # avoid unnecessary warnings that min_length is larger than max_new_tokens
        self.main_model_min_length = self.generation_config.min_length
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        input_ids = input_ids.to(self.assistant_model.device)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
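        # For example, with `max_length=50`, 5 assistant tokens and a current length of 47 tokens, only
        # min(5, 50 - 47 - 1) = 2 candidate tokens are requested, so that the sequence still has room for the
        # extra token the target model appends during verification.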
        if max_new_tokens == 0:
            return input_ids, None

        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = new_cur_len - 1
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
            )  # the assistant does not have the token after the last match, hence the -1

            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }

        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

        # 3. Update variables for the next round of candidate generation
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

        # 4. Prepare variables for output
        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        candidate_ids = assistant_output.sequences
        return candidate_ids, candidate_logits

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or the log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
        # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
        # cost of forecasting incorrect assistant tokens.
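        # For example, a window of 5 assistant tokens grows to 7 after a round in which all 5 candidates are
        # accepted, and shrinks to 4 after a round with any rejection; the window never drops below 1.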
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
            if num_matches == int(self.num_assistant_tokens):
                self.num_assistant_tokens += 2.0
            else:
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)


class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
    likely continuations in the provided prompt (input_ids) itself.
    See the following repository for more information: https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt.
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
        max_length (`int`):
            The maximum total number of tokens that can be generated. For decoder-only models this includes the
            prompt length. Defaults to 20, which is the default `max_length` in the generation config.
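
    Examples:

    A sketch of how prompt lookup decoding is typically enabled from `generate()` (the checkpoint below is only a
    placeholder; any decoder-only checkpoint works):

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> inputs = tokenizer("The quick brown fox jumps over the lazy dog. The quick brown", return_tensors="pt")
    >>> outputs = model.generate(**inputs, prompt_lookup_num_tokens=10)
    ```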
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: Optional[int] = None,
        max_length: int = 20,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
        self.max_length = max_length

        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)`: The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        if self.max_length == input_length + 1:
            return input_ids, None

        chosen_ids = None
        match_found = False
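        # Look for the longest suffix of the prompt (up to `max_matching_ngram_size` tokens) that also occurs
        # earlier in the prompt, and propose the tokens that followed that earlier occurrence. For example, for
        # input ids [8, 3, 5, 8, 3] and ngram_size=2, the final ngram [8, 3] also matches at position 0, so the
        # continuation starting with token 5 is proposed.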
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # Convert ngram to a tensor for comparison
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length, self.max_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True
                    break
            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # In case we didn't find a match, return the input sequence unchanged, which reverts back to
            # autoregressive decoding
            return input_ids, None

        # Now we need to extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here, so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search, or the log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Currently does nothing
        return


def _crop_past_key_values(model, past_key_values, maximum_length):
    """Crops the past key values up to a certain maximum length."""
    new_past = []
    if model.config.is_encoder_decoder:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                    past_key_values[idx][2],
                    past_key_values[idx][3],
                )
            )
        past_key_values = tuple(new_past)
    # bloom is special
    elif "bloom" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
    ):
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length],
                    past_key_values[idx][1][:, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    # gptbigcode is too
    elif "gptbigcode" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
    ):
        if model.config.multi_query:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
        else:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
    elif isinstance(past_key_values, DynamicCache):
        for idx in range(len(past_key_values.key_cache)):
            if past_key_values.value_cache[idx].shape[-1] != 0:
                past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
                past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]

    elif past_key_values is not None:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    return past_key_values


def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""

    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]

    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
    return model_kwargs


def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    type_length_diff = new_length - token_type_ids.shape[1]

    if type_length_diff < 0:
        model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    return model_kwargs