ai-content-maker/.venv/Lib/site-packages/huggingface_hub/inference/_templating.py

from functools import lru_cache
from typing import Callable, Dict, List, Optional, Union

from ..utils import HfHubHTTPError, RepositoryNotFoundError, is_minijinja_available


class TemplateError(Exception):
    """Any error raised while trying to fetch or render a chat template."""


def _import_minijinja():
    if not is_minijinja_available():
        raise ImportError("Cannot render template. Please install minijinja using `pip install minijinja`.")
    import minijinja  # noqa: F401

    return minijinja
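

# Note: `minijinja` is an optional dependency. The lazy import above means this
# module can be imported without minijinja installed; the dependency is only
# required once a template is actually compiled or rendered.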


def render_chat_prompt(
    *,
    model_id: str,
    messages: List[Dict[str, str]],
    token: Union[str, bool, None] = None,
    add_generation_prompt: bool = True,
    **kwargs,
) -> str:
    """Render a chat prompt using a model's chat template.

    Args:
        model_id (`str`):
            The model id.
        messages (`List[Dict[str, str]]`):
            The list of messages to render.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.
        add_generation_prompt (`bool`, *optional*, defaults to `True`):
            Whether to ask the template to append the tokens that mark the start of an assistant turn.
        **kwargs:
            Extra variables forwarded to the template as rendering context.

    Returns:
        `str`: The rendered chat prompt.

    Raises:
        `TemplateError`: If there's any issue while fetching, compiling or rendering the chat template.
    """
    minijinja = _import_minijinja()
    template = _fetch_and_compile_template(model_id=model_id, token=token)

    try:
        return template(messages=messages, add_generation_prompt=add_generation_prompt, **kwargs)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to render chat prompt for model '{model_id}': {e}") from e


@lru_cache  # TODO: lru_cache for raised exceptions
def _fetch_and_compile_template(*, model_id: str, token: Union[str, bool, None]) -> Callable:
    """Fetch and compile a model's chat template.

    The method is cached to avoid fetching the same model's config multiple times.

    Args:
        model_id (`str`):
            The model id.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.

    Returns:
        `Callable`: A callable that takes a list of messages and returns the rendered chat prompt.
    """
    from huggingface_hub.hf_api import HfApi

    minijinja = _import_minijinja()

    # 1. fetch config from API
    try:
        config = HfApi(token=token).model_info(model_id).config
    except RepositoryNotFoundError as e:
        raise TemplateError(f"Cannot render chat template: model '{model_id}' not found.") from e
    except HfHubHTTPError as e:
        raise TemplateError(f"Error while trying to fetch chat template for model '{model_id}': {e}") from e

    # 2. check config validity
    if config is None:
        raise TemplateError(f"Config not found for model '{model_id}'.")
    tokenizer_config = config.get("tokenizer_config")
    if tokenizer_config is None:
        raise TemplateError(f"Tokenizer config not found for model '{model_id}'.")
    if tokenizer_config.get("chat_template") is None:
        raise TemplateError(f"Chat template not found in tokenizer_config for model '{model_id}'.")

    chat_template = tokenizer_config["chat_template"]
    if not isinstance(chat_template, str):
        raise TemplateError(f"Chat template must be a string, not '{type(chat_template)}' (model: {model_id}).")

    special_tokens: Dict[str, Optional[str]] = {}
    for key, value in tokenizer_config.items():
        if "token" in key:
            if isinstance(value, str):
                special_tokens[key] = value
            elif isinstance(value, dict) and value.get("__type") == "AddedToken":
                special_tokens[key] = value.get("content")
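
    # Hypothetical tokenizer_config entries and the values harvested above:
    #   "eos_token": "</s>"                                      -> special_tokens["eos_token"] = "</s>"
    #   "bos_token": {"__type": "AddedToken", "content": "<s>"}  -> special_tokens["bos_token"] = "<s>"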

    # 3. compile template and return
    env = minijinja.Environment()
    try:
        env.add_template("chat_template", chat_template)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to compile chat template for model '{model_id}': {e}") from e

    return lambda **kwargs: env.render_template("chat_template", **kwargs, **special_tokens)
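

# Minimal end-to-end sketch (assumptions: `minijinja` is installed, the machine
# has network access to the Hub, and the placeholder model id below points to a
# public repo whose tokenizer_config defines a `chat_template`).
if __name__ == "__main__":
    example_prompt = render_chat_prompt(
        model_id="HuggingFaceH4/zephyr-7b-beta",  # placeholder id; swap in any chat model
        messages=[
            {"role": "system", "content": "You are a friendly chatbot."},
            {"role": "user", "content": "Hello!"},
        ],
    )
    print(example_prompt)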