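"""CLI plumbing for the `completions.create` subcommand.

`register()` attaches the argument parser, `CLICompletionCreateArgs` models the
parsed flags, and `CLICompletions` holds the handlers that call the API and
print the result (streaming or not).
"""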
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Optional, cast
from argparse import ArgumentParser
from functools import partial

from openai.types.completion import Completion

from .._utils import get_client
from ..._types import NOT_GIVEN, NotGivenOr
from ..._utils import is_given
from .._errors import CLIError
from .._models import BaseModel
from ..._streaming import Stream

if TYPE_CHECKING:
    from argparse import _SubParsersAction


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    sub = subparser.add_parser("completions.create")

    # Required
    sub.add_argument(
        "-m",
        "--model",
        help="The model to use",
        required=True,
    )

    # Optional
    sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
    sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
    sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
    sub.add_argument(
        "-t",
        "--temperature",
        help="""What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.

Mutually exclusive with `top_p`.""",
        type=float,
    )
    sub.add_argument(
        "-P",
        "--top_p",
        help="""An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.

Mutually exclusive with `temperature`.""",
        type=float,
    )
    sub.add_argument(
        "-n",
        "--n",
        help="How many sub-completions to generate for each prompt.",
        type=int,
    )
    sub.add_argument(
        "--logprobs",
        help="Include the log probabilities of the `logprobs` most likely tokens, as well as the chosen tokens. For example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
        type=int,
    )
    sub.add_argument(
        "--best_of",
        help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
        type=int,
    )
    sub.add_argument(
        "--echo",
        help="Echo back the prompt in addition to the completion",
        action="store_true",
    )
    sub.add_argument(
        "--frequency_penalty",
        help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
        type=float,
    )
    sub.add_argument(
        "--presence_penalty",
        help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
        type=float,
    )
    sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
    sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
    sub.add_argument(
        "--user",
        help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
    )
    # TODO: add support for logit_bias
    sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)
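

# A sketch of a typical invocation, assuming this subparser is mounted under the
# package's `openai api` entry point (the exact command prefix is wiring that
# lives outside this module, so treat it as an assumption):
#
#   openai api completions.create -m <model> -p "Once upon a time" -M 32 --stream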
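

# Parsed-flag model for the subcommand. Fields default to NOT_GIVEN, the SDK's
# "omit this parameter" sentinel, so unset flags are left out of the API request
# entirely rather than being sent as an explicit null.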
class CLICompletionCreateArgs(BaseModel):
    model: str
    stream: bool = False

    prompt: Optional[str] = None
    n: NotGivenOr[int] = NOT_GIVEN
    stop: NotGivenOr[str] = NOT_GIVEN
    user: NotGivenOr[str] = NOT_GIVEN
    echo: NotGivenOr[bool] = NOT_GIVEN
    suffix: NotGivenOr[str] = NOT_GIVEN
    best_of: NotGivenOr[int] = NOT_GIVEN
    top_p: NotGivenOr[float] = NOT_GIVEN
    logprobs: NotGivenOr[int] = NOT_GIVEN
    max_tokens: NotGivenOr[int] = NOT_GIVEN
    temperature: NotGivenOr[float] = NOT_GIVEN
    presence_penalty: NotGivenOr[float] = NOT_GIVEN
    frequency_penalty: NotGivenOr[float] = NOT_GIVEN
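

# Handlers wired in via `sub.set_defaults(...)` above. The CLI runner is expected
# to validate argv into a `CLICompletionCreateArgs` instance before calling
# `create` (an assumption about plumbing that lives outside this module).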
class CLICompletions:
    @staticmethod
    def create(args: CLICompletionCreateArgs) -> None:
        if is_given(args.n) and args.n > 1 and args.stream:
            raise CLIError("Can't stream completions with n>1 with the current CLI")
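
        # Bind every flag up front; `partial` defers the actual API call so the
        # same request can be issued below either with or without `stream=True`.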
        make_request = partial(
            get_client().completions.create,
            n=args.n,
            echo=args.echo,
            stop=args.stop,
            user=args.user,
            model=args.model,
            top_p=args.top_p,
            prompt=args.prompt,
            suffix=args.suffix,
            best_of=args.best_of,
            logprobs=args.logprobs,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            presence_penalty=args.presence_penalty,
            frequency_penalty=args.frequency_penalty,
        )

        if args.stream:
            return CLICompletions._stream_create(
                # mypy doesn't understand the `partial` function but pyright does
                cast(Stream[Completion], make_request(stream=True))  # pyright: ignore[reportUnnecessaryCast]
            )

        return CLICompletions._create(make_request())

    @staticmethod
    def _create(completion: Completion) -> None:
        should_print_header = len(completion.choices) > 1
        for choice in completion.choices:
            if should_print_header:
                sys.stdout.write("===== Completion {} =====\n".format(choice.index))

            sys.stdout.write(choice.text)

            if should_print_header or not choice.text.endswith("\n"):
                sys.stdout.write("\n")

            sys.stdout.flush()

    @staticmethod
    def _stream_create(stream: Stream[Completion]) -> None:
        for completion in stream:
            should_print_header = len(completion.choices) > 1
            # Chunks may carry choices out of order; sort by index for stable output.
            for choice in sorted(completion.choices, key=lambda c: c.index):
                if should_print_header:
                    sys.stdout.write("===== Completion {} =====\n".format(choice.index))

                sys.stdout.write(choice.text)

                if should_print_header:
                    sys.stdout.write("\n")

                sys.stdout.flush()

        # Finish with a trailing newline once the stream is exhausted.
        sys.stdout.write("\n")