Skip to content
159 changes: 131 additions & 28 deletions Sources/AnyLanguageModel/Extensions/URLSession+Extensions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,64 @@ enum HTTP {
}
}

#if canImport(FoundationNetworking)
/// Serializes Linux URLSession operations to mitigate a FoundationNetworking race.
///
/// AnyLanguageModel performs many concurrent HTTP requests across model implementations.
/// On Linux, `FoundationNetworking` routes `URLSession` through a shared
/// `_MultiHandle`, which has a known thread-safety bug that can crash under
/// concurrent access (`URLSession._MultiHandle.endOperation(for:)`).
///
/// This gate intentionally allows only one in-flight request path at a time on Linux.
Comment thread
mattt marked this conversation as resolved.
Outdated
/// This fully serializes HTTP request setup paths on Linux and reduces request-level
/// parallelism, which can lower throughput for heavily concurrent workloads.
Comment thread
mattt marked this conversation as resolved.
Outdated
/// Keep this scoped to Linux-only code paths until the upstream issue is resolved.
///
/// See: https://github.com/swiftlang/swift-corelibs-foundation/issues/4791
actor LinuxURLSessionRequestGate {
static let shared = LinuxURLSessionRequestGate()

private var isLocked = false
private var waiters: [CheckedContinuation<Void, Never>] = []

func acquire() async {
if !isLocked {
isLocked = true
return
}

await withCheckedContinuation { continuation in
waiters.append(continuation)
}
}
Comment thread
mattt marked this conversation as resolved.
Outdated

func release() {
if waiters.isEmpty {
isLocked = false
return
}

let continuation = waiters.removeFirst()
continuation.resume()
}

}

func withLinuxRequestLock(
_ operation: () async throws -> Void
) async throws {
let gate = LinuxURLSessionRequestGate.shared
await gate.acquire()
do {
try await operation()
await gate.release()
} catch {
await gate.release()
throw error
}
Comment thread
mattt marked this conversation as resolved.
}
#endif

extension URLSession {
func fetch<T: Decodable>(
_ method: HTTP.Method,
Expand All @@ -34,7 +92,20 @@ extension URLSession {
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
}

let (data, response) = try await data(for: request)
#if canImport(FoundationNetworking)
var lockedData: Data?
var lockedResponse: URLResponse?
try await withLinuxRequestLock {
let (data, response) = try await data(for: request)
lockedData = data
lockedResponse = response
}
guard let data = lockedData, let response = lockedResponse else {
throw URLSessionError.invalidResponse
}
Comment thread
mattt marked this conversation as resolved.
Comment thread
mattt marked this conversation as resolved.
#else
let (data, response) = try await data(for: request)
#endif

guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
Expand Down Expand Up @@ -83,7 +154,20 @@ extension URLSession {
request.addValue("application/json", forHTTPHeaderField: "Content-Type")
}

let (data, response) = try await self.data(for: request)
#if canImport(FoundationNetworking)
var lockedData: Data?
var lockedResponse: URLResponse?
try await withLinuxRequestLock {
let (data, response) = try await self.data(for: request)
lockedData = data
lockedResponse = response
}
guard let data = lockedData, let response = lockedResponse else {
throw URLSessionError.invalidResponse
}
#else
let (data, response) = try await self.data(for: request)
#endif

guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
Expand Down Expand Up @@ -143,35 +227,21 @@ extension URLSession {
}

#if canImport(FoundationNetworking)
let (asyncBytes, response) = try await self.linuxBytes(for: request)
var lockedAsyncBytes: AsyncThrowingStream<UInt8, Error>?
try await withLinuxRequestLock {
let (bytes, response) = try await self.linuxBytes(for: request)
try await self.validateEventStreamResponse(response, asyncBytes: bytes)
lockedAsyncBytes = bytes
}
guard let asyncBytes = lockedAsyncBytes else {
throw URLSessionError.invalidResponse
}
Comment thread
mattt marked this conversation as resolved.
try await decodeAndYieldEventStream(asyncBytes, to: continuation)
Comment thread
mattt marked this conversation as resolved.
#else
let (asyncBytes, response) = try await self.bytes(for: request)
try await validateEventStreamResponse(response, asyncBytes: asyncBytes)
try await decodeAndYieldEventStream(asyncBytes, to: continuation)
#endif

guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
}

guard (200 ..< 300).contains(httpResponse.statusCode) else {
var errorData = Data()
for try await byte in asyncBytes {
errorData.append(byte)
}
if let errorString = String(data: errorData, encoding: .utf8) {
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString)
}
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response")
}

