diff --git a/Sources/AnyLanguageModel/LanguageModelSession.swift b/Sources/AnyLanguageModel/LanguageModelSession.swift index ba38550..1f9cf5d 100644 --- a/Sources/AnyLanguageModel/LanguageModelSession.swift +++ b/Sources/AnyLanguageModel/LanguageModelSession.swift @@ -174,6 +174,7 @@ public final class LanguageModelSession: @unchecked Sendable { public struct Response: Sendable where Content: Generable, Content: Sendable { public let content: Content public let rawContent: GeneratedContent + public let usage: LanguageModelUsage? public let transcriptEntries: ArraySlice /// Creates a response value from generated content and transcript entries. @@ -184,10 +185,12 @@ public final class LanguageModelSession: @unchecked Sendable { public init( content: Content, rawContent: GeneratedContent, + usage: LanguageModelUsage? = nil, transcriptEntries: ArraySlice ) { self.content = content self.rawContent = rawContent + self.usage = usage self.transcriptEntries = transcriptEntries } } @@ -801,8 +804,12 @@ extension LanguageModelSession { /// - Parameters: /// - content: The complete response content. /// - rawContent: The raw content produced by the model. - public init(content: Content, rawContent: GeneratedContent) { - self.fallbackSnapshot = Snapshot(content: content.asPartiallyGenerated(), rawContent: rawContent) + public init(content: Content, rawContent: GeneratedContent, usage: LanguageModelUsage? = nil) { + self.fallbackSnapshot = Snapshot( + content: content.asPartiallyGenerated(), + rawContent: rawContent, + usage: usage + ) self.streaming = nil } @@ -817,14 +824,20 @@ extension LanguageModelSession { public struct Snapshot: Sendable where Content.PartiallyGenerated: Sendable { public var content: Content.PartiallyGenerated public var rawContent: GeneratedContent + public var usage: LanguageModelUsage? /// Creates a snapshot from partially generated content and raw content. /// - Parameters: /// - content: The partially generated content. /// - rawContent: The raw content produced by the model. - public init(content: Content.PartiallyGenerated, rawContent: GeneratedContent) { + public init( + content: Content.PartiallyGenerated, + rawContent: GeneratedContent, + usage: LanguageModelUsage? = nil + ) { self.content = content self.rawContent = rawContent + self.usage = usage } } } @@ -887,6 +900,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { return LanguageModelSession.Response( content: finalContent, rawContent: last.rawContent, + usage: last.usage, transcriptEntries: [] ) } @@ -902,6 +916,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence { return LanguageModelSession.Response( content: finalContent, rawContent: fallbackSnapshot.rawContent, + usage: fallbackSnapshot.usage, transcriptEntries: [] ) } diff --git a/Sources/AnyLanguageModel/LanguageModelUsage.swift b/Sources/AnyLanguageModel/LanguageModelUsage.swift new file mode 100644 index 0000000..0e5359d --- /dev/null +++ b/Sources/AnyLanguageModel/LanguageModelUsage.swift @@ -0,0 +1,89 @@ +/// Provider-reported token usage for a language model response. +/// +/// Values are optional because not every provider or endpoint reports every field. +/// When a single `respond` call internally triggers multiple provider requests +/// (for example while resolving tool calls), the returned usage is aggregated +/// across those underlying requests. +public struct LanguageModelUsage: Hashable, Codable, Sendable { + /// Tokens consumed by the request input or prompt. + public var inputTokens: Int? + + /// Tokens generated in the response. + public var outputTokens: Int? + + /// Total tokens reported for the request. + public var totalTokens: Int? + + /// Tokens spent on reasoning or thinking, when reported separately. + public var reasoningTokens: Int? + + /// Input tokens served from prompt cache, when reported separately. + public var cachedInputTokens: Int? + + /// Input tokens written into a prompt cache, when reported separately. + public var cacheCreationInputTokens: Int? + + public init( + inputTokens: Int? = nil, + outputTokens: Int? = nil, + totalTokens: Int? = nil, + reasoningTokens: Int? = nil, + cachedInputTokens: Int? = nil, + cacheCreationInputTokens: Int? = nil + ) { + self.inputTokens = inputTokens + self.outputTokens = outputTokens + self.totalTokens = totalTokens + self.reasoningTokens = reasoningTokens + self.cachedInputTokens = cachedInputTokens + self.cacheCreationInputTokens = cacheCreationInputTokens + } +} + +extension LanguageModelUsage { + var isEmpty: Bool { + inputTokens == nil + && outputTokens == nil + && totalTokens == nil + && reasoningTokens == nil + && cachedInputTokens == nil + && cacheCreationInputTokens == nil + } + + var normalized: Self? { + isEmpty ? nil : self + } + + mutating func add(_ other: Self?) { + guard let other else { return } + inputTokens = Self.sum(inputTokens, other.inputTokens) + outputTokens = Self.sum(outputTokens, other.outputTokens) + totalTokens = Self.sum(totalTokens, other.totalTokens) + reasoningTokens = Self.sum(reasoningTokens, other.reasoningTokens) + cachedInputTokens = Self.sum(cachedInputTokens, other.cachedInputTokens) + cacheCreationInputTokens = Self.sum(cacheCreationInputTokens, other.cacheCreationInputTokens) + } + + mutating func merge(_ other: Self?) { + guard let other else { return } + inputTokens = other.inputTokens ?? inputTokens + outputTokens = other.outputTokens ?? outputTokens + totalTokens = other.totalTokens ?? totalTokens + reasoningTokens = other.reasoningTokens ?? reasoningTokens + cachedInputTokens = other.cachedInputTokens ?? cachedInputTokens + cacheCreationInputTokens = other.cacheCreationInputTokens ?? cacheCreationInputTokens + } + + private static func sum(_ lhs: Int?, _ rhs: Int?) -> Int? { + switch (lhs, rhs) { + case (.some(let lhs), .some(let rhs)): + lhs + rhs + case (.some(let lhs), .none): + lhs + case (.none, .some(let rhs)): + rhs + case (.none, .none): + nil + } + } +} diff --git a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift index f78cd07..6ac2ba2 100644 --- a/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift @@ -345,6 +345,7 @@ public struct AnthropicLanguageModel: LanguageModel { ) var entries: [Transcript.Entry] = [] + let usage = message.usage?.languageModelUsage // Handle tool calls, if present let toolUses: [AnthropicToolUse] = message.content.compactMap { block in @@ -363,6 +364,7 @@ public struct AnthropicLanguageModel: LanguageModel { return LanguageModelSession.Response( content: empty.content, rawContent: empty.rawContent, + usage: usage, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -386,6 +388,7 @@ public struct AnthropicLanguageModel: LanguageModel { return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), + usage: usage, transcriptEntries: ArraySlice(entries) ) } @@ -395,6 +398,7 @@ public struct AnthropicLanguageModel: LanguageModel { return LanguageModelSession.Response( content: content, rawContent: rawContent, + usage: usage, transcriptEntries: ArraySlice(entries) ) } @@ -445,30 +449,48 @@ public struct AnthropicLanguageModel: LanguageModel { var accumulatedText = "" let expectsStructuredResponse = type != String.self + var latestUsage = LanguageModelUsage() + var lastSnapshot: LanguageModelSession.ResponseStream.Snapshot? for try await event in events { switch event { + case .messageStart(let start): + latestUsage.merge(start.message.usage?.languageModelUsage) case .contentBlockDelta(let delta): if case .textDelta(let textDelta) = delta.delta { accumulatedText += textDelta.text if expectsStructuredResponse { - if let snapshot: LanguageModelSession.ResponseStream.Snapshot = + if var snapshot: LanguageModelSession.ResponseStream.Snapshot = try? partialSnapshot(from: accumulatedText) { + snapshot.usage = latestUsage.normalized + lastSnapshot = snapshot continuation.yield(snapshot) } } else { let raw = GeneratedContent(accumulatedText) let content: Content.PartiallyGenerated = (accumulatedText as! Content) .asPartiallyGenerated() - continuation.yield(.init(content: content, rawContent: raw)) + let snapshot = LanguageModelSession.ResponseStream.Snapshot( + content: content, + rawContent: raw, + usage: latestUsage.normalized + ) + lastSnapshot = snapshot + continuation.yield(snapshot) } } + case .messageDelta(let delta): + latestUsage.merge(delta.usage?.languageModelUsage) case .messageStop: + if var lastSnapshot, lastSnapshot.usage != latestUsage.normalized { + lastSnapshot.usage = latestUsage.normalized + continuation.yield(lastSnapshot) + } continuation.finish() return - case .messageStart, .contentBlockStart, .contentBlockStop, .messageDelta, .ping, .ignored: + case .contentBlockStart, .contentBlockStop, .ping, .ignored: break } } @@ -995,9 +1017,10 @@ private struct AnthropicMessageResponse: Codable, Sendable { let content: [AnthropicContent] let model: String let stopReason: StopReason? + let usage: AnthropicUsage? enum CodingKeys: String, CodingKey { - case id, type, role, content, model + case id, type, role, content, model, usage case stopReason = "stop_reason" } @@ -1012,6 +1035,20 @@ private struct AnthropicMessageResponse: Codable, Sendable { } } +private struct AnthropicUsage: Codable, Sendable { + let inputTokens: Int? + let outputTokens: Int? + let cacheCreationInputTokens: Int? + let cacheReadInputTokens: Int? + + enum CodingKeys: String, CodingKey { + case inputTokens = "input_tokens" + case outputTokens = "output_tokens" + case cacheCreationInputTokens = "cache_creation_input_tokens" + case cacheReadInputTokens = "cache_read_input_tokens" + } +} + private struct AnthropicErrorResponse: Codable { let error: AnthropicErrorDetail } private struct AnthropicErrorDetail: Codable { let type: String @@ -1157,6 +1194,7 @@ private enum AnthropicStreamEvent: Codable, Sendable { struct MessageDeltaEvent: Codable, Sendable { let type: String let delta: Delta + let usage: AnthropicUsage? struct Delta: Codable, Sendable { let stopReason: String? @@ -1169,3 +1207,14 @@ private enum AnthropicStreamEvent: Codable, Sendable { } } } + +private extension AnthropicUsage { + var languageModelUsage: LanguageModelUsage? { + LanguageModelUsage( + inputTokens: inputTokens, + outputTokens: outputTokens, + cachedInputTokens: cacheReadInputTokens, + cacheCreationInputTokens: cacheCreationInputTokens + ).normalized + } +} diff --git a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift index b074ddc..17d3d41 100644 --- a/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/GeminiLanguageModel.swift @@ -281,6 +281,7 @@ public struct GeminiLanguageModel: LanguageModel { let geminiTools = try buildTools(from: session.tools, serverTools: effectiveServerTools) var transcript = session.transcript + var usage = LanguageModelUsage() // Multi-turn conversation loop for tool calling while true { @@ -301,6 +302,7 @@ public struct GeminiLanguageModel: LanguageModel { headers: headers, body: body ) + usage.add(response.usageMetadata?.languageModelUsage) guard let firstCandidate = response.candidates.first else { throw GeminiError.noCandidate @@ -324,6 +326,7 @@ public struct GeminiLanguageModel: LanguageModel { return LanguageModelSession.Response( content: empty.content, rawContent: empty.rawContent, + usage: usage.normalized, transcriptEntries: ArraySlice(transcript) ) case .invocations(let invocations): @@ -352,6 +355,7 @@ public struct GeminiLanguageModel: LanguageModel { return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), + usage: usage.normalized, transcriptEntries: ArraySlice(transcript) ) } @@ -361,6 +365,7 @@ public struct GeminiLanguageModel: LanguageModel { return LanguageModelSession.Response( content: content, rawContent: generatedContent, + usage: usage.normalized, transcriptEntries: ArraySlice(transcript) ) } @@ -416,11 +421,15 @@ public struct GeminiLanguageModel: LanguageModel { ) var accumulatedText = "" + var latestUsage = LanguageModelUsage() + var lastSnapshot: LanguageModelSession.ResponseStream.Snapshot? for try await chunk in stream { - guard let candidate = chunk.candidates.first else { continue } + let previousUsage = latestUsage.normalized + latestUsage.merge(chunk.usageMetadata?.languageModelUsage) + var yieldedSnapshot = false - if let parts = candidate.content.parts { + if let candidate = chunk.candidates.first, let parts = candidate.content.parts { for part in parts { if case .text(let textPart) = part { accumulatedText += textPart.text @@ -444,11 +453,23 @@ public struct GeminiLanguageModel: LanguageModel { } if let content { - continuation.yield(.init(content: content, rawContent: raw)) + let snapshot = LanguageModelSession.ResponseStream.Snapshot( + content: content, + rawContent: raw, + usage: latestUsage.normalized + ) + lastSnapshot = snapshot + continuation.yield(snapshot) + yieldedSnapshot = true } } } } + + if !yieldedSnapshot, previousUsage != latestUsage.normalized, var lastSnapshot { + lastSnapshot.usage = latestUsage.normalized + continuation.yield(lastSnapshot) + } } continuation.finish() @@ -1022,6 +1043,17 @@ private struct GeminiUsageMetadata: Codable, Sendable { } } +private extension GeminiUsageMetadata { + var languageModelUsage: LanguageModelUsage? { + LanguageModelUsage( + inputTokens: promptTokenCount, + outputTokens: candidatesTokenCount, + totalTokens: totalTokenCount, + reasoningTokens: thoughtsTokenCount + ).normalized + } +} + enum GeminiError: Error, CustomStringConvertible { case noCandidate diff --git a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift index 82be5cc..9efb107 100644 --- a/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OllamaLanguageModel.swift @@ -124,6 +124,7 @@ public struct OllamaLanguageModel: LanguageModel { return LanguageModelSession.Response( content: "" as! Content, rawContent: GeneratedContent(""), + usage: chatResponse.usage, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -141,6 +142,7 @@ public struct OllamaLanguageModel: LanguageModel { return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), + usage: chatResponse.usage, transcriptEntries: ArraySlice(entries) ) } @@ -150,6 +152,7 @@ public struct OllamaLanguageModel: LanguageModel { return LanguageModelSession.Response( content: content, rawContent: generatedContent, + usage: chatResponse.usage, transcriptEntries: ArraySlice(entries) ) } @@ -207,23 +210,31 @@ public struct OllamaLanguageModel: LanguageModel { ) as AsyncThrowingStream var partialText = "" + var latestUsage = LanguageModelUsage() + var lastSnapshot: LanguageModelSession.ResponseStream.Snapshot? for try await chunk in chunks { + let previousUsage = latestUsage.normalized + latestUsage.merge(chunk.usage) if let piece = chunk.message.content { partialText += piece if type == String.self { let snapshot = LanguageModelSession.ResponseStream.Snapshot( content: (partialText as! Content).asPartiallyGenerated(), - rawContent: GeneratedContent(partialText) + rawContent: GeneratedContent(partialText), + usage: latestUsage.normalized ) + lastSnapshot = snapshot continuation.yield(snapshot) } else if let raw = try? GeneratedContent(json: partialText), let parsed = try? type.init(raw) { let snapshot = LanguageModelSession.ResponseStream.Snapshot( content: parsed.asPartiallyGenerated(), - rawContent: raw + rawContent: raw, + usage: latestUsage.normalized ) + lastSnapshot = snapshot continuation.yield(snapshot) } else { // Structured responses can stream as incomplete JSON fragments. @@ -231,6 +242,14 @@ public struct OllamaLanguageModel: LanguageModel { } } + if chunk.message.content == nil, + previousUsage != latestUsage.normalized, + var lastSnapshot + { + lastSnapshot.usage = latestUsage.normalized + continuation.yield(lastSnapshot) + } + if chunk.done { break } @@ -532,12 +551,37 @@ private struct ChatResponse: Decodable, Sendable { let createdAt: Date let message: ChatMessageResponse let done: Bool + let promptEvalCount: Int? + let evalCount: Int? private enum CodingKeys: String, CodingKey { case model case createdAt = "created_at" case message case done + case promptEvalCount = "prompt_eval_count" + case evalCount = "eval_count" + } +} + +private extension ChatResponse { + var usage: LanguageModelUsage? { + LanguageModelUsage( + inputTokens: promptEvalCount, + outputTokens: evalCount, + totalTokens: { + switch (promptEvalCount, evalCount) { + case (.some(let promptEvalCount), .some(let evalCount)): + promptEvalCount + evalCount + case (.some(let promptEvalCount), .none): + promptEvalCount + case (.none, .some(let evalCount)): + evalCount + case (.none, .none): + nil + } + }() + ).normalized } } diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index 20bca5e..46dabec 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -471,6 +471,7 @@ public struct OpenAILanguageModel: LanguageModel { var entries: [Transcript.Entry] = [] var text = "" var messages = messages + var usage = LanguageModelUsage() // Loop until no more tool calls while true { @@ -493,6 +494,7 @@ public struct OpenAILanguageModel: LanguageModel { ], body: body ) + usage.add(resp.usage?.languageModelUsage) guard let choice = resp.choices.first else { throw OpenAILanguageModelError.noResponseGenerated @@ -525,6 +527,7 @@ public struct OpenAILanguageModel: LanguageModel { return LanguageModelSession.Response( content: empty.content, rawContent: empty.rawContent, + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -553,6 +556,7 @@ public struct OpenAILanguageModel: LanguageModel { return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) } @@ -562,6 +566,7 @@ public struct OpenAILanguageModel: LanguageModel { return LanguageModelSession.Response( content: content, rawContent: generatedContent, + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) } @@ -577,6 +582,7 @@ public struct OpenAILanguageModel: LanguageModel { var text = "" var lastOutput: [JSONValue]? var messages = messages + var usage = LanguageModelUsage() let url = baseURL.appendingPathComponent("responses") @@ -601,6 +607,7 @@ public struct OpenAILanguageModel: LanguageModel { ], body: body ) + usage.add(resp.usage?.languageModelUsage) let toolCalls = extractToolCallsFromOutput(resp.output) lastOutput = resp.output @@ -620,6 +627,7 @@ public struct OpenAILanguageModel: LanguageModel { return LanguageModelSession.Response( content: empty.content, rawContent: empty.rawContent, + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -650,6 +658,7 @@ public struct OpenAILanguageModel: LanguageModel { return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) } @@ -660,6 +669,7 @@ public struct OpenAILanguageModel: LanguageModel { return LanguageModelSession.Response( content: content, rawContent: generatedContent, + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) } @@ -714,6 +724,8 @@ public struct OpenAILanguageModel: LanguageModel { ) var accumulatedText = "" + var latestUsage: LanguageModelUsage? + var lastSnapshot: LanguageModelSession.ResponseStream.Snapshot? for try await event in events { switch event { @@ -739,7 +751,13 @@ public struct OpenAILanguageModel: LanguageModel { } if let content { - continuation.yield(.init(content: content, rawContent: raw)) + let snapshot = LanguageModelSession.ResponseStream.Snapshot( + content: content, + rawContent: raw, + usage: latestUsage + ) + lastSnapshot = snapshot + continuation.yield(snapshot) } case .toolCallCreated(_): @@ -748,7 +766,12 @@ public struct OpenAILanguageModel: LanguageModel { case .toolCallDelta(_): // Minimal streaming implementation ignores tool call deltas break - case .completed(_): + case .completed(let usage): + latestUsage = usage ?? latestUsage + if var lastSnapshot, lastSnapshot.usage != latestUsage { + lastSnapshot.usage = latestUsage + continuation.yield(lastSnapshot) + } continuation.finish() case .ignored: break @@ -798,8 +821,18 @@ public struct OpenAILanguageModel: LanguageModel { ) var accumulatedText = "" + var latestUsage: LanguageModelUsage? + var lastSnapshot: LanguageModelSession.ResponseStream.Snapshot? for try await chunk in events { + if let usage = chunk.usage?.languageModelUsage { + latestUsage = usage + if chunk.choices.isEmpty, var lastSnapshot, lastSnapshot.usage != latestUsage { + lastSnapshot.usage = latestUsage + continuation.yield(lastSnapshot) + } + } + if let choice = chunk.choices.first { if let piece = choice.delta.content, !piece.isEmpty { accumulatedText += piece @@ -823,13 +856,16 @@ public struct OpenAILanguageModel: LanguageModel { } if let content { - continuation.yield(.init(content: content, rawContent: raw)) + let snapshot = LanguageModelSession.ResponseStream.Snapshot( + content: content, + rawContent: raw, + usage: latestUsage + ) + lastSnapshot = snapshot + continuation.yield(snapshot) } } - if choice.finishReason != nil { - continuation.finish() - } } } @@ -866,6 +902,10 @@ private enum ChatCompletions { "stream": .bool(stream), ] + if stream { + body["stream_options"] = .object(["include_usage": .bool(true)]) + } + if let tools { body["tools"] = .array(tools.map { $0.jsonValue(for: .chatCompletions) }) } @@ -973,6 +1013,39 @@ private enum ChatCompletions { struct Response: Decodable, Sendable { let id: String let choices: [Choice] + let usage: Usage? + + struct Usage: Codable, Sendable { + let promptTokens: Int? + let completionTokens: Int? + let totalTokens: Int? + let promptTokensDetails: PromptTokensDetails? + let completionTokensDetails: CompletionTokensDetails? + + enum CodingKeys: String, CodingKey { + case promptTokens = "prompt_tokens" + case completionTokens = "completion_tokens" + case totalTokens = "total_tokens" + case promptTokensDetails = "prompt_tokens_details" + case completionTokensDetails = "completion_tokens_details" + } + + struct PromptTokensDetails: Codable, Sendable { + let cachedTokens: Int? + + enum CodingKeys: String, CodingKey { + case cachedTokens = "cached_tokens" + } + } + + struct CompletionTokensDetails: Codable, Sendable { + let reasoningTokens: Int? + + enum CodingKeys: String, CodingKey { + case reasoningTokens = "reasoning_tokens" + } + } + } struct Choice: Codable, Sendable { let message: Message @@ -1225,6 +1298,39 @@ private enum Responses { let error: [JSONValue]? let outputText: String? let finishReason: String? + let usage: Usage? + + struct Usage: Decodable, Sendable { + let inputTokens: Int? + let outputTokens: Int? + let totalTokens: Int? + let inputTokensDetails: InputTokensDetails? + let outputTokensDetails: OutputTokensDetails? + + enum CodingKeys: String, CodingKey { + case inputTokens = "input_tokens" + case outputTokens = "output_tokens" + case totalTokens = "total_tokens" + case inputTokensDetails = "input_tokens_details" + case outputTokensDetails = "output_tokens_details" + } + + struct InputTokensDetails: Decodable, Sendable { + let cachedTokens: Int? + + enum CodingKeys: String, CodingKey { + case cachedTokens = "cached_tokens" + } + } + + struct OutputTokensDetails: Decodable, Sendable { + let reasoningTokens: Int? + + enum CodingKeys: String, CodingKey { + case reasoningTokens = "reasoning_tokens" + } + } + } private enum CodingKeys: String, CodingKey { case id @@ -1232,6 +1338,7 @@ private enum Responses { case outputText = "output_text" case finishReason = "finish_reason" case error = "error" + case usage } } } @@ -1554,7 +1661,7 @@ private enum OpenAIResponsesServerEvent: Decodable, Sendable { case outputTextDelta(String) case toolCallCreated(OpenAIToolCall) case toolCallDelta(OpenAIToolCall) - case completed(String) + case completed(LanguageModelUsage?) case ignored init(from decoder: any Decoder) throws { @@ -1568,7 +1675,10 @@ private enum OpenAIResponsesServerEvent: Decodable, Sendable { case "response.tool_call.delta": self = .toolCallDelta(try container.decode(OpenAIToolCall.self, forKey: .toolCall)) case "response.completed": - self = .completed((try? container.decode(String.self, forKey: .finishReason)) ?? "stop") + let usage = + (try? container.decode(Responses.Response.self, forKey: .response))?.usage?.languageModelUsage + ?? (try? container.decode(Responses.Response.Usage.self, forKey: .usage))?.languageModelUsage + self = .completed(usage) default: self = .ignored } @@ -1579,6 +1689,8 @@ private enum OpenAIResponsesServerEvent: Decodable, Sendable { case delta case toolCall = "tool_call" case finishReason = "finish_reason" + case response + case usage } } @@ -1599,6 +1711,31 @@ private struct OpenAIChatCompletionsChunk: Decodable, Sendable { let id: String let choices: [Choice] + let usage: ChatCompletions.Response.Usage? +} + +private extension ChatCompletions.Response.Usage { + var languageModelUsage: LanguageModelUsage? { + LanguageModelUsage( + inputTokens: promptTokens, + outputTokens: completionTokens, + totalTokens: totalTokens, + reasoningTokens: completionTokensDetails?.reasoningTokens, + cachedInputTokens: promptTokensDetails?.cachedTokens + ).normalized + } +} + +private extension Responses.Response.Usage { + var languageModelUsage: LanguageModelUsage? { + LanguageModelUsage( + inputTokens: inputTokens, + outputTokens: outputTokens, + totalTokens: totalTokens, + reasoningTokens: outputTokensDetails?.reasoningTokens, + cachedInputTokens: inputTokensDetails?.cachedTokens + ).normalized + } } private struct OpenAIToolInvocationResult { diff --git a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift index 8370e14..0bb3dbb 100644 --- a/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenResponsesLanguageModel.swift @@ -440,6 +440,8 @@ public struct OpenResponsesLanguageModel: LanguageModel { body: body ) var accumulatedText = "" + var latestUsage: LanguageModelUsage? + var lastSnapshot: LanguageModelSession.ResponseStream.Snapshot? for try await event in events { switch event { case .outputTextDelta(let delta): @@ -456,9 +458,20 @@ public struct OpenResponsesLanguageModel: LanguageModel { content = (try? type.init(raw))?.asPartiallyGenerated() } if let content { - continuation.yield(.init(content: content, rawContent: raw)) + let snapshot = LanguageModelSession.ResponseStream.Snapshot( + content: content, + rawContent: raw, + usage: latestUsage + ) + lastSnapshot = snapshot + continuation.yield(snapshot) + } + case .completed(let usage): + latestUsage = usage ?? latestUsage + if var lastSnapshot, lastSnapshot.usage != latestUsage { + lastSnapshot.usage = latestUsage + continuation.yield(lastSnapshot) } - case .completed: continuation.finish() return case .failed: @@ -493,6 +506,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { var text = "" var lastOutput: [JSONValue]? var messages = messages + var usage = LanguageModelUsage() let url = baseURL.appendingPathComponent("responses") while true { @@ -511,6 +525,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { headers: ["Authorization": "Bearer \(tokenProvider())"], body: body ) + usage.add(resp.usage?.languageModelUsage) let toolCalls = extractToolCallsFromOutput(resp.output) lastOutput = resp.output @@ -530,6 +545,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { return LanguageModelSession.Response( content: empty.content, rawContent: empty.rawContent, + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) case .invocations(let invocations): @@ -557,6 +573,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { return LanguageModelSession.Response( content: text as! Content, rawContent: GeneratedContent(text), + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) } @@ -566,6 +583,7 @@ public struct OpenResponsesLanguageModel: LanguageModel { return LanguageModelSession.Response( content: content, rawContent: generatedContent, + usage: usage.normalized, transcriptEntries: ArraySlice(entries) ) } @@ -731,12 +749,46 @@ private enum OpenResponsesAPI { let output: [JSONValue]? let outputText: String? let error: OpenResponsesError? + let usage: Usage? + + struct Usage: Decodable, Sendable { + let inputTokens: Int? + let outputTokens: Int? + let totalTokens: Int? + let inputTokensDetails: InputTokensDetails? + let outputTokensDetails: OutputTokensDetails? + + enum CodingKeys: String, CodingKey { + case inputTokens = "input_tokens" + case outputTokens = "output_tokens" + case totalTokens = "total_tokens" + case inputTokensDetails = "input_tokens_details" + case outputTokensDetails = "output_tokens_details" + } + + struct InputTokensDetails: Decodable, Sendable { + let cachedTokens: Int? + + enum CodingKeys: String, CodingKey { + case cachedTokens = "cached_tokens" + } + } + + struct OutputTokensDetails: Decodable, Sendable { + let reasoningTokens: Int? + + enum CodingKeys: String, CodingKey { + case reasoningTokens = "reasoning_tokens" + } + } + } private enum CodingKeys: String, CodingKey { case id case output case outputText = "output_text" case error + case usage } } @@ -1108,7 +1160,7 @@ private func resolveToolCalls( private enum OpenResponsesStreamEvent: Decodable, Sendable { case outputTextDelta(String) - case completed + case completed(LanguageModelUsage?) case failed case ignored @@ -1119,14 +1171,29 @@ private enum OpenResponsesStreamEvent: Decodable, Sendable { case "response.output_text.delta": self = .outputTextDelta(try c.decode(String.self, forKey: .delta)) case "response.completed": - self = .completed + let usage = + (try? c.decode(OpenResponsesAPI.Response.self, forKey: .response))?.usage?.languageModelUsage + ?? (try? c.decode(OpenResponsesAPI.Response.Usage.self, forKey: .usage))?.languageModelUsage + self = .completed(usage) case "response.failed": self = .failed default: self = .ignored } } - private enum CodingKeys: String, CodingKey { case type, delta } + private enum CodingKeys: String, CodingKey { case type, delta, response, usage } +} + +private extension OpenResponsesAPI.Response.Usage { + var languageModelUsage: LanguageModelUsage? { + LanguageModelUsage( + inputTokens: inputTokens, + outputTokens: outputTokens, + totalTokens: totalTokens, + reasoningTokens: outputTokensDetails?.reasoningTokens, + cachedInputTokens: inputTokensDetails?.cachedTokens + ).normalized + } } // MARK: - Errors diff --git a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift index 9375851..7934ee6 100644 --- a/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/GeminiLanguageModelTests.swift @@ -25,6 +25,22 @@ struct GeminiLanguageModelTests { #expect(!response.content.isEmpty) } + @Test func reportsUsage() async throws { + let session = LanguageModelSession(model: model) + + var options = GenerationOptions(maximumResponseTokens: 64) + options[custom: GeminiLanguageModel.self] = .init(thinking: .disabled) + + let response = try await session.respond( + to: "Reply with exactly: OK", + options: options + ) + + #expect(!response.content.isEmpty) + #expect((response.usage?.inputTokens ?? 0) > 0) + #expect((response.usage?.totalTokens ?? 0) > 0) + } + @Test func withInstructions() async throws { let session = LanguageModelSession( model: model, diff --git a/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift index ad8e1c0..8358936 100644 --- a/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MockLanguageModelTests.swift @@ -33,6 +33,51 @@ private func waitUntil( return true } +private struct UsageReportingModel: LanguageModel { + typealias UnavailableReason = Never + + let usage: LanguageModelUsage + + func respond( + within session: LanguageModelSession, + to prompt: Prompt, + generating type: Content.Type, + includeSchemaInPrompt: Bool, + options: GenerationOptions + ) async throws -> LanguageModelSession.Response where Content: Generable { + #expect(type == String.self) + return LanguageModelSession.Response( + content: "usage" as! Content, + rawContent: GeneratedContent("usage"), + usage: usage, + transcriptEntries: [] + ) + } + + func streamResponse( + within session: LanguageModelSession, + to prompt: Prompt, + generating type: Content.Type, + includeSchemaInPrompt: Bool, + options: GenerationOptions + ) -> sending LanguageModelSession.ResponseStream where Content: Generable { + #expect(type == String.self) + let stream = AsyncThrowingStream.Snapshot, any Error> { + continuation in + continuation.yield( + .init( + content: ("usage" as! Content).asPartiallyGenerated(), + rawContent: GeneratedContent("usage"), + usage: usage + ) + ) + continuation.finish() + } + + return LanguageModelSession.ResponseStream(stream: stream) + } +} + @Suite("MockLanguageModel") struct MockLanguageModelTests { @Test func fixedResponse() async throws { @@ -375,4 +420,45 @@ struct MockLanguageModelTests { Issue.record("First entry should be prompt with images") } } + + @Test func responseUsageIsExposed() async throws { + let usage = LanguageModelUsage( + inputTokens: 12, + outputTokens: 7, + totalTokens: 19, + reasoningTokens: 3, + cachedInputTokens: 4 + ) + let session = LanguageModelSession(model: UsageReportingModel(usage: usage)) + + let response = try await session.respond(to: "Track usage") + + #expect(response.usage == usage) + } + + @Test func streamingUsageIsExposedAndCollected() async throws { + let usage = LanguageModelUsage( + inputTokens: 5, + outputTokens: 2, + totalTokens: 7 + ) + let session = LanguageModelSession(model: UsageReportingModel(usage: usage)) + + let stream = session.streamResponse(to: "Track streaming usage") + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(snapshots.count == 1) + #expect(snapshots.first?.usage == usage) + + let collected = try await LanguageModelSession.ResponseStream( + content: "usage", + rawContent: GeneratedContent("usage"), + usage: usage + ).collect() + #expect(collected.usage == usage) + } } diff --git a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift index 4b4d2a5..bee0dbf 100644 --- a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift @@ -74,6 +74,21 @@ struct OpenAILanguageModelTests { #expect(!snapshots.last!.rawContent.jsonString.isEmpty) } + @Test func streamingReportsUsage() async throws { + let session = LanguageModelSession(model: model) + + let stream = session.streamResponse(to: "Reply with exactly: OK") + + var snapshots: [LanguageModelSession.ResponseStream.Snapshot] = [] + for try await snapshot in stream { + snapshots.append(snapshot) + } + + #expect(!snapshots.isEmpty) + #expect((snapshots.last?.usage?.inputTokens ?? 0) > 0) + #expect((snapshots.last?.usage?.totalTokens ?? 0) > 0) + } + @Test func withGenerationOptions() async throws { let session = LanguageModelSession(model: model) @@ -313,6 +328,16 @@ struct OpenAILanguageModelTests { #expect(!response.content.isEmpty) } + @Test func reportsUsage() async throws { + let session = LanguageModelSession(model: model) + + let response = try await session.respond(to: "Reply with exactly: OK") + + #expect(!response.content.isEmpty) + #expect((response.usage?.inputTokens ?? 0) > 0) + #expect((response.usage?.totalTokens ?? 0) > 0) + } + @Test func withInstructions() async throws { let session = LanguageModelSession( model: model,