From 3b56a8975ecf475efdd4558a31953319fb74f4d1 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 20 Apr 2026 01:28:51 +0200 Subject: [PATCH 1/3] Fix race condition in WebSocket.messages --- Sources/WebSocket/WebSocket.swift | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index bed23d2..855f760 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -73,27 +73,15 @@ public extension WebSocket { /// The WebSocket's received messages as an asynchronous stream. var messages: AsyncStream { - let cancellable = Locked(nil) - - return AsyncStream { cont in - func finish() { - cancellable.access { cancellable in - if cancellable != nil { - cont.finish() - cancellable = nil - } - } - } - - let _cancellable = self.messagesPublisher() - .handleEvents(receiveCancel: { finish() }) - .sink( - receiveCompletion: { _ in finish() }, - receiveValue: { cont.yield($0) } - ) - - cancellable.access { $0 = _cancellable } - } + let (stream, continuation) = AsyncStream.makeStream() + let cancellable = messagesPublisher() + .handleEvents(receiveCancel: { continuation.finish() }) + .sink( + receiveCompletion: { _ in continuation.finish() }, + receiveValue: { continuation.yield($0) } + ) + continuation.onTermination = { @Sendable _ in cancellable.cancel() } + return stream } } From aa17b19e6ad5c2302f42b4d7f6a7a4d675c0b018 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 20 Apr 2026 02:57:51 +0200 Subject: [PATCH 2/3] Fix open and close bugs --- Sources/WebSocket/SystemURLSession.swift | 65 +++++--- Sources/WebSocket/SystemWebSocket.swift | 8 +- .../Server/WebSocketServer.swift | 13 +- .../WebSocketTests/SystemWebSocketTests.swift | 144 +++++++++++++++++- 4 files changed, 203 insertions(+), 27 deletions(-) diff --git a/Sources/WebSocket/SystemURLSession.swift b/Sources/WebSocket/SystemURLSession.swift index 00b841c..b16558e 100644 --- a/Sources/WebSocket/SystemURLSession.swift +++ b/Sources/WebSocket/SystemURLSession.swift @@ -53,21 +53,30 @@ private func configuration(with options: WebSocketOptions) -> URLSessionConfigur return config } -private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable { +final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable { private struct Callbacks: Sendable { let onOpen: @Sendable () async -> Void let onClose: @Sendable (WebSocketCloseCode, Data?) async -> Void } - // `Dictionary` - private let state: Locked<[ObjectIdentifier: Callbacks]> = .init([:]) + private struct State: Sendable { + var callbacks: [ObjectIdentifier: Callbacks] = [:] + var callbackTasks: [ObjectIdentifier: Task] = [:] + } + + private let state = Locked(State()) func set( onOpen: @escaping @Sendable () async -> Void, onClose: @escaping @Sendable (WebSocketCloseCode, Data?) async -> Void, for taskID: ObjectIdentifier ) { - state.access { $0[taskID] = .init(onOpen: onOpen, onClose: onClose) } + state.access { state in + state.callbacks[taskID] = .init( + onOpen: onOpen, + onClose: onClose + ) + } } func urlSession( @@ -76,9 +85,8 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable { didOpenWithProtocol _: String? ) { let taskID = ObjectIdentifier(webSocketTask) - - if let onOpen = state.access({ $0[taskID]?.onOpen }) { - Task { await onOpen() } + enqueue(for: taskID) { callbacks in + await callbacks.onOpen() } } @@ -89,9 +97,8 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable { reason: Data? ) { let taskID = ObjectIdentifier(webSocketTask) - - if let onClose = state.access({ $0[taskID]?.onClose }) { - Task { await onClose(WebSocketCloseCode(closeCode), reason) } + enqueue(for: taskID) { callbacks in + await callbacks.onClose(WebSocketCloseCode(closeCode), reason) } } @@ -101,20 +108,36 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable { didCompleteWithError error: Error? ) { let taskID = ObjectIdentifier(task) + let closeCode: WebSocketCloseCode = error == nil ? .normalClosure : .abnormalClosure + let reason = error.map { Data($0.localizedDescription.utf8) } + enqueue(for: taskID, removeAfterwards: true) { callbacks in + await callbacks.onClose(closeCode, reason) + } + } - if let onClose = state.access({ $0[taskID]?.onClose }) { - Task { [weak self] in - if let error { - await onClose( - .abnormalClosure, - Data(error.localizedDescription.utf8) - ) - } else { - await onClose(.normalClosure, nil) - } + private func enqueue( + for taskID: ObjectIdentifier, + removeAfterwards: Bool = false, + _ operation: @escaping @Sendable (Callbacks) async -> Void + ) { + state.access { state in + guard let callbacks = state.callbacks[taskID] else { + return + } - self?.state.access { _ = $0.removeValue(forKey: taskID) } + let previousTask = state.callbackTasks[taskID] + let task = Task { [weak self] in + _ = await previousTask?.result + await operation(callbacks) + + guard removeAfterwards else { return } + self?.state.access { state in + _ = state.callbacks.removeValue(forKey: taskID) + _ = state.callbackTasks.removeValue(forKey: taskID) + } } + + state.callbackTasks[taskID] = task } } } diff --git a/Sources/WebSocket/SystemWebSocket.swift b/Sources/WebSocket/SystemWebSocket.swift index f6afa00..0dcb525 100644 --- a/Sources/WebSocket/SystemWebSocket.swift +++ b/Sources/WebSocket/SystemWebSocket.swift @@ -86,7 +86,7 @@ final actor SystemWebSocket: Publisher { do { try await didOpen.value } catch is CancellationError { - doClose(closeCode: .cancelled, reason: Data("cancelled".utf8)) + throw CancellationError() } catch is TimeoutError { doClose(closeCode: .timeout, reason: Data("timeout".utf8)) throw TimeoutError() @@ -250,9 +250,12 @@ private extension SystemWebSocket { } func doClose(closeCode: WebSocketCloseCode, reason: Data?) { + let close = WebSocketClose(closeCode, reason) + didOpen.fail(WebSocketError(closeCode, reason)) + switch state { case .unopened: - state = .closed(.init(closeCode, reason)) + state = .closed(close) case let .connecting(ws), let .open(ws): os_log( @@ -271,7 +274,6 @@ private extension SystemWebSocket { } } - let close = WebSocketClose(closeCode, nil) state = .closed(close) onClose(close) didClose?.resolve((code: closeCode, reason: reason)) diff --git a/Tests/WebSocketTests/Server/WebSocketServer.swift b/Tests/WebSocketTests/Server/WebSocketServer.swift index 757f2d2..03dbf16 100644 --- a/Tests/WebSocketTests/Server/WebSocketServer.swift +++ b/Tests/WebSocketTests/Server/WebSocketServer.swift @@ -1,12 +1,14 @@ import Combine import Foundation import NIO +import NIOWebSocket import WebSocket import WebSocketKit -enum WebSocketServerOutput: Hashable { +enum WebSocketServerOutput { case message(WebSocketMessage) case remoteClose + case remoteCloseWithReason(WebSocketErrorCode, Data) } final class WebSocketServer { @@ -71,6 +73,15 @@ final class WebSocketServer { do { try ws.close(code: .goingAway).wait() } catch {} + case let .remoteCloseWithReason(code, reason): + var buffer = ByteBufferAllocator().buffer(capacity: 2 + reason.count) + buffer.write(webSocketErrorCode: code) + buffer.writeBytes(reason) + ws.send( + raw: buffer.readableBytesView, + opcode: .connectionClose + ) + case let .message(message): switch message { case let .data(data): diff --git a/Tests/WebSocketTests/SystemWebSocketTests.swift b/Tests/WebSocketTests/SystemWebSocketTests.swift index b73936e..c02d61f 100644 --- a/Tests/WebSocketTests/SystemWebSocketTests.swift +++ b/Tests/WebSocketTests/SystemWebSocketTests.swift @@ -1,4 +1,7 @@ +import AsyncExtensions import Combine +import NIO +import NIOWebSocket import Synchronized @testable import WebSocket import XCTest @@ -40,7 +43,7 @@ class SystemWebSocketTests: XCTestCase { onOpen: { XCTFail("Should not have opened") }, onClose: { close in XCTAssertEqual(.abnormalClosure, close.code) - XCTAssertNil(close.reason) + XCTAssertNotNil(close.reason) ex.fulfill() } ) @@ -52,6 +55,52 @@ class SystemWebSocketTests: XCTestCase { XCTAssertTrue(isClosed) } + func testOpenCancellationThrowsCancellationError() async throws { + let server = try HangingServer() + defer { server.shutDown() } + + let client = try await SystemWebSocket( + request: request(server.port), + options: .init(timeoutIntervalForRequest: 5) + ) + + let openTask = Task { + try await client.open() + } + + try await Task.sleep(nanoseconds: 50 * NSEC_PER_MSEC) + openTask.cancel() + + switch await openTask.result { + case .success: + XCTFail("Expected `open()` to throw `CancellationError`") + + case let .failure(error): + XCTAssertTrue( + error is CancellationError, + "Received wrong error: \(String(reflecting: error))" + ) + } + } + + func testOpenThrowsConnectionErrorWhenServerIsUnreachable() async throws { + let (server, client) = try await makeOfflineServerAndClient( + timeoutIntervalForRequest: 0.2 + ) + defer { server.shutDown() } + + do { + try await client.open() + XCTFail("Should not have opened") + } catch is TimeoutError { + XCTFail("Should surface the connection failure instead of timing out") + } catch let error as WebSocketError { + XCTAssertEqual(.abnormalClosure, error.closeCode) + } catch { + XCTFail("Received wrong error: \(error)") + } + } + func _testErrorWhenRemoteCloses() async throws { let errorEx = expectation(description: "Should have closed") let (server, client) = try await makeServerAndClient( @@ -114,6 +163,53 @@ class SystemWebSocketTests: XCTestCase { await fulfillment(of: [secondCloseEx], timeout: 0.1) } + func testDelegateDoesNotReorderOpenAndCloseCallbacks() async throws { + let delegate = Delegate() + let session = URLSession(configuration: .ephemeral) + defer { session.invalidateAndCancel() } + + let task = session.webSocketTask(with: URL(string: "ws://127.0.0.1/socket")!) + let openStarted = AsyncThrowingFuture(timeout: 2) + let allowOpenToFinish = AsyncThrowingFuture(timeout: 2) + let records = Locked([String]()) + + delegate.set( + onOpen: { + records.access { $0.append("open-started") } + openStarted.resolve() + do { try await allowOpenToFinish.value } + catch { XCTFail() } + records.access { $0.append("open-finished") } + }, + onClose: { _, _ in + records.access { $0.append("close") } + }, + for: ObjectIdentifier(task) + ) + + delegate.urlSession(session, webSocketTask: task, didOpenWithProtocol: nil) + try await openStarted.value + + delegate.urlSession( + session, + webSocketTask: task, + didCloseWith: .goingAway, + reason: nil + ) + + try await Task.sleep(nanoseconds: 10 * NSEC_PER_MSEC) + let eventsBeforeOpenFinishes = records.access { $0 } + XCTAssertEqual(["open-started"], eventsBeforeOpenFinishes) + + allowOpenToFinish.resolve() + try await Task.sleep(nanoseconds: 10 * NSEC_PER_MSEC) + let eventsAfterOpenFinishes = records.access { $0 } + XCTAssertEqual( + ["open-started", "open-finished", "close"], + eventsAfterOpenFinishes + ) + } + func testPushAndReceiveText() async throws { let (server, client) = try await makeServerAndClient() defer { server.shutDown() } @@ -338,9 +434,27 @@ class SystemWebSocketTests: XCTestCase { } } + await fulfillment(of: [closeEx], timeout: 2) + XCTAssertEqual(3, messagesReceivedByClient) XCTAssertEqual(3, messagesReceivedByServer) + } + func testRemoteCloseReasonIsPassedToOnClose() async throws { + let closeEx = expectation(description: "Should expose the close reason") + let reason = Data("server said goodbye".utf8) + + let (server, client) = try await makeServerAndClient( + onClose: { close in + XCTAssertEqual(.goingAway, close.code) + XCTAssertEqual(reason, close.reason) + closeEx.fulfill() + } + ) + defer { server.shutDown() } + + try await client.open() + subject.send(.remoteCloseWithReason(.goingAway, reason)) await fulfillment(of: [closeEx], timeout: 2) } } @@ -373,13 +487,14 @@ private extension SystemWebSocketTests { } func makeOfflineServerAndClient( + timeoutIntervalForRequest: TimeInterval = 2, onOpen: @escaping @Sendable () -> Void = {}, onClose: @escaping @Sendable (WebSocketClose) -> Void = { _ in } ) async throws -> (WebSocketServer, SystemWebSocket) { let server = try WebSocketServer(outputPublisher: empty) let client = try! await SystemWebSocket( request: request(19), - options: .init(timeoutIntervalForRequest: 2), + options: .init(timeoutIntervalForRequest: timeoutIntervalForRequest), onOpen: onOpen, onClose: onClose ) @@ -400,3 +515,28 @@ private extension SystemWebSocketTests { return (server, try! await .system(client)) } } + +private final class HangingServer { + var port: Int { channel!.localAddress!.port! } + + private let eventLoopGroup: EventLoopGroup + private var channel: Channel? + + init() throws { + eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + channel = try ServerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.backlog, value: 256) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelInitializer { channel in + channel.eventLoop.makeSucceededFuture(()) + } + .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind(host: "127.0.0.1", port: 0) + .wait() + } + + func shutDown() { + try? channel?.close(mode: .all).wait() + try? eventLoopGroup.syncShutdownGracefully() + } +} From 32568bd2fb2c7c03e9049dd6673c8d7a9d9db5c2 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 20 Apr 2026 03:24:30 +0200 Subject: [PATCH 3/3] Fix warning --- Sources/WebSocket/WebSocket.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index 855f760..772a6e6 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -80,7 +80,7 @@ public extension WebSocket { receiveCompletion: { _ in continuation.finish() }, receiveValue: { continuation.yield($0) } ) - continuation.onTermination = { @Sendable _ in cancellable.cancel() } + continuation.onTermination = { _ in cancellable.cancel() } return stream } }