235 lines
6.6 KiB
Python
235 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
import sys
|
|
import logging
|
|
import argparse
|
|
from typing import Any, List, Type, Optional
|
|
from typing_extensions import ClassVar
|
|
|
|
import httpx
|
|
import pydantic
|
|
|
|
import openai
|
|
|
|
from . import _tools
|
|
from .. import _ApiType, __version__
|
|
from ._api import register_commands
|
|
from ._utils import can_use_http2
|
|
from .._types import ProxiesDict
|
|
from ._errors import CLIError, display_error
|
|
from .._compat import PYDANTIC_V2, ConfigDict, model_parse
|
|
from .._models import BaseModel
|
|
from .._exceptions import APIError
|
|
|
|
# CLI-wide logging: a stderr handler with a timestamp prefix is attached to the
# root logger, so any record that propagates up to the root logger is printed.
# The root logger's level is not changed here.
formatter = logging.Formatter(fmt="[%(asctime)s] %(message)s")
handler = logging.StreamHandler(stream=sys.stderr)
handler.setFormatter(formatter)
logger = logging.getLogger(name=None)
logger.addHandler(handler)
|
|
|
|
|
|
class Arguments(BaseModel):
    """Global CLI options, validated from the top-level argparse namespace.

    `_parse_args` builds this with ``model_parse(Arguments, vars(parsed))``.
    The namespace also carries keys that are not declared here, such as the
    ``func`` default set by the parser and any subcommand-specific options.
    Those keys are dropped because the model is configured with
    ``extra="ignore"``.
    """

    # Support both pydantic major versions: v2 reads ``model_config``, while v1
    # reads the nested ``Config`` class. Either way, undeclared keys are ignored.
    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )
    else:

        class Config(pydantic.BaseConfig):  # type: ignore
            extra: Any = pydantic.Extra.ignore  # type: ignore

    # Number of times -v/--verbose was given (argparse action="count", default 0).
    verbosity: int
    # Defaults to None. -V/--version is handled by argparse's "version" action.
    version: Optional[str] = None

    # Values of -k/--api-key, -b/--api-base, -o/--organization and -p/--proxy
    # (nargs="+", so a list). No default is declared: the parser always supplies
    # an attribute, which is None when the option is omitted.
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    proxy: Optional[List[str]]
    # -t/--api-type; argparse restricts it to "openai" or "azure".
    api_type: Optional[_ApiType] = None
    # --api-version
    api_version: Optional[str] = None

    # azure
    # --azure-endpoint and --azure-ad-token; copied onto the `openai` module only when not None.
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args
    # (when set, `_main` validates the namespace against this model and passes it to `func`)
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments
    unknown_args: List[str] = []
    # When False, `_parse_args` re-parses the command line strictly, so
    # unrecognised arguments become an argparse error.
    allow_unknown_args: bool = False
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Construct the top-level ``openai`` argument parser.

    The parser gets:

    - the global options (verbosity, credentials, base URL, proxy, API type
      and version, Azure settings, and ``-V/--version``);
    - a default ``func`` that prints the help text;
    - the ``api`` and ``tools`` subcommand groups.
    """
    parser = argparse.ArgumentParser(description=None, prog="openai")

    # (option strings, add_argument keyword arguments), in help-output order.
    global_options: list[tuple[tuple[str, ...], dict[str, Any]]] = [
        (
            ("-v", "--verbose"),
            {"action": "count", "dest": "verbosity", "default": 0, "help": "Set verbosity."},
        ),
        (("-b", "--api-base"), {"help": "What API base url to use."}),
        (("-k", "--api-key"), {"help": "What API key to use."}),
        (("-p", "--proxy"), {"nargs": "+", "help": "What proxy to use."}),
        (
            ("-o", "--organization"),
            {"help": "Which organization to run as (will use your default organization if not specified)"},
        ),
        (
            ("-t", "--api-type"),
            {
                "type": str,
                "choices": ("openai", "azure"),
                "help": "The backend API to call, must be `openai` or `azure`",
            },
        ),
        (
            ("--api-version",),
            {
                "help": "The Azure API version, e.g. 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'",
            },
        ),
        # azure
        (
            ("--azure-endpoint",),
            {"help": "The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'"},
        ),
        (
            ("--azure-ad-token",),
            {
                "help": "A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
            },
        ),
        # prints the package version
        (("-V", "--version"), {"action": "version", "version": "%(prog)s " + __version__}),
    ]
    for option_strings, option_kwargs in global_options:
        parser.add_argument(*option_strings, **option_kwargs)

    def _show_help() -> None:
        parser.print_help()

    # Default `func` when no subcommand replaces it: print the help text.
    parser.set_defaults(func=_show_help)

    subparsers = parser.add_subparsers()

    api_group = subparsers.add_parser("api", help="Direct API calls")
    register_commands(api_group)

    tools_group = subparsers.add_parser("tools", help="Client side tools for convenience")
    _tools.register_commands(tools_group, subparsers)

    return parser
|
|
|
|
|
|
def main() -> int:
    """Run the CLI and translate expected failures into a process exit status.

    Returns:
        0 when the command completed.

        1 when an ``APIError``, ``CLIError`` or ``pydantic.ValidationError``
        was raised (shown via ``display_error``), or when the user interrupted
        with Ctrl-C.
    """
    exit_code = 0
    try:
        _main()
    except (APIError, CLIError, pydantic.ValidationError) as err:
        display_error(err)
        exit_code = 1
    except KeyboardInterrupt:
        # Finish the current terminal line so the shell prompt starts cleanly.
        sys.stderr.write("\n")
        exit_code = 1
    return exit_code
|
|
|
|
|
|
def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
    """Parse ``sys.argv`` for the CLI.

    Everything from the first ``--`` onward (the ``--`` included) bypasses
    argparse and is appended verbatim to the leftover arguments. Unless the
    selected command set ``allow_unknown_args``, the command line is then
    parsed a second time with ``parser.parse_args()``. That strict pass makes
    argparse report any unrecognised arguments.

    Returns:
        A tuple of:

        - the raw argparse namespace;
        - the validated global ``Arguments``;
        - the leftover arguments.
    """
    argv = sys.argv
    try:
        separator_at = argv.index("--")
    except ValueError:
        # No `--` separator: everything after the program name goes to argparse.
        leading, trailing = argv[1:], []
    else:
        # argparse would strip the `--`; split it off ourselves so it is preserved.
        leading, trailing = argv[1:separator_at], argv[separator_at:]

    namespace, leftovers = parser.parse_known_args(leading)
    leftovers.extend(trailing)

    options = model_parse(Arguments, vars(namespace))
    if not options.allow_unknown_args:
        # Strict second pass over the whole command line so that unknown
        # arguments are reported as an argparse error.
        parser.parse_args()

    return namespace, options, leftovers
|
|
|
|
|
|
def _build_proxy_mounts(proxy_urls: Optional[List[str]], *, http2: bool = False) -> dict[str, httpx.BaseTransport]:
    """Turn the ``-p/--proxy`` values into httpx transport mounts.

    A proxy whose URL starts with ``https`` is mounted under ``"https://"``;
    any other proxy is mounted under ``"http://"``. Each proxy therefore serves
    requests whose target uses that scheme.

    Args:
        proxy_urls: the raw ``--proxy`` values, or ``None`` when the flag was not given.
        http2: whether the proxy transports should enable HTTP/2.

    Returns:
        A mapping for ``httpx.Client(mounts=...)``. It is empty when no proxies were given.

    Raises:
        CLIError: if two proxies map to the same scheme.
    """
    mounts: dict[str, httpx.BaseTransport] = {}
    for proxy in proxy_urls or []:
        key = "https://" if proxy.startswith("https") else "http://"
        if key in mounts:
            raise CLIError(f"Multiple {key} proxies given - only the last one would be used")

        mounts[key] = httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy)), http2=http2)
    return mounts


def _apply_client_settings(args: Arguments) -> None:
    """Copy user-supplied global options onto the module-level ``openai`` settings.

    Options the user did not supply are skipped, so the corresponding
    ``openai`` attributes keep their current values.
    """
    if args.organization:
        openai.organization = args.organization

    if args.api_key:
        openai.api_key = args.api_key

    if args.api_base:
        openai.base_url = args.api_base

    # azure
    if args.api_type is not None:
        openai.api_type = args.api_type

    if args.azure_endpoint is not None:
        openai.azure_endpoint = args.azure_endpoint

    if args.api_version is not None:
        openai.api_version = args.api_version

    if args.azure_ad_token is not None:
        openai.azure_ad_token = args.azure_ad_token


def _main() -> None:
    """Parse the command line, configure the ``openai`` module, and run the selected command.

    The shared ``httpx.Client`` installed as ``openai.http_client`` is closed
    once the command finishes, whether or not it raised.

    Raises:
        CLIError: if more than one proxy is given for the same scheme.
    """
    parser = _build_parser()
    parsed, args, unknown = _parse_args(parser)

    if args.verbosity != 0:
        sys.stderr.write("Warning: --verbose isn't supported yet\n")

    use_http2 = can_use_http2()
    # `httpx.Client(proxies=...)` was deprecated in httpx 0.26 and removed in
    # 0.28; mounting one proxy transport per scheme is the supported replacement.
    proxy_mounts = _build_proxy_mounts(args.proxy, http2=use_http2)

    http_client = httpx.Client(
        mounts=proxy_mounts or None,
        http2=use_http2,
    )
    openai.http_client = http_client

    _apply_client_settings(args)

    try:
        if args.args_model:
            parsed.func(
                model_parse(
                    args.args_model,
                    {
                        **{
                            # we omit None values so that they can be defaulted to `NotGiven`
                            # and we'll strip it from the API request
                            key: value
                            for key, value in vars(parsed).items()
                            if value is not None
                        },
                        "unknown_args": unknown,
                    },
                )
            )
        else:
            parsed.func()
    finally:
        try:
            http_client.close()
        except Exception:
            # Best-effort cleanup: a failed close must not mask the command's
            # own result or exception, but leave a trace for debugging.
            logger.debug("Failed to close the HTTP client", exc_info=True)
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|