diff --git a/Sources/OpenAI/OpenAI+OpenAIAsync.swift b/Sources/OpenAI/OpenAI+OpenAIAsync.swift index f4d47b94..7003e0d1 100644 --- a/Sources/OpenAI/OpenAI+OpenAIAsync.swift +++ b/Sources/OpenAI/OpenAI+OpenAIAsync.swift @@ -43,7 +43,16 @@ extension OpenAI: OpenAIAsync { chatsStream(query: query, onResult: onResult, completion: completion) } } - + + public func chatsStream( + query: ChatQuery, + onWebSearchEvent: @escaping @Sendable (WebSearchEvent) -> Void + ) -> AsyncThrowingStream { + makeAsyncStream { onResult, completion in + chatsStream(query: query, onResult: onResult, onWebSearchEvent: onWebSearchEvent, completion: completion) + } + } + public func model(query: ModelQuery) async throws -> ModelResult { try await performRequestAsync( request: makeModelRequest(query: query) diff --git a/Sources/OpenAI/OpenAI.swift b/Sources/OpenAI/OpenAI.swift index 4712fdda..ccfb7371 100644 --- a/Sources/OpenAI/OpenAI.swift +++ b/Sources/OpenAI/OpenAI.swift @@ -106,6 +106,62 @@ final public class OpenAI: OpenAIProtocol, @unchecked Sendable { ) } + /// Creates an OpenAI client with a custom URLSession protocol implementation. + /// + /// - Important: This initializer only uses the custom session for non-streaming requests. + /// For streaming requests, use the initializer that accepts a `URLSessionFactory`. + /// + /// - Parameters: + /// - configuration: The client configuration + /// - customSession: Custom URLSession protocol implementation + /// - middlewares: Optional middlewares for request/response interception + public convenience init( + configuration: Configuration, + customSession: any URLSessionProtocol, + middlewares: [OpenAIMiddleware] = [] + ) { + let streamingSessionFactory = ImplicitURLSessionStreamingSessionFactory( + middlewares: middlewares, + parsingOptions: configuration.parsingOptions, + sslDelegate: nil + ) + + self.init( + configuration: configuration, + session: customSession, + streamingSessionFactory: streamingSessionFactory, + middlewares: middlewares + ) + } + + /// Creates an OpenAI client with custom session handling for both regular and streaming requests. + /// + /// - Parameters: + /// - configuration: The client configuration + /// - customSession: Custom URLSession protocol implementation for non-streaming requests + /// - streamingURLSessionFactory: Factory for creating sessions for streaming requests + /// - middlewares: Optional middlewares for request/response interception + public convenience init( + configuration: Configuration, + customSession: any URLSessionProtocol, + streamingURLSessionFactory: URLSessionFactory, + middlewares: [OpenAIMiddleware] = [] + ) { + let streamingSessionFactory = ImplicitURLSessionStreamingSessionFactory( + urlSessionFactory: streamingURLSessionFactory, + middlewares: middlewares, + parsingOptions: configuration.parsingOptions, + sslDelegate: nil + ) + + self.init( + configuration: configuration, + session: customSession, + streamingSessionFactory: streamingSessionFactory, + middlewares: middlewares + ) + } + init( configuration: Configuration, session: URLSessionProtocol, @@ -284,9 +340,19 @@ final public class OpenAI: OpenAIProtocol, @unchecked Sendable { } public func chatsStream(query: ChatQuery, onResult: @escaping @Sendable (Result) -> Void, completion: (@Sendable (Error?) -> Void)?) -> CancellableRequest { + chatsStream(query: query, onResult: onResult, onWebSearchEvent: nil, completion: completion) + } + + public func chatsStream( + query: ChatQuery, + onResult: @escaping @Sendable (Result) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?, + completion: (@Sendable (Error?) -> Void)? + ) -> CancellableRequest { performStreamingRequest( request: JSONRequest(body: query.makeStreamable(), url: buildURL(path: .chats)), onResult: onResult, + onWebSearchEvent: onWebSearchEvent, completion: completion ) } @@ -355,9 +421,10 @@ extension OpenAI { func performStreamingRequest( request: any URLRequestBuildable, onResult: @escaping @Sendable (Result) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)? = nil, completion: (@Sendable (Error?) -> Void)? ) -> CancellableRequest { - streamingClient.performStreamingRequest(request: request, onResult: onResult, completion: completion) + streamingClient.performStreamingRequest(request: request, onResult: onResult, onWebSearchEvent: onWebSearchEvent, completion: completion) } func performSpeechRequest(request: any URLRequestBuildable, completion: @escaping @Sendable (Result) -> Void) -> CancellableRequest { diff --git a/Sources/OpenAI/Private/Client/StreamingClient.swift b/Sources/OpenAI/Private/Client/StreamingClient.swift index e545a586..b127a0b8 100644 --- a/Sources/OpenAI/Private/Client/StreamingClient.swift +++ b/Sources/OpenAI/Private/Client/StreamingClient.swift @@ -32,6 +32,7 @@ final class StreamingClient: @unchecked Sendable { func performStreamingRequest( request: any URLRequestBuildable, onResult: @escaping @Sendable (Result) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)? = nil, completion: (@Sendable (Error?) -> Void)? ) -> CancellableRequest { do { @@ -41,16 +42,20 @@ final class StreamingClient: @unchecked Sendable { } let session = streamingSessionFactory.makeServerSentEventsStreamingSession( - urlRequest: interceptedRequest - ) { _, object in - onResult(.success(object)) - } onProcessingError: { _, error in - onResult(.failure(error)) - } onComplete: { [weak self] session, error in - completion?(error) - self?.invalidateSession(session) - } - + urlRequest: interceptedRequest, + onReceiveContent: { _, object in + onResult(.success(object)) + }, + onWebSearchEvent: onWebSearchEvent, + onProcessingError: { _, error in + onResult(.failure(error)) + }, + onComplete: { [weak self] session, error in + completion?(error) + self?.invalidateSession(session) + } + ) + return runSession(session) } catch { completion?(error) diff --git a/Sources/OpenAI/Private/Streaming/InvalidatableSession.swift b/Sources/OpenAI/Private/Streaming/InvalidatableSession.swift index f262c69c..94e9ebf3 100644 --- a/Sources/OpenAI/Private/Streaming/InvalidatableSession.swift +++ b/Sources/OpenAI/Private/Streaming/InvalidatableSession.swift @@ -7,7 +7,7 @@ import Foundation -protocol InvalidatableSession: Sendable { +public protocol InvalidatableSession: Sendable { func invalidateAndCancel() func finishTasksAndInvalidate() } diff --git a/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamInterpreter.swift b/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamInterpreter.swift index 72b2e4ef..ef354bc2 100644 --- a/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamInterpreter.swift +++ b/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamInterpreter.swift @@ -17,6 +17,7 @@ final class ServerSentEventsStreamInterpreter : private var previousChunkBuffer = "" private var onEventDispatched: ((ResultType) -> Void)? + private var onWebSearchEvent: ((WebSearchEvent) -> Void)? private var onError: ((Error) -> Void)? private let parsingOptions: ParsingOptions @@ -39,8 +40,26 @@ final class ServerSentEventsStreamInterpreter : /// - Parameters: /// - onEventDispatched: Can be called multiple times per `processData` /// - onError: Will only be called once per `processData` - func setCallbackClosures(onEventDispatched: @escaping @Sendable (ResultType) -> Void, onError: @escaping @Sendable (Error) -> Void) { + func setCallbackClosures( + onEventDispatched: @escaping @Sendable (ResultType) -> Void, + onError: @escaping @Sendable (Error) -> Void + ) { + setCallbackClosures(onEventDispatched: onEventDispatched, onWebSearchEvent: nil, onError: onError) + } + + /// Sets closures an instance of type. Not thread safe. + /// + /// - Parameters: + /// - onEventDispatched: Can be called multiple times per `processData` + /// - onWebSearchEvent: Called when a web search event is received (optional) + /// - onError: Will only be called once per `processData` + func setCallbackClosures( + onEventDispatched: @escaping @Sendable (ResultType) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?, + onError: @escaping @Sendable (Error) -> Void + ) { self.onEventDispatched = onEventDispatched + self.onWebSearchEvent = onWebSearchEvent self.onError = onError } @@ -66,7 +85,21 @@ final class ServerSentEventsStreamInterpreter : onError?(StreamingError.unknownContent) return } - + + // Handle web search events (they have "type" field instead of "object") + // Event types include: "web_search_call", or prefixed like "response.web_search_call.*" + if let json = try? JSONSerialization.jsonObject(with: jsonData) as? [String: Any], + let eventType = json["type"] as? String, + eventType.contains("web_search") { + do { + let webSearchEvent = try JSONDecoder().decode(WebSearchEvent.self, from: jsonData) + onWebSearchEvent?(webSearchEvent) + } catch { + onError?(error) + } + return + } + let decoder = JSONResponseDecoder(parsingOptions: parsingOptions) do { let object: ResultType = try decoder.decodeResponseData(jsonData) diff --git a/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamingSessionFactory.swift b/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamingSessionFactory.swift index 21fb2b14..25616010 100644 --- a/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamingSessionFactory.swift +++ b/Sources/OpenAI/Private/Streaming/ServerSentEventsStreamingSessionFactory.swift @@ -15,6 +15,7 @@ protocol StreamingSessionFactory: Sendable { func makeServerSentEventsStreamingSession( urlRequest: URLRequest, onReceiveContent: @Sendable @escaping (StreamingSession>, ResultType) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?, onProcessingError: @Sendable @escaping (StreamingSession>, Error) -> Void, onComplete: @Sendable @escaping (StreamingSession>, Error?) -> Void ) -> StreamingSession> @@ -35,27 +36,43 @@ protocol StreamingSessionFactory: Sendable { } struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory { + let urlSessionFactory: URLSessionFactory let middlewares: [OpenAIMiddleware] let parsingOptions: ParsingOptions let sslDelegate: SSLDelegateProtocol? - + + init( + urlSessionFactory: URLSessionFactory = FoundationURLSessionFactory(), + middlewares: [OpenAIMiddleware], + parsingOptions: ParsingOptions, + sslDelegate: SSLDelegateProtocol? + ) { + self.urlSessionFactory = urlSessionFactory + self.middlewares = middlewares + self.parsingOptions = parsingOptions + self.sslDelegate = sslDelegate + } + func makeServerSentEventsStreamingSession( urlRequest: URLRequest, onReceiveContent: @Sendable @escaping (StreamingSession>, ResultType) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?, onProcessingError: @Sendable @escaping (StreamingSession>, any Error) -> Void, onComplete: @Sendable @escaping (StreamingSession>, (any Error)?) -> Void ) -> StreamingSession> where ResultType : Decodable, ResultType : Encodable, ResultType : Sendable { .init( + urlSessionFactory: urlSessionFactory, urlRequest: urlRequest, interpreter: .init(parsingOptions: parsingOptions), sslDelegate: sslDelegate, middlewares: middlewares, onReceiveContent: onReceiveContent, + onWebSearchEvent: onWebSearchEvent, onProcessingError: onProcessingError, onComplete: onComplete ) } - + func makeAudioSpeechStreamingSession( urlRequest: URLRequest, onReceiveContent: @Sendable @escaping (StreamingSession, AudioSpeechResult) -> Void, @@ -63,6 +80,7 @@ struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory { onComplete: @Sendable @escaping (StreamingSession, (any Error)?) -> Void ) -> StreamingSession { .init( + urlSessionFactory: urlSessionFactory, urlRequest: urlRequest, interpreter: .init(), sslDelegate: sslDelegate, @@ -72,7 +90,7 @@ struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory { onComplete: onComplete ) } - + func makeModelResponseStreamingSession( urlRequest: URLRequest, onReceiveContent: @Sendable @escaping (StreamingSession, ResponseStreamEvent) -> Void, @@ -80,6 +98,7 @@ struct ImplicitURLSessionStreamingSessionFactory: StreamingSessionFactory { onComplete: @Sendable @escaping (StreamingSession, (any Error)?) -> Void ) -> StreamingSession { .init( + urlSessionFactory: urlSessionFactory, urlRequest: urlRequest, interpreter: .init(), sslDelegate: sslDelegate, diff --git a/Sources/OpenAI/Private/Streaming/StreamingSession.swift b/Sources/OpenAI/Private/Streaming/StreamingSession.swift index b718db6b..fe162c75 100644 --- a/Sources/OpenAI/Private/Streaming/StreamingSession.swift +++ b/Sources/OpenAI/Private/Streaming/StreamingSession.swift @@ -21,6 +21,7 @@ final class StreamingSession: NSObject, Identifi private let middlewares: [OpenAIMiddleware] private let executionSerializer: ExecutionSerializer private let onReceiveContent: (@Sendable (StreamingSession, ResultType) -> Void)? + private let onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)? private let onProcessingError: (@Sendable (StreamingSession, Error) -> Void)? private let onComplete: (@Sendable (StreamingSession, Error?) -> Void)? @@ -32,6 +33,7 @@ final class StreamingSession: NSObject, Identifi middlewares: [OpenAIMiddleware], executionSerializer: ExecutionSerializer = GCDQueueAsyncExecutionSerializer(queue: .userInitiated), onReceiveContent: @escaping @Sendable (StreamingSession, ResultType) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)? = nil, onProcessingError: @escaping @Sendable (StreamingSession, Error) -> Void, onComplete: @escaping @Sendable (StreamingSession, Error?) -> Void ) { @@ -42,6 +44,7 @@ final class StreamingSession: NSObject, Identifi self.middlewares = middlewares self.executionSerializer = executionSerializer self.onReceiveContent = onReceiveContent + self.onWebSearchEvent = onWebSearchEvent self.onProcessingError = onProcessingError self.onComplete = onComplete super.init() @@ -96,12 +99,25 @@ final class StreamingSession: NSObject, Identifi } private func subscribeToParser() { - interpreter.setCallbackClosures { [weak self] content in - guard let self else { return } - self.onReceiveContent?(self, content) - } onError: { [weak self] error in - guard let self else { return } - self.onProcessingError?(self, error) + // Check if interpreter supports web search events (ServerSentEventsStreamInterpreter) + if let sseInterpreter = interpreter as? ServerSentEventsStreamInterpreter { + sseInterpreter.setCallbackClosures { [weak self] content in + guard let self else { return } + self.onReceiveContent?(self, content) + } onWebSearchEvent: { [weak self] event in + self?.onWebSearchEvent?(event) + } onError: { [weak self] error in + guard let self else { return } + self.onProcessingError?(self, error) + } + } else { + interpreter.setCallbackClosures { [weak self] content in + guard let self else { return } + self.onReceiveContent?(self, content) + } onError: { [weak self] error in + guard let self else { return } + self.onProcessingError?(self, error) + } } } } diff --git a/Sources/OpenAI/Private/URLSessionCombine.swift b/Sources/OpenAI/Private/URLSessionCombine.swift index e4c521a7..e4776bc5 100644 --- a/Sources/OpenAI/Private/URLSessionCombine.swift +++ b/Sources/OpenAI/Private/URLSessionCombine.swift @@ -14,19 +14,19 @@ import FoundationNetworking #if canImport(Combine) import Combine -protocol URLSessionCombine { +public protocol URLSessionCombine { func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> } extension URLSession: URLSessionCombine { - func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> { + public func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), URLError> { let typedPublisher: URLSession.DataTaskPublisher = dataTaskPublisher(for: request) return typedPublisher.eraseToAnyPublisher() } } #else -protocol URLSessionCombine { +public protocol URLSessionCombine { } extension URLSession: URLSessionCombine {} diff --git a/Sources/OpenAI/Private/URLSessionDataTaskProtocol.swift b/Sources/OpenAI/Private/URLSessionDataTaskProtocol.swift index ab044231..a8416b8d 100644 --- a/Sources/OpenAI/Private/URLSessionDataTaskProtocol.swift +++ b/Sources/OpenAI/Private/URLSessionDataTaskProtocol.swift @@ -10,14 +10,14 @@ import Foundation import FoundationNetworking #endif -protocol URLSessionTaskProtocol: Sendable { +public protocol URLSessionTaskProtocol: Sendable { var originalRequest: URLRequest? { get } func cancel() } extension URLSessionTask: URLSessionTaskProtocol {} -protocol URLSessionDataTaskProtocol: URLSessionTaskProtocol { +public protocol URLSessionDataTaskProtocol: URLSessionTaskProtocol { func resume() } diff --git a/Sources/OpenAI/Private/URLSessionDelegateProtocol.swift b/Sources/OpenAI/Private/URLSessionDelegateProtocol.swift index f85fc619..4f743d07 100644 --- a/Sources/OpenAI/Private/URLSessionDelegateProtocol.swift +++ b/Sources/OpenAI/Private/URLSessionDelegateProtocol.swift @@ -11,9 +11,12 @@ import Foundation import FoundationNetworking #endif -protocol URLSessionDelegateProtocol: Sendable { // Sendable to make a better match with URLSessionDelegate, it's sendable too +/// Protocol for handling URLSession delegate callbacks. +/// Sendable to match URLSessionDelegate behavior. +/// AnyObject constraint allows weak references to delegate implementations. +public protocol URLSessionDelegateProtocol: AnyObject, Sendable { func urlSession(_ session: URLSessionProtocol, task: URLSessionTaskProtocol, didCompleteWithError error: Error?) - + func urlSession( _ session: URLSession, didReceive challenge: URLAuthenticationChallenge, @@ -21,9 +24,11 @@ protocol URLSessionDelegateProtocol: Sendable { // Sendable to make a better mat ) } -protocol URLSessionDataDelegateProtocol: URLSessionDelegateProtocol { +/// Protocol for handling URLSession data delegate callbacks. +/// Used for streaming data reception. +public protocol URLSessionDataDelegateProtocol: URLSessionDelegateProtocol { func urlSession(_ session: URLSessionProtocol, dataTask: URLSessionDataTaskProtocol, didReceive data: Data) - + func urlSession( _ session: URLSessionProtocol, dataTask: URLSessionDataTaskProtocol, diff --git a/Sources/OpenAI/Private/URLSessionFactory.swift b/Sources/OpenAI/Private/URLSessionFactory.swift index 7b99769b..3baba037 100644 --- a/Sources/OpenAI/Private/URLSessionFactory.swift +++ b/Sources/OpenAI/Private/URLSessionFactory.swift @@ -10,12 +10,20 @@ import Foundation import FoundationNetworking #endif -protocol URLSessionFactory: Sendable { +/// Factory protocol for creating URLSession instances. +/// Implement this protocol to provide custom session creation for streaming requests. +public protocol URLSessionFactory: Sendable { + /// Creates a URLSession for streaming requests. + /// - Parameter delegate: The delegate to receive streaming data callbacks + /// - Returns: A URLSession protocol implementation func makeUrlSession(delegate: URLSessionDataDelegateProtocol) -> URLSessionProtocol } -struct FoundationURLSessionFactory: URLSessionFactory { - func makeUrlSession(delegate: URLSessionDataDelegateProtocol) -> any URLSessionProtocol { +/// Default factory that creates standard Foundation URLSession instances. +public struct FoundationURLSessionFactory: URLSessionFactory { + public init() {} + + public func makeUrlSession(delegate: URLSessionDataDelegateProtocol) -> any URLSessionProtocol { let forwarder = URLSessionDataDelegateForwarder(target: delegate) return URLSession(configuration: .default, delegate: forwarder, delegateQueue: nil) } diff --git a/Sources/OpenAI/Private/URLSessionProtocol.swift b/Sources/OpenAI/Private/URLSessionProtocol.swift index cd4bccd7..a59506db 100644 --- a/Sources/OpenAI/Private/URLSessionProtocol.swift +++ b/Sources/OpenAI/Private/URLSessionProtocol.swift @@ -10,20 +10,20 @@ import Foundation import FoundationNetworking #endif -protocol URLSessionProtocol: InvalidatableSession, URLSessionCombine { +public protocol URLSessionProtocol: InvalidatableSession, URLSessionCombine { func dataTask(with request: URLRequest, completionHandler: @escaping @Sendable (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol - + @available(iOS 15.0, macOS 12.0, tvOS 15.0, watchOS 8.0, *) func data(for request: URLRequest, delegate: (any URLSessionTaskDelegate)?) async throws -> (Data, URLResponse) } extension URLSession: URLSessionProtocol { - func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol { + public func dataTask(with request: URLRequest) -> URLSessionDataTaskProtocol { dataTask(with: request) as URLSessionDataTask } - - func dataTask(with request: URLRequest, completionHandler: @escaping @Sendable (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol { + + public func dataTask(with request: URLRequest, completionHandler: @escaping @Sendable (Data?, URLResponse?, Error?) -> Void) -> URLSessionDataTaskProtocol { dataTask(with: request, completionHandler: completionHandler) as URLSessionDataTask } } diff --git a/Sources/OpenAI/Public/Models/ChatStreamResult.swift b/Sources/OpenAI/Public/Models/ChatStreamResult.swift index c1f30216..8ba132c3 100644 --- a/Sources/OpenAI/Public/Models/ChatStreamResult.swift +++ b/Sources/OpenAI/Public/Models/ChatStreamResult.swift @@ -26,6 +26,9 @@ public struct ChatStreamResult: Codable, Equatable, Sendable { public let role: Self.Role? public let toolCalls: [Self.ChoiceDeltaToolCall]? + /// URL citation annotations from web search results + public let annotations: [Self.Annotation]? + /// Value for `reasoning` field in response. /// /// Provided by: @@ -107,11 +110,44 @@ public struct ChatStreamResult: Codable, Equatable, Sendable { } } + /// An annotation containing citation information from web search + public struct Annotation: Codable, Equatable, Sendable { + /// The type of annotation (e.g., "url_citation") + public let type: String + /// URL citation details + public let urlCitation: URLCitation? + + /// URL citation information from web search results + public struct URLCitation: Codable, Equatable, Sendable { + /// The URL of the cited source + public let url: String + /// The title of the cited source + public let title: String? + /// Start index in the content where this citation applies + public let startIndex: Int? + /// End index in the content where this citation applies + public let endIndex: Int? + + public enum CodingKeys: String, CodingKey { + case url + case title + case startIndex = "start_index" + case endIndex = "end_index" + } + } + + public enum CodingKeys: String, CodingKey { + case type + case urlCitation = "url_citation" + } + } + public enum CodingKeys: String, CodingKey { case content case audio case role case toolCalls = "tool_calls" + case annotations case _reasoning = "reasoning" case _reasoningContent = "reasoning_content" } diff --git a/Sources/OpenAI/Public/Models/Types/WebSearchEvent.swift b/Sources/OpenAI/Public/Models/Types/WebSearchEvent.swift new file mode 100644 index 00000000..debb6f38 --- /dev/null +++ b/Sources/OpenAI/Public/Models/Types/WebSearchEvent.swift @@ -0,0 +1,58 @@ +// +// WebSearchEvent.swift +// OpenAI +// +// Created on 01/02/2026. +// + +import Foundation + +/// Represents a web search event during streaming. +/// These events indicate the status of web search operations triggered by the model. +public struct WebSearchEvent: Codable, Equatable, Sendable { + /// The type of event, typically "web_search_call" + public let type: String + + /// Unique ID for the output item associated with the web search call + public let itemId: String? + + /// The index of the output item that the web search call is associated with + public let outputIndex: Int? + + /// The status of the web search operation + public let status: Status + + /// The action being performed (contains the search query) + public let action: Action? + + /// Reason for blocked or failed searches + public let reason: String? + + /// Possible statuses for a web search event + public enum Status: String, Codable, Sendable { + case inProgress = "in_progress" + case searching + case completed + case failed + case blocked + } + + /// The action associated with the web search + public struct Action: Codable, Equatable, Sendable { + /// The type of action: "search", "open_page", or "find_in_page" + public let type: String? + /// The search query being executed (for "search" actions) + public let query: String? + /// The URL being fetched (for "open_page" actions) + public let url: String? + } + + private enum CodingKeys: String, CodingKey { + case type + case itemId = "item_id" + case outputIndex = "output_index" + case status + case action + case reason + } +} diff --git a/Sources/OpenAI/Public/Protocols/OpenAIAsync.swift b/Sources/OpenAI/Public/Protocols/OpenAIAsync.swift index 8a6a2431..fa5246af 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIAsync.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIAsync.swift @@ -14,6 +14,7 @@ public protocol OpenAIAsync: Sendable { func embeddings(query: EmbeddingsQuery) async throws -> EmbeddingsResult func chats(query: ChatQuery) async throws -> ChatResult func chatsStream(query: ChatQuery) -> AsyncThrowingStream + func chatsStream(query: ChatQuery, onWebSearchEvent: @escaping @Sendable (WebSearchEvent) -> Void) -> AsyncThrowingStream func model(query: ModelQuery) async throws -> ModelResult func models() async throws -> ModelsResult func moderations(query: ModerationsQuery) async throws -> ModerationsResult diff --git a/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift b/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift index dba875b5..f4c14093 100644 --- a/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift +++ b/Sources/OpenAI/Public/Protocols/OpenAIProtocol.swift @@ -124,7 +124,9 @@ public protocol OpenAIProtocol: OpenAIModern { - Note: This method creates and configures separate session object specifically for streaming. In order for it to work properly and don't leak memory you should hold a reference to the returned value, and when you're done - call cancel() on it. */ @discardableResult func chatsStream(query: ChatQuery, onResult: @escaping @Sendable (Result) -> Void, completion: (@Sendable (Error?) -> Void)?) -> CancellableRequest - + + @discardableResult func chatsStream(query: ChatQuery, onResult: @escaping @Sendable (Result) -> Void, onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?, completion: (@Sendable (Error?) -> Void)?) -> CancellableRequest + /** This function sends a model query to the OpenAI API and retrieves a model instance, providing owner information. The Models API in this usage enables you to gather detailed information on the model in question, like GPT-3. diff --git a/Tests/OpenAITests/Mocks/MockStreamingSessionFactory.swift b/Tests/OpenAITests/Mocks/MockStreamingSessionFactory.swift index b7861397..3023abd9 100644 --- a/Tests/OpenAITests/Mocks/MockStreamingSessionFactory.swift +++ b/Tests/OpenAITests/Mocks/MockStreamingSessionFactory.swift @@ -19,6 +19,7 @@ class MockStreamingSessionFactory: StreamingSessionFactory, @unchecked Sendable func makeServerSentEventsStreamingSession( urlRequest: URLRequest, onReceiveContent: @Sendable @escaping (StreamingSession>, ResultType) -> Void, + onWebSearchEvent: (@Sendable (WebSearchEvent) -> Void)?, onProcessingError: @Sendable @escaping (StreamingSession>, any Error) -> Void, onComplete: @Sendable @escaping (StreamingSession>, (any Error)?) -> Void ) -> StreamingSession> where ResultType : Decodable, ResultType : Encodable, ResultType : Sendable { @@ -30,6 +31,7 @@ class MockStreamingSessionFactory: StreamingSessionFactory, @unchecked Sendable middlewares: [], executionSerializer: executionSerializer, onReceiveContent: onReceiveContent, + onWebSearchEvent: onWebSearchEvent, onProcessingError: onProcessingError, onComplete: onComplete ) diff --git a/Tests/OpenAITests/ServerSentEventsStreamInterpreterTests.swift b/Tests/OpenAITests/ServerSentEventsStreamInterpreterTests.swift index 464ff0fa..dcd514c5 100644 --- a/Tests/OpenAITests/ServerSentEventsStreamInterpreterTests.swift +++ b/Tests/OpenAITests/ServerSentEventsStreamInterpreterTests.swift @@ -66,7 +66,7 @@ struct ServerSentEventsStreamInterpreterTests { @Test func parseApiError() async throws { var error: Error! - + await withCheckedContinuation { continuation in interpreter.setCallbackClosures { result in } onError: { apiError in @@ -77,13 +77,141 @@ struct ServerSentEventsStreamInterpreterTests { } } } - + interpreter.processData(chatCompletionError()) } - + #expect(error is APIErrorResponse) } - + + @Test func parseWebSearchEvent() async throws { + var webSearchEvents: [WebSearchEvent] = [] + var chatResults: [ChatStreamResult] = [] + var errors: [Error] = [] + + await withCheckedContinuation { continuation in + interpreter.setCallbackClosures { result in + Task { @MainActor in + chatResults.append(result) + continuation.resume() + } + } onWebSearchEvent: { event in + Task { @MainActor in + webSearchEvents.append(event) + continuation.resume() + } + } onError: { error in + Task { @MainActor in + errors.append(error) + continuation.resume() + } + } + + interpreter.processData(webSearchEventData()) + } + + #expect(chatResults.isEmpty, "Expected no chat results") + #expect(errors.isEmpty, "Expected no errors") + #expect(webSearchEvents.count == 1) + #expect(webSearchEvents.first?.type == "web_search_call") + #expect(webSearchEvents.first?.status == .inProgress) + #expect(webSearchEvents.first?.action?.query == "latest news") + } + + @Test func parseWebSearchEventCompleted() async throws { + var webSearchEvents: [WebSearchEvent] = [] + var chatResults: [ChatStreamResult] = [] + var errors: [Error] = [] + + await withCheckedContinuation { continuation in + interpreter.setCallbackClosures { result in + Task { @MainActor in + chatResults.append(result) + continuation.resume() + } + } onWebSearchEvent: { event in + Task { @MainActor in + webSearchEvents.append(event) + continuation.resume() + } + } onError: { error in + Task { @MainActor in + errors.append(error) + continuation.resume() + } + } + + interpreter.processData(webSearchEventCompletedData()) + } + + #expect(chatResults.isEmpty, "Expected no chat results") + #expect(errors.isEmpty, "Expected no errors") + #expect(webSearchEvents.count == 1) + #expect(webSearchEvents.first?.status == .completed) + } + + @Test func webSearchEventDoesNotTriggerOnResult() async throws { + var chatStreamResults: [ChatStreamResult] = [] + var webSearchEvents: [WebSearchEvent] = [] + var errors: [Error] = [] + + await withCheckedContinuation { continuation in + interpreter.setCallbackClosures { result in + Task { @MainActor in + chatStreamResults.append(result) + continuation.resume() + } + } onWebSearchEvent: { event in + Task { @MainActor in + webSearchEvents.append(event) + continuation.resume() + } + } onError: { error in + Task { @MainActor in + errors.append(error) + continuation.resume() + } + } + + interpreter.processData(webSearchEventData()) + } + + #expect(errors.isEmpty, "Expected no errors") + #expect(chatStreamResults.isEmpty) + #expect(webSearchEvents.count == 1) + } + + @Test func invalidWebSearchEventReportsError() async throws { + var receivedError: Error? + var webSearchEvents: [WebSearchEvent] = [] + var chatResults: [ChatStreamResult] = [] + + await withCheckedContinuation { continuation in + interpreter.setCallbackClosures { result in + Task { @MainActor in + chatResults.append(result) + continuation.resume() + } + } onWebSearchEvent: { event in + Task { @MainActor in + webSearchEvents.append(event) + continuation.resume() + } + } onError: { error in + Task { @MainActor in + receivedError = error + continuation.resume() + } + } + + interpreter.processData(invalidWebSearchEventData()) + } + + #expect(chatResults.isEmpty, "Expected no chat results") + #expect(webSearchEvents.isEmpty, "Expected no web search events") + #expect(receivedError != nil) + } + private func chatCompletionChunk() -> Data { MockServerSentEvent.chatCompletionChunk() } @@ -100,6 +228,18 @@ struct ServerSentEventsStreamInterpreterTests { private func chatCompletionError() -> Data { MockServerSentEvent.chatCompletionError() } + + private func webSearchEventData() -> Data { + "data: {\"type\":\"web_search_call\",\"item_id\":\"ws_123\",\"output_index\":0,\"status\":\"in_progress\",\"action\":{\"type\":\"search\",\"query\":\"latest news\"}}\n\n".data(using: .utf8)! + } + + private func webSearchEventCompletedData() -> Data { + "data: {\"type\":\"web_search_call\",\"item_id\":\"ws_123\",\"output_index\":0,\"status\":\"completed\"}\n\n".data(using: .utf8)! + } + + private func invalidWebSearchEventData() -> Data { + "data: {\"type\":\"web_search_call\",\"status\":\"unknown_status\"}\n\n".data(using: .utf8)! + } } private actor ChatStreamResultsActor {