# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

import torch

from ..cache_utils import DynamicCache


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from .configuration_utils import GenerationConfig
    from .logits_process import LogitsProcessorList


class CandidateGenerator:
    """Abstract base class for all candidate generators that can be applied during assisted generation."""

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
        )

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
            "`update_candidate_strategy`."
        )
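

# Hedged illustration (not part of the library API): a minimal sketch of a custom candidate generator implementing
# the two-method contract defined by `CandidateGenerator` above. The class name `_FixedTokensCandidateGenerator` and
# its fixed-continuation behavior are hypothetical and only exist to show the expected signatures; it assumes a batch
# size of 1, which is the setting assisted generation is used with.
class _FixedTokensCandidateGenerator(CandidateGenerator):
    def __init__(self, fixed_token_ids: torch.LongTensor):
        # 1D tensor of token ids proposed as the candidate continuation for every input
        self.fixed_token_ids = fixed_token_ids

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        # Append the fixed tokens to the current sequence; no candidate logits are available, so return None
        candidates = torch.cat([input_ids, self.fixed_token_ids.unsqueeze(0).to(input_ids.device)], dim=-1)
        return candidates, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        # A fixed proposal has nothing to adapt
        return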


class AssistedCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the
            assistant model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        logits_processor: "LogitsProcessorList",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
    ):
        # Make sure all data is on the same device as the assistant model
        device = assistant_model.device
        input_ids = input_ids.to(device)
        if inputs_tensor is not None:
            inputs_tensor = inputs_tensor.to(device)

        # Prepare the assistant and the starting number of candidate tokens
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

        # Prepare the kwargs for the assistant model
        assistant_kwargs = {}
        for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
            if key not in ("encoder_outputs", "assistant_encoder_outputs"):
                assistant_kwargs[key] = (
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )

        if "assistant_encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
        elif assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs

        # Prepare assistant model's keys of inputs
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            self.input_ids_key = "decoder_input_ids"
        elif "encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            self.input_ids_key = "input_ids"
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
            # both are decoder-only
            self.input_ids_key = "input_ids"

        # Prepare generation-related options.
        self.logits_processor = logits_processor
        self.generation_config = copy.deepcopy(generation_config)
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

        # avoid unnecessary warnings that min_length is larger than max_new_tokens
        self.main_model_min_length = self.generation_config.min_length
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None
    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated with each candidate.
        """
        input_ids = input_ids.to(self.assistant_model.device)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
        if max_new_tokens == 0:
            return input_ids, None

        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = new_cur_len - 1
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
            )  # the assistant does not have the token after the last match, hence the -1

            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }

        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

        # 3. Update variables for the next round of candidate generation
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

        # 4. Prepare variables for output
        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        candidate_ids = assistant_output.sequences
        return candidate_ids, candidate_logits

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
        # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
        # cost of forecasting incorrect assistant tokens.
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
            if num_matches == int(self.num_assistant_tokens):
                self.num_assistant_tokens += 2.0
            else:
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
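

# Hedged usage sketch (not part of this module): `AssistedCandidateGenerator` is normally constructed internally by
# `generate` when an `assistant_model` is passed. The checkpoints below are placeholders; any main/assistant pair
# sharing the same tokenizer should work. Wrapped in a function so nothing runs at import time.
def _example_assisted_generation():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")  # placeholder main checkpoint
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
    assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder (smaller) assistant

    inputs = tokenizer("Assisted decoding lets a small model draft tokens that the", return_tensors="pt")
    # Passing `assistant_model` switches `generate` to assisted decoding, which drives the candidate generator above.
    outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=32)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)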


class PromptLookupCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
    likely continuations in the provided prompt (input_ids) itself.
    Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

    Args:
        max_matching_ngram_size (`int`):
            The maximum ngram size to be considered for matching in the prompt
        num_output_tokens (`int`):
            The number of tokens to be output as candidate tokens.
        max_length (`int`):
            The maximum total number of tokens that can be generated. For decoder-only models, this includes the
            prompt length. Defaults to 20, which is the max length used as default in generation config.
    """

    def __init__(
        self,
        num_output_tokens: int = 10,
        max_matching_ngram_size: Optional[int] = None,
        max_length: int = 20,
    ):
        self.num_output_tokens = num_output_tokens
        self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
        self.max_length = max_length

        if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
            raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
        """
        input_length = input_ids.size(1)

        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        if self.max_length == input_length + 1:
            return input_ids, None

        chosen_ids = None
        match_found = False
        for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
            # Create sliding windows of size ngram_size
            windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

            # Convert ngram to a tensor for comparison
            ngram_tensor = input_ids[0, -ngram_size:]

            # Find where the windows match the ngram
            matches = (windows == ngram_tensor).all(dim=2)

            # Get the indices of matches
            match_indices = matches.nonzero(as_tuple=True)[1]

            # Iterate through match indices to find a valid continuation
            for idx in match_indices:
                start_idx = idx + ngram_size
                end_idx = start_idx + self.num_output_tokens
                end_idx = min(end_idx, input_length, self.max_length)

                if start_idx < end_idx:
                    chosen_ids = input_ids[0, start_idx:end_idx]
                    match_found = True
                    break

            if match_found:
                break

        if chosen_ids is None or len(chosen_ids) == 0:
            # In case we didn't find a match, return the input sequence unchanged, reverting back to autoregressive
            # decoding
            return input_ids, None

        # Now we need to extend input_ids with chosen_ids
        chosen_ids = chosen_ids.unsqueeze(0)
        candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
        # assisted_generation expects logits as well, but we don't have those here, so returning None
        return candidate_input_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
                using beam search or log softmax for each vocabulary token when using beam search.
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Currently does nothing
        return
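

# Hedged illustration (not part of the module): the lookup above slides windows over the prompt with `Tensor.unfold`
# and compares them to the trailing ngram. The toy function below reproduces that matching step on a small tensor so
# the indexing is easy to inspect; from the user side, prompt lookup decoding is typically enabled by passing
# `prompt_lookup_num_tokens` to `generate`. The numbers are arbitrary token ids used only for this sketch.
def _example_prompt_lookup_matching():
    # Toy "prompt": the bigram (7, 8) also appears earlier in the sequence, so its continuation can be proposed
    input_ids = torch.tensor([[1, 7, 8, 9, 4, 5, 7, 8]])
    ngram_size = 2
    num_output_tokens = 3

    windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)  # shape (1, 7, 2): all sliding bigrams
    ngram_tensor = input_ids[0, -ngram_size:]  # trailing bigram: tensor([7, 8])
    matches = (windows == ngram_tensor).all(dim=2)
    match_indices = matches.nonzero(as_tuple=True)[1]  # tensor([1, 6]); index 6 is the trailing bigram itself

    start_idx = match_indices[0] + ngram_size  # continuation starts right after the first match
    end_idx = min(start_idx + num_output_tokens, input_ids.shape[1])
    return input_ids[0, start_idx:end_idx]  # tensor([9, 4, 5]): the proposed candidate continuation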


def _crop_past_key_values(model, past_key_values, maximum_length):
    """Crops the past key values up to a certain maximum length."""
    new_past = []
    if model.config.is_encoder_decoder:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                    past_key_values[idx][2],
                    past_key_values[idx][3],
                )
            )
        past_key_values = tuple(new_past)
    # bloom is special
    elif "bloom" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
    ):
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length],
                    past_key_values[idx][1][:, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    # gptbigcode is too
    elif "gptbigcode" in model.__class__.__name__.lower() or (
        model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
    ):
        if model.config.multi_query:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
        else:
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
    elif isinstance(past_key_values, DynamicCache):
        for idx in range(len(past_key_values.key_cache)):
            if past_key_values.value_cache[idx].shape[-1] != 0:
                past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
                past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
    elif past_key_values is not None:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, :maximum_length, :],
                    past_key_values[idx][1][:, :, :maximum_length, :],
                )
            )
        past_key_values = tuple(new_past)
    return past_key_values


def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
    """Expands or crops the model's mask for decoding purposes, to the defined length"""

    mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
    if mask_key not in model_kwargs:
        return model_kwargs

    mask = model_kwargs[mask_key]
    mask_length_diff = new_length - mask.shape[1]

    if mask_length_diff < 0:
        model_kwargs[mask_key] = mask[:, :mask_length_diff]
    elif mask_length_diff > 0:
        model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
    return model_kwargs
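

# Hedged illustration (not part of the module): `_prepare_attention_mask` either crops the mask or pads it with ones
# so that its length matches the new sequence length. The toy function below exercises both directions on a 1x3 mask;
# the dict copies are only there to keep the two calls independent.
def _example_prepare_attention_mask():
    kwargs = {"attention_mask": torch.ones((1, 3), dtype=torch.long)}

    # Expanding to length 5 appends two ones: [[1, 1, 1, 1, 1]]
    expanded = _prepare_attention_mask(dict(kwargs), new_length=5, is_encoder_decoder=False)["attention_mask"]

    # Cropping to length 2 drops the trailing position: [[1, 1]]
    cropped = _prepare_attention_mask(dict(kwargs), new_length=2, is_encoder_decoder=False)["attention_mask"]
    return expanded, cropped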


def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
    """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
    if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
        return model_kwargs

    token_type_ids = model_kwargs["token_type_ids"]
    final_token_type = token_type_ids[:, -1].unsqueeze(-1)
    type_length_diff = new_length - token_type_ids.shape[1]

    if type_length_diff < 0:
        # write the cropped ids back so that cropping actually takes effect
        model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff]
    elif type_length_diff > 0:
        token_type_copies = final_token_type.repeat(1, type_length_diff)
        model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
    return model_kwargs
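

# Hedged illustration (not part of the module): `_prepare_token_type_ids` extends token type ids by repeating the
# final segment id, mirroring what `_prepare_attention_mask` does for the attention mask. A 1x3 toy example:
def _example_prepare_token_type_ids():
    kwargs = {"token_type_ids": torch.tensor([[0, 0, 1]])}

    # Expanding to length 5 repeats the final type id twice: [[0, 0, 1, 1, 1]]
    return _prepare_token_type_ids(dict(kwargs), new_length=5)["token_type_ids"]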