ai-content-maker/.venv/Lib/site-packages/openai/lib/streaming/_assistants.py

1036 lines
40 KiB
Python
Raw Permalink Normal View History

2024-05-11 23:00:43 +03:00
from __future__ import annotations
import asyncio
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Iterable, Iterator, cast
from typing_extensions import Awaitable, AsyncIterable, AsyncIterator, assert_never
import httpx
from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
from ..._models import construct_type
from ..._streaming import Stream, AsyncStream
from ...types.beta import AssistantStreamEvent
from ...types.beta.threads import (
Run,
Text,
Message,
ImageFile,
TextDelta,
MessageDelta,
MessageContent,
MessageContentDelta,
)
from ...types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta
class AssistantEventHandler:
    """Base class for consuming an Assistants API event stream.

    Subclass this and override the ``on_*`` callbacks you are interested in.
    A handler is attached to exactly one stream via ``_init()`` (done by
    ``AssistantStreamManager.__enter__``). Iterating the handler yields every
    raw event while accumulating run / run-step / message snapshots and
    dispatching the matching callbacks.
    """

    text_deltas: Iterable[str]
    """Iterator over just the text deltas in the stream.

    This corresponds to the `thread.message.delta` event
    in the API.

    ```py
    for text in stream.text_deltas:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self) -> None:
        self._current_event: AssistantStreamEvent | None = None
        # Message content block that message deltas are currently applied to; tracked so
        # that on_text_done / on_image_file_done fire when the active block changes.
        self._current_message_content_index: int | None = None
        self._current_message_content: MessageContent | None = None
        # Tool call that run step deltas are currently applied to; tracked so that
        # on_tool_call_created / on_tool_call_done fire at tool call boundaries.
        self._current_tool_call_index: int | None = None
        self._current_tool_call: ToolCall | None = None
        self.__current_run_step_id: str | None = None
        self.__current_run: Run | None = None
        self.__run_step_snapshots: dict[str, RunStep] = {}
        self.__message_snapshots: dict[str, Message] = {}
        self.__current_message_snapshot: Message | None = None

        # Generator bodies only run once iterated, so creating them before a
        # stream has been attached via `_init` is safe.
        self.text_deltas = self.__text_deltas__()
        self._iterator = self.__stream__()
        self.__stream: Stream[AssistantStreamEvent] | None = None

    def _init(self, stream: Stream[AssistantStreamEvent]) -> None:
        """Attach the handler to `stream`.

        Raises:
            RuntimeError: if the handler is already attached to a stream.
        """
        if self.__stream:
            raise RuntimeError(
                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
            )

        self.__stream = stream

    def __next__(self) -> AssistantStreamEvent:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[AssistantStreamEvent]:
        for item in self._iterator:
            yield item

    @property
    def current_event(self) -> AssistantStreamEvent | None:
        """The event being dispatched; only set while its callbacks run, `None` otherwise."""
        return self._current_event

    @property
    def current_run(self) -> Run | None:
        """The Run object from the most recent `thread.run.*` event, if any."""
        return self.__current_run

    @property
    def current_run_step_snapshot(self) -> RunStep | None:
        """Accumulated snapshot of the run step currently in progress, or `None` between steps."""
        if not self.__current_run_step_id:
            return None

        return self.__run_step_snapshots[self.__current_run_step_id]

    @property
    def current_message_snapshot(self) -> Message | None:
        """Accumulated snapshot of the message most recently created or updated."""
        return self.__current_message_snapshot

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called when the context manager exits.
        """
        if self.__stream:
            self.__stream.close()

    def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        consume_sync_iterator(self)

    def get_final_run(self) -> Run:
        """Wait for the stream to finish and returns the completed Run object"""
        self.until_done()

        if not self.__current_run:
            raise RuntimeError("No final run object found")

        return self.__current_run

    def get_final_run_steps(self) -> list[RunStep]:
        """Wait for the stream to finish and returns the steps taken in this run"""
        self.until_done()

        if not self.__run_step_snapshots:
            raise RuntimeError("No run steps found")

        return [step for step in self.__run_step_snapshots.values()]

    def get_final_messages(self) -> list[Message]:
        """Wait for the stream to finish and returns the messages emitted in this run"""
        self.until_done()

        if not self.__message_snapshots:
            raise RuntimeError("No messages found")

        return [message for message in self.__message_snapshots.values()]

    def __text_deltas__(self) -> Iterator[str]:
        for event in self:
            if event.event != "thread.message.delta":
                continue

            for content_delta in event.data.delta.content or []:
                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
                    yield content_delta.text.value

    # event handlers

    def on_end(self) -> None:
        """Fires when the stream has finished.

        This happens if the stream is read to completion
        or if an exception occurs during iteration.
        """

    def on_event(self, event: AssistantStreamEvent) -> None:
        """Callback that is fired for every Server-Sent-Event"""

    def on_run_step_created(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is created"""

    def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Callback that is fired whenever a run step delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the run step. For example, a tool calls event may
        look like this:

        # delta
        tool_calls=[
            RunStepDeltaToolCallsCodeInterpreter(
                index=0,
                type='code_interpreter',
                id=None,
                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
            )
        ]
        # snapshot
        tool_calls=[
            CodeToolCall(
                id='call_wKayJlcYV12NiadiZuJXxcfx',
                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
                type='code_interpreter',
                index=0
            )
        ],
        """

    def on_run_step_done(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is completed"""

    def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is created"""

    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Callback that is fired when a tool call delta is encountered"""

    def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired once when a tool call is finished"""

    def on_exception(self, exception: Exception) -> None:
        """Fired whenever an exception happens during streaming"""

    def on_timeout(self) -> None:
        """Fires if the request times out"""

    def on_message_created(self, message: Message) -> None:
        """Callback that is fired when a message is created"""

    def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Callback that is fired whenever a message delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the message. For example, a text content event may
        look like this:

        # delta
        MessageDeltaText(
            index=0,
            type='text',
            text=Text(
                value=' Jane'
            ),
        )
        # snapshot
        MessageContentText(
            index=0,
            type='text',
            text=Text(
                value='Certainly, Jane'
            ),
        )
        """

    def on_message_done(self, message: Message) -> None:
        """Callback that is fired when a message is completed"""

    def on_text_created(self, text: Text) -> None:
        """Callback that is fired when a text content block is created"""

    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """

    def on_text_done(self, text: Text) -> None:
        """Callback that is fired when a text content block is finished"""

    def on_image_file_done(self, image_file: ImageFile) -> None:
        """Callback that is fired when an image file block is finished"""

    def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
        """Update the snapshots from `event` and dispatch the matching callbacks."""
        self._current_event = event
        self.on_event(event)

        self.__current_message_snapshot, new_content = accumulate_event(
            event=event,
            current_message_snapshot=self.__current_message_snapshot,
        )
        if self.__current_message_snapshot is not None:
            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot

        accumulate_run_step(
            event=event,
            run_step_snapshots=self.__run_step_snapshots,
        )

        for content_delta in new_content:
            assert self.__current_message_snapshot is not None

            block = self.__current_message_snapshot.content[content_delta.index]
            if block.type == "text":
                self.on_text_created(block.text)

        if (
            event.event == "thread.run.completed"
            or event.event == "thread.run.cancelled"
            or event.event == "thread.run.expired"
            or event.event == "thread.run.failed"
            or event.event == "thread.run.requires_action"
        ):
            self.__current_run = event.data
            if self._current_tool_call:
                self.on_tool_call_done(self._current_tool_call)
                # Clear the tracking state so the same tool call is never reported twice.
                self._current_tool_call = None
                self._current_tool_call_index = None
        elif (
            event.event == "thread.run.created"
            or event.event == "thread.run.in_progress"
            or event.event == "thread.run.cancelling"
            or event.event == "thread.run.queued"
        ):
            self.__current_run = event.data
        elif event.event == "thread.message.created":
            self.on_message_created(event.data)
        elif event.event == "thread.message.delta":
            snapshot = self.__current_message_snapshot
            assert snapshot is not None

            message_delta = event.data.delta
            if message_delta.content is not None:
                for content_delta in message_delta.content:
                    if content_delta.type == "text" and content_delta.text:
                        snapshot_content = snapshot.content[content_delta.index]
                        assert snapshot_content.type == "text"
                        self.on_text_delta(content_delta.text, snapshot_content.text)

                    # If the delta is for a new message content:
                    # - emit on_text_done/on_image_file_done for the previous message content
                    # - emit on_text_created/on_image_created for the new message content
                    if content_delta.index != self._current_message_content_index:
                        if self._current_message_content is not None:
                            if self._current_message_content.type == "text":
                                self.on_text_done(self._current_message_content.text)
                            elif self._current_message_content.type == "image_file":
                                self.on_image_file_done(self._current_message_content.image_file)

                        self._current_message_content_index = content_delta.index
                        self._current_message_content = snapshot.content[content_delta.index]

                    # Update the current_message_content (delta event is correctly emitted already)
                    self._current_message_content = snapshot.content[content_delta.index]

            self.on_message_delta(event.data.delta, snapshot)
        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
            self.__current_message_snapshot = event.data
            self.__message_snapshots[event.data.id] = event.data

            if self._current_message_content_index is not None:
                content = event.data.content[self._current_message_content_index]
                if content.type == "text":
                    self.on_text_done(content.text)
                elif content.type == "image_file":
                    self.on_image_file_done(content.image_file)

            self.on_message_done(event.data)

            # The message is finished: forget its active content block so the next
            # message neither re-reports it as done nor indexes into the wrong message.
            self._current_message_content_index = None
            self._current_message_content = None
        elif event.event == "thread.run.step.created":
            self.__current_run_step_id = event.data.id
            self.on_run_step_created(event.data)
        elif event.event == "thread.run.step.in_progress":
            self.__current_run_step_id = event.data.id
        elif event.event == "thread.run.step.delta":
            step_snapshot = self.__run_step_snapshots[event.data.id]

            run_step_delta = event.data.delta
            if (
                run_step_delta.step_details
                and run_step_delta.step_details.type == "tool_calls"
                and run_step_delta.step_details.tool_calls is not None
            ):
                assert step_snapshot.step_details.type == "tool_calls"
                for tool_call_delta in run_step_delta.step_details.tool_calls:
                    if tool_call_delta.index == self._current_tool_call_index:
                        self.on_tool_call_delta(
                            tool_call_delta,
                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
                        )

                    # If the delta is for a new tool call:
                    # - emit on_tool_call_done for the previous tool_call
                    # - emit on_tool_call_created for the new tool_call
                    if tool_call_delta.index != self._current_tool_call_index:
                        if self._current_tool_call is not None:
                            self.on_tool_call_done(self._current_tool_call)

                        self._current_tool_call_index = tool_call_delta.index
                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
                        self.on_tool_call_created(self._current_tool_call)

                    # Update the current_tool_call (delta event is correctly emitted already)
                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]

            self.on_run_step_delta(
                event.data.delta,
                step_snapshot,
            )
        elif (
            event.event == "thread.run.step.completed"
            or event.event == "thread.run.step.cancelled"
            or event.event == "thread.run.step.expired"
            or event.event == "thread.run.step.failed"
        ):
            if self._current_tool_call:
                self.on_tool_call_done(self._current_tool_call)

            self.on_run_step_done(event.data)
            self.__current_run_step_id = None

            # Tool call indices restart at 0 in every run step: reset the tracking so the
            # next step's first tool call fires on_tool_call_created, and so the terminal
            # run event does not fire on_tool_call_done again for this step's last call.
            self._current_tool_call = None
            self._current_tool_call_index = None
        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
            # currently no special handling
            ...
        else:
            # we only want to error at build-time
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event)

        self._current_event = None

    def __stream__(self) -> Iterator[AssistantStreamEvent]:
        stream = self.__stream
        if not stream:
            raise RuntimeError("Stream has not been started yet")

        try:
            for event in stream:
                self._emit_sse_event(event)

                yield event
        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
            self.on_timeout()
            self.on_exception(exc)
            raise
        except Exception as exc:
            self.on_exception(exc)
            raise
        finally:
            self.on_end()


