# ai-content-maker/.venv/Lib/site-packages/openai/cli/_api/completions.py

from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Optional, cast
from argparse import ArgumentParser
from functools import partial

from openai.types.completion import Completion

from .._utils import get_client
from ..._types import NOT_GIVEN, NotGivenOr
from ..._utils import is_given
from .._errors import CLIError
from .._models import BaseModel
from ..._streaming import Stream

if TYPE_CHECKING:
    from argparse import _SubParsersAction


def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    sub = subparser.add_parser("completions.create")

    # Required
    sub.add_argument(
        "-m",
        "--model",
        help="The model to use",
        required=True,
    )

    # Optional
    sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
    sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
    sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
    sub.add_argument(
        "-t",
        "--temperature",
        help="""What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.

Mutually exclusive with `top_p`.""",
        type=float,
    )
    sub.add_argument(
        "-P",
        "--top_p",
        help="""An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.

Mutually exclusive with `temperature`.""",
        type=float,
    )
    sub.add_argument(
        "-n",
        "--n",
        help="How many sub-completions to generate for each prompt.",
        type=int,
    )
    sub.add_argument(
        "--logprobs",
        help="Include the log probabilities on the `logprobs` most likely tokens, as well as the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
        type=int,
    )
    sub.add_argument(
        "--best_of",
        help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
        type=int,
    )
    sub.add_argument(
        "--echo",
        help="Echo back the prompt in addition to the completion",
        action="store_true",
    )
    sub.add_argument(
        "--frequency_penalty",
        help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
        type=float,
    )
    sub.add_argument(
        "--presence_penalty",
        help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
        type=float,
    )
    sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
    sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
    sub.add_argument(
        "--user",
        help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
    )
    # TODO: add support for logit_bias
    sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)
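
# Illustrative example (assumption, not part of the upstream module): if the parser
# built above is mounted under the `openai api` entry point, a minimal invocation
# using only the flags defined in register() might look like:
#
#   openai api completions.create -m <model> -p "Write a haiku about CLIs" -M 32 --stream
#
# Note that argparse stores `--max-tokens` as the `max_tokens` attribute, matching
# the field name on CLICompletionCreateArgs below.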

class CLICompletionCreateArgs(BaseModel):
    model: str
    stream: bool = False

    prompt: Optional[str] = None
    n: NotGivenOr[int] = NOT_GIVEN
    stop: NotGivenOr[str] = NOT_GIVEN
    user: NotGivenOr[str] = NOT_GIVEN
    echo: NotGivenOr[bool] = NOT_GIVEN
    suffix: NotGivenOr[str] = NOT_GIVEN
    best_of: NotGivenOr[int] = NOT_GIVEN
    top_p: NotGivenOr[float] = NOT_GIVEN
    logprobs: NotGivenOr[int] = NOT_GIVEN
    max_tokens: NotGivenOr[int] = NOT_GIVEN
    temperature: NotGivenOr[float] = NOT_GIVEN
    presence_penalty: NotGivenOr[float] = NOT_GIVEN
    frequency_penalty: NotGivenOr[float] = NOT_GIVEN
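
# Fields typed NotGivenOr[...] default to the NOT_GIVEN sentinel rather than None, so
# options the user never passed are left out of the API request entirely; is_given()
# (imported above) is how create() tells "flag not supplied" apart from a real value.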

class CLICompletions:
    @staticmethod
    def create(args: CLICompletionCreateArgs) -> None:
        if is_given(args.n) and args.n > 1 and args.stream:
            raise CLIError("Can't stream completions with n>1 with the current CLI")

        make_request = partial(
            get_client().completions.create,
            n=args.n,
            echo=args.echo,
            stop=args.stop,
            user=args.user,
            model=args.model,
            top_p=args.top_p,
            prompt=args.prompt,
            suffix=args.suffix,
            best_of=args.best_of,
            logprobs=args.logprobs,
            max_tokens=args.max_tokens,
            temperature=args.temperature,
            presence_penalty=args.presence_penalty,
            frequency_penalty=args.frequency_penalty,
        )

        if args.stream:
            return CLICompletions._stream_create(
                # mypy doesn't understand the `partial` function but pyright does
                cast(Stream[Completion], make_request(stream=True))  # pyright: ignore[reportUnnecessaryCast]
            )

        return CLICompletions._create(make_request())
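
    # The partial above binds every request parameter once, so the same bound call can
    # be issued either as a plain request or, when --stream was passed, with stream=True
    # and handed off to _stream_create() below.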

    @staticmethod
    def _create(completion: Completion) -> None:
        should_print_header = len(completion.choices) > 1
        for choice in completion.choices:
            if should_print_header:
                sys.stdout.write("===== Completion {} =====\n".format(choice.index))

            sys.stdout.write(choice.text)

            if should_print_header or not choice.text.endswith("\n"):
                sys.stdout.write("\n")

            sys.stdout.flush()

    @staticmethod
    def _stream_create(stream: Stream[Completion]) -> None:
        for completion in stream:
            should_print_header = len(completion.choices) > 1
            for choice in sorted(completion.choices, key=lambda c: c.index):
                if should_print_header:
                    sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))

                sys.stdout.write(choice.text)

                if should_print_header:
                    sys.stdout.write("\n")

                sys.stdout.flush()

        sys.stdout.write("\n")
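
# Sketch of the calling convention (assumption about the surrounding CLI driver, not
# code from this file): register() only attaches `func` and `args_model` as argparse
# defaults, so the top-level driver is expected to do roughly the equivalent of
#
#   parsed = parser.parse_args()
#   parsed.func(parsed.args_model(model=parsed.model, ...))  # validate, then dispatch
#
# i.e. the raw argparse namespace is funneled through CLICompletionCreateArgs before
# CLICompletions.create() runs.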