140 lines
4.9 KiB
Python
140 lines
4.9 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
from typing import TYPE_CHECKING, Any, cast
|
||
|
from argparse import ArgumentParser
|
||
|
|
||
|
from .._utils import get_client, print_model
|
||
|
from ..._types import NOT_GIVEN, NotGiven, NotGivenOr
|
||
|
from .._models import BaseModel
|
||
|
from .._progress import BufferReader
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from argparse import _SubParsersAction
|
||
|
|
||
|
|
||
|
def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
    """Attach the ``images.*`` sub-commands to an argparse sub-parser group.

    Three sub-commands are registered, in this order:

    * ``images.generate`` -> ``CLIImage.create`` / ``CLIImageCreateArgs``
    * ``images.edit`` -> ``CLIImage.edit`` / ``CLIImageEditArgs``
    * ``images.create_variation`` -> ``CLIImage.create_variation`` /
      ``CLIImageCreateVariationArgs``

    Each sub-parser stores its handler in ``func`` and its argument model in
    ``args_model`` via ``set_defaults``.
    """
    image_help = "Image to modify. Should be a local path and a PNG encoded image."
    mask_help = (
        "Path to a mask image. It should be the same size as the image you're editing "
        "and a RGBA PNG image. The Alpha channel acts as the mask."
    )

    # Small builders for options shared between sub-commands. Call order
    # below mirrors the order options should appear in ``--help`` output.
    def add_model(parser: ArgumentParser) -> None:
        parser.add_argument("-m", "--model", type=str)

    def add_prompt(parser: ArgumentParser) -> None:
        parser.add_argument("-p", "--prompt", type=str, required=True)

    def add_num_images(parser: ArgumentParser) -> None:
        parser.add_argument("-n", "--num-images", type=int, default=1)

    def add_image(parser: ArgumentParser) -> None:
        parser.add_argument("-I", "--image", type=str, required=True, help=image_help)

    def add_output_options(parser: ArgumentParser) -> None:
        parser.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
        parser.add_argument("--response-format", type=str, default="url")

    generate_parser = subparser.add_parser("images.generate")
    add_model(generate_parser)
    add_prompt(generate_parser)
    add_num_images(generate_parser)
    add_output_options(generate_parser)
    generate_parser.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)

    edit_parser = subparser.add_parser("images.edit")
    add_model(edit_parser)
    add_prompt(edit_parser)
    add_num_images(edit_parser)
    add_image(edit_parser)
    add_output_options(edit_parser)
    # The mask is optional and only exists for the edit command.
    edit_parser.add_argument("-M", "--mask", type=str, required=False, help=mask_help)
    edit_parser.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)

    variation_parser = subparser.add_parser("images.create_variation")
    add_model(variation_parser)
    add_num_images(variation_parser)
    add_image(variation_parser)
    add_output_options(variation_parser)
    variation_parser.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)
|
||
|
|
||
|
|
||
|
class CLIImageCreateArgs(BaseModel):
    # Argument model for the ``images.generate`` sub-command; ``register``
    # installs it as that sub-parser's ``args_model``. Field names match the
    # argparse destinations (e.g. ``--num-images`` -> ``num_images``).
    prompt: str  # text prompt, from ``-p/--prompt`` (required on the CLI)
    num_images: int  # from ``-n/--num-images``; argparse default is 1
    size: str  # from ``-s/--size``; argparse default "1024x1024"; kept as plain str for forwards-compat
    response_format: str  # from ``--response-format``; argparse default "url"
    model: NotGivenOr[str] = NOT_GIVEN  # from ``-m/--model``; NOT_GIVEN when the field is not supplied
