ai-content-maker/.venv/Lib/site-packages/transformers/generation/candidate_generator.py

426 lines
20 KiB
Python

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ..cache_utils import DynamicCache
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from .configuration_utils import GenerationConfig
from .logits_process import LogitsProcessorList
class CandidateGenerator:
    """Interface shared by every candidate generator usable during assisted generation.

    Both methods raise `NotImplementedError` here; concrete generators are expected to override them.
    """

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Produce the candidate tokens that should be checked for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` holding the candidate sequences to be
            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` holding the logits associated to each candidate.

        Raises:
            NotImplementedError: always, on this base class.
        """
        message = (
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
        )
        raise NotImplementedError(message)

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Adapt the candidate generation strategy to the outcome of the latest round.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.

        Raises:
            NotImplementedError: always, on this base class.
        """
        message = (
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
            "`update_candidate_strategy`."
        )
        raise NotImplementedError(message)
class AssistedCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        logits_processor: "LogitsProcessorList",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
    ):
        # Make sure all data at the same device as assistant model
        device = assistant_model.device
        input_ids = input_ids.to(device)
        if inputs_tensor is not None:
            inputs_tensor = inputs_tensor.to(device)

        # Prepare the assistant and the starting number of candidate tokens
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

        # Prepare the kwargs for the assistant model: tensors are detached and moved to the assistant's device,
        # everything else is deep-copied so the main model's kwargs are never mutated from here.
        assistant_kwargs = {}
        for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
            if key not in ("encoder_outputs", "assistant_encoder_outputs"):
                assistant_kwargs[key] = (
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )

        # Encoder outputs, in order of preference: outputs computed specifically for the assistant, outputs
        # computed by the assistant's own encoder, and finally the main model's encoder outputs.
        if "assistant_encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
        elif assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs

        # Prepare assistant model's keys of inputs
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            self.input_ids_key = "decoder_input_ids"
        elif "encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            self.input_ids_key = "input_ids"
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
            # both are decoder-only
            self.input_ids_key = "input_ids"

        # Prepare generation-related options. The config is copied so the caller's object is left untouched.
        self.logits_processor = logits_processor
        self.generation_config = copy.deepcopy(generation_config)
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

        # avoid unnecessary warnings that min_length is larger than max_new_tokens
        self.main_model_min_length = self.generation_config.min_length
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated to each candidate. When there is no room left for
            any candidate token, `input_ids` (moved to the assistant's device) is returned together with `None`.
        """
        input_ids = input_ids.to(self.assistant_model.device)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        # `<= 0` rather than `== 0`: if the input is already at or past `max_length` the budget is negative, and a
        # negative `max_new_tokens` must not be forwarded to the assistant's `generate` call.
        if max_new_tokens <= 0:
            return input_ids, None
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)

        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = new_cur_len - 1
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
            )  # the assistant does not have the token after the last match, hence the -1

            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }
        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

        # 3. Update variables for the next round of candidate generation
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

        # 4. Prepare variables for output
        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        candidate_ids = assistant_output.sequences
        return candidate_ids, candidate_logits

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Only `num_matches` is read by this implementation; `input_ids` and `scores` are accepted to satisfy the
        `CandidateGenerator` interface.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
        # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
        # cost of forecasting incorrect assistant tokens.
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
            if num_matches == int(self.num_assistant_tokens):
                # every candidate was accepted: be more ambitious next round
                self.num_assistant_tokens += 2.0
            else:
                # at least one candidate was rejected: back off, but never below one token
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
    likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt. Falsy values (`None`, `0`) fall
            back to 2.
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
        max_length (`int`):
            The number of total maximum tokens that can be generated. For decoder-only models that includes the
            prompt length. Defaults to 20, which is the max length used as default in generation config.

    Raises:
        ValueError: if the resolved `max_matching_ngram_size` or `num_output_tokens` is not strictly positive.
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: Optional[int] = None,
        max_length: int = 20,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
        self.max_length = max_length

        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        The trailing n-gram of the first row of `input_ids` is searched for earlier in the sequence, trying the
        largest n-gram size first. The tokens that followed the first usable match are appended as candidates.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried,
            together with `None` (no logits are available for looked-up tokens). When no candidate can be
            proposed, `input_ids` is returned unchanged.
        """
        input_length = input_ids.size(1)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        # The remaining budget caps the number of looked-up tokens, so `input_ids` plus the candidates never grows
        # past `max_length - 1`; a non-positive budget means there is nothing left to propose.
        max_candidate_tokens = min(self.num_output_tokens, self.max_length - input_length - 1)
        if max_candidate_tokens <= 0:
            return input_ids, None

        chosen_ids = None
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # The n-gram to look for is the tail of the first sequence
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the (window start) indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation. The trailing n-gram always matches
            # itself, but it has no continuation (`start_idx == input_length`) and is skipped by the check below.
            for idx in match_indices.tolist():
                start_idx = idx + ngram_size
                end_idx = min(start_idx + max_candidate_tokens, input_length)
                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    break
            if chosen_ids is not None:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
            return input_ids, None

        # Now need extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here, so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Currently does nothing
        return
def _crop_past_key_values(model, past_key_values, maximum_length):
    """Trim a key/value cache so that it keeps at most `maximum_length` positions.

    The cache layout is picked from the model: encoder-decoder models, Bloom, GPTBigCode, `DynamicCache` objects
    and, finally, plain per-layer `(key, value)` pairs. The GPTBigCode and `DynamicCache` branches modify
    `past_key_values` in place; the other branches build and return a new tuple. `None` is returned as-is when no
    branch applies.
    """

    def _is_architecture(tag):
        # A model matches either through its class name or through the first entry of `config.architectures`.
        if tag in model.__class__.__name__.lower():
            return True
        return model.config.architectures is not None and tag in model.config.architectures[0].lower()

    if model.config.is_encoder_decoder:
        # Only the first two entries of each layer are cropped; entries 2 and 3 are passed through untouched.
        return tuple(
            (
                layer[0][:, :, :maximum_length, :],
                layer[1][:, :, :maximum_length, :],
                layer[2],
                layer[3],
            )
            for layer in past_key_values
        )

    # bloom is special: the two entries are cropped along different axes
    if _is_architecture("bloom"):
        return tuple(
            (
                layer[0][:, :, :maximum_length],
                layer[1][:, :maximum_length, :],
            )
            for layer in past_key_values
        )

    # gptbigcode is too: one tensor per layer, cropped in place
    if _is_architecture("gptbigcode"):
        if model.config.multi_query:
            for layer_idx in range(len(past_key_values)):
                past_key_values[layer_idx] = past_key_values[layer_idx][:, :maximum_length, :]
        else:
            for layer_idx in range(len(past_key_values)):
                past_key_values[layer_idx] = past_key_values[layer_idx][:, :, :maximum_length, :]
        return past_key_values

    if isinstance(past_key_values, DynamicCache):
        for layer_idx in range(len(past_key_values.key_cache)):
            # layers whose value tensor has an empty last dimension are left alone
            if past_key_values.value_cache[layer_idx].shape[-1] != 0:
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[layer_idx][
                    :, :, :maximum_length, :
                ]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[layer_idx][
                    :, :, :maximum_length, :
                ]
        return past_key_values

    if past_key_values is not None:
        return tuple(
            (
                layer[0][:, :, :maximum_length, :],
                layer[1][:, :, :maximum_length, :],
            )
            for layer in past_key_values
        )

    return past_key_values
def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Resize the decoding attention mask stored in `model_kwargs` to `new_length` columns.

    The mask is read from `"decoder_attention_mask"` for encoder-decoder models and from `"attention_mask"`
    otherwise. A mask that is too long is cropped on the right; one that is too short is padded on the right
    with ones. `model_kwargs` is updated in place and returned; it is returned untouched when the key is absent.
    """
    if is_encoder_decoder:
        mask_key = "decoder_attention_mask"
    else:
        mask_key = "attention_mask"

    if mask_key in model_kwargs:
        current_mask = model_kwargs[mask_key]
        missing_columns = new_length - current_mask.shape[1]
        if missing_columns > 0:
            padding = current_mask.new_ones((current_mask.shape[0], missing_columns))
            model_kwargs[mask_key] = torch.cat([current_mask, padding], dim=-1)
        elif missing_columns < 0:
            # a negative stop index drops the surplus columns from the right
            model_kwargs[mask_key] = current_mask[:, :missing_columns]

    return model_kwargs
def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length.

    Args:
        model_kwargs (`Dict[str, Any]`):
            Model keyword arguments. Only the `"token_type_ids"` entry is read and updated.
        new_length (`int`):
            Target size of the second dimension of `token_type_ids`.

    Return:
        The same `model_kwargs` dict. When `"token_type_ids"` is missing or `None` it is returned untouched;
        otherwise the entry is cropped on the right, or extended on the right by repeating each row's last
        token type, so that it has `new_length` columns.
    """
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    type_length_diff = new_length - token_type_ids.shape[1]

    if type_length_diff < 0:
        # Store the cropped tensor back: assigning it to the local name only would leave `model_kwargs` with the
        # stale, too-long tensor.
        model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        # New positions inherit the token type of the last known position of each row.
        final_token_type = token_type_ids[:, -1].unsqueeze(-1)
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_copies], dim=-1)
    return model_kwargs