let decoder = JSONDecoder()

for try await event in asyncBytes.events {
guard let data = event.data.data(using: .utf8) else { continue }
if let decoded = try? decoder.decode(T.self, from: data) {
continuation.yield(decoded)
}
}

continuation.finish()
} catch {
continuation.finish(throwing: error)
Expand All @@ -183,6 +253,39 @@ extension URLSession {
}
}
}

private func validateEventStreamResponse<Bytes>(
_ response: URLResponse,
asyncBytes: Bytes
) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 {
guard let httpResponse = response as? HTTPURLResponse else {
throw URLSessionError.invalidResponse
}

guard (200 ..< 300).contains(httpResponse.statusCode) else {
var errorData = Data()
for try await byte in asyncBytes {
errorData.append(byte)
}
if let errorString = String(data: errorData, encoding: .utf8) {
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: errorString)
}
throw URLSessionError.httpError(statusCode: httpResponse.statusCode, detail: "Invalid response")
}
}

private func decodeAndYieldEventStream<T: Decodable & Sendable, Bytes>(
_ asyncBytes: Bytes,
to continuation: AsyncThrowingStream<T, any Error>.Continuation
) async throws where Bytes: AsyncSequence, Bytes.Element == UInt8 {
let decoder = JSONDecoder()
for try await event in asyncBytes.events {
guard let data = event.data.data(using: .utf8) else { continue }
if let decoded = try? decoder.decode(T.self, from: data) {
continuation.yield(decoded)
}
}
}
}

#if canImport(FoundationNetworking)
Expand Down
82 changes: 82 additions & 0 deletions Tests/AnyLanguageModelTests/URLSessionExtensionsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,85 @@ struct URLSessionExtensionsTests {
#expect(error.description == "Decoding error: keyNotFound")
}
}

#if canImport(FoundationNetworking)
private actor GateCounter {
private(set) var current = 0
private(set) var maxConcurrent = 0

func enter() {
current += 1
maxConcurrent = max(maxConcurrent, current)
}

func leave() {
current -= 1
}
}

private enum GateTestError: Error {
case expected
}

extension URLSessionExtensionsTests {
@Test func linuxGateSerializesConcurrentOperations() async throws {
let counter = GateCounter()

try await withThrowingTaskGroup(of: Void.self) { group in
for _ in 0 ..< 8 {
group.addTask {
try await withLinuxRequestLock {
await counter.enter()
do {
try await Task.sleep(for: .milliseconds(20))
await counter.leave()
} catch {
await counter.leave()
throw error
}
}
}
}
try await group.waitForAll()
}

#expect(await counter.maxConcurrent == 1)
}

@Test func linuxGateReleasesAfterError() async throws {
do {
try await withLinuxRequestLock {
throw GateTestError.expected
}
Issue.record("Expected error was not thrown")
} catch GateTestError.expected {
// expected
}

var ranSecondOperation = false
try await withLinuxRequestLock {
ranSecondOperation = true
}
#expect(ranSecondOperation)
}

@Test func linuxGateReleasesAfterCancellation() async throws {
let longTask = Task {
try await withLinuxRequestLock {
try await Task.sleep(for: .seconds(10))
}
}

try await Task.sleep(for: .milliseconds(30))
longTask.cancel()
_ = await longTask.result

var acquiredAfterCancellation = false
try await withLinuxRequestLock {
acquiredAfterCancellation = true
}

#expect(acquiredAfterCancellation)
}
}
#endif