AssistantEventHandlerT = TypeVar("AssistantEventHandlerT", bound=AssistantEventHandler)
class AssistantStreamManager(Generic[AssistantEventHandlerT]):
    """Wrapper over AssistantEventHandler that is returned by `.stream()`
    so that a context manager can be used.

    ```py
    with client.beta.threads.create_and_run_stream(...) as stream:
        for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[AssistantStreamEvent]],
        *,
        event_handler: AssistantEventHandlerT,
    ) -> None:
        self.__stream: Stream[AssistantStreamEvent] | None = None
        self.__event_handler = event_handler
        self.__api_request = api_request

    def __enter__(self) -> AssistantEventHandlerT:
        """Perform the API request and attach the resulting stream to the event handler."""
        self.__stream = self.__api_request()
        try:
            self.__event_handler._init(self.__stream)
        except BaseException:
            # `__exit__` is not invoked when `__enter__` raises (e.g. the handler was
            # already bound to another stream), so release the connection here.
            self.__stream.close()
            raise

        return self.__event_handler

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the underlying stream, if one was opened."""
        if self.__stream is not None:
            self.__stream.close()
class AsyncAssistantEventHandler:
    """Async base class for consuming an Assistants API event stream.

    Subclass this and override the async ``on_*`` callbacks you are interested
    in. A handler is attached to exactly one stream via ``_init()`` (done by
    ``AsyncAssistantStreamManager.__aenter__``). Iterating the handler with
    ``async for`` yields every raw event while accumulating run / run-step /
    message snapshots and awaiting the matching callbacks.
    """

    text_deltas: AsyncIterable[str]
    """Iterator over just the text deltas in the stream.

    This corresponds to the `thread.message.delta` event
    in the API.

    ```py
    async for text in stream.text_deltas:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self) -> None:
        self._current_event: AssistantStreamEvent | None = None
        # Message content block that message deltas are currently applied to; tracked so
        # that on_text_done / on_image_file_done fire when the active block changes.
        self._current_message_content_index: int | None = None
        self._current_message_content: MessageContent | None = None
        # Tool call that run step deltas are currently applied to; tracked so that
        # on_tool_call_created / on_tool_call_done fire at tool call boundaries.
        self._current_tool_call_index: int | None = None
        self._current_tool_call: ToolCall | None = None
        self.__current_run_step_id: str | None = None
        self.__current_run: Run | None = None
        self.__run_step_snapshots: dict[str, RunStep] = {}
        self.__message_snapshots: dict[str, Message] = {}
        self.__current_message_snapshot: Message | None = None

        # Async generator bodies only run once iterated, so creating them before a
        # stream has been attached via `_init` is safe.
        self.text_deltas = self.__text_deltas__()
        self._iterator = self.__stream__()
        self.__stream: AsyncStream[AssistantStreamEvent] | None = None

    def _init(self, stream: AsyncStream[AssistantStreamEvent]) -> None:
        """Attach the handler to `stream`.

        Raises:
            RuntimeError: if the handler is already attached to a stream.
        """
        if self.__stream:
            raise RuntimeError(
                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
            )

        self.__stream = stream

    async def __anext__(self) -> AssistantStreamEvent:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[AssistantStreamEvent]:
        async for item in self._iterator:
            yield item

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called when the context manager exits.
        """
        if self.__stream:
            await self.__stream.close()

    @property
    def current_event(self) -> AssistantStreamEvent | None:
        """The event being dispatched; only set while its callbacks run, `None` otherwise."""
        return self._current_event

    @property
    def current_run(self) -> Run | None:
        """The Run object from the most recent `thread.run.*` event, if any."""
        return self.__current_run

    @property
    def current_run_step_snapshot(self) -> RunStep | None:
        """Accumulated snapshot of the run step currently in progress, or `None` between steps."""
        if not self.__current_run_step_id:
            return None

        return self.__run_step_snapshots[self.__current_run_step_id]

    @property
    def current_message_snapshot(self) -> Message | None:
        """Accumulated snapshot of the message most recently created or updated."""
        return self.__current_message_snapshot

    async def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        await consume_async_iterator(self)

    async def get_final_run(self) -> Run:
        """Wait for the stream to finish and returns the completed Run object"""
        await self.until_done()

        if not self.__current_run:
            raise RuntimeError("No final run object found")

        return self.__current_run

    async def get_final_run_steps(self) -> list[RunStep]:
        """Wait for the stream to finish and returns the steps taken in this run"""
        await self.until_done()

        if not self.__run_step_snapshots:
            raise RuntimeError("No run steps found")

        return [step for step in self.__run_step_snapshots.values()]

    async def get_final_messages(self) -> list[Message]:
        """Wait for the stream to finish and returns the messages emitted in this run"""
        await self.until_done()

        if not self.__message_snapshots:
            raise RuntimeError("No messages found")

        return [message for message in self.__message_snapshots.values()]

    async def __text_deltas__(self) -> AsyncIterator[str]:
        async for event in self:
            if event.event != "thread.message.delta":
                continue

            for content_delta in event.data.delta.content or []:
                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
                    yield content_delta.text.value

    # event handlers

    async def on_end(self) -> None:
        """Fires when the stream has finished.

        This happens if the stream is read to completion
        or if an exception occurs during iteration.
        """

    async def on_event(self, event: AssistantStreamEvent) -> None:
        """Callback that is fired for every Server-Sent-Event"""

    async def on_run_step_created(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is created"""

    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Callback that is fired whenever a run step delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the run step. For example, a tool calls event may
        look like this:

        # delta
        tool_calls=[
            RunStepDeltaToolCallsCodeInterpreter(
                index=0,
                type='code_interpreter',
                id=None,
                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
            )
        ]
        # snapshot
        tool_calls=[
            CodeToolCall(
                id='call_wKayJlcYV12NiadiZuJXxcfx',
                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
                type='code_interpreter',
                index=0
            )
        ],
        """

    async def on_run_step_done(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is completed"""

    async def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is created"""

    async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Callback that is fired when a tool call delta is encountered"""

    async def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired once when a tool call is finished"""

    async def on_exception(self, exception: Exception) -> None:
        """Fired whenever an exception happens during streaming"""

    async def on_timeout(self) -> None:
        """Fires if the request times out"""

    async def on_message_created(self, message: Message) -> None:
        """Callback that is fired when a message is created"""

    async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Callback that is fired whenever a message delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the message. For example, a text content event may
        look like this:

        # delta
        MessageDeltaText(
            index=0,
            type='text',
            text=Text(
                value=' Jane'
            ),
        )
        # snapshot
        MessageContentText(
            index=0,
            type='text',
            text=Text(
                value='Certainly, Jane'
            ),
        )
        """

    async def on_message_done(self, message: Message) -> None:
        """Callback that is fired when a message is completed"""

    async def on_text_created(self, text: Text) -> None:
        """Callback that is fired when a text content block is created"""

    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """

    async def on_text_done(self, text: Text) -> None:
        """Callback that is fired when a text content block is finished"""

    async def on_image_file_done(self, image_file: ImageFile) -> None:
        """Callback that is fired when an image file block is finished"""

    async def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
        """Update the snapshots from `event` and await the matching callbacks."""
        self._current_event = event
        await self.on_event(event)

        self.__current_message_snapshot, new_content = accumulate_event(
            event=event,
            current_message_snapshot=self.__current_message_snapshot,
        )
        if self.__current_message_snapshot is not None:
            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot

        accumulate_run_step(
            event=event,
            run_step_snapshots=self.__run_step_snapshots,
        )

        for content_delta in new_content:
            assert self.__current_message_snapshot is not None

            block = self.__current_message_snapshot.content[content_delta.index]
            if block.type == "text":
                await self.on_text_created(block.text)

        if (
            event.event == "thread.run.completed"
            or event.event == "thread.run.cancelled"
            or event.event == "thread.run.expired"
            or event.event == "thread.run.failed"
            or event.event == "thread.run.requires_action"
        ):
            self.__current_run = event.data
            if self._current_tool_call:
                await self.on_tool_call_done(self._current_tool_call)
                # Clear the tracking state so the same tool call is never reported twice.
                self._current_tool_call = None
                self._current_tool_call_index = None
        elif (
            event.event == "thread.run.created"
            or event.event == "thread.run.in_progress"
            or event.event == "thread.run.cancelling"
            or event.event == "thread.run.queued"
        ):
            self.__current_run = event.data
        elif event.event == "thread.message.created":
            await self.on_message_created(event.data)
        elif event.event == "thread.message.delta":
            snapshot = self.__current_message_snapshot
            assert snapshot is not None

            message_delta = event.data.delta
            if message_delta.content is not None:
                for content_delta in message_delta.content:
                    if content_delta.type == "text" and content_delta.text:
                        snapshot_content = snapshot.content[content_delta.index]
                        assert snapshot_content.type == "text"
                        await self.on_text_delta(content_delta.text, snapshot_content.text)

                    # If the delta is for a new message content:
                    # - emit on_text_done/on_image_file_done for the previous message content
                    # - emit on_text_created/on_image_created for the new message content
                    if content_delta.index != self._current_message_content_index:
                        if self._current_message_content is not None:
                            if self._current_message_content.type == "text":
                                await self.on_text_done(self._current_message_content.text)
                            elif self._current_message_content.type == "image_file":
                                await self.on_image_file_done(self._current_message_content.image_file)

                        self._current_message_content_index = content_delta.index
                        self._current_message_content = snapshot.content[content_delta.index]

                    # Update the current_message_content (delta event is correctly emitted already)
                    self._current_message_content = snapshot.content[content_delta.index]

            await self.on_message_delta(event.data.delta, snapshot)
        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
            self.__current_message_snapshot = event.data
            self.__message_snapshots[event.data.id] = event.data

            if self._current_message_content_index is not None:
                content = event.data.content[self._current_message_content_index]
                if content.type == "text":
                    await self.on_text_done(content.text)
                elif content.type == "image_file":
                    await self.on_image_file_done(content.image_file)

            await self.on_message_done(event.data)

            # The message is finished: forget its active content block so the next
            # message neither re-reports it as done nor indexes into the wrong message.
            self._current_message_content_index = None
            self._current_message_content = None
        elif event.event == "thread.run.step.created":
            self.__current_run_step_id = event.data.id
            await self.on_run_step_created(event.data)
        elif event.event == "thread.run.step.in_progress":
            self.__current_run_step_id = event.data.id
        elif event.event == "thread.run.step.delta":
            step_snapshot = self.__run_step_snapshots[event.data.id]

            run_step_delta = event.data.delta
            if (
                run_step_delta.step_details
                and run_step_delta.step_details.type == "tool_calls"
                and run_step_delta.step_details.tool_calls is not None
            ):
                assert step_snapshot.step_details.type == "tool_calls"
                for tool_call_delta in run_step_delta.step_details.tool_calls:
                    if tool_call_delta.index == self._current_tool_call_index:
                        await self.on_tool_call_delta(
                            tool_call_delta,
                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
                        )

                    # If the delta is for a new tool call:
                    # - emit on_tool_call_done for the previous tool_call
                    # - emit on_tool_call_created for the new tool_call
                    if tool_call_delta.index != self._current_tool_call_index:
                        if self._current_tool_call is not None:
                            await self.on_tool_call_done(self._current_tool_call)

                        self._current_tool_call_index = tool_call_delta.index
                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
                        await self.on_tool_call_created(self._current_tool_call)

                    # Update the current_tool_call (delta event is correctly emitted already)
                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]

            await self.on_run_step_delta(
                event.data.delta,
                step_snapshot,
            )
        elif (
            event.event == "thread.run.step.completed"
            or event.event == "thread.run.step.cancelled"
            or event.event == "thread.run.step.expired"
            or event.event == "thread.run.step.failed"
        ):
            if self._current_tool_call:
                await self.on_tool_call_done(self._current_tool_call)

            await self.on_run_step_done(event.data)
            self.__current_run_step_id = None

            # Tool call indices restart at 0 in every run step: reset the tracking so the
            # next step's first tool call fires on_tool_call_created, and so the terminal
            # run event does not fire on_tool_call_done again for this step's last call.
            self._current_tool_call = None
            self._current_tool_call_index = None
        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
            # currently no special handling
            ...
        else:
            # we only want to error at build-time
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event)

        self._current_event = None

    async def __stream__(self) -> AsyncIterator[AssistantStreamEvent]:
        stream = self.__stream
        if not stream:
            raise RuntimeError("Stream has not been started yet")

        try:
            async for event in stream:
                await self._emit_sse_event(event)

                yield event
        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
            await self.on_timeout()
            await self.on_exception(exc)
            raise
        except Exception as exc:
            await self.on_exception(exc)
            raise
        finally:
            await self.on_end()


