Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 44 additions & 21 deletions Sources/WebSocket/SystemURLSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,30 @@ private func configuration(with options: WebSocketOptions) -> URLSessionConfigur
return config
}

private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
private struct Callbacks: Sendable {
let onOpen: @Sendable () async -> Void
let onClose: @Sendable (WebSocketCloseCode, Data?) async -> Void
}

// `Dictionary<ObjectIdentifier(URLWebSocketTask): Callbacks>`
private let state: Locked<[ObjectIdentifier: Callbacks]> = .init([:])
private struct State: Sendable {
var callbacks: [ObjectIdentifier: Callbacks] = [:]
var callbackTasks: [ObjectIdentifier: Task<Void, Never>] = [:]
}

private let state = Locked(State())

func set(
onOpen: @escaping @Sendable () async -> Void,
onClose: @escaping @Sendable (WebSocketCloseCode, Data?) async -> Void,
for taskID: ObjectIdentifier
) {
state.access { $0[taskID] = .init(onOpen: onOpen, onClose: onClose) }
state.access { state in
state.callbacks[taskID] = .init(
onOpen: onOpen,
onClose: onClose
)
}
}

func urlSession(
Expand All @@ -76,9 +85,8 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
didOpenWithProtocol _: String?
) {
let taskID = ObjectIdentifier(webSocketTask)

if let onOpen = state.access({ $0[taskID]?.onOpen }) {
Task { await onOpen() }
enqueue(for: taskID) { callbacks in
await callbacks.onOpen()
}
}

Expand All @@ -89,9 +97,8 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
reason: Data?
) {
let taskID = ObjectIdentifier(webSocketTask)

if let onClose = state.access({ $0[taskID]?.onClose }) {
Task { await onClose(WebSocketCloseCode(closeCode), reason) }
enqueue(for: taskID) { callbacks in
await callbacks.onClose(WebSocketCloseCode(closeCode), reason)
}
}

Expand All @@ -101,20 +108,36 @@ private final class Delegate: NSObject, URLSessionWebSocketDelegate, Sendable {
didCompleteWithError error: Error?
) {
let taskID = ObjectIdentifier(task)
let closeCode: WebSocketCloseCode = error == nil ? .normalClosure : .abnormalClosure
let reason = error.map { Data($0.localizedDescription.utf8) }
enqueue(for: taskID, removeAfterwards: true) { callbacks in
await callbacks.onClose(closeCode, reason)
}
}

if let onClose = state.access({ $0[taskID]?.onClose }) {
Task { [weak self] in
if let error {
await onClose(
.abnormalClosure,
Data(error.localizedDescription.utf8)
)
} else {
await onClose(.normalClosure, nil)
}
private func enqueue(
for taskID: ObjectIdentifier,
removeAfterwards: Bool = false,
_ operation: @escaping @Sendable (Callbacks) async -> Void
) {
state.access { state in
guard let callbacks = state.callbacks[taskID] else {
return
}

self?.state.access { _ = $0.removeValue(forKey: taskID) }
let previousTask = state.callbackTasks[taskID]
let task = Task { [weak self] in
_ = await previousTask?.result
await operation(callbacks)

guard removeAfterwards else { return }
self?.state.access { state in
_ = state.callbacks.removeValue(forKey: taskID)
_ = state.callbackTasks.removeValue(forKey: taskID)
}
}

state.callbackTasks[taskID] = task
}
}
}
8 changes: 5 additions & 3 deletions Sources/WebSocket/SystemWebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ final actor SystemWebSocket: Publisher {
do {
try await didOpen.value
} catch is CancellationError {
doClose(closeCode: .cancelled, reason: Data("cancelled".utf8))
throw CancellationError()
} catch is TimeoutError {
doClose(closeCode: .timeout, reason: Data("timeout".utf8))
throw TimeoutError()
Expand Down Expand Up @@ -250,9 +250,12 @@ private extension SystemWebSocket {
}

func doClose(closeCode: WebSocketCloseCode, reason: Data?) {
let close = WebSocketClose(closeCode, reason)
didOpen.fail(WebSocketError(closeCode, reason))

switch state {
case .unopened:
state = .closed(.init(closeCode, reason))
state = .closed(close)

case let .connecting(ws), let .open(ws):
os_log(
Expand All @@ -271,7 +274,6 @@ private extension SystemWebSocket {
}
}

let close = WebSocketClose(closeCode, nil)
state = .closed(close)
onClose(close)
didClose?.resolve((code: closeCode, reason: reason))
Expand Down
30 changes: 9 additions & 21 deletions Sources/WebSocket/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,27 +73,15 @@ public extension WebSocket {

/// The WebSocket's received messages as an asynchronous stream.
var messages: AsyncStream<WebSocketMessage> {
let cancellable = Locked<AnyCancellable?>(nil)

return AsyncStream { cont in
func finish() {
cancellable.access { cancellable in
if cancellable != nil {
cont.finish()
cancellable = nil
}
}
}

let _cancellable = self.messagesPublisher()
.handleEvents(receiveCancel: { finish() })
.sink(
receiveCompletion: { _ in finish() },
receiveValue: { cont.yield($0) }
)

cancellable.access { $0 = _cancellable }
}
let (stream, continuation) = AsyncStream<WebSocketMessage>.makeStream()
let cancellable = messagesPublisher()
.handleEvents(receiveCancel: { continuation.finish() })
.sink(
receiveCompletion: { _ in continuation.finish() },
receiveValue: { continuation.yield($0) }
)
continuation.onTermination = { _ in cancellable.cancel() }
return stream
Comment thread
atdrendel marked this conversation as resolved.
}
}

Expand Down
13 changes: 12 additions & 1 deletion Tests/WebSocketTests/Server/WebSocketServer.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import Combine
import Foundation
import NIO
import NIOWebSocket
import WebSocket
import WebSocketKit

enum WebSocketServerOutput: Hashable {
enum WebSocketServerOutput {
case message(WebSocketMessage)
case remoteClose
case remoteCloseWithReason(WebSocketErrorCode, Data)
}

final class WebSocketServer {
Expand Down Expand Up @@ -71,6 +73,15 @@ final class WebSocketServer {
do { try ws.close(code: .goingAway).wait() }
catch {}

case let .remoteCloseWithReason(code, reason):
var buffer = ByteBufferAllocator().buffer(capacity: 2 + reason.count)
buffer.write(webSocketErrorCode: code)
buffer.writeBytes(reason)
ws.send(
raw: buffer.readableBytesView,
opcode: .connectionClose
)

case let .message(message):
switch message {
case let .data(data):
Expand Down
Loading
Loading