229 lines
7.8 KiB
Python
229 lines
7.8 KiB
Python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
from typing import Any, List, Optional
|
|
|
|
from ..pipelines import Pipeline, get_supported_tasks, pipeline
|
|
from ..utils import logging
|
|
from . import BaseTransformersCLICommand
|
|
|
|
|
|
# The serving stack is optional. When it cannot be imported, placeholders for
# `BaseModel` and `Body` keep the class and signature definitions further down
# importable; `_serve_dependencies_installed` records which case applies.
try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run
except (ImportError, AttributeError):
    _serve_dependencies_installed = False

    # Plain `object` stands in as the base class of the result models.
    BaseModel = object

    def Body(*args, **kwargs):
        # Placeholder used only as a default-argument factory; always yields None.
        return None

else:
    _serve_dependencies_installed = True
|
|
|
|
|
|
# Module-level logger for the serving command, obtained through the package's `logging` utilities.
logger = logging.get_logger("transformers-cli/serving")
|
|
|
|
|
|
def serve_command_factory(args: Namespace):
    """
    Build the serving command from parsed command line arguments.

    Args:
        args: Parsed arguments. The `task`, `model`, `config`, `tokenizer` and `device` attributes are forwarded to
            `pipeline`; `host`, `port` and `workers` are forwarded to `ServeCommand`.

    Returns: ServeCommand
    """
    pipeline_kwargs = {
        "task": args.task,
        # A falsy model value (e.g. an empty string) is passed on as None.
        "model": args.model or None,
        "config": args.config,
        "tokenizer": args.tokenizer,
        "device": args.device,
    }
    nlp = pipeline(**pipeline_kwargs)
    return ServeCommand(nlp, args.host, args.port, args.workers)
|
|
|
|
|
|
class ServeModelInfoResult(BaseModel):
    """
    Response model carrying model information.
    """

    # Mapping returned by the model-info endpoint.
    infos: dict
|
|
|
|
|
|
class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    # Tokens produced for the input text.
    tokens: List[str]
    # Ids matching `tokens`. The explicit `None` default keeps this field optional:
    # without it the field is *required* under pydantic v2, and building the result
    # without ids would fail validation.
    tokens_ids: Optional[List[int]] = None
|
|
|
|
|
|
class ServeDeTokenizeResult(BaseModel):
    """
    Response model for a detokenize request.
    """

    # Decoded text.
    text: str
|
|
|
|
|
|
class ServeForwardResult(BaseModel):
    """
    Response model for a forward request.
    """

    # Whatever the pipeline call returned; deliberately untyped.
    output: Any
|
|
|
|
|
|
class ServeCommand(BaseTransformersCLICommand):
    """
    CLI command serving a pipeline over HTTP.

    The application built in `__init__` exposes `GET /` (model info), `POST /tokenize`, `POST /detokenize` and
    `POST /forward`.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        """
        Store the serving configuration and build the HTTP application.

        Args:
            pipeline: Pipeline used by every endpoint.
            host: Interface the server will listen on.
            port: Port the server will listen to.
            workers: Number of http workers.

        Raises:
            RuntimeError: If the optional serving dependencies could not be imported.
        """
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )

        logger.info(f"Serving model over {host}:{port}")

        # (path, endpoint, response model, HTTP method) for every exposed route.
        endpoints = [
            ("/", self.model_info, ServeModelInfoResult, "GET"),
            ("/tokenize", self.tokenize, ServeTokenizeResult, "POST"),
            ("/detokenize", self.detokenize, ServeDeTokenizeResult, "POST"),
            ("/forward", self.forward, ServeForwardResult, "POST"),
        ]
        self._app = FastAPI(
            routes=[
                APIRoute(
                    path,
                    endpoint,
                    response_model=response_model,
                    response_class=JSONResponse,
                    methods=[method],
                )
                for path, endpoint, response_model, method in endpoints
            ],
            timeout=600,
        )

    def run(self):
        """
        Start the HTTP server on the configured host and port.
        """
        # Inside this method `run` resolves to the module-level server entry point, not to this method.
        # NOTE(review): uvicorn appears to require an import string (not an app object) for workers > 1 - confirm.
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        """
        Return the attributes of the served model's configuration.
        """
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and eventually returns corresponding tokens id:

        - **text_input**: String to tokenize
        - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
        """
        try:
            tokenizer = self._pipeline.tokenizer
            tokens_txt = tokenizer.tokenize(text_input)
            tokens_ids = tokenizer.convert_tokens_to_ids(tokens_txt) if return_ids else None

            # `tokens_ids` is always passed explicitly so the result validates even when
            # the field has no declared default.
            return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:

        - **tokens_ids**: List of tokens ids
        - **skip_special_tokens**: Flag indicating to not try to decode special tokens
        - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            # Keyword arguments keep the flags bound to the right parameters whatever their order in `decode`.
            decoded_str = self._pipeline.tokenizer.decode(
                tokens_ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=cleanup_tokenization_spaces,
            )
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)}) from e

    async def forward(self, inputs=Body(None, embed=True)):
        """
        Run the provided inputs through the pipeline:

        - **inputs**: Inputs handed to the pipeline as-is.
        """
        try:
            # `inputs` is None when the request body omits it; treat that like an empty input
            # instead of letting `len(None)` escape as an unhandled TypeError.
            if inputs is None or len(inputs) == 0:
                return ServeForwardResult(output=[])

            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"error": str(e)}) from e
|