AsyncAssistantEventHandlerT = TypeVar("AsyncAssistantEventHandlerT", bound=AsyncAssistantEventHandler)
class AsyncAssistantStreamManager(Generic[AsyncAssistantEventHandlerT]):
    """Wrapper over AsyncAssistantEventHandler that is returned by `.stream()`
    so that an async context manager can be used without `await`ing the
    original client call.

    ```py
    async with client.beta.threads.create_and_run_stream(...) as stream:
        async for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[AssistantStreamEvent]],
        *,
        event_handler: AsyncAssistantEventHandlerT,
    ) -> None:
        self.__stream: AsyncStream[AssistantStreamEvent] | None = None
        self.__event_handler = event_handler
        self.__api_request = api_request

    async def __aenter__(self) -> AsyncAssistantEventHandlerT:
        """Await the API request and attach the resulting stream to the event handler."""
        self.__stream = await self.__api_request
        try:
            self.__event_handler._init(self.__stream)
        except BaseException:
            # `__aexit__` is not invoked when `__aenter__` raises (e.g. the handler was
            # already bound to another stream), so release the connection here.
            await self.__stream.close()
            raise

        return self.__event_handler

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the underlying stream, if one was opened."""
        if self.__stream is not None:
            await self.__stream.close()
def accumulate_run_step(
    *,
    event: AssistantStreamEvent,
    run_step_snapshots: dict[str, RunStep],
) -> None:
    """Fold one stream event into `run_step_snapshots` (mutated in place).

    - `thread.run.step.created`: the event payload becomes the snapshot for its id.
    - `thread.run.step.delta`: the snapshot for the delta's id is looked up
      (raising `KeyError` if absent) and, when the delta is non-empty, replaced
      by the merge of the snapshot and the delta.
    - any other event is ignored.
    """
    kind = event.event

    if kind == "thread.run.step.created":
        run_step_snapshots[event.data.id] = event.data
        return

    if kind != "thread.run.step.delta":
        return None

    step_delta = event.data
    existing = run_step_snapshots[step_delta.id]
    if not step_delta.delta:
        return None

    existing_dump = cast("dict[object, object]", existing.model_dump(exclude_unset=True))
    delta_dump = cast("dict[object, object]", step_delta.delta.model_dump(exclude_unset=True))
    combined = accumulate_delta(existing_dump, delta_dump)
    run_step_snapshots[existing.id] = cast(RunStep, construct_type(type_=RunStep, value=combined))
    return None
