import enum
import warnings
from typing import Dict, List

from ..utils import add_end_docstrings, is_tf_available, is_torch_available
from .base import Pipeline, build_pipeline_init_args


if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES


class ReturnType(enum.Enum):
    TENSORS = 0
    NEW_TEXT = 1
    FULL_TEXT = 2


class Chat:
    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""

    def __init__(self, messages: List[Dict[str, str]]):
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        self.messages = messages


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts,
    where each dict contains "role" and "content" keys.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="openai-community/gpt2")
    >>> generator("I can't believe you did such a ", do_sample=False)
    [{'generated_text': "I can't believe you did such a icky thing to me. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I"}]

    >>> # These parameters will return 4 suggestions, and only the newly generated text, which makes them easier to use for prompting.
    >>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
    ```
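
    If the tokenizer defines a chat template, a conversation can also be passed as a list of dicts. This is only a
    minimal sketch: the checkpoint below is an illustration, any chat-capable model works, and the output depends on
    the model:

    ```python
    >>> chat = [{"role": "user", "content": "Tell me a joke about pipelines."}]
    >>> chat_generator = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta")  # doctest: +SKIP
    >>> outputs = chat_generator(chat, max_new_tokens=40, return_full_text=False)  # doctest: +SKIP
    ```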

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
    generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
    text generation parameters in [Text generation strategies](../generation_strategies) and [Text
    generation](text_generation).

    This language generation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. openai-community/gpt2). See the list of
    available models on [huggingface.co/models](https://huggingface.co/models?filter=text-generation).
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    XL_PREFIX = """
    In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The
    voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western
    Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision
    and denounces one of the men as a horse thief. Although his father initially slaps him for making such an
    accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop,
    begging for his blessing. <eod> </s> <eos>
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        if "prefix" not in self._preprocess_params:
            # This is very specific. The logic is quite complex and needs to be done
            # as a "default".
            # It also defines both some preprocess_kwargs and generate_kwargs
            # which is why we cannot put them in their respective methods.
            prefix = None
            if self.model.config.prefix is not None:
                prefix = self.model.config.prefix
            if prefix is None and self.model.__class__.__name__ in [
                "XLNetLMHeadModel",
                "TransfoXLLMHeadModel",
                "TFXLNetLMHeadModel",
                "TFTransfoXLLMHeadModel",
            ]:
                # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                prefix = self.XL_PREFIX
            if prefix is not None:
                # Recalculate some generate_kwargs linked to prefix.
                preprocess_params, forward_params, _ = self._sanitize_parameters(prefix=prefix, **self._forward_params)
                self._preprocess_params = {**self._preprocess_params, **preprocess_params}
                self._forward_params = {**self._forward_params, **forward_params}

    def _sanitize_parameters(
        self,
        return_full_text=None,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        prefix=None,
        handle_long_generation=None,
        stop_sequence=None,
        add_special_tokens=False,
        truncation=None,
        padding=False,
        max_length=None,
        **generate_kwargs,
    ):
        preprocess_params = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,
        }
        if max_length is not None:
            generate_kwargs["max_length"] = max_length
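
        # If a prefix is prepended to the prompt, record how many tokens it adds: `_forward` uses this
        # `prefix_length` to extend `max_length`/`min_length` so the prefix does not eat into the requested
        # generation budget.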
        if prefix is not None:
            preprocess_params["prefix"] = prefix
        if prefix:
            prefix_inputs = self.tokenizer(
                prefix, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework
            )
            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]

        if handle_long_generation is not None:
            if handle_long_generation not in {"hole"}:
                raise ValueError(
                    f"{handle_long_generation} is not a valid value for the `handle_long_generation` parameter,"
                    " expected one of [None, 'hole']"
                )
            preprocess_params["handle_long_generation"] = handle_long_generation
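
        # generate_kwargs are forwarded both to `preprocess` (the "hole" strategy needs `max_length`/`max_new_tokens`
        # to size the truncation) and to `_forward` (where they are passed to `generate`).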
        preprocess_params.update(generate_kwargs)
        forward_params = generate_kwargs

        postprocess_params = {}
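        # Resolve the requested output format into a `ReturnType`, enforcing that the legacy flags
        # (`return_full_text`, `return_text`, `return_tensors`) are mutually exclusive.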
        if return_full_text is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
            if return_tensors is not None:
                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
        if return_tensors is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if clean_up_tokenization_spaces is not None:
            postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
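
        # A textual stop sequence is mapped onto `eos_token_id`; only single-token stop sequences are fully
        # supported, longer ones fall back to their first token.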
        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
            if len(stop_sequence_ids) > 1:
                warnings.warn(
                    "Stopping on a multi-token sequence is not yet supported in transformers. The first token of"
                    " the stop sequence will be used as the stop token in the interim."
                )
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

        return preprocess_params, forward_params, postprocess_params

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            kwargs.update({"add_space_before_punct_symbol": True})

        return super()._parse_and_tokenize(*args, **kwargs)

    def __call__(self, text_inputs, **kwargs):
        """
        Complete the prompt(s) given as inputs.

        Args:
            text_inputs (`str` or `List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to return the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False`, only the added text is returned, otherwise the full text is returned. Only
                meaningful if *return_text* is set to `True`.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (`str`, *optional*):
                Prefix added to the prompt.
            handle_long_generation (`str`, *optional*):
                By default, this pipeline does not handle long generation (generation that in one form or another
                exceeds the model's maximum length). There is no perfect way to address this (more info:
                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
                strategies to work around the problem, depending on your use case.

                - `None`: default strategy, where nothing in particular happens
                - `"hole"`: Truncates the left of the input and leaves a gap wide enough to let generation happen
                  (this might truncate a lot of the prompt, and is not suitable when the desired generation exceeds
                  the model's capacity)
            generate_kwargs (`dict`, *optional*):
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a combination
            of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
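
        Example (a minimal sketch; the checkpoint and prompt are only illustrations):

        ```python
        >>> from transformers import pipeline

        >>> generator = pipeline("text-generation", model="openai-community/gpt2")
        >>> # Return only the completion, and left-truncate the prompt if it would not fit the model context.
        >>> outputs = generator(
        ...     "Some long prompt ...",
        ...     max_new_tokens=20,
        ...     return_full_text=False,
        ...     handle_long_generation="hole",
        ... )
        ```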
        """
        if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)):
            # We have one or more prompts in list-of-dicts format, so this is chat mode
            if isinstance(text_inputs[0], dict):
                return super().__call__(Chat(text_inputs), **kwargs)
            else:
                chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈
                return super().__call__(chats, **kwargs)
        else:
            return super().__call__(text_inputs, **kwargs)

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        add_special_tokens=False,
        truncation=None,
        padding=False,
        max_length=None,
        **generate_kwargs,
    ):
        if isinstance(prompt_text, Chat):
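            # Chats are rendered with the tokenizer's chat template; `add_generation_prompt=True` appends the
            # tokens that cue the model to answer as the assistant.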
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors=self.framework,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_special_tokens=add_special_tokens,
                return_tensors=self.framework,
            )
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
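            # Left-truncate the prompt so that `new_tokens` generated tokens still fit within the tokenizer's
            # `model_max_length`.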
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation: the number of desired tokens exceeds the"
                        " model's max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
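        # `generate` may return several sequences per prompt (e.g. with `num_return_sequences`), so regroup the
        # output as (batch, candidates, sequence_length); `postprocess` later consumes one batch item at a time.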
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # Decode text
                text = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )

                # Strip the prompt (including any prefix such as the XLNet/Transfo-XL padding text) from the decoded
                # sequence by measuring the decoded length of the input ids
                if input_ids is None:
                    prompt_length = 0
                else:
                    prompt_length = len(
                        self.tokenizer.decode(
                            input_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                        )
                    )

                all_text = text[prompt_length:]
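                # For FULL_TEXT, re-attach the original prompt: plain string prompts are concatenated, chat prompts
                # get the completion appended as a new assistant message.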
                if return_type == ReturnType.FULL_TEXT:
                    if isinstance(prompt_text, str):
                        all_text = prompt_text + all_text
                    elif isinstance(prompt_text, Chat):
                        all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}]

                record = {"generated_text": all_text}
            records.append(record)

        return records