182 lines
5.4 KiB
Python
182 lines
5.4 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import typing
|
||
|
|
||
|
import sniffio
|
||
|
|
||
|
from .._models import Request, Response
|
||
|
from .._types import AsyncByteStream
|
||
|
from .base import AsyncBaseTransport
|
||
|
|
||
|
# These imports exist for static type checking only: at runtime the matching
# library is imported lazily inside `create_event()`, so neither one is loaded
# just to spell the return annotation.
if typing.TYPE_CHECKING:  # pragma: no cover
    import asyncio

    import trio

    # Either flavour of event object that `create_event()` may return.
    Event = typing.Union[asyncio.Event, trio.Event]
|
||
|
|
||
|
|
||
|
# A single ASGI event dictionary, e.g. {"type": "http.request", ...}.
_Message = typing.Dict[str, typing.Any]

# Argument-less callable the application awaits to obtain the next event.
_Receive = typing.Callable[[], typing.Awaitable[_Message]]

# Callable the application awaits to hand one event back to the transport.
_Send = typing.Callable[[_Message], typing.Coroutine[None, None, None]]

# The application itself: called with (scope, receive, send).
_ASGIApp = typing.Callable[
    [typing.Dict[str, typing.Any], _Receive, _Send],
    typing.Coroutine[None, None, None],
]
|
||
|
|
||
|
|
||
|
def create_event() -> Event:
    """
    Return a fresh event object for the async library currently running.

    `sniffio` is asked which library is active: `"trio"` yields a
    `trio.Event`, anything else yields an `asyncio.Event`. Each library is
    imported only on the branch that needs it.
    """
    running_under_trio = sniffio.current_async_library() == "trio"

    if not running_under_trio:
        import asyncio

        return asyncio.Event()

    import trio

    return trio.Event()
|
||
|
|
||
|
|
||
|
class ASGIResponseStream(AsyncByteStream):
    """
    Async byte stream over a list of already-collected response body parts.

    Iterating yields exactly one chunk: all parts concatenated. The join
    happens when iteration starts, not when the stream is constructed.
    """

    def __init__(self, body: list[bytes]) -> None:
        # Keep a reference to the caller's list rather than copying it.
        self._body = body

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        whole_body = b"".join(self._body)
        yield whole_body
|
||
|
|
||
|
|
||
|
class ASGITransport(AsyncBaseTransport):
    """
    A custom AsyncTransport that handles sending requests directly to an ASGI app.
    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.AsyncClient(app=app)
    ```

    Alternatively, you can setup the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the ASGITransport class:

    ```
    transport = httpx.ASGITransport(
        app=app,
        root_path="/submount",
        client=("1.2.3.4", 123)
    )
    client = httpx.AsyncClient(transport=transport)
    ```

    Arguments:

    * `app` - The ASGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.
    """

    def __init__(
        self,
        app: _ASGIApp,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: tuple[str, int] = ("127.0.0.1", 123),
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client

    async def handle_async_request(
        self,
        request: Request,
    ) -> Response:
        """
        Run the ASGI application once for `request` and return its response.

        The application is awaited to completion before this method returns,
        so the whole response body is collected in memory and then exposed
        through an `ASGIResponseStream`. Body bytes sent in reply to a `HEAD`
        request are discarded.

        If the application raises and `raise_app_exceptions` is true, the
        exception propagates to the caller. Otherwise it is swallowed, and a
        status of 500 and empty headers are substituted for whichever of the
        two the application had not yet sent.

        An `AssertionError` is raised if the request stream is not an
        `AsyncByteStream`, or if the application returns without having
        finished its response.
        """
        assert isinstance(request.stream, AsyncByteStream)

        # ASGI scope.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
            "scheme": request.url.scheme,
            "path": request.url.path,
            # `raw_path` carries the query string too; the scope wants the path only.
            "raw_path": request.url.raw_path.split(b"?")[0],
            "query_string": request.url.query,
            "server": (request.url.host, request.url.port),
            "client": self.client,
            "root_path": self.root_path,
        }

        # Request.
        request_body_chunks = request.stream.__aiter__()
        request_complete = False

        # Response.
        status_code = None
        response_headers = None
        body_parts: list[bytes] = []
        response_started = False
        response_complete = create_event()

        # ASGI callables.

        async def receive() -> dict[str, typing.Any]:
            nonlocal request_complete

            if request_complete:
                # The request body is exhausted: block until the response is
                # finished, then report the disconnect.
                await response_complete.wait()
                return {"type": "http.disconnect"}

            try:
                body = await request_body_chunks.__anext__()
            except StopAsyncIteration:
                # End of the request stream is signalled by an empty final chunk.
                request_complete = True
                return {"type": "http.request", "body": b"", "more_body": False}
            return {"type": "http.request", "body": body, "more_body": True}

        async def send(message: dict[str, typing.Any]) -> None:
            nonlocal status_code, response_headers, response_started

            if message["type"] == "http.response.start":
                assert not response_started

                status_code = message["status"]
                response_headers = message.get("headers", [])
                response_started = True

            elif message["type"] == "http.response.body":
                assert not response_complete.is_set()
                body = message.get("body", b"")
                more_body = message.get("more_body", False)

                if body and request.method != "HEAD":
                    body_parts.append(body)

                if not more_body:
                    response_complete.set()

        try:
            await self.app(scope, receive, send)
        except Exception:  # noqa: PIE-786
            if self.raise_app_exceptions:
                raise

            # Swallowing the error: mark the response finished and fill in
            # whatever the application did not get round to sending.
            response_complete.set()
            if status_code is None:
                status_code = 500
            if response_headers is None:
                # Same shape as the list taken from "http.response.start".
                response_headers = []

        assert response_complete.is_set()
        assert status_code is not None
        assert response_headers is not None

        stream = ASGIResponseStream(body_parts)

        return Response(status_code, headers=response_headers, stream=stream)
|