Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ public final class LanguageModelSession: @unchecked Sendable {
public struct Response<Content>: Sendable where Content: Generable, Content: Sendable {
public let content: Content
public let rawContent: GeneratedContent
public let usage: LanguageModelUsage?
public let transcriptEntries: ArraySlice<Transcript.Entry>

/// Creates a response value from generated content and transcript entries.
Expand All @@ -184,10 +185,12 @@ public final class LanguageModelSession: @unchecked Sendable {
public init(
content: Content,
rawContent: GeneratedContent,
usage: LanguageModelUsage? = nil,
transcriptEntries: ArraySlice<Transcript.Entry>
) {
self.content = content
self.rawContent = rawContent
self.usage = usage
self.transcriptEntries = transcriptEntries
}
}
Expand Down Expand Up @@ -801,8 +804,12 @@ extension LanguageModelSession {
/// - Parameters:
/// - content: The complete response content.
/// - rawContent: The raw content produced by the model.
public init(content: Content, rawContent: GeneratedContent) {
self.fallbackSnapshot = Snapshot(content: content.asPartiallyGenerated(), rawContent: rawContent)
public init(content: Content, rawContent: GeneratedContent, usage: LanguageModelUsage? = nil) {
self.fallbackSnapshot = Snapshot(
content: content.asPartiallyGenerated(),
rawContent: rawContent,
usage: usage
)
self.streaming = nil
}

Expand All @@ -817,14 +824,20 @@ extension LanguageModelSession {
public struct Snapshot: Sendable where Content.PartiallyGenerated: Sendable {
    public var content: Content.PartiallyGenerated
    public var rawContent: GeneratedContent
    // Provider-reported token usage for the response so far; nil until the
    // provider reports any counts.
    public var usage: LanguageModelUsage?

    /// Creates a snapshot from partially generated content and raw content.
    /// - Parameters:
    ///   - content: The partially generated content.
    ///   - rawContent: The raw content produced by the model.
    ///   - usage: Provider-reported token usage accumulated so far, or `nil`
    ///     when the provider has not reported any. Defaults to `nil` so
    ///     existing two-argument call sites keep compiling.
    public init(
        content: Content.PartiallyGenerated,
        rawContent: GeneratedContent,
        usage: LanguageModelUsage? = nil
    ) {
        self.content = content
        self.rawContent = rawContent
        self.usage = usage
    }
}
}
Expand Down Expand Up @@ -887,6 +900,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence {
return LanguageModelSession.Response(
content: finalContent,
rawContent: last.rawContent,
usage: last.usage,
transcriptEntries: []
)
}
Expand All @@ -902,6 +916,7 @@ extension LanguageModelSession.ResponseStream: AsyncSequence {
return LanguageModelSession.Response(
content: finalContent,
rawContent: fallbackSnapshot.rawContent,
usage: fallbackSnapshot.usage,
transcriptEntries: []
)
}
Expand Down
89 changes: 89 additions & 0 deletions Sources/AnyLanguageModel/LanguageModelUsage.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/// Provider-reported token usage for a language model response.
///
/// Values are optional because not every provider or endpoint reports every field.
/// When a single `respond` call internally triggers multiple provider requests
/// (for example while resolving tool calls), the returned usage is aggregated
/// across those underlying requests.
public struct LanguageModelUsage: Hashable, Codable, Sendable {
    /// Tokens consumed by the request input or prompt.
    public var inputTokens: Int?

    /// Tokens generated in the response.
    public var outputTokens: Int?

    /// Total tokens reported for the request.
    public var totalTokens: Int?

    /// Tokens spent on reasoning or thinking, when reported separately.
    public var reasoningTokens: Int?

    /// Input tokens served from prompt cache, when reported separately.
    public var cachedInputTokens: Int?

    /// Input tokens written into a prompt cache, when reported separately.
    public var cacheCreationInputTokens: Int?

    /// Creates a usage value from individually optional token counts.
    ///
    /// Pass `nil` for any field the provider did not report. No field is
    /// derived from the others — in particular, `totalTokens` is stored
    /// as given, not computed from `inputTokens + outputTokens`.
    /// - Parameters:
    ///   - inputTokens: Tokens consumed by the request input or prompt.
    ///   - outputTokens: Tokens generated in the response.
    ///   - totalTokens: Total tokens reported for the request.
    ///   - reasoningTokens: Tokens spent on reasoning or thinking.
    ///   - cachedInputTokens: Input tokens served from a prompt cache.
    ///   - cacheCreationInputTokens: Input tokens written into a prompt cache.
    public init(
        inputTokens: Int? = nil,
        outputTokens: Int? = nil,
        totalTokens: Int? = nil,
        reasoningTokens: Int? = nil,
        cachedInputTokens: Int? = nil,
        cacheCreationInputTokens: Int? = nil
    ) {
        self.inputTokens = inputTokens
        self.outputTokens = outputTokens
        self.totalTokens = totalTokens
        self.reasoningTokens = reasoningTokens
        self.cachedInputTokens = cachedInputTokens
        self.cacheCreationInputTokens = cacheCreationInputTokens
    }
}

extension LanguageModelUsage {
    /// `true` when every field is `nil`, i.e. the provider reported nothing.
    var isEmpty: Bool {
        let fields = [
            inputTokens, outputTokens, totalTokens,
            reasoningTokens, cachedInputTokens, cacheCreationInputTokens,
        ]
        return fields.allSatisfy { $0 == nil }
    }

    /// This value, or `nil` when it carries no information at all.
    var normalized: Self? {
        guard !isEmpty else { return nil }
        return self
    }

    /// Field-wise addition with `other`. A field stays `nil` only when it is
    /// `nil` on both sides; otherwise missing operands are treated as absent,
    /// not zero-extended (`nil + x == x`).
    mutating func add(_ other: Self?) {
        guard let other else { return }
        inputTokens = Self.sum(inputTokens, other.inputTokens)
        outputTokens = Self.sum(outputTokens, other.outputTokens)
        totalTokens = Self.sum(totalTokens, other.totalTokens)
        reasoningTokens = Self.sum(reasoningTokens, other.reasoningTokens)
        cachedInputTokens = Self.sum(cachedInputTokens, other.cachedInputTokens)
        cacheCreationInputTokens = Self.sum(cacheCreationInputTokens, other.cacheCreationInputTokens)
    }

    /// Field-wise overwrite with `other`: any non-`nil` field of `other`
    /// replaces the corresponding field here; `nil` fields leave ours intact.
    mutating func merge(_ other: Self?) {
        guard let other else { return }
        if let value = other.inputTokens { inputTokens = value }
        if let value = other.outputTokens { outputTokens = value }
        if let value = other.totalTokens { totalTokens = value }
        if let value = other.reasoningTokens { reasoningTokens = value }
        if let value = other.cachedInputTokens { cachedInputTokens = value }
        if let value = other.cacheCreationInputTokens { cacheCreationInputTokens = value }
    }

    /// Optional-aware sum: `nil` when both operands are `nil`, otherwise the
    /// sum of whichever operands are present.
    private static func sum(_ lhs: Int?, _ rhs: Int?) -> Int? {
        guard let lhs else { return rhs }
        guard let rhs else { return lhs }
        return lhs + rhs
    }
}
57 changes: 53 additions & 4 deletions Sources/AnyLanguageModel/Models/AnthropicLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ public struct AnthropicLanguageModel: LanguageModel {
)

var entries: [Transcript.Entry] = []
let usage = message.usage?.languageModelUsage

// Handle tool calls, if present
let toolUses: [AnthropicToolUse] = message.content.compactMap { block in
Expand All @@ -363,6 +364,7 @@ public struct AnthropicLanguageModel: LanguageModel {
return LanguageModelSession.Response(
content: empty.content,
rawContent: empty.rawContent,
usage: usage,
transcriptEntries: ArraySlice(entries)
)
case .invocations(let invocations):
Expand All @@ -386,6 +388,7 @@ public struct AnthropicLanguageModel: LanguageModel {
return LanguageModelSession.Response(
content: text as! Content,
rawContent: GeneratedContent(text),
usage: usage,
transcriptEntries: ArraySlice(entries)
)
}
Expand All @@ -395,6 +398,7 @@ public struct AnthropicLanguageModel: LanguageModel {
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
usage: usage,
transcriptEntries: ArraySlice(entries)
)
}
Expand Down Expand Up @@ -445,30 +449,48 @@ public struct AnthropicLanguageModel: LanguageModel {

var accumulatedText = ""
let expectsStructuredResponse = type != String.self
var latestUsage = LanguageModelUsage()
var lastSnapshot: LanguageModelSession.ResponseStream<Content>.Snapshot?

for try await event in events {
switch event {
case .messageStart(let start):
latestUsage.merge(start.message.usage?.languageModelUsage)
case .contentBlockDelta(let delta):
if case .textDelta(let textDelta) = delta.delta {
accumulatedText += textDelta.text

if expectsStructuredResponse {
if let snapshot: LanguageModelSession.ResponseStream<Content>.Snapshot =
if var snapshot: LanguageModelSession.ResponseStream<Content>.Snapshot =
try? partialSnapshot(from: accumulatedText)
{
snapshot.usage = latestUsage.normalized
lastSnapshot = snapshot
continuation.yield(snapshot)
}
} else {
let raw = GeneratedContent(accumulatedText)
let content: Content.PartiallyGenerated = (accumulatedText as! Content)
.asPartiallyGenerated()
continuation.yield(.init(content: content, rawContent: raw))
let snapshot = LanguageModelSession.ResponseStream<Content>.Snapshot(
content: content,
rawContent: raw,
usage: latestUsage.normalized
)
lastSnapshot = snapshot
continuation.yield(snapshot)
}
}
case .messageDelta(let delta):
latestUsage.merge(delta.usage?.languageModelUsage)
case .messageStop:
if var lastSnapshot, lastSnapshot.usage != latestUsage.normalized {
lastSnapshot.usage = latestUsage.normalized
continuation.yield(lastSnapshot)
}
continuation.finish()
return
case .messageStart, .contentBlockStart, .contentBlockStop, .messageDelta, .ping, .ignored:
case .contentBlockStart, .contentBlockStop, .ping, .ignored:
break
}
}
Expand Down Expand Up @@ -995,9 +1017,10 @@ private struct AnthropicMessageResponse: Codable, Sendable {
let content: [AnthropicContent]
let model: String
let stopReason: StopReason?
let usage: AnthropicUsage?

enum CodingKeys: String, CodingKey {
case id, type, role, content, model
case id, type, role, content, model, usage
case stopReason = "stop_reason"
}

Expand All @@ -1012,6 +1035,20 @@ private struct AnthropicMessageResponse: Codable, Sendable {
}
}

/// Token-usage payload as decoded from Anthropic API responses.
///
/// Every field is optional; presumably different endpoints and stream events
/// populate different subsets — TODO confirm against the Anthropic API docs.
private struct AnthropicUsage: Codable, Sendable {
    // Prompt tokens (`input_tokens` in the JSON payload).
    let inputTokens: Int?
    // Completion tokens (`output_tokens` in the JSON payload).
    let outputTokens: Int?
    // Tokens written into the prompt cache (`cache_creation_input_tokens`).
    let cacheCreationInputTokens: Int?
    // Tokens read back from the prompt cache (`cache_read_input_tokens`).
    let cacheReadInputTokens: Int?

    // Maps the Swift camelCase property names to the API's snake_case keys.
    enum CodingKeys: String, CodingKey {
        case inputTokens = "input_tokens"
        case outputTokens = "output_tokens"
        case cacheCreationInputTokens = "cache_creation_input_tokens"
        case cacheReadInputTokens = "cache_read_input_tokens"
    }
}

/// Top-level error envelope in Anthropic API responses: `{ "error": { … } }`.
private struct AnthropicErrorResponse: Codable { let error: AnthropicErrorDetail }
private struct AnthropicErrorDetail: Codable {
let type: String
Expand Down Expand Up @@ -1157,6 +1194,7 @@ private enum AnthropicStreamEvent: Codable, Sendable {
struct MessageDeltaEvent: Codable, Sendable {
let type: String
let delta: Delta
let usage: AnthropicUsage?

struct Delta: Codable, Sendable {
let stopReason: String?
Expand All @@ -1169,3 +1207,14 @@ private enum AnthropicStreamEvent: Codable, Sendable {
}
}
}

private extension AnthropicUsage {
    /// Bridges this provider payload to the library-level `LanguageModelUsage`.
    ///
    /// `cacheReadInputTokens` maps to `cachedInputTokens`; `totalTokens` and
    /// `reasoningTokens` are never populated from this payload. Returns `nil`
    /// (via `normalized`) when every field is absent, so callers can treat
    /// "no usage reported" uniformly.
    var languageModelUsage: LanguageModelUsage? {
        let bridged = LanguageModelUsage(
            inputTokens: inputTokens,
            outputTokens: outputTokens,
            cachedInputTokens: cacheReadInputTokens,
            cacheCreationInputTokens: cacheCreationInputTokens
        )
        return bridged.normalized
    }
}
Loading