|
||
|
|
||
|
|
||
|
class CLIImageCreateVariationArgs(BaseModel):
    # Argument model for the ``images.create_variation`` sub-command;
    # ``register`` installs it as that sub-parser's ``args_model``. There is
    # no prompt: a variation is derived from the input image alone.
    image: str  # local path to the source image, from ``-I/--image`` (required on the CLI)
    num_images: int  # from ``-n/--num-images``; argparse default is 1
    size: str  # from ``-s/--size``; argparse default "1024x1024"; kept as plain str for forwards-compat
    response_format: str  # from ``--response-format``; argparse default "url"
    model: NotGivenOr[str] = NOT_GIVEN  # from ``-m/--model``; NOT_GIVEN when the field is not supplied
|
||
|
|
||
|
|
||
|
class CLIImageEditArgs(BaseModel):
    # Argument model for the ``images.edit`` sub-command; ``register``
    # installs it as that sub-parser's ``args_model``.
    image: str  # local path to the image to edit, from ``-I/--image`` (required on the CLI)
    num_images: int  # from ``-n/--num-images``; argparse default is 1
    size: str  # from ``-s/--size``; argparse default "1024x1024"; kept as plain str for forwards-compat
    response_format: str  # from ``--response-format``; argparse default "url"
    prompt: str  # edit instruction, from ``-p/--prompt`` (required on the CLI)
    # Optional local path to a mask image, from ``-M/--mask``. CLIImage.edit
    # only opens and uploads it when it is not NOT_GIVEN.
    mask: NotGivenOr[str] = NOT_GIVEN
    model: NotGivenOr[str] = NOT_GIVEN  # from ``-m/--model``; NOT_GIVEN when the field is not supplied
|
||
|
|
||
|
|
||
|
class CLIImage:
    """Handlers for the ``images.*`` CLI sub-commands.

    Each handler receives its parsed argument model, calls the matching
    ``images`` endpoint on the client returned by ``get_client()``, and prints
    the response with ``print_model``.

    ``size`` and ``response_format`` are passed through ``cast(Any, ...)``.
    The API types them as enums, but the CLI deliberately does not validate
    them, so newer server-side values keep working (forwards-compat).
    """

    @staticmethod
    def _buffered(path: str, desc: str) -> BufferReader:
        """Read the whole file at *path* and wrap its bytes in a ``BufferReader`` labelled *desc*.

        The file is closed before returning; only the in-memory bytes are kept.
        """
        with open(path, "rb") as handle:
            return BufferReader(handle.read(), desc=desc)

    @staticmethod
    def create(args: CLIImageCreateArgs) -> None:
        """Generate images from ``args.prompt`` and print the response."""
        client = get_client()
        response = client.images.generate(
            model=args.model,
            prompt=args.prompt,
            n=args.num_images,
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(response)

    @staticmethod
    def create_variation(args: CLIImageCreateVariationArgs) -> None:
        """Upload ``args.image``, request variations of it, and print the response."""
        # The source image is read fully into memory before the client is created.
        upload = CLIImage._buffered(args.image, "Upload progress")

        client = get_client()
        response = client.images.create_variation(
            model=args.model,
            image=("image", upload),
            n=args.num_images,
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(response)

    @staticmethod
    def edit(args: CLIImageEditArgs) -> None:
        """Upload ``args.image`` (and ``args.mask`` if given), request an edit, and print the response."""
        image_upload = CLIImage._buffered(args.image, "Image upload progress")

        # The mask is optional. Forward NOT_GIVEN unchanged when it was not
        # supplied; otherwise send it as a named (filename, buffer) upload.
        mask_upload: NotGivenOr[tuple[str, BufferReader]] = NOT_GIVEN
        if not isinstance(args.mask, NotGiven):
            mask_upload = ("mask", CLIImage._buffered(args.mask, "Mask progress"))

        client = get_client()
        response = client.images.edit(
            model=args.model,
            prompt=args.prompt,
            image=("image", image_upload),
            n=args.num_images,
            mask=mask_upload,
            size=cast(Any, args.size),
            response_format=cast(Any, args.response_format),
        )
        print_model(response)
|