Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion src/openai/lib/streaming/responses/_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._models import build, construct_type_unchecked
from ...._streaming import Stream, AsyncStream
from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
from ....types.responses import Response, ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
from ..._parsing._responses import TextFormatT, parse_text, parse_response
from ....types.responses.tool_param import ToolParam
from ....types.responses.parsed_response import (
Expand All @@ -35,12 +35,14 @@ def __init__(
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
cancel_response: Callable[[str], Response] | None = None,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
self._starting_after = starting_after
self._cancel_response = cancel_response

def __next__(self) -> ResponseStreamEvent[TextFormatT]:
return self._iterator.__next__()
Expand Down Expand Up @@ -91,6 +93,32 @@ def until_done(self) -> Self:
consume_sync_iterator(self)
return self

@property
def response_id(self) -> str | None:
"""The ID of the response being streamed, available after the first event."""
snapshot = self._state.current_snapshot
if snapshot is not None:
return snapshot.id
return None

def cancel(self) -> Response:
    """Cancel the response being streamed.

    Calls the API to cancel the response, then closes this stream —
    the stream is closed even if the cancel request raises.

    Returns:
        The cancelled ``Response`` object returned by the API.

    Raises:
        ValueError: If no event has been received yet (the response ID is
            not known) or if this stream was created without a cancel
            callback.
    """
    response_id = self.response_id
    if response_id is None:
        raise ValueError("Cannot cancel: response ID not yet available. Wait for the first event.")
    if self._cancel_response is None:
        raise ValueError("Cancel not available for this stream.")
    try:
        result = self._cancel_response(response_id)
    finally:
        # Always release the underlying HTTP connection, even when the
        # cancel call itself fails.
        self.close()
    # BUG FIX: the original computed `result` but never returned it,
    # contradicting the `-> Response` annotation and the docstring.
    return result



class ResponseStreamManager(Generic[TextFormatT]):
def __init__(
Expand All @@ -100,12 +128,14 @@ def __init__(
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
cancel_response: Callable[[str], Response] | None = None,
) -> None:
self.__stream: ResponseStream[TextFormatT] | None = None
self.__api_request = api_request
self.__text_format = text_format
self.__input_tools = input_tools
self.__starting_after = starting_after
self.__cancel_response = cancel_response

def __enter__(self) -> ResponseStream[TextFormatT]:
raw_stream = self.__api_request()
Expand All @@ -115,6 +145,7 @@ def __enter__(self) -> ResponseStream[TextFormatT]:
text_format=self.__text_format,
input_tools=self.__input_tools,
starting_after=self.__starting_after,
cancel_response=self.__cancel_response,
)

return self.__stream
Expand All @@ -137,12 +168,14 @@ def __init__(
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
cancel_response: Callable[[str], Awaitable[Response]] | None = None,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
self._starting_after = starting_after
self._cancel_response = cancel_response

async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
return await self._iterator.__anext__()
Expand Down Expand Up @@ -193,6 +226,31 @@ async def until_done(self) -> Self:
await consume_async_iterator(self)
return self

@property
def response_id(self) -> str | None:
"""The ID of the response being streamed, available after the first event."""
snapshot = self._state.current_snapshot
if snapshot is not None:
return snapshot.id
return None

async def cancel(self) -> Response:
    """Cancel the response being streamed.

    Calls the API to cancel the response, then closes this stream —
    the stream is closed even if the cancel request raises.

    Returns:
        The cancelled ``Response`` object returned by the API.

    Raises:
        ValueError: If no event has been received yet (the response ID is
            not known) or if this stream was created without a cancel
            callback.
    """
    response_id = self.response_id
    if response_id is None:
        raise ValueError("Cannot cancel: response ID not yet available. Wait for the first event.")
    if self._cancel_response is None:
        raise ValueError("Cancel not available for this stream.")
    try:
        result = await self._cancel_response(response_id)
    finally:
        # Always release the underlying HTTP connection, even when the
        # cancel call itself fails.
        await self.close()
    # BUG FIX: the original computed `result` but never returned it,
    # contradicting the `-> Response` annotation and the docstring.
    return result


class AsyncResponseStreamManager(Generic[TextFormatT]):
def __init__(
Expand All @@ -202,12 +260,14 @@ def __init__(
text_format: type[TextFormatT] | Omit,
input_tools: Iterable[ToolParam] | Omit,
starting_after: int | None,
cancel_response: Callable[[str], Awaitable[Response]] | None = None,
) -> None:
self.__stream: AsyncResponseStream[TextFormatT] | None = None
self.__api_request = api_request
self.__text_format = text_format
self.__input_tools = input_tools
self.__starting_after = starting_after
self.__cancel_response = cancel_response

async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
raw_stream = await self.__api_request
Expand All @@ -217,6 +277,7 @@ async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
text_format=self.__text_format,
input_tools=self.__input_tools,
starting_after=self.__starting_after,
cancel_response=self.__cancel_response,
)

return self.__stream
Expand Down Expand Up @@ -244,6 +305,10 @@ def __init__(
self._text_format = text_format
self._rich_text_format: type | Omit = text_format if inspect.isclass(text_format) else omit

@property
def current_snapshot(self) -> ParsedResponseSnapshot | None:
return self.__current_snapshot

def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
self.__current_snapshot = snapshot = self.accumulate_event(event)

Expand Down
35 changes: 34 additions & 1 deletion src/openai/resources/responses/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,7 +1129,19 @@ def stream(
timeout=timeout,
)

return ResponseStreamManager(api_request, text_format=text_format, input_tools=tools, starting_after=None)
return ResponseStreamManager(
api_request,
text_format=text_format,
input_tools=tools,
starting_after=None,
cancel_response=lambda response_id: self.cancel(
response_id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
),
)
else:
if not is_given(response_id):
raise ValueError("id must be provided when streaming an existing response")
Expand All @@ -1148,6 +1160,13 @@ def stream(
text_format=text_format,
input_tools=tools,
starting_after=starting_after if is_given(starting_after) else None,
cancel_response=lambda response_id: self.cancel(
response_id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
),
)

def parse(
Expand Down Expand Up @@ -2794,6 +2813,13 @@ def stream(
text_format=text_format,
input_tools=tools,
starting_after=None,
cancel_response=lambda response_id: self.cancel(
response_id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
),
)
else:
if isinstance(response_id, Omit):
Expand All @@ -2813,6 +2839,13 @@ def stream(
text_format=text_format,
input_tools=tools,
starting_after=starting_after if is_given(starting_after) else None,
cancel_response=lambda response_id: self.cancel(
response_id,
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
),
)

async def parse(
Expand Down
Loading