# Source: ai-content-maker/.venv/Lib/site-packages/httpx/_models.py
# (1210 lines, 41 KiB, Python)

from __future__ import annotations
import datetime
import email.message
import json as jsonlib
import typing
import urllib.request
from collections.abc import Mapping
from http.cookiejar import Cookie, CookieJar
from ._content import ByteStream, UnattachedStream, encode_request, encode_response
from ._decoders import (
SUPPORTED_DECODERS,
ByteChunker,
ContentDecoder,
IdentityDecoder,
LineDecoder,
MultiDecoder,
TextChunker,
TextDecoder,
)
from ._exceptions import (
CookieConflict,
HTTPStatusError,
RequestNotRead,
ResponseNotRead,
StreamClosed,
StreamConsumed,
request_context,
)
from ._multipart import get_multipart_boundary_from_content_type
from ._status_codes import codes
from ._types import (
AsyncByteStream,
CookieTypes,
HeaderTypes,
QueryParamTypes,
RequestContent,
RequestData,
RequestExtensions,
RequestFiles,
ResponseContent,
ResponseExtensions,
SyncByteStream,
)
from ._urls import URL
from ._utils import (
is_known_encoding,
normalize_header_key,
normalize_header_value,
obfuscate_sensitive_headers,
parse_content_type_charset,
parse_header_links,
)
class Headers(typing.MutableMapping[str, str]):
    """
    HTTP headers, as a case-insensitive multi-dict.

    Every header is held internally as a `(raw_key, lowercased_key, value)`
    triple of bytes. The raw key keeps the casing that was supplied, while
    the lowercased key is what lookups are matched against.
    """

    def __init__(
        self,
        headers: HeaderTypes | None = None,
        encoding: str | None = None,
    ) -> None:
        """
        Build the header list from `None`, another `Headers` instance,
        a mapping, or an iterable of `(key, value)` pairs.

        `encoding` is forwarded to the key/value normalizers and remembered
        as the explicit header encoding (or `None` for auto-detection).
        """
        if headers is None:
            entries: list[tuple[bytes, bytes, bytes]] = []
        elif isinstance(headers, Headers):
            # Shallow copy, so the two instances do not share a list.
            entries = list(headers._list)
        else:
            pairs = headers.items() if isinstance(headers, Mapping) else headers
            entries = [
                (
                    normalize_header_key(k, lower=False, encoding=encoding),
                    normalize_header_key(k, lower=True, encoding=encoding),
                    normalize_header_value(v, encoding),
                )
                for k, v in pairs
            ]
        self._list = entries
        self._encoding = encoding

    @property
    def encoding(self) -> str:
        """
        Header encoding is mandated as ascii, but we allow fallbacks to utf-8
        or iso-8859-1.

        When no encoding has been set, the first of "ascii" and "utf-8" that
        decodes every raw key and value is selected and cached.
        """
        if self._encoding is None:
            for candidate in ("ascii", "utf-8"):
                try:
                    for key, value in self.raw:
                        key.decode(candidate)
                        value.decode(candidate)
                except UnicodeDecodeError:
                    continue
                # Every key and value fitted this encoding.
                self._encoding = candidate
                break
            else:
                # The ISO-8859-1 encoding covers all 256 code points in a byte,
                # so will never raise decode errors.
                self._encoding = "iso-8859-1"
        return self._encoding

    @encoding.setter
    def encoding(self, value: str) -> None:
        self._encoding = value

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """
        Returns a list of the raw header items, as byte pairs.
        """
        return [(raw_key, value) for raw_key, _, value in self._list]

    def _comma_joined(self) -> dict[str, str]:
        """
        Return a dict of decoded lowercased keys to decoded values, where
        repeated keys have their values joined with ", ".
        """
        merged: dict[str, str] = {}
        for _, key, value in self._list:
            # `self.encoding` is looked up per item on purpose: with no
            # headers stored it is never evaluated, so no encoding gets cached.
            str_key = key.decode(self.encoding)
            str_value = value.decode(self.encoding)
            if str_key in merged:
                merged[str_key] += f", {str_value}"
            else:
                merged[str_key] = str_value
        return merged

    def keys(self) -> typing.KeysView[str]:
        """
        Return the decoded lowercased header names, without duplicates.
        """
        return {key.decode(self.encoding): None for _, key, _ in self._list}.keys()

    def values(self) -> typing.ValuesView[str]:
        """
        Return the header values, with the values of a repeated key
        concatenated into a single comma separated value.
        """
        return self._comma_joined().values()

    def items(self) -> typing.ItemsView[str, str]:
        """
        Return `(key, value)` items of headers. Concatenate headers
        into a single comma separated value when a key occurs multiple times.
        """
        return self._comma_joined().items()

    def multi_items(self) -> list[tuple[str, str]]:
        """
        Return a list of `(key, value)` pairs of headers. Allow multiple
        occurrences of the same key without concatenating into a single
        comma separated value.
        """
        return [
            (key.decode(self.encoding), value.decode(self.encoding))
            for _, key, value in self._list
        ]

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        """
        Return a header value. If multiple occurrences of the header occur
        then concatenate them together with commas.

        Returns `default` when the header is not present.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def get_list(self, key: str, split_commas: bool = False) -> list[str]:
        """
        Return a list of all header values for a given key.
        If `split_commas=True` is passed, then any comma separated header
        values are split into multiple return strings.
        """
        get_header_key = key.lower().encode(self.encoding)
        values = [
            item_value.decode(self.encoding)
            for _, item_key, item_value in self._list
            if item_key.lower() == get_header_key
        ]
        if not split_commas:
            return values
        return [item.strip() for value in values for item in value.split(",")]

    def update(self, headers: HeaderTypes | None = None) -> None:  # type: ignore
        """
        Replace any existing headers that share a key with `headers`, then
        append all of the new header entries.
        """
        headers = Headers(headers)
        for key in headers.keys():
            if key in self:
                self.pop(key)
        self._list.extend(headers._list)

    def copy(self) -> Headers:
        """
        Return a new `Headers` instance with the same entries and encoding.
        """
        return Headers(self, encoding=self.encoding)

    def __getitem__(self, key: str) -> str:
        """
        Return a single header value.
        If there are multiple headers with the same key, then we concatenate
        them with commas. See: https://tools.ietf.org/html/rfc7230#section-3.2.2

        Raises `KeyError` when no header matches `key`.
        """
        normalized_key = key.lower().encode(self.encoding)
        items = [
            header_value.decode(self.encoding)
            for _, header_key, header_value in self._list
            if header_key == normalized_key
        ]
        if items:
            return ", ".join(items)
        raise KeyError(key)

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        # Uses the explicitly set encoding if there is one, otherwise utf-8;
        # this does not trigger encoding auto-detection.
        codec = self._encoding or "utf-8"
        set_key = key.encode(codec)
        set_value = value.encode(codec)
        lookup_key = set_key.lower()
        entry = (set_key, lookup_key, set_value)

        found_indexes = [
            idx
            for idx, (_, item_key, _) in enumerate(self._list)
            if item_key == lookup_key
        ]
        if not found_indexes:
            self._list.append(entry)
            return

        # Drop later duplicates back-to-front so earlier indexes stay valid,
        # then overwrite the first occurrence in place.
        for idx in reversed(found_indexes[1:]):
            del self._list[idx]
        self._list[found_indexes[0]] = entry

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.

        Raises `KeyError` when no header matches `key`.
        """
        del_key = key.lower().encode(self.encoding)
        pop_indexes = [
            idx
            for idx, (_, item_key, _) in enumerate(self._list)
            if item_key.lower() == del_key
        ]
        if not pop_indexes:
            raise KeyError(key)
        # Delete back-to-front so the remaining indexes stay valid.
        for idx in reversed(pop_indexes):
            del self._list[idx]

    def __contains__(self, key: typing.Any) -> bool:
        """
        Return `True` if a header named `key` is present (case-insensitive).

        Keys that are not strings are never present, so `False` is returned
        for them rather than failing on the missing string methods.
        """
        if not isinstance(key, str):
            return False
        header_key = key.lower().encode(self.encoding)
        return any(item_key == header_key for _, item_key, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts stored entries, so a repeated key is counted once per entry.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        """
        Compare against anything accepted by the `Headers` constructor,
        ignoring key casing and entry order.

        Returns `False` when `other` cannot be interpreted as headers.
        """
        try:
            other_headers = Headers(other)
        except (TypeError, ValueError):
            # ValueError: items that do not unpack into (key, value) pairs.
            # TypeError: objects that are not iterable at all.
            return False

        self_list = [(key, value) for _, key, value in self._list]
        other_list = [(key, value) for _, key, value in other_headers._list]
        return sorted(self_list) == sorted(other_list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__

        encoding_str = ""
        if self.encoding != "ascii":
            encoding_str = f", encoding={self.encoding!r}"

        as_list = list(obfuscate_sensitive_headers(self.multi_items()))
        as_dict = dict(as_list)

        # Use the compact dict form only when it does not lose repeated keys.
        no_duplicate_keys = len(as_dict) == len(as_list)
        if no_duplicate_keys:
            return f"{class_name}({as_dict!r}{encoding_str})"
        return f"{class_name}({as_list!r}{encoding_str})"
class Request:
    """
    An HTTP request: method, URL, headers, extensions and a body stream.
    """

    def __init__(
        self,
        method: str | bytes,
        url: URL | str,
        *,
        params: QueryParamTypes | None = None,
        headers: HeaderTypes | None = None,
        cookies: CookieTypes | None = None,
        content: RequestContent | None = None,
        data: RequestData | None = None,
        files: RequestFiles | None = None,
        json: typing.Any | None = None,
        stream: SyncByteStream | AsyncByteStream | None = None,
        extensions: RequestExtensions | None = None,
    ) -> None:
        """
        Create a request. The body is taken either from one of
        `content`/`data`/`files`/`json`, or from an explicit `stream`.
        """
        if isinstance(method, bytes):
            method = method.decode("ascii")
        self.method = method.upper()

        self.url = URL(url)
        if params is not None:
            self.url = self.url.copy_merge_params(params=params)

        self.headers = Headers(headers)
        self.extensions = extensions if extensions is not None else {}

        if cookies:
            Cookies(cookies).set_cookie_header(self)

        if stream is not None:
            # There's an important distinction between `Request(content=...)`,
            # and `Request(stream=...)`.
            #
            # Using `content=...` implies automatically populated `Host` and
            # content headers, of either `Content-Length: ...` or
            # `Transfer-Encoding: chunked`.
            #
            # Using `stream=...` will not automatically include *any*
            # auto-populated headers.
            #
            # As an end-user you don't really need `stream=...`. It's only
            # useful when:
            #
            # * Preserving the request stream when copying requests,
            #   eg for redirects.
            # * Creating request instances on the *server-side* of the
            #   transport API.
            self.stream = stream
            return

        content_type: str | None = self.headers.get("content-type")
        boundary = get_multipart_boundary_from_content_type(
            content_type=content_type.encode(self.headers.encoding)
            if content_type
            else None
        )
        default_headers, body_stream = encode_request(
            content=content,
            data=data,
            files=files,
            json=json,
            boundary=boundary,
        )
        self._prepare(default_headers)
        self.stream = body_stream

        # Load the request body, except for streaming content.
        if isinstance(body_stream, ByteStream):
            self.read()

    def _prepare(self, default_headers: dict[str, str]) -> None:
        """
        Fill in `default_headers` where not already set, then prepend the
        automatic `Host` and `Content-Length: 0` headers where they apply.
        """
        for name, value in default_headers.items():
            # Ignore Transfer-Encoding if the Content-Length has been set
            # explicitly.
            skip = (
                name.lower() == "transfer-encoding"
                and "Content-Length" in self.headers
            )
            if not skip:
                self.headers.setdefault(name, value)

        has_host = "Host" in self.headers
        has_content_length = (
            "Content-Length" in self.headers or "Transfer-Encoding" in self.headers
        )

        prepended: list[tuple[bytes, bytes]] = []
        if not has_host and self.url.host:
            prepended.append((b"Host", self.url.netloc))
        if not has_content_length and self.method in ("POST", "PUT", "PATCH"):
            prepended.append((b"Content-Length", b"0"))

        # The automatic headers are placed ahead of the existing ones.
        self.headers = Headers(prepended + self.headers.raw)

    @property
    def content(self) -> bytes:
        """
        The request body as bytes. Raises `RequestNotRead` if the body has
        not been loaded yet.
        """
        if hasattr(self, "_content"):
            return self._content
        raise RequestNotRead()

    def read(self) -> bytes:
        """
        Read and return the request content.
        """
        if hasattr(self, "_content"):
            return self._content

        assert isinstance(self.stream, typing.Iterable)
        self._content = b"".join(self.stream)
        if not isinstance(self.stream, ByteStream):
            # If a streaming request has been read entirely into memory, then
            # we can replace the stream with a raw bytes implementation,
            # to ensure that any non-replayable streams can still be used.
            self.stream = ByteStream(self._content)
        return self._content

    async def aread(self) -> bytes:
        """
        Read and return the request content.
        """
        if hasattr(self, "_content"):
            return self._content

        assert isinstance(self.stream, typing.AsyncIterable)
        parts = [part async for part in self.stream]
        self._content = b"".join(parts)
        if not isinstance(self.stream, ByteStream):
            # If a streaming request has been read entirely into memory, then
            # we can replace the stream with a raw bytes implementation,
            # to ensure that any non-replayable streams can still be used.
            self.stream = ByteStream(self._content)
        return self._content

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}({self.method!r}, {str(self.url)!r})>"

    def __getstate__(self) -> dict[str, typing.Any]:
        # The extensions and the stream are left out of the pickled state.
        excluded = ["extensions", "stream"]
        state: dict[str, typing.Any] = {}
        for name, value in self.__dict__.items():
            if name not in excluded:
                state[name] = value
        return state

    def __setstate__(self, state: dict[str, typing.Any]) -> None:
        for name, value in state.items():
            setattr(self, name, value)
        # Restored requests get empty extensions and an unattached stream.
        self.extensions = {}
        self.stream = UnattachedStream()
class Response:
    """
    An HTTP response: status code, headers, extensions and a body stream,
    with helpers for reading and decoding the body.
    """

    def __init__(
        self,
        status_code: int,
        *,
        headers: HeaderTypes | None = None,
        content: ResponseContent | None = None,
        text: str | None = None,
        html: str | None = None,
        json: typing.Any = None,
        stream: SyncByteStream | AsyncByteStream | None = None,
        request: Request | None = None,
        extensions: ResponseExtensions | None = None,
        history: list[Response] | None = None,
        default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
    ) -> None:
        """
        Create a response. The body is taken either from one of
        `content`/`text`/`html`/`json`, or from an explicit `stream`.
        """
        self.status_code = status_code
        self.headers = Headers(headers)
        self._request: Request | None = request

        # When follow_redirects=False and a redirect is received,
        # the client will set `response.next_request`.
        self.next_request: Request | None = None

        self.extensions: ResponseExtensions = (
            extensions if extensions is not None else {}
        )
        self.history = list(history) if history is not None else []

        self.is_closed = False
        self.is_stream_consumed = False

        self.default_encoding = default_encoding

        if stream is not None:
            # There's an important distinction between `Response(content=...)`,
            # and `Response(stream=...)`.
            #
            # Using `content=...` implies automatically populated content
            # headers, of either `Content-Length: ...` or
            # `Transfer-Encoding: chunked`.
            #
            # Using `stream=...` will not automatically include any content
            # headers.
            #
            # As an end-user you don't really need `stream=...`. It's only
            # useful when creating response instances having received a stream
            # from the transport API.
            self.stream = stream
        else:
            default_headers, body_stream = encode_response(content, text, html, json)
            self._prepare(default_headers)
            self.stream = body_stream
            if isinstance(body_stream, ByteStream):
                # Load the response body, except for streaming content.
                self.read()

        # Set last, so the counter is zero even after the eager read above.
        self._num_bytes_downloaded = 0

    def _prepare(self, default_headers: dict[str, str]) -> None:
        """
        Fill in `default_headers` for any header that is not already set.
        """
        for name, value in default_headers.items():
            # Ignore Transfer-Encoding if the Content-Length has been set
            # explicitly.
            skip = (
                name.lower() == "transfer-encoding"
                and "content-length" in self.headers
            )
            if not skip:
                self.headers.setdefault(name, value)

    @property
    def elapsed(self) -> datetime.timedelta:
        """
        Returns the time taken for the complete request/response
        cycle to complete.
        """
        if hasattr(self, "_elapsed"):
            return self._elapsed
        raise RuntimeError(
            "'.elapsed' may only be accessed after the response "
            "has been read or closed."
        )

    @elapsed.setter
    def elapsed(self, elapsed: datetime.timedelta) -> None:
        self._elapsed = elapsed

    @property
    def request(self) -> Request:
        """
        Returns the request instance associated to the current response.
        """
        if self._request is not None:
            return self._request
        raise RuntimeError(
            "The request instance has not been set on this response."
        )

    @request.setter
    def request(self, value: Request) -> None:
        self._request = value

    @property
    def http_version(self) -> str:
        """
        The `http_version` extension decoded as ascii, or "HTTP/1.1" when the
        extension is not present.
        """
        try:
            raw_version: bytes = self.extensions["http_version"]
        except KeyError:
            return "HTTP/1.1"
        return raw_version.decode("ascii", errors="ignore")

    @property
    def reason_phrase(self) -> str:
        """
        The `reason_phrase` extension decoded as ascii, or the phrase looked
        up from the status code when the extension is not present.
        """
        try:
            raw_phrase: bytes = self.extensions["reason_phrase"]
        except KeyError:
            return codes.get_reason_phrase(self.status_code)
        return raw_phrase.decode("ascii", errors="ignore")

    @property
    def url(self) -> URL:
        """
        Returns the URL for which the request was made.
        """
        return self.request.url

    @property
    def content(self) -> bytes:
        """
        The response body as bytes. Raises `ResponseNotRead` if the body has
        not been loaded yet.
        """
        if hasattr(self, "_content"):
            return self._content
        raise ResponseNotRead()

    @property
    def text(self) -> str:
        """
        The response body decoded to text. The result is cached.
        """
        if hasattr(self, "_text"):
            return self._text

        content = self.content
        if content:
            decoder = TextDecoder(encoding=self.encoding or "utf-8")
            self._text = "".join([decoder.decode(content), decoder.flush()])
        else:
            self._text = ""
        return self._text

    @property
    def encoding(self) -> str | None:
        """
        Return an encoding to use for decoding the byte content into text.
        The priority for determining this is given by...

        * `.encoding = <>` has been set explicitly.
        * The encoding as specified by the charset parameter in the
          Content-Type header.
        * The encoding as determined by `default_encoding`, which may either be
          a string like "utf-8" indicating the encoding to use, or may be a
          callable which enables charset autodetection.
        """
        if hasattr(self, "_encoding"):
            return self._encoding

        encoding = self.charset_encoding
        if encoding is None or not is_known_encoding(encoding):
            if isinstance(self.default_encoding, str):
                encoding = self.default_encoding
            elif hasattr(self, "_content"):
                # Callable default: detect the encoding from the loaded body.
                encoding = self.default_encoding(self._content)
        self._encoding = encoding or "utf-8"
        return self._encoding

    @encoding.setter
    def encoding(self, value: str) -> None:
        """
        Set the encoding to use for decoding the byte content into text.

        If the `text` attribute has been accessed, attempting to set the
        encoding will throw a ValueError.
        """
        if hasattr(self, "_text"):
            raise ValueError(
                "Setting encoding after `text` has been accessed is not allowed."
            )
        self._encoding = value

    @property
    def charset_encoding(self) -> str | None:
        """
        Return the encoding, as specified by the Content-Type header.
        """
        content_type = self.headers.get("Content-Type")
        if content_type is None:
            return None
        return parse_content_type_charset(content_type)

    def _get_content_decoder(self) -> ContentDecoder:
        """
        Returns a decoder instance which can be used to decode the raw byte
        content, depending on the Content-Encoding used in the response.
        """
        if hasattr(self, "_decoder"):
            return self._decoder

        decoders: list[ContentDecoder] = []
        for coding in self.headers.get_list("content-encoding", split_commas=True):
            coding = coding.strip().lower()
            try:
                decoder_cls = SUPPORTED_DECODERS[coding]
                decoders.append(decoder_cls())
            except KeyError:
                # Codings without a supported decoder are skipped.
                continue

        if not decoders:
            self._decoder = IdentityDecoder()
        elif len(decoders) == 1:
            self._decoder = decoders[0]
        else:
            self._decoder = MultiDecoder(children=decoders)
        return self._decoder

    @property
    def is_informational(self) -> bool:
        """
        A property which is `True` for 1xx status codes, `False` otherwise.
        """
        return codes.is_informational(self.status_code)

    @property
    def is_success(self) -> bool:
        """
        A property which is `True` for 2xx status codes, `False` otherwise.
        """
        return codes.is_success(self.status_code)

    @property
    def is_redirect(self) -> bool:
        """
        A property which is `True` for 3xx status codes, `False` otherwise.

        Note that not all responses with a 3xx status code indicate a URL
        redirect.

        Use `response.has_redirect_location` to determine responses with a
        properly formed URL redirection.
        """
        return codes.is_redirect(self.status_code)

    @property
    def is_client_error(self) -> bool:
        """
        A property which is `True` for 4xx status codes, `False` otherwise.
        """
        return codes.is_client_error(self.status_code)

    @property
    def is_server_error(self) -> bool:
        """
        A property which is `True` for 5xx status codes, `False` otherwise.
        """
        return codes.is_server_error(self.status_code)

    @property
    def is_error(self) -> bool:
        """
        A property which is `True` for 4xx and 5xx status codes, `False`
        otherwise.
        """
        return codes.is_error(self.status_code)

    @property
    def has_redirect_location(self) -> bool:
        """
        Returns True for 3xx responses with a properly formed URL redirection,
        `False` otherwise.
        """
        redirect_codes = (
            # 301 (Cacheable redirect. Method may change to GET.)
            codes.MOVED_PERMANENTLY,
            # 302 (Uncacheable redirect. Method may change to GET.)
            codes.FOUND,
            # 303 (Client should make a GET or HEAD request.)
            codes.SEE_OTHER,
            # 307 (Equiv. 302, but retain method)
            codes.TEMPORARY_REDIRECT,
            # 308 (Equiv. 301, but retain method)
            codes.PERMANENT_REDIRECT,
        )
        return self.status_code in redirect_codes and "Location" in self.headers

    def raise_for_status(self) -> Response:
        """
        Raise the `HTTPStatusError` if one occurred.

        Returns the response itself for 2xx status codes. Raises
        `RuntimeError` if no request instance has been set.
        """
        request = self._request
        if request is None:
            raise RuntimeError(
                "Cannot call `raise_for_status` as the request "
                "instance has not been set on this response."
            )

        if self.is_success:
            return self

        template_parts = [
            "{error_type} '{0.status_code} {0.reason_phrase}' for url '{0.url}'\n"
        ]
        if self.has_redirect_location:
            template_parts.append("Redirect location: '{0.headers[location]}'\n")
        template_parts.append(
            "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/{0.status_code}"
        )
        template = "".join(template_parts)

        error_types = {
            1: "Informational response",
            3: "Redirect response",
            4: "Client error",
            5: "Server error",
        }
        status_class = self.status_code // 100
        error_type = error_types.get(status_class, "Invalid status code")
        message = template.format(self, error_type=error_type)
        raise HTTPStatusError(message, request=request, response=self)

    def json(self, **kwargs: typing.Any) -> typing.Any:
        """
        Parse the response content as JSON. Keyword arguments are passed
        through to the JSON loader.
        """
        return jsonlib.loads(self.content, **kwargs)

    @property
    def cookies(self) -> Cookies:
        """
        The cookies extracted from this response. The result is cached.
        """
        if not hasattr(self, "_cookies"):
            self._cookies = Cookies()
            self._cookies.extract_cookies(self)
        return self._cookies

    @property
    def links(self) -> dict[str | None, dict[str, str]]:
        """
        Returns the parsed header links of the response, if any
        """
        header = self.headers.get("link")
        if header is None:
            return {}

        parsed: dict[str | None, dict[str, str]] = {}
        for link in parse_header_links(header):
            # Keyed by "rel" when present, falling back to the link "url".
            parsed[link.get("rel") or link.get("url")] = link
        return parsed

    @property
    def num_bytes_downloaded(self) -> int:
        return self._num_bytes_downloaded

    def __repr__(self) -> str:
        return f"<Response [{self.status_code} {self.reason_phrase}]>"

    def __getstate__(self) -> dict[str, typing.Any]:
        # These attributes are left out of the pickled state.
        excluded = ["extensions", "stream", "is_closed", "_decoder"]
        state: dict[str, typing.Any] = {}
        for name, value in self.__dict__.items():
            if name not in excluded:
                state[name] = value
        return state

    def __setstate__(self, state: dict[str, typing.Any]) -> None:
        for name, value in state.items():
            setattr(self, name, value)
        # Restored responses are closed, with an unattached stream.
        self.is_closed = True
        self.extensions = {}
        self.stream = UnattachedStream()

    def read(self) -> bytes:
        """
        Read and return the response content.
        """
        if not hasattr(self, "_content"):
            self._content = b"".join(self.iter_bytes())
        return self._content

    def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
        """
        A byte-iterator over the decoded response content.
        This allows us to handle gzip, deflate, and brotli encoded responses.
        """
        if hasattr(self, "_content"):
            # Body already loaded: slice it up rather than touch the stream.
            body = self._content
            chunk_size = len(body) if chunk_size is None else chunk_size
            for start in range(0, len(body), max(chunk_size, 1)):
                yield body[start : start + chunk_size]
            return

        decoder = self._get_content_decoder()
        chunker = ByteChunker(chunk_size=chunk_size)
        with request_context(request=self._request):
            for raw_bytes in self.iter_raw():
                yield from chunker.decode(decoder.decode(raw_bytes))
            yield from chunker.decode(decoder.flush())
            yield from chunker.flush()

    def iter_text(self, chunk_size: int | None = None) -> typing.Iterator[str]:
        """
        A str-iterator over the decoded response content
        that handles both gzip, deflate, etc but also detects the content's
        string encoding.
        """
        decoder = TextDecoder(encoding=self.encoding or "utf-8")
        chunker = TextChunker(chunk_size=chunk_size)
        with request_context(request=self._request):
            for byte_content in self.iter_bytes():
                yield from chunker.decode(decoder.decode(byte_content))
            yield from chunker.decode(decoder.flush())
            yield from chunker.flush()

    def iter_lines(self) -> typing.Iterator[str]:
        """
        A str-iterator over the response text, one line at a time.
        """
        decoder = LineDecoder()
        with request_context(request=self._request):
            for text in self.iter_text():
                yield from decoder.decode(text)
            yield from decoder.flush()

    def iter_raw(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
        """
        A byte-iterator over the raw response content.
        """
        if self.is_stream_consumed:
            raise StreamConsumed()
        if self.is_closed:
            raise StreamClosed()
        if not isinstance(self.stream, SyncByteStream):
            raise RuntimeError("Attempted to call a sync iterator on an async stream.")

        self.is_stream_consumed = True
        self._num_bytes_downloaded = 0
        chunker = ByteChunker(chunk_size=chunk_size)

        with request_context(request=self._request):
            for raw_stream_bytes in self.stream:
                self._num_bytes_downloaded += len(raw_stream_bytes)
                yield from chunker.decode(raw_stream_bytes)

        yield from chunker.flush()

        self.close()

    def close(self) -> None:
        """
        Close the response and release the connection.
        Automatically called if the response body is read to completion.
        """
        if not isinstance(self.stream, SyncByteStream):
            raise RuntimeError("Attempted to call an sync close on an async stream.")

        if self.is_closed:
            return
        self.is_closed = True
        with request_context(request=self._request):
            self.stream.close()

    async def aread(self) -> bytes:
        """
        Read and return the response content.
        """
        if not hasattr(self, "_content"):
            parts = [part async for part in self.aiter_bytes()]
            self._content = b"".join(parts)
        return self._content

    async def aiter_bytes(
        self, chunk_size: int | None = None
    ) -> typing.AsyncIterator[bytes]:
        """
        A byte-iterator over the decoded response content.
        This allows us to handle gzip, deflate, and brotli encoded responses.
        """
        if hasattr(self, "_content"):
            # Body already loaded: slice it up rather than touch the stream.
            body = self._content
            chunk_size = len(body) if chunk_size is None else chunk_size
            for start in range(0, len(body), max(chunk_size, 1)):
                yield body[start : start + chunk_size]
            return

        decoder = self._get_content_decoder()
        chunker = ByteChunker(chunk_size=chunk_size)
        with request_context(request=self._request):
            async for raw_bytes in self.aiter_raw():
                for piece in chunker.decode(decoder.decode(raw_bytes)):
                    yield piece
            for piece in chunker.decode(decoder.flush()):
                yield piece  # pragma: no cover
            for piece in chunker.flush():
                yield piece

    async def aiter_text(
        self, chunk_size: int | None = None
    ) -> typing.AsyncIterator[str]:
        """
        A str-iterator over the decoded response content
        that handles both gzip, deflate, etc but also detects the content's
        string encoding.
        """
        decoder = TextDecoder(encoding=self.encoding or "utf-8")
        chunker = TextChunker(chunk_size=chunk_size)
        with request_context(request=self._request):
            async for byte_content in self.aiter_bytes():
                for piece in chunker.decode(decoder.decode(byte_content)):
                    yield piece
            for piece in chunker.decode(decoder.flush()):
                yield piece  # pragma: no cover
            for piece in chunker.flush():
                yield piece

    async def aiter_lines(self) -> typing.AsyncIterator[str]:
        """
        A str-iterator over the response text, one line at a time.
        """
        decoder = LineDecoder()
        with request_context(request=self._request):
            async for text in self.aiter_text():
                for line in decoder.decode(text):
                    yield line
            for line in decoder.flush():
                yield line

    async def aiter_raw(
        self, chunk_size: int | None = None
    ) -> typing.AsyncIterator[bytes]:
        """
        A byte-iterator over the raw response content.
        """
        if self.is_stream_consumed:
            raise StreamConsumed()
        if self.is_closed:
            raise StreamClosed()
        if not isinstance(self.stream, AsyncByteStream):
            raise RuntimeError("Attempted to call an async iterator on an sync stream.")

        self.is_stream_consumed = True
        self._num_bytes_downloaded = 0
        chunker = ByteChunker(chunk_size=chunk_size)

        with request_context(request=self._request):
            async for raw_stream_bytes in self.stream:
                self._num_bytes_downloaded += len(raw_stream_bytes)
                for piece in chunker.decode(raw_stream_bytes):
                    yield piece

        for piece in chunker.flush():
            yield piece

        await self.aclose()

    async def aclose(self) -> None:
        """
        Close the response and release the connection.
        Automatically called if the response body is read to completion.
        """
        if not isinstance(self.stream, AsyncByteStream):
            raise RuntimeError("Attempted to call an async close on an sync stream.")

        if self.is_closed:
            return
        self.is_closed = True
        with request_context(request=self._request):
            await self.stream.aclose()
class Cookies(typing.MutableMapping[str, str]):
    """
    HTTP Cookies, as a mutable mapping.

    The cookies themselves are held in a `CookieJar`, available as `.jar`.
    """

    def __init__(self, cookies: CookieTypes | None = None) -> None:
        """
        Build the cookie store from `None`, a dict, a list of `(name, value)`
        pairs, another `Cookies` instance, or an existing jar object.
        """
        if cookies is None or isinstance(cookies, (dict, list, Cookies)):
            self.jar = CookieJar()
            if isinstance(cookies, dict):
                for key, value in cookies.items():
                    self.set(key, value)
            elif isinstance(cookies, list):
                for key, value in cookies:
                    self.set(key, value)
            elif isinstance(cookies, Cookies):
                for cookie in cookies.jar:
                    self.jar.set_cookie(cookie)
        else:
            # Anything else is used directly as the jar, without copying.
            self.jar = cookies

    def extract_cookies(self, response: Response) -> None:
        """
        Loads any cookies based on the response `Set-Cookie` headers.
        """
        urllib_response = self._CookieCompatResponse(response)
        urllib_request = self._CookieCompatRequest(response.request)

        self.jar.extract_cookies(urllib_response, urllib_request)  # type: ignore

    def set_cookie_header(self, request: Request) -> None:
        """
        Sets an appropriate 'Cookie:' HTTP header on the `Request`.
        """
        urllib_request = self._CookieCompatRequest(request)
        self.jar.add_cookie_header(urllib_request)

    def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
        """
        Set a cookie value by name. May optionally include domain and path.
        """
        kwargs = {
            "version": 0,
            "name": name,
            "value": value,
            "port": None,
            "port_specified": False,
            "domain": domain,
            "domain_specified": bool(domain),
            "domain_initial_dot": domain.startswith("."),
            "path": path,
            "path_specified": bool(path),
            "secure": False,
            "expires": None,
            "discard": True,
            "comment": None,
            "comment_url": None,
            "rest": {"HttpOnly": None},
            "rfc2109": False,
        }
        cookie = Cookie(**kwargs)  # type: ignore
        self.jar.set_cookie(cookie)

    def get(  # type: ignore
        self,
        name: str,
        default: str | None = None,
        domain: str | None = None,
        path: str | None = None,
    ) -> str | None:
        """
        Get a cookie by name. May optionally include domain and path
        in order to specify exactly which cookie to retrieve.

        Returns `default` when no cookie matches, and raises `CookieConflict`
        when more than one cookie with a value matches.
        """
        value = None
        for cookie in self.jar:
            if cookie.name != name:
                continue
            if domain is not None and cookie.domain != domain:
                continue
            if path is not None and cookie.path != path:
                continue
            if value is not None:
                message = f"Multiple cookies exist with name={name}"
                raise CookieConflict(message)
            value = cookie.value

        if value is None:
            return default
        return value

    def delete(
        self,
        name: str,
        domain: str | None = None,
        path: str | None = None,
    ) -> None:
        """
        Delete a cookie by name. May optionally include domain and path
        in order to specify exactly which cookie to delete.
        """
        if domain is not None and path is not None:
            # Fully specified: remove that single cookie directly.
            return self.jar.clear(domain, path, name)

        # Collect the matches first, so the jar is not modified while it is
        # being iterated.
        remove = [
            cookie
            for cookie in self.jar
            if cookie.name == name
            and (domain is None or cookie.domain == domain)
            and (path is None or cookie.path == path)
        ]

        for cookie in remove:
            self.jar.clear(cookie.domain, cookie.path, cookie.name)

    def clear(self, domain: str | None = None, path: str | None = None) -> None:
        """
        Delete all cookies. Optionally include a domain and path in
        order to only delete a subset of all the cookies.

        A `path` may only be given together with a `domain`; otherwise an
        `AssertionError` is raised.
        """
        if path is not None and domain is None:
            # Raised explicitly rather than via `assert`, so the check still
            # runs under `python -O`. Without it the path would be handed to
            # the jar as its first positional argument, i.e. as the domain.
            raise AssertionError(
                "A `domain` is required in order to clear cookies by `path`."
            )
        args = [arg for arg in (domain, path) if arg is not None]
        self.jar.clear(*args)

    def update(self, cookies: CookieTypes | None = None) -> None:  # type: ignore
        """
        Add every cookie from `cookies` to this instance's jar.
        """
        cookies = Cookies(cookies)
        for cookie in cookies.jar:
            self.jar.set_cookie(cookie)

    def __setitem__(self, name: str, value: str) -> None:
        return self.set(name, value)

    def __getitem__(self, name: str) -> str:
        value = self.get(name)
        if value is None:
            raise KeyError(name)
        return value

    def __delitem__(self, name: str) -> None:
        return self.delete(name)

    def __len__(self) -> int:
        return len(self.jar)

    def __iter__(self) -> typing.Iterator[str]:
        return (cookie.name for cookie in self.jar)

    def __bool__(self) -> bool:
        # True as soon as the jar yields one cookie.
        for _ in self.jar:
            return True
        return False

    def __repr__(self) -> str:
        cookies_repr = ", ".join(
            [
                f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
                for cookie in self.jar
            ]
        )

        return f"<Cookies[{cookies_repr}]>"

    class _CookieCompatRequest(urllib.request.Request):
        """
        Wraps a `Request` instance up in a compatibility interface suitable
        for use with `CookieJar` operations.
        """

        def __init__(self, request: Request) -> None:
            super().__init__(
                url=str(request.url),
                headers=dict(request.headers),
                method=request.method,
            )
            self.request = request

        def add_unredirected_header(self, key: str, value: str) -> None:
            super().add_unredirected_header(key, value)
            # Mirror the header onto the wrapped request as well.
            self.request.headers[key] = value

    class _CookieCompatResponse:
        """
        Wraps a `Response` instance up in a compatibility interface suitable
        for use with `CookieJar` operations.
        """

        def __init__(self, response: Response) -> None:
            self.response = response

        def info(self) -> email.message.Message:
            info = email.message.Message()
            for key, value in self.response.headers.multi_items():
                # Note that setting `info[key]` here is an "append" operation,
                # not a "replace" operation.
                # https://docs.python.org/3/library/email.compat32-message.html#email.message.Message.__setitem__
                info[key] = value
            return info