# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from queue import Queue
from typing import TYPE_CHECKING, Optional


if TYPE_CHECKING:
    from ..models.auto import AutoTokenizer


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()
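

# Illustrative sketch (a hypothetical class, not part of the public API): the contract
# above is all `.generate()` relies on, so a custom streamer only has to implement
# `put` and `end`. This minimal subclass just collects the raw token ids it receives.
class _ExampleListStreamer(BaseStreamer):
    def __init__(self):
        self.tokens = []

    def put(self, value):
        # `value` is a tensor of token ids pushed by `.generate()`
        self.tokens.extend(value.flatten().tolist())

    def end(self):
        # nothing to flush -- generation is finished
        pass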


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, print until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)  # CJK Unified Ideographs
            or (cp >= 0x3400 and cp <= 0x4DBF)  # CJK Unified Ideographs Extension A
            or (cp >= 0x20000 and cp <= 0x2A6DF)  # CJK Unified Ideographs Extension B
            or (cp >= 0x2A700 and cp <= 0x2B73F)  # CJK Unified Ideographs Extension C
            or (cp >= 0x2B740 and cp <= 0x2B81F)  # CJK Unified Ideographs Extension D
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  # CJK Unified Ideographs Extension E
            or (cp >= 0xF900 and cp <= 0xFAFF)  # CJK Compatibility Ideographs
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  # CJK Compatibility Ideographs Supplement
        ):
            return True

        return False
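

# Illustrative sketch (a hypothetical class, not part of the public API): because
# `TextStreamer` funnels all output through `on_finalized_text`, overriding that one
# method is enough to redirect the stream somewhere other than stdout, e.g. a file.
class _ExampleFileStreamer(TextStreamer):
    def __init__(self, tokenizer: "AutoTokenizer", file_path: str, **decode_kwargs):
        super().__init__(tokenizer, **decode_kwargs)
        self.file_path = file_path

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # append each finalized chunk; add a final newline when the stream ends
        with open(self.file_path, "a", encoding="utf-8") as f:
            f.write(text + ("\n" if stream_end else ""))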


class TextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an
    interactive Gradio demo).

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        >>> from threading import Thread

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextIteratorStreamer(tok)

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> generated_text = ""
        >>> for new_text in streamer:
        ...     generated_text += new_text
        >>> generated_text
        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
        ```
    """

    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value
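

# Illustrative sketch (a hypothetical helper, not part of the public API): with a
# finite `timeout`, the iterator above raises `queue.Empty` instead of blocking
# forever if `.generate()` crashes in its thread before pushing the stop signal.
def _example_generate_with_timeout(model, tokenizer, inputs, timeout: float = 30.0) -> str:
    from threading import Thread

    streamer = TextIteratorStreamer(tokenizer, timeout=timeout)
    thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=20))
    thread.start()
    generated_text = ""
    for new_text in streamer:  # each item is a print-ready chunk of text
        generated_text += new_text
    thread.join()
    return generated_text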