Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ import Foundation
public var kvGroupSize: Int
/// Sets the token offset where quantized KV storage starts.
public var quantizedKVStart: Int
/// Additional key-value pairs injected into the chat template rendering context.
public var additionalContext: [String: MLXLMCommon.JSONValue]?

/// The `additionalContext` dictionary with each `JSONValue` unwrapped to its
/// primitive Swift equivalent (via `toSendable()`), in the
/// `[String: any Sendable]` shape that `MLXLMCommon.UserInput` accepts.
/// Returns `nil` when no additional context was provided.
var additionalContextForUserInput: [String: any Sendable]? {
    additionalContext?.mapValues { $0.toSendable() }
}
Comment on lines +211 to +216
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

additionalContext is exposed as [String: MLXLMCommon.JSONValue]?, which leaks the MLXLMCommon dependency into AnyLanguageModel’s public API. Other models’ custom options use the package-wide JSONValue (e.g., Anthropic’s extraBody: [String: JSONValue]?), so consider changing this to [String: JSONValue]? (or a typealias) and converting internally before building MLXLMCommon.UserInput.

Copilot uses AI. Check for mistakes.

/// Creates MLX-specific generation options.
///
Expand All @@ -218,16 +224,20 @@ import Foundation
/// Pass `nil` to disable KV quantization.
/// - kvGroupSize: The token group size used for KV quantization.
/// - quantizedKVStart: The token index where quantized KV storage begins.
/// - additionalContext: Additional key-value pairs injected into the chat
/// template rendering context.
public init(
    maxKVSize: Int? = nil,
    kvBits: Int? = nil,
    kvGroupSize: Int = 64,
    quantizedKVStart: Int = 0,
    additionalContext: [String: MLXLMCommon.JSONValue]? = nil
) {
    self.maxKVSize = maxKVSize
    self.kvBits = kvBits
    self.kvGroupSize = kvGroupSize
    self.quantizedKVStart = quantizedKVStart
    // Stored as JSONValue and converted to `any Sendable` lazily via
    // `additionalContextForUserInput` when building MLXLMCommon.UserInput.
    self.additionalContext = additionalContext
}
}

Expand Down Expand Up @@ -813,6 +823,9 @@ import Foundation
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
let generateParameters = toGenerateParameters(options)

// Extract additional context from custom options
let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput

// Build chat history from full transcript
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

Expand All @@ -828,7 +841,8 @@ import Foundation
let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: toolSpecs
tools: toolSpecs,
additionalContext: additionalContext
)
Comment on lines 841 to 846
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makeUserInput(...) now appears unused, and MLXLMCommon.UserInput construction (including processing defaults) is duplicated in respond, streamResponse, and structured generation. To avoid drift and keep future changes (like additionalContext) consistent, consider updating makeUserInput to accept additionalContext and reusing it in these call sites, or removing the helper if you prefer inline construction.

Copilot uses AI. Check for mistakes.
let lmInput = try await context.processor.prepare(input: userInput)
let resolved = resolveCache(
Expand Down Expand Up @@ -991,10 +1005,15 @@ import Foundation

// Build chat inside task to avoid Sendable issues
let generateParameters = toGenerateParameters(options)
let userInput = makeUserInput(
session: session,
fallbackPrompt: prompt.description,
tools: nil
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil,
additionalContext: additionalContext
)
let lmInput = try await context.processor.prepare(input: userInput)
let resolved = resolveCache(
Expand Down Expand Up @@ -1529,10 +1548,14 @@ import Foundation
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)

let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil
tools: nil,
additionalContext: additionalContext
)
let lmInput = try await context.processor.prepare(input: userInput)

Expand Down Expand Up @@ -1773,4 +1796,18 @@ import Foundation
return sampledToken.item(Int.self)
}
}
extension MLXLMCommon.JSONValue {
    /// Recursively converts a `JSONValue` tree to its primitive Swift
    /// equivalent (`String`, `Int`, `Double`, `Bool`, `NSNull`, arrays,
    /// and `[String: any Sendable]` dictionaries).
    ///
    /// Used to adapt the public `additionalContext` option into the
    /// `[String: any Sendable]` form that `MLXLMCommon.UserInput` expects
    /// for chat-template rendering.
    func toSendable() -> any Sendable {
        switch self {
        case .string(let s): return s
        case .int(let i): return i
        case .double(let d): return d
        case .bool(let b): return b
        // NSNull is the conventional stand-in for JSON null inside
        // heterogeneous `[String: Any]`-style dictionaries; returning
        // `JSONValue.null` here would leak the enum back into a context
        // that expects fully-unwrapped primitive values.
        case .null: return NSNull()
        case .array(let arr): return arr.map { $0.toSendable() }
        case .object(let obj): return obj.mapValues { $0.toSendable() }
        }
    }
}
#endif // MLX
22 changes: 22 additions & 0 deletions Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,28 @@ import Testing
#expect([Priority.low, Priority.medium, Priority.high].contains(response.content))
}

@Test func withAdditionalContext() async throws {
    // Inject chat-template context through the MLX-specific custom options.
    var options = GenerationOptions(temperature: 0.7, maximumResponseTokens: 32)
    options[custom: MLXLanguageModel.self] = .init(additionalContext: [
        "user_name": .string("Alice"),
        "turn_count": .int(3),
        "verbose": .bool(true),
    ])

    let session = LanguageModelSession(model: model)
    let response = try await session.respond(to: "Say hello", options: options)

    // The model should still produce some output with extra context present.
    #expect(!response.content.isEmpty)
}

@Test func unavailableForNonexistentModel() async {
let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test")
await model.removeFromCache()
Expand Down
Loading