def accumulate_event(
    *,
    event: AssistantStreamEvent,
    current_message_snapshot: Message | None,
) -> tuple[Message | None, list[MessageContentDelta]]:
    """Fold one stream event into the current message snapshot.

    Returns `(snapshot, new_content)`: the message snapshot after applying the
    event, and the content deltas (of any content type) that introduced a
    content block not yet present in the snapshot.

    - `thread.message.created`: the event payload becomes the new snapshot.
    - `thread.message.delta`: each content delta is merged into the snapshot's
      `content` list *in place*; raises `RuntimeError` if there is no snapshot.
    - any other event leaves the snapshot unchanged.
    """
    if event.event == "thread.message.created":
        return event.data, []

    if event.event != "thread.message.delta":
        return current_message_snapshot, []

    if not current_message_snapshot:
        raise RuntimeError("Encountered a message delta with no previous snapshot")

    added: list[MessageContentDelta] = []
    blocks = current_message_snapshot.content

    for content_delta in event.data.delta.content or []:
        position = content_delta.index
        try:
            existing_block = blocks[position]
        except IndexError:
            # First delta for this position: it becomes the initial content block.
            initial = content_delta.model_dump(exclude_unset=True)
            # mypy doesn't allow Content for some reason, hence the cast to Any
            blocks.insert(
                position,
                cast(MessageContent, construct_type(type_=cast(Any, MessageContent), value=initial)),
            )
            added.append(content_delta)
            continue

        combined = accumulate_delta(
            cast("dict[object, object]", existing_block.model_dump(exclude_unset=True)),
            cast("dict[object, object]", content_delta.model_dump(exclude_unset=True)),
        )
        # mypy doesn't allow Content for some reason, hence the cast to Any
        blocks[position] = cast(MessageContent, construct_type(type_=cast(Any, MessageContent), value=combined))

    return current_message_snapshot, added
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
    """Merge a streamed `delta` dict into the accumulated `acc` dict, in place.

    For each key of `delta`:
    - missing from `acc`, or `acc` holds `None`: the delta value is taken as-is;
    - `index` / `type`: overwritten rather than accumulated (they identify a list
      entry, e.g. [{'foo': 'bar', 'index': 0}], or a union discriminator);
    - both `str`: concatenated; both `int`/`float`: added;
    - both dicts: merged recursively;
    - both lists: if every accumulated entry is a `str`/`int`/`float` (this includes
      an empty list) the delta entries are appended, otherwise each delta entry is
      merged by its `index` key (see `_accumulate_indexed_entries`);
    - any other combination: the accumulated value is kept.

    Returns `acc` itself.
    """
    for key, incoming in delta.items():
        if key not in acc:
            acc[key] = incoming
            continue

        existing = acc[key]
        if existing is None or key in ("index", "type"):
            acc[key] = incoming
            continue

        if isinstance(existing, str) and isinstance(incoming, str):
            existing += incoming
        elif isinstance(existing, (int, float)) and isinstance(incoming, (int, float)):
            existing += incoming
        elif is_dict(existing) and is_dict(incoming):
            existing = accumulate_delta(existing, incoming)
        elif is_list(existing) and is_list(incoming):
            scalars_only = all(isinstance(item, (str, int, float)) for item in existing)
            if scalars_only:
                # lists of non-dictionary items only ever receive new entries;
                # existing entries are never changed
                existing.extend(incoming)
                continue
            _accumulate_indexed_entries(existing, incoming)

        acc[key] = existing

    return acc


def _accumulate_indexed_entries(target: list[object], entries: list[object]) -> None:
    """Merge each `{"index": i, ...}` dict of `entries` into `target[i]`, in place.

    An entry whose index is not yet present is inserted with `list.insert`.

    Raises:
        TypeError: an entry (or the existing item it targets) is not a dict, or its
            `index` is not an int.
        RuntimeError: an entry has no `index` key.
    """
    for entry in entries:
        if not is_dict(entry):
            raise TypeError(f"Unexpected list delta entry is not a dictionary: {entry}")

        try:
            position = entry["index"]
        except KeyError as exc:
            raise RuntimeError(f"Expected list delta entry to have an `index` key; {entry}") from exc

        if not isinstance(position, int):
            raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {position}")

        try:
            current_entry = target[position]
        except IndexError:
            target.insert(position, entry)
            continue

        if not is_dict(current_entry):
            raise TypeError("not handled yet")

        target[position] = accumulate_delta(current_entry, entry)