from __future__ import annotations import asyncio from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Iterable, Iterator, cast from typing_extensions import Awaitable, AsyncIterable, AsyncIterator, assert_never import httpx from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator from ..._models import construct_type from ..._streaming import Stream, AsyncStream from ...types.beta import AssistantStreamEvent from ...types.beta.threads import ( Run, Text, Message, ImageFile, TextDelta, MessageDelta, MessageContent, MessageContentDelta, ) from ...types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta class AssistantEventHandler: text_deltas: Iterable[str] """Iterator over just the text deltas in the stream. This corresponds to the `thread.message.delta` event in the API. ```py for text in stream.text_deltas: print(text, end="", flush=True) print() ``` """ def __init__(self) -> None: self._current_event: AssistantStreamEvent | None = None self._current_message_content_index: int | None = None self._current_message_content: MessageContent | None = None self._current_tool_call_index: int | None = None self._current_tool_call: ToolCall | None = None self.__current_run_step_id: str | None = None self.__current_run: Run | None = None self.__run_step_snapshots: dict[str, RunStep] = {} self.__message_snapshots: dict[str, Message] = {} self.__current_message_snapshot: Message | None = None self.text_deltas = self.__text_deltas__() self._iterator = self.__stream__() self.__stream: Stream[AssistantStreamEvent] | None = None def _init(self, stream: Stream[AssistantStreamEvent]) -> None: if self.__stream: raise RuntimeError( "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance" ) self.__stream = stream def __next__(self) -> AssistantStreamEvent: return self._iterator.__next__() def __iter__(self) -> Iterator[AssistantStreamEvent]: for item in self._iterator: yield item @property def current_event(self) -> AssistantStreamEvent | None: return self._current_event @property def current_run(self) -> Run | None: return self.__current_run @property def current_run_step_snapshot(self) -> RunStep | None: if not self.__current_run_step_id: return None return self.__run_step_snapshots[self.__current_run_step_id] @property def current_message_snapshot(self) -> Message | None: return self.__current_message_snapshot def close(self) -> None: """ Close the response and release the connection. Automatically called when the context manager exits. """ if self.__stream: self.__stream.close() def until_done(self) -> None: """Waits until the stream has been consumed""" consume_sync_iterator(self) def get_final_run(self) -> Run: """Wait for the stream to finish and returns the completed Run object""" self.until_done() if not self.__current_run: raise RuntimeError("No final run object found") return self.__current_run def get_final_run_steps(self) -> list[RunStep]: """Wait for the stream to finish and returns the steps taken in this run""" self.until_done() if not self.__run_step_snapshots: raise RuntimeError("No run steps found") return [step for step in self.__run_step_snapshots.values()] def get_final_messages(self) -> list[Message]: """Wait for the stream to finish and returns the messages emitted in this run""" self.until_done() if not self.__message_snapshots: raise RuntimeError("No messages found") return [message for message in self.__message_snapshots.values()] def __text_deltas__(self) -> Iterator[str]: for event in self: if event.event != "thread.message.delta": continue for content_delta in event.data.delta.content or []: if content_delta.type == "text" and content_delta.text and content_delta.text.value: yield content_delta.text.value # event handlers def on_end(self) -> None: """Fires when the stream has finished. This happens if the stream is read to completion or if an exception occurs during iteration. """ def on_event(self, event: AssistantStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" def on_run_step_created(self, run_step: RunStep) -> None: """Callback that is fired when a run step is created""" def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: """Callback that is fired whenever a run step delta is returned from the API The first argument is just the delta as sent by the API and the second argument is the accumulated snapshot of the run step. For example, a tool calls event may look like this: # delta tool_calls=[ RunStepDeltaToolCallsCodeInterpreter( index=0, type='code_interpreter', id=None, code_interpreter=CodeInterpreter(input=' sympy', outputs=None) ) ] # snapshot tool_calls=[ CodeToolCall( id='call_wKayJlcYV12NiadiZuJXxcfx', code_interpreter=CodeInterpreter(input='from sympy', outputs=[]), type='code_interpreter', index=0 ) ], """ def on_run_step_done(self, run_step: RunStep) -> None: """Callback that is fired when a run step is completed""" def on_tool_call_created(self, tool_call: ToolCall) -> None: """Callback that is fired when a tool call is created""" def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None: """Callback that is fired when a tool call delta is encountered""" def on_tool_call_done(self, tool_call: ToolCall) -> None: """Callback that is fired when a tool call delta is encountered""" def on_exception(self, exception: Exception) -> None: """Fired whenever an exception happens during streaming""" def on_timeout(self) -> None: """Fires if the request times out""" def on_message_created(self, message: Message) -> None: """Callback that is fired when a message is created""" def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None: """Callback that is fired whenever a message delta is returned from the API The first argument is just the delta as sent by the API and the second argument is the accumulated snapshot of the message. For example, a text content event may look like this: # delta MessageDeltaText( index=0, type='text', text=Text( value=' Jane' ), ) # snapshot MessageContentText( index=0, type='text', text=Text( value='Certainly, Jane' ), ) """ def on_message_done(self, message: Message) -> None: """Callback that is fired when a message is completed""" def on_text_created(self, text: Text) -> None: """Callback that is fired when a text content block is created""" def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: """Callback that is fired whenever a text content delta is returned by the API. The first argument is just the delta as sent by the API and the second argument is the accumulated snapshot of the text. For example: on_text_delta(TextDelta(value="The"), Text(value="The")), on_text_delta(TextDelta(value=" solution"), Text(value="The solution")), on_text_delta(TextDelta(value=" to"), Text(value="The solution to")), on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")), on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equivalent")), """ def on_text_done(self, text: Text) -> None: """Callback that is fired when a text content block is finished""" def on_image_file_done(self, image_file: ImageFile) -> None: """Callback that is fired when an image file block is finished""" def _emit_sse_event(self, event: AssistantStreamEvent) -> None: self._current_event = event self.on_event(event) self.__current_message_snapshot, new_content = accumulate_event( event=event, current_message_snapshot=self.__current_message_snapshot, ) if self.__current_message_snapshot is not None: self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot accumulate_run_step( event=event, run_step_snapshots=self.__run_step_snapshots, ) for content_delta in new_content: assert self.__current_message_snapshot is not None block = self.__current_message_snapshot.content[content_delta.index] if block.type == "text": self.on_text_created(block.text) if ( event.event == "thread.run.completed" or event.event == "thread.run.cancelled" or event.event == "thread.run.expired" or event.event == "thread.run.failed" or event.event == "thread.run.requires_action" ): self.__current_run = event.data if self._current_tool_call: self.on_tool_call_done(self._current_tool_call) elif ( event.event == "thread.run.created" or event.event == "thread.run.in_progress" or event.event == "thread.run.cancelling" or event.event == "thread.run.queued" ): self.__current_run = event.data elif event.event == "thread.message.created": self.on_message_created(event.data) elif event.event == "thread.message.delta": snapshot = self.__current_message_snapshot assert snapshot is not None message_delta = event.data.delta if message_delta.content is not None: for content_delta in message_delta.content: if content_delta.type == "text" and content_delta.text: snapshot_content = snapshot.content[content_delta.index] assert snapshot_content.type == "text" self.on_text_delta(content_delta.text, snapshot_content.text) # If the delta is for a new message content: # - emit on_text_done/on_image_file_done for the previous message content # - emit on_text_created/on_image_created for the new message content if content_delta.index != self._current_message_content_index: if self._current_message_content is not None: if self._current_message_content.type == "text": self.on_text_done(self._current_message_content.text) elif self._current_message_content.type == "image_file": self.on_image_file_done(self._current_message_content.image_file) self._current_message_content_index = content_delta.index self._current_message_content = snapshot.content[content_delta.index] # Update the current_message_content (delta event is correctly emitted already) self._current_message_content = snapshot.content[content_delta.index] self.on_message_delta(event.data.delta, snapshot) elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete": self.__current_message_snapshot = event.data self.__message_snapshots[event.data.id] = event.data if self._current_message_content_index is not None: content = event.data.content[self._current_message_content_index] if content.type == "text": self.on_text_done(content.text) elif content.type == "image_file": self.on_image_file_done(content.image_file) self.on_message_done(event.data) elif event.event == "thread.run.step.created": self.__current_run_step_id = event.data.id self.on_run_step_created(event.data) elif event.event == "thread.run.step.in_progress": self.__current_run_step_id = event.data.id elif event.event == "thread.run.step.delta": step_snapshot = self.__run_step_snapshots[event.data.id] run_step_delta = event.data.delta if ( run_step_delta.step_details and run_step_delta.step_details.type == "tool_calls" and run_step_delta.step_details.tool_calls is not None ): assert step_snapshot.step_details.type == "tool_calls" for tool_call_delta in run_step_delta.step_details.tool_calls: if tool_call_delta.index == self._current_tool_call_index: self.on_tool_call_delta( tool_call_delta, step_snapshot.step_details.tool_calls[tool_call_delta.index], ) # If the delta is for a new tool call: # - emit on_tool_call_done for the previous tool_call # - emit on_tool_call_created for the new tool_call if tool_call_delta.index != self._current_tool_call_index: if self._current_tool_call is not None: self.on_tool_call_done(self._current_tool_call) self._current_tool_call_index = tool_call_delta.index self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index] self.on_tool_call_created(self._current_tool_call) # Update the current_tool_call (delta event is correctly emitted already) self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index] self.on_run_step_delta( event.data.delta, step_snapshot, ) elif ( event.event == "thread.run.step.completed" or event.event == "thread.run.step.cancelled" or event.event == "thread.run.step.expired" or event.event == "thread.run.step.failed" ): if self._current_tool_call: self.on_tool_call_done(self._current_tool_call) self.on_run_step_done(event.data) self.__current_run_step_id = None elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error": # currently no special handling ... else: # we only want to error at build-time if TYPE_CHECKING: # type: ignore[unreachable] assert_never(event) self._current_event = None def __stream__(self) -> Iterator[AssistantStreamEvent]: stream = self.__stream if not stream: raise RuntimeError("Stream has not been started yet") try: for event in stream: self._emit_sse_event(event) yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: self.on_timeout() self.on_exception(exc) raise except Exception as exc: self.on_exception(exc) raise finally: self.on_end() AssistantEventHandlerT = TypeVar("AssistantEventHandlerT", bound=AssistantEventHandler) class AssistantStreamManager(Generic[AssistantEventHandlerT]): """Wrapper over AssistantStreamEventHandler that is returned by `.stream()` so that a context manager can be used. ```py with client.threads.create_and_run_stream(...) as stream: for event in stream: ... ``` """ def __init__( self, api_request: Callable[[], Stream[AssistantStreamEvent]], *, event_handler: AssistantEventHandlerT, ) -> None: self.__stream: Stream[AssistantStreamEvent] | None = None self.__event_handler = event_handler self.__api_request = api_request def __enter__(self) -> AssistantEventHandlerT: self.__stream = self.__api_request() self.__event_handler._init(self.__stream) return self.__event_handler def __exit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: if self.__stream is not None: self.__stream.close() class AsyncAssistantEventHandler: text_deltas: AsyncIterable[str] """Iterator over just the text deltas in the stream. This corresponds to the `thread.message.delta` event in the API. ```py async for text in stream.text_deltas: print(text, end="", flush=True) print() ``` """ def __init__(self) -> None: self._current_event: AssistantStreamEvent | None = None self._current_message_content_index: int | None = None self._current_message_content: MessageContent | None = None self._current_tool_call_index: int | None = None self._current_tool_call: ToolCall | None = None self.__current_run_step_id: str | None = None self.__current_run: Run | None = None self.__run_step_snapshots: dict[str, RunStep] = {} self.__message_snapshots: dict[str, Message] = {} self.__current_message_snapshot: Message | None = None self.text_deltas = self.__text_deltas__() self._iterator = self.__stream__() self.__stream: AsyncStream[AssistantStreamEvent] | None = None def _init(self, stream: AsyncStream[AssistantStreamEvent]) -> None: if self.__stream: raise RuntimeError( "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance" ) self.__stream = stream async def __anext__(self) -> AssistantStreamEvent: return await self._iterator.__anext__() async def __aiter__(self) -> AsyncIterator[AssistantStreamEvent]: async for item in self._iterator: yield item async def close(self) -> None: """ Close the response and release the connection. Automatically called when the context manager exits. """ if self.__stream: await self.__stream.close() @property def current_event(self) -> AssistantStreamEvent | None: return self._current_event @property def current_run(self) -> Run | None: return self.__current_run @property def current_run_step_snapshot(self) -> RunStep | None: if not self.__current_run_step_id: return None return self.__run_step_snapshots[self.__current_run_step_id] @property def current_message_snapshot(self) -> Message | None: return self.__current_message_snapshot async def until_done(self) -> None: """Waits until the stream has been consumed""" await consume_async_iterator(self) async def get_final_run(self) -> Run: """Wait for the stream to finish and returns the completed Run object""" await self.until_done() if not self.__current_run: raise RuntimeError("No final run object found") return self.__current_run async def get_final_run_steps(self) -> list[RunStep]: """Wait for the stream to finish and returns the steps taken in this run""" await self.until_done() if not self.__run_step_snapshots: raise RuntimeError("No run steps found") return [step for step in self.__run_step_snapshots.values()] async def get_final_messages(self) -> list[Message]: """Wait for the stream to finish and returns the messages emitted in this run""" await self.until_done() if not self.__message_snapshots: raise RuntimeError("No messages found") return [message for message in self.__message_snapshots.values()] async def __text_deltas__(self) -> AsyncIterator[str]: async for event in self: if event.event != "thread.message.delta": continue for content_delta in event.data.delta.content or []: if content_delta.type == "text" and content_delta.text and content_delta.text.value: yield content_delta.text.value # event handlers async def on_end(self) -> None: """Fires when the stream has finished. This happens if the stream is read to completion or if an exception occurs during iteration. """ async def on_event(self, event: AssistantStreamEvent) -> None: """Callback that is fired for every Server-Sent-Event""" async def on_run_step_created(self, run_step: RunStep) -> None: """Callback that is fired when a run step is created""" async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: """Callback that is fired whenever a run step delta is returned from the API The first argument is just the delta as sent by the API and the second argument is the accumulated snapshot of the run step. For example, a tool calls event may look like this: # delta tool_calls=[ RunStepDeltaToolCallsCodeInterpreter( index=0, type='code_interpreter', id=None, code_interpreter=CodeInterpreter(input=' sympy', outputs=None) ) ] # snapshot tool_calls=[ CodeToolCall( id='call_wKayJlcYV12NiadiZuJXxcfx', code_interpreter=CodeInterpreter(input='from sympy', outputs=[]), type='code_interpreter', index=0 ) ], """ async def on_run_step_done(self, run_step: RunStep) -> None: """Callback that is fired when a run step is completed""" async def on_tool_call_created(self, tool_call: ToolCall) -> None: """Callback that is fired when a tool call is created""" async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None: """Callback that is fired when a tool call delta is encountered""" async def on_tool_call_done(self, tool_call: ToolCall) -> None: """Callback that is fired when a tool call delta is encountered""" async def on_exception(self, exception: Exception) -> None: """Fired whenever an exception happens during streaming""" async def on_timeout(self) -> None: """Fires if the request times out""" async def on_message_created(self, message: Message) -> None: """Callback that is fired when a message is created""" async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None: """Callback that is fired whenever a message delta is returned from the API The first argument is just the delta as sent by the API and the second argument is the accumulated snapshot of the message. For example, a text content event may look like this: # delta MessageDeltaText( index=0, type='text', text=Text( value=' Jane' ), ) # snapshot MessageContentText( index=0, type='text', text=Text( value='Certainly, Jane' ), ) """ async def on_message_done(self, message: Message) -> None: """Callback that is fired when a message is completed""" async def on_text_created(self, text: Text) -> None: """Callback that is fired when a text content block is created""" async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: """Callback that is fired whenever a text content delta is returned by the API. The first argument is just the delta as sent by the API and the second argument is the accumulated snapshot of the text. For example: on_text_delta(TextDelta(value="The"), Text(value="The")), on_text_delta(TextDelta(value=" solution"), Text(value="The solution")), on_text_delta(TextDelta(value=" to"), Text(value="The solution to")), on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")), on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equivalent")), """ async def on_text_done(self, text: Text) -> None: """Callback that is fired when a text content block is finished""" async def on_image_file_done(self, image_file: ImageFile) -> None: """Callback that is fired when an image file block is finished""" async def _emit_sse_event(self, event: AssistantStreamEvent) -> None: self._current_event = event await self.on_event(event) self.__current_message_snapshot, new_content = accumulate_event( event=event, current_message_snapshot=self.__current_message_snapshot, ) if self.__current_message_snapshot is not None: self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot accumulate_run_step( event=event, run_step_snapshots=self.__run_step_snapshots, ) for content_delta in new_content: assert self.__current_message_snapshot is not None block = self.__current_message_snapshot.content[content_delta.index] if block.type == "text": await self.on_text_created(block.text) if ( event.event == "thread.run.completed" or event.event == "thread.run.cancelled" or event.event == "thread.run.expired" or event.event == "thread.run.failed" or event.event == "thread.run.requires_action" ): self.__current_run = event.data if self._current_tool_call: await self.on_tool_call_done(self._current_tool_call) elif ( event.event == "thread.run.created" or event.event == "thread.run.in_progress" or event.event == "thread.run.cancelling" or event.event == "thread.run.queued" ): self.__current_run = event.data elif event.event == "thread.message.created": await self.on_message_created(event.data) elif event.event == "thread.message.delta": snapshot = self.__current_message_snapshot assert snapshot is not None message_delta = event.data.delta if message_delta.content is not None: for content_delta in message_delta.content: if content_delta.type == "text" and content_delta.text: snapshot_content = snapshot.content[content_delta.index] assert snapshot_content.type == "text" await self.on_text_delta(content_delta.text, snapshot_content.text) # If the delta is for a new message content: # - emit on_text_done/on_image_file_done for the previous message content # - emit on_text_created/on_image_created for the new message content if content_delta.index != self._current_message_content_index: if self._current_message_content is not None: if self._current_message_content.type == "text": await self.on_text_done(self._current_message_content.text) elif self._current_message_content.type == "image_file": await self.on_image_file_done(self._current_message_content.image_file) self._current_message_content_index = content_delta.index self._current_message_content = snapshot.content[content_delta.index] # Update the current_message_content (delta event is correctly emitted already) self._current_message_content = snapshot.content[content_delta.index] await self.on_message_delta(event.data.delta, snapshot) elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete": self.__current_message_snapshot = event.data self.__message_snapshots[event.data.id] = event.data if self._current_message_content_index is not None: content = event.data.content[self._current_message_content_index] if content.type == "text": await self.on_text_done(content.text) elif content.type == "image_file": await self.on_image_file_done(content.image_file) await self.on_message_done(event.data) elif event.event == "thread.run.step.created": self.__current_run_step_id = event.data.id await self.on_run_step_created(event.data) elif event.event == "thread.run.step.in_progress": self.__current_run_step_id = event.data.id elif event.event == "thread.run.step.delta": step_snapshot = self.__run_step_snapshots[event.data.id] run_step_delta = event.data.delta if ( run_step_delta.step_details and run_step_delta.step_details.type == "tool_calls" and run_step_delta.step_details.tool_calls is not None ): assert step_snapshot.step_details.type == "tool_calls" for tool_call_delta in run_step_delta.step_details.tool_calls: if tool_call_delta.index == self._current_tool_call_index: await self.on_tool_call_delta( tool_call_delta, step_snapshot.step_details.tool_calls[tool_call_delta.index], ) # If the delta is for a new tool call: # - emit on_tool_call_done for the previous tool_call # - emit on_tool_call_created for the new tool_call if tool_call_delta.index != self._current_tool_call_index: if self._current_tool_call is not None: await self.on_tool_call_done(self._current_tool_call) self._current_tool_call_index = tool_call_delta.index self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index] await self.on_tool_call_created(self._current_tool_call) # Update the current_tool_call (delta event is correctly emitted already) self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index] await self.on_run_step_delta( event.data.delta, step_snapshot, ) elif ( event.event == "thread.run.step.completed" or event.event == "thread.run.step.cancelled" or event.event == "thread.run.step.expired" or event.event == "thread.run.step.failed" ): if self._current_tool_call: await self.on_tool_call_done(self._current_tool_call) await self.on_run_step_done(event.data) self.__current_run_step_id = None elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error": # currently no special handling ... else: # we only want to error at build-time if TYPE_CHECKING: # type: ignore[unreachable] assert_never(event) self._current_event = None async def __stream__(self) -> AsyncIterator[AssistantStreamEvent]: stream = self.__stream if not stream: raise RuntimeError("Stream has not been started yet") try: async for event in stream: await self._emit_sse_event(event) yield event except (httpx.TimeoutException, asyncio.TimeoutError) as exc: await self.on_timeout() await self.on_exception(exc) raise except Exception as exc: await self.on_exception(exc) raise finally: await self.on_end() AsyncAssistantEventHandlerT = TypeVar("AsyncAssistantEventHandlerT", bound=AsyncAssistantEventHandler) class AsyncAssistantStreamManager(Generic[AsyncAssistantEventHandlerT]): """Wrapper over AsyncAssistantStreamEventHandler that is returned by `.stream()` so that an async context manager can be used without `await`ing the original client call. ```py async with client.threads.create_and_run_stream(...) as stream: async for event in stream: ... ``` """ def __init__( self, api_request: Awaitable[AsyncStream[AssistantStreamEvent]], *, event_handler: AsyncAssistantEventHandlerT, ) -> None: self.__stream: AsyncStream[AssistantStreamEvent] | None = None self.__event_handler = event_handler self.__api_request = api_request async def __aenter__(self) -> AsyncAssistantEventHandlerT: self.__stream = await self.__api_request self.__event_handler._init(self.__stream) return self.__event_handler async def __aexit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None, ) -> None: if self.__stream is not None: await self.__stream.close() def accumulate_run_step( *, event: AssistantStreamEvent, run_step_snapshots: dict[str, RunStep], ) -> None: if event.event == "thread.run.step.created": run_step_snapshots[event.data.id] = event.data return if event.event == "thread.run.step.delta": data = event.data snapshot = run_step_snapshots[data.id] if data.delta: merged = accumulate_delta( cast( "dict[object, object]", snapshot.model_dump(exclude_unset=True), ), cast( "dict[object, object]", data.delta.model_dump(exclude_unset=True), ), ) run_step_snapshots[snapshot.id] = cast(RunStep, construct_type(type_=RunStep, value=merged)) return None def accumulate_event( *, event: AssistantStreamEvent, current_message_snapshot: Message | None, ) -> tuple[Message | None, list[MessageContentDelta]]: """Returns a tuple of message snapshot and newly created text message deltas""" if event.event == "thread.message.created": return event.data, [] new_content: list[MessageContentDelta] = [] if event.event != "thread.message.delta": return current_message_snapshot, [] if not current_message_snapshot: raise RuntimeError("Encountered a message delta with no previous snapshot") data = event.data if data.delta.content: for content_delta in data.delta.content: try: block = current_message_snapshot.content[content_delta.index] except IndexError: current_message_snapshot.content.insert( content_delta.index, cast( MessageContent, construct_type( # mypy doesn't allow Content for some reason type_=cast(Any, MessageContent), value=content_delta.model_dump(exclude_unset=True), ), ), ) new_content.append(content_delta) else: merged = accumulate_delta( cast( "dict[object, object]", block.model_dump(exclude_unset=True), ), cast( "dict[object, object]", content_delta.model_dump(exclude_unset=True), ), ) current_message_snapshot.content[content_delta.index] = cast( MessageContent, construct_type( # mypy doesn't allow Content for some reason type_=cast(Any, MessageContent), value=merged, ), ) return current_message_snapshot, new_content def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]: for key, delta_value in delta.items(): if key not in acc: acc[key] = delta_value continue acc_value = acc[key] if acc_value is None: acc[key] = delta_value continue # the `index` property is used in arrays of objects so it should # not be accumulated like other values e.g. # [{'foo': 'bar', 'index': 0}] # # the same applies to `type` properties as they're used for # discriminated unions if key == "index" or key == "type": acc[key] = delta_value continue if isinstance(acc_value, str) and isinstance(delta_value, str): acc_value += delta_value elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)): acc_value += delta_value elif is_dict(acc_value) and is_dict(delta_value): acc_value = accumulate_delta(acc_value, delta_value) elif is_list(acc_value) and is_list(delta_value): # for lists of non-dictionary items we'll only ever get new entries # in the array, existing entries will never be changed if all(isinstance(x, (str, int, float)) for x in acc_value): acc_value.extend(delta_value) continue for delta_entry in delta_value: if not is_dict(delta_entry): raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}") try: index = delta_entry["index"] except KeyError as exc: raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc if not isinstance(index, int): raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}") try: acc_entry = acc_value[index] except IndexError: acc_value.insert(index, delta_entry) else: if not is_dict(acc_entry): raise TypeError("not handled yet") acc_value[index] = accumulate_delta(acc_entry, delta_entry) acc[key] = acc_value return acc