diff --git a/src/openai/lib/streaming/responses/_responses.py b/src/openai/lib/streaming/responses/_responses.py
index 6975a9260d..85ffede217 100644
--- a/src/openai/lib/streaming/responses/_responses.py
+++ b/src/openai/lib/streaming/responses/_responses.py
@@ -17,7 +17,7 @@
 from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
 from ...._models import build, construct_type_unchecked
 from ...._streaming import Stream, AsyncStream
-from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
+from ....types.responses import Response, ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
 from ..._parsing._responses import TextFormatT, parse_text, parse_response
 from ....types.responses.tool_param import ToolParam
 from ....types.responses.parsed_response import (
@@ -35,12 +35,14 @@ def __init__(
         text_format: type[TextFormatT] | Omit,
         input_tools: Iterable[ToolParam] | Omit,
         starting_after: int | None,
+        cancel_response: Callable[[str], Response] | None = None,
     ) -> None:
         self._raw_stream = raw_stream
         self._response = raw_stream.response
         self._iterator = self.__stream__()
         self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
         self._starting_after = starting_after
+        self._cancel_response = cancel_response
 
     def __next__(self) -> ResponseStreamEvent[TextFormatT]:
         return self._iterator.__next__()
@@ -91,6 +93,32 @@ def until_done(self) -> Self:
         consume_sync_iterator(self)
         return self
 
+    @property
+    def response_id(self) -> str | None:
+        """The ID of the response being streamed, available after the first event."""
+        snapshot = self._state.current_snapshot
+        if snapshot is not None:
+            return snapshot.id
+        return None
+
+    def cancel(self) -> Response:
+        """Cancel the response being streamed.
+
+        Calls the API to cancel the response, then closes the stream.
+
+        Returns the cancelled Response object.
+        """
+        response_id = self.response_id
+        if response_id is None:
+            raise ValueError("Cannot cancel: response ID not yet available. Wait for the first event.")
+        if self._cancel_response is None:
+            raise ValueError("Cancel not available for this stream.")
+        try:
+            return self._cancel_response(response_id)
+        finally:
+            self.close()
+
 
 class ResponseStreamManager(Generic[TextFormatT]):
     def __init__(
@@ -100,12 +128,14 @@ def __init__(
         text_format: type[TextFormatT] | Omit,
         input_tools: Iterable[ToolParam] | Omit,
         starting_after: int | None,
+        cancel_response: Callable[[str], Response] | None = None,
     ) -> None:
         self.__stream: ResponseStream[TextFormatT] | None = None
         self.__api_request = api_request
         self.__text_format = text_format
         self.__input_tools = input_tools
         self.__starting_after = starting_after
+        self.__cancel_response = cancel_response
 
     def __enter__(self) -> ResponseStream[TextFormatT]:
         raw_stream = self.__api_request()
@@ -115,6 +145,7 @@ def __enter__(self) -> ResponseStream[TextFormatT]:
             text_format=self.__text_format,
             input_tools=self.__input_tools,
             starting_after=self.__starting_after,
+            cancel_response=self.__cancel_response,
         )
         return self.__stream
 
@@ -137,12 +168,14 @@ def __init__(
         text_format: type[TextFormatT] | Omit,
         input_tools: Iterable[ToolParam] | Omit,
         starting_after: int | None,
+        cancel_response: Callable[[str], Awaitable[Response]] | None = None,
    ) -> None:
         self._raw_stream = raw_stream
         self._response = raw_stream.response
         self._iterator = self.__stream__()
         self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
         self._starting_after = starting_after
+        self._cancel_response = cancel_response
 
     async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
         return await self._iterator.__anext__()
@@ -193,6 +226,31 @@ async def until_done(self) -> Self:
         await consume_async_iterator(self)
         return self
 
+    @property
+    def response_id(self) -> str | None:
+        """The ID of the response being streamed, available after the first event."""
+        snapshot = self._state.current_snapshot
+        if snapshot is not None:
+            return snapshot.id
+        return None
+
+    async def cancel(self) -> Response:
+        """Cancel the response being streamed.
+
+        Calls the API to cancel the response and closes the stream.
+
+        Returns the cancelled Response object.
+        """
+        response_id = self.response_id
+        if response_id is None:
+            raise ValueError("Cannot cancel: response ID not yet available. Wait for the first event.")
+        if self._cancel_response is None:
+            raise ValueError("Cancel not available for this stream.")
+        try:
+            return await self._cancel_response(response_id)
+        finally:
+            await self.close()
 
 class AsyncResponseStreamManager(Generic[TextFormatT]):
     def __init__(
@@ -202,12 +260,14 @@ def __init__(
         text_format: type[TextFormatT] | Omit,
         input_tools: Iterable[ToolParam] | Omit,
         starting_after: int | None,
+        cancel_response: Callable[[str], Awaitable[Response]] | None = None,
     ) -> None:
         self.__stream: AsyncResponseStream[TextFormatT] | None = None
         self.__api_request = api_request
         self.__text_format = text_format
         self.__input_tools = input_tools
         self.__starting_after = starting_after
+        self.__cancel_response = cancel_response
 
     async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
         raw_stream = await self.__api_request
@@ -217,6 +277,7 @@ async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
             text_format=self.__text_format,
             input_tools=self.__input_tools,
             starting_after=self.__starting_after,
+            cancel_response=self.__cancel_response,
         )
         return self.__stream
 
@@ -244,6 +305,10 @@ def __init__(
         self._text_format = text_format
         self._rich_text_format: type | Omit = text_format if inspect.isclass(text_format) else omit
 
+    @property
+    def current_snapshot(self) -> ParsedResponseSnapshot | None:
+        return self.__current_snapshot
+
     def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
         self.__current_snapshot = snapshot = self.accumulate_event(event)
 
diff --git a/src/openai/resources/responses/responses.py b/src/openai/resources/responses/responses.py
index c85a94495d..11d7029bed 100644
--- a/src/openai/resources/responses/responses.py
+++ b/src/openai/resources/responses/responses.py
@@ -1129,7 +1129,19 @@ def stream(
                 timeout=timeout,
             )
 
-            return ResponseStreamManager(api_request, text_format=text_format, input_tools=tools, starting_after=None)
+            return ResponseStreamManager(
+                api_request,
+                text_format=text_format,
+                input_tools=tools,
+                starting_after=None,
+                cancel_response=lambda response_id: self.cancel(
+                    response_id,
+                    extra_headers=extra_headers,
+                    extra_query=extra_query,
+                    extra_body=extra_body,
+                    timeout=timeout,
+                ),
+            )
         else:
             if not is_given(response_id):
                 raise ValueError("id must be provided when streaming an existing response")
@@ -1148,6 +1160,13 @@ def stream(
                 text_format=text_format,
                 input_tools=tools,
                 starting_after=starting_after if is_given(starting_after) else None,
+                cancel_response=lambda response_id: self.cancel(
+                    response_id,
+                    extra_headers=extra_headers,
+                    extra_query=extra_query,
+                    extra_body=extra_body,
+                    timeout=timeout,
+                ),
             )
 
     def parse(
@@ -2794,6 +2813,13 @@ def stream(
                 text_format=text_format,
                 input_tools=tools,
                 starting_after=None,
+                cancel_response=lambda response_id: self.cancel(
+                    response_id,
+                    extra_headers=extra_headers,
+                    extra_query=extra_query,
+                    extra_body=extra_body,
+                    timeout=timeout,
+                ),
             )
         else:
             if isinstance(response_id, Omit):
@@ -2813,6 +2839,13 @@ def stream(
                 text_format=text_format,
                 input_tools=tools,
                 starting_after=starting_after if is_given(starting_after) else None,
+                cancel_response=lambda response_id: self.cancel(
+                    response_id,
+                    extra_headers=extra_headers,
+                    extra_query=extra_query,
+                    extra_body=extra_body,
+                    timeout=timeout,
+                ),
             )
 
     async def parse(
diff --git a/tests/lib/responses/test_response_stream_cancel.py b/tests/lib/responses/test_response_stream_cancel.py
new file mode 100644
index 0000000000..907c1d1105
--- /dev/null
+++ b/tests/lib/responses/test_response_stream_cancel.py
@@ -0,0 +1,223 @@
+from __future__ import annotations
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from openai._types import Omit
+from openai._streaming import Stream, AsyncStream
+from openai.types.responses import Response
+from openai.lib.streaming.responses._responses import (
+    ResponseStream,
+    AsyncResponseStream,
+    ResponseStreamState,
+)
+
+
+def _make_state_with_snapshot(response_id: str = "resp_123") -> ResponseStreamState[object]:
+    """Create a ResponseStreamState that has a snapshot with a given id."""
+    from openai.types.responses.response import Response as RawResponse
+    from openai.types.responses.response_created_event import ResponseCreatedEvent
+
+    state = ResponseStreamState(text_format=Omit(), input_tools=Omit())
+
+    raw_response = RawResponse.construct(
+        id=response_id,
+        object="response",
+        created_at=0,
+        status="in_progress",
+        output=[],
+        model="gpt-4o",
+        parallel_tool_calls=True,
+        tool_choice="auto",
+        tools=[],
+        temperature=1.0,
+        top_p=1.0,
+        max_output_tokens=None,
+        max_tool_calls=None,
+        previous_response_id=None,
+        reasoning=None,
+        truncation="disabled",
+        error=None,
+        incomplete_details=None,
+        instructions=None,
+        metadata={},
+        text={"format": {"type": "text"}},
+        usage=None,
+        user=None,
+        background=False,
+        store=True,
+    )
+
+    event = ResponseCreatedEvent.construct(
+        type="response.created",
+        response=raw_response,
+        sequence_number=0,
+    )
+
+    state.handle_event(event)
+    return state
+
+
+class TestResponseStreamCancel:
+    def test_response_id_none_initially(self) -> None:
+        raw_stream = MagicMock(spec=Stream)
+        raw_stream.response = MagicMock()
+
+        stream = ResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+        )
+
+        assert stream.response_id is None
+
+    def test_response_id_available_after_event(self) -> None:
+        raw_stream = MagicMock(spec=Stream)
+        raw_stream.response = MagicMock()
+
+        stream = ResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+        )
+
+        stream._state = _make_state_with_snapshot("resp_abc")
+        assert stream.response_id == "resp_abc"
+
+    def test_cancel_raises_when_no_response_id(self) -> None:
+        raw_stream = MagicMock(spec=Stream)
+        raw_stream.response = MagicMock()
+
+        stream = ResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+            cancel_response=MagicMock(),
+        )
+
+        with pytest.raises(ValueError, match="response ID not yet available"):
+            stream.cancel()
+
+    def test_cancel_raises_when_no_callback(self) -> None:
+        raw_stream = MagicMock(spec=Stream)
+        raw_stream.response = MagicMock()
+
+        stream = ResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+        )
+
+        stream._state = _make_state_with_snapshot("resp_abc")
+
+        with pytest.raises(ValueError, match="Cancel not available"):
+            stream.cancel()
+
+    def test_cancel_calls_callback_and_closes(self) -> None:
+        raw_stream = MagicMock(spec=Stream)
+        raw_stream.response = MagicMock()
+
+        mock_response = MagicMock(spec=Response)
+        cancel_fn = MagicMock(return_value=mock_response)
+
+        stream = ResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+            cancel_response=cancel_fn,
+        )
+
+        stream._state = _make_state_with_snapshot("resp_xyz")
+
+        result = stream.cancel()
+
+        cancel_fn.assert_called_once_with("resp_xyz")
+        raw_stream.response.close.assert_called_once()
+        assert result is mock_response
+
+
+class TestAsyncResponseStreamCancel:
+    def test_response_id_none_initially(self) -> None:
+        raw_stream = MagicMock(spec=AsyncStream)
+        raw_stream.response = MagicMock()
+
+        stream = AsyncResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+        )
+
+        assert stream.response_id is None
+
+    def test_response_id_available_after_event(self) -> None:
+        raw_stream = MagicMock(spec=AsyncStream)
+        raw_stream.response = MagicMock()
+
+        stream = AsyncResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+        )
+
+        stream._state = _make_state_with_snapshot("resp_abc")
+        assert stream.response_id == "resp_abc"
+
+    @pytest.mark.asyncio
+    async def test_cancel_raises_when_no_response_id(self) -> None:
+        raw_stream = MagicMock(spec=AsyncStream)
+        raw_stream.response = MagicMock()
+
+        stream = AsyncResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+            cancel_response=AsyncMock(),
+        )
+
+        with pytest.raises(ValueError, match="response ID not yet available"):
+            await stream.cancel()
+
+    @pytest.mark.asyncio
+    async def test_cancel_calls_callback_and_closes(self) -> None:
+        raw_stream = MagicMock(spec=AsyncStream)
+        raw_stream.response = MagicMock()
+        raw_stream.response.aclose = AsyncMock()
+
+        mock_response = MagicMock(spec=Response)
+        cancel_fn = AsyncMock(return_value=mock_response)
+
+        stream = AsyncResponseStream(
+            raw_stream=raw_stream,
+            text_format=Omit(),
+            input_tools=Omit(),
+            starting_after=None,
+            cancel_response=cancel_fn,
+        )
+
+        stream._state = _make_state_with_snapshot("resp_xyz")
+
+        result = await stream.cancel()
+
+        cancel_fn.assert_called_once_with("resp_xyz")
+        raw_stream.response.aclose.assert_called_once()
+        assert result is mock_response
+
+
+class TestResponseStreamStateSnapshot:
+    def test_current_snapshot_none_initially(self) -> None:
+        state = ResponseStreamState(text_format=Omit(), input_tools=Omit())
+        assert state.current_snapshot is None
+
+    def test_current_snapshot_available_after_event(self) -> None:
+        state = _make_state_with_snapshot("resp_test")
+        assert state.current_snapshot is not None
+        assert state.current_snapshot.id == "resp_test"