404 lines
11 KiB
Python
404 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import re
|
|
import inspect
|
|
import functools
|
|
from typing import (
|
|
Any,
|
|
Tuple,
|
|
Mapping,
|
|
TypeVar,
|
|
Callable,
|
|
Iterable,
|
|
Sequence,
|
|
cast,
|
|
overload,
|
|
)
|
|
from pathlib import Path
|
|
from typing_extensions import TypeGuard
|
|
|
|
import sniffio
|
|
|
|
from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
|
|
from .._compat import parse_date as parse_date, parse_datetime as parse_datetime
|
|
|
|
_T = TypeVar("_T")
|
|
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
|
|
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
|
|
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
|
|
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
|
|
|
|
|
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    """Concatenate the inner iterables of ``t`` into one list, preserving order."""
    flat: list[_T] = []
    for inner in t:
        flat.extend(inner)
    return flat
|
|
|
|
|
|
def extract_files(
    # TODO: this needs to take Dict but variance issues.....
    # create protocol type ?
    query: Mapping[str, object],
    *,
    paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
    """Recursively extract files from ``query`` along each of the given ``paths``.

    A path is a list of key segments walked from the top of ``query``; the
    special segment ``'<array>'`` visits every element of a list, e.g.
    ``['foo', 'files', '<array>', 'data']``.

    Note: this mutates the given dictionary. When the final segment of a path
    addresses a dict key, that key is popped from its dict.

    Returns:
        A flat list of ``(flattened_key, file)`` pairs, in the order of ``paths``.
    """
    return [
        entry
        for path in paths
        for entry in _extract_items(query, path, index=0, flattened_key=None)
    ]
|
|
|
|
|
|
def _extract_items(
    obj: object,
    path: Sequence[str],
    *,
    index: int,
    flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
    """Follow ``path`` from position ``index`` through ``obj`` and collect the files found.

    Args:
        obj: The value reached so far while walking the path.
        path: Key segments; ``"<array>"`` means "every element of a list".
        index: Position in ``path`` of the segment to apply to ``obj``.
        flattened_key: Form-style key built so far (``a[b][]...``), or ``None``
            before the first segment has been applied.

    Returns:
        ``(flattened_key, file)`` pairs. A ``NotGiven`` leaf, a missing dict key,
        a segment other than ``"<array>"`` applied to a list, and an object that
        is neither a dict nor a list all contribute nothing.
    """
    try:
        segment = path[index]
    except IndexError:
        # Every segment has been consumed, so ``obj`` is the value the path points at.
        if isinstance(obj, NotGiven):
            # no value was provided - we can safely ignore
            return []

        # Imported here rather than at module level because of a cyclical import.
        from .._files import assert_is_file_content

        assert_is_file_content(obj, key=flattened_key)
        assert flattened_key is not None
        return [(flattened_key, cast(FileTypes, obj))]

    next_index = index + 1

    if is_dict(obj):
        try:
            # On the final segment the field is popped; this is how
            # ``extract_files`` mutates the caller's dictionary.
            child = obj.pop(segment) if len(path) == next_index else obj[segment]
        except KeyError:
            # Key was not present in the dictionary, this is not indicative of an error
            # as the given path may not point to a required field. We also do not want
            # to enforce required fields as the API may differ from the spec in some cases.
            return []
        child_key = segment if flattened_key is None else flattened_key + f"[{segment}]"
        return _extract_items(child, path, index=next_index, flattened_key=child_key)

    if is_list(obj):
        if segment != "<array>":
            return []
        element_key = "[]" if flattened_key is None else flattened_key + "[]"
        collected: list[tuple[str, FileTypes]] = []
        for element in obj:
            collected.extend(
                _extract_items(element, path, index=next_index, flattened_key=element_key)
            )
        return collected

    # Something unexpected was passed, just ignore it.
    return []
|
|
|
|
|
|
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
    """Return ``True`` when ``obj`` holds a real value, i.e. is not a ``NotGiven`` instance."""
    if isinstance(obj, NotGiven):
        return False
    return True
|
|
|
|
|
|
# Type-safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to rightfully report errors. Since we don't
# care about the contained types, we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t`, for different use cases:
# `is_*` is for when you're dealing with an unknown input;
# `is_*_t` is for when you're narrowing a known union type to a specific subset.
|
|
|
|
|
|
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
    """Narrow an unknown ``obj`` to ``tuple[object, ...]`` when it is a tuple (or subclass)."""
    return isinstance(obj, (tuple,))
|
|
|
|
|
|
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
    """Narrow a known union to its tuple member ``_TupleT`` when ``obj`` is a tuple."""
    return isinstance(obj, (tuple,))
|
|
|
|
|
|
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
    """Narrow an unknown ``obj`` to ``Sequence[object]``; note ``str`` also counts as a sequence."""
    return isinstance(obj, (Sequence,))
|
|
|
|
|
|
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
    """Narrow a known union to its sequence member ``_SequenceT`` when ``obj`` is a ``Sequence``."""
    return isinstance(obj, (Sequence,))
|
|
|
|
|
|
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
    """Narrow an unknown ``obj`` to ``Mapping[str, object]`` when it is any ``Mapping``.

    Only the container type is checked; key types are not inspected.
    """
    return isinstance(obj, (Mapping,))
|
|
|
|
|
|
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
    """Narrow a known union to its mapping member ``_MappingT`` when ``obj`` is a ``Mapping``."""
    return isinstance(obj, (Mapping,))
|
|
|
|
|
|
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
    """Narrow ``obj`` to ``dict[object, object]`` when it is a ``dict`` (or subclass)."""
    return isinstance(obj, (dict,))
|
|
|
|
|
|
def is_list(obj: object) -> TypeGuard[list[object]]:
    """Narrow ``obj`` to ``list[object]`` when it is a ``list`` (or subclass)."""
    return isinstance(obj, (list,))
|
|
|
|
|
|
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
    """Narrow ``obj`` to ``Iterable[object]`` when it is an ``Iterable`` instance."""
    return isinstance(obj, (Iterable,))
|
|
|
|
|
|
def deepcopy_minimal(item: _T) -> _T:
    """Minimal reimplementation of copy.deepcopy() that only copies mappings and lists.

    Any ``Mapping`` is rebuilt as a new ``dict`` and any ``list`` as a new ``list``,
    recursively. Every other object (tuples, sets, scalars, custom objects) is
    returned as-is and therefore shared with the input.

    This is done for performance reasons.
    """
    if is_mapping(item):
        copied_mapping = {key: deepcopy_minimal(value) for key, value in item.items()}
        return cast(_T, copied_mapping)
    if is_list(item):
        copied_list = [deepcopy_minimal(element) for element in item]
        return cast(_T, copied_list)
    return item
|
|
|
|
|
|
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    """Join ``seq`` into an English-style list, e.g. ``"a, b or c"``.

    ``delim`` separates all but the last two items, and ``final`` is the word
    placed before the last item. Two items are joined by ``final`` alone, a
    single item is returned unchanged, and an empty sequence gives ``""``.
    """
    count = len(seq)
    if count < 2:
        return seq[0] if count == 1 else ""
    if count == 2:
        first, second = seq[0], seq[1]
        return f"{first} {final} {second}"
    head = delim.join(seq[:-1])
    return f"{head} {final} {seq[-1]}"
|
|
|
|
|
|
def quote(string: str) -> str:
    """Wrap ``string`` in single quotation marks. Does *not* do any escaping."""
    return "'{}'".format(string)
|
|
|
|
|
|
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Each variant is a sequence of parameter names. A call is accepted when every
    name in at least one variant was passed, either positionally or by keyword.
    A name counts as passed even if its value is ``None``.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str:
        ...


    @overload
    def foo(*, b: bool) -> str:
        ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str:
        ...
    ```

    Raises (when the decorated function is called):
        TypeError: if more positional arguments are passed than the function
            has positional parameters, or if no variant is fully satisfied.
            With a single variant, the missing names are reported in the order
            the variant declares them, so the message is deterministic.
    """

    def inner(func: CallableT) -> CallableT:
        params = inspect.signature(func).parameters
        # Parameter names that can be bound positionally, in declaration order, so
        # positional call arguments can be mapped back to the names they fill.
        positional = [
            name
            for name, param in params.items()
            if param.kind in {param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD}
        ]

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            if len(args) > len(positional):
                raise TypeError(
                    f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                )

            given_params: set[str] = set(positional[: len(args)])
            given_params.update(kwargs)

            if not any(all(param in given_params for param in variant) for variant in variants):
                if len(variants) > 1:
                    variations = human_join(
                        ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
                    )
                    msg = f"Missing required arguments; Expected either {variations} arguments to be given"
                else:
                    assert len(variants) > 0

                    # Keep the variant's declaration order (dropping duplicates) so the
                    # message is deterministic; iterating a set difference is not.
                    missing = [arg for arg in dict.fromkeys(variants[0]) if arg not in given_params]
                    if len(missing) > 1:
                        msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                    else:
                        msg = f"Missing required argument: {quote(missing[0])}"
                raise TypeError(msg)
            return func(*args, **kwargs)

        return wrapper  # type: ignore

    return inner
|
|
|
|
|
|
_K = TypeVar("_K")
|
|
_V = TypeVar("_V")
|
|
|
|
|
|
@overload
def strip_not_given(obj: None) -> None:
    ...


@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]:
    ...


@overload
def strip_not_given(obj: object) -> object:
    ...


def strip_not_given(obj: object | None) -> object:
    """Drop the top-level entries of a mapping whose value is a ``NotGiven`` instance.

    ``None`` and non-mapping objects are returned unchanged. For a mapping, a new
    ``dict`` is returned; nested values are not inspected.
    """
    if obj is None or not is_mapping(obj):
        return obj
    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
|
|
|
|
|
|
def coerce_integer(val: str) -> int:
    """Parse ``val`` as a base-10 integer; raises ``ValueError`` on malformed input."""
    return int(val, 10)
|
|
|
|
|
|
def coerce_float(val: str) -> float:
    """Parse ``val`` with the built-in ``float`` constructor."""
    parsed = float(val)
    return parsed
|
|
|
|
|
|
def coerce_boolean(val: str) -> bool:
    """Return ``True`` only for the exact strings ``"true"``, ``"1"`` or ``"on"``.

    The comparison is case-sensitive; any other string (including ``"True"``) is ``False``.
    """
    return val in ("true", "1", "on")
|
|
|
|
|
|
def maybe_coerce_integer(val: str | None) -> int | None:
    """``coerce_integer`` that passes ``None`` through unchanged."""
    return None if val is None else coerce_integer(val)
|
|
|
|
|
|
def maybe_coerce_float(val: str | None) -> float | None:
    """``coerce_float`` that passes ``None`` through unchanged."""
    return None if val is None else coerce_float(val)
|
|
|
|
|
|
def maybe_coerce_boolean(val: str | None) -> bool | None:
    """``coerce_boolean`` that passes ``None`` through unchanged."""
    return None if val is None else coerce_boolean(val)
|
|
|
|
|
|
def removeprefix(string: str, prefix: str) -> str:
    """Remove a prefix from a string.

    Backport of `str.removeprefix` for Python < 3.9: if ``string`` starts with
    ``prefix``, the prefix is sliced off; otherwise ``string`` is returned unchanged.
    """
    return string[len(prefix) :] if string.startswith(prefix) else string
|
|
|
|
|
|
def removesuffix(string: str, suffix: str) -> str:
    """Remove a suffix from a string.

    Backport of `str.removesuffix` for Python < 3.9: if ``string`` ends with a
    non-empty ``suffix``, the suffix is sliced off; otherwise ``string`` is
    returned unchanged. As with ``str.removesuffix``, an empty ``suffix`` is a no-op.
    """
    # The non-empty guard matters: ``-len("")`` is 0 and ``string[:0]`` is "",
    # so without it an empty suffix would erase the whole string.
    if suffix and string.endswith(suffix):
        return string[: -len(suffix)]
    return string
|
|
|
|
|
|
def file_from_path(path: str) -> FileTypes:
    """Read the file at ``path`` fully into memory and return ``(file_name, contents)``.

    ``file_name`` is ``os.path.basename(path)`` and ``contents`` is the file's bytes.
    """
    raw = Path(path).read_bytes()
    name = os.path.basename(path)
    return name, raw
|
|
|
|
|
|
def get_required_header(headers: HeadersLike, header: str) -> str:
    """Return the value of ``header`` from ``headers``, matching the name loosely.

    If ``headers`` is a ``Mapping``, its keys are first compared against ``header``
    case-insensitively and the first matching ``str`` value is returned. Otherwise
    (or if no string value matched), ``headers.get`` is tried with several spellings
    of the name: as given, lower-case, upper-case and "intercaps"
    (e.g. ``Stainless-Event-Id``). The first truthy result is returned.

    Raises:
        ValueError: if no spelling of the header yields a value.
    """
    lower_header = header.lower()
    if isinstance(headers, Mapping):
        headers = cast(Headers, headers)
        for k, v in headers.items():
            if k.lower() == lower_header and isinstance(v, str):
                return v

    # To deal with the case where the header looks like Stainless-Event-Id:
    # capitalize the first character and every word character that follows a
    # non-word character (e.g. "stainless-event-id" -> "Stainless-Event-Id").
    intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())

    for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
        value = headers.get(normalized_header)
        if value:
            return value

    raise ValueError(f"Could not find {header} header")
|
|
|
|
|
|
def get_async_library() -> str:
    """Return the name of the running async library as reported by ``sniffio``.

    Any exception raised during detection (presumably sniffio's "no async library
    found" error when called outside an event loop; TODO confirm) is swallowed,
    and the string ``"false"`` is returned instead.
    """
    try:
        library = sniffio.current_async_library()
    except Exception:
        library = "false"
    return library
|
|
|
|
|
|
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
    """A version of functools.lru_cache that retains the type signature
    for the wrapped function arguments.

    At runtime this is exactly ``functools.lru_cache(maxsize=maxsize)``; only the
    static return type differs, so type checkers see the decorated function's
    original signature.
    """
    decorator: Any = functools.lru_cache(maxsize=maxsize)  # noqa: TID251
    return decorator  # type: ignore[no-any-return]
|