-
Notifications
You must be signed in to change notification settings - Fork 63
Add additionalContext support to MLXLanguageModel
#145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -208,6 +208,12 @@ import Foundation | |||||
| public var kvGroupSize: Int | ||||||
| /// Sets the token offset where quantized KV storage starts. | ||||||
| public var quantizedKVStart: Int | ||||||
| /// Additional key-value pairs injected into the chat template rendering context. | ||||||
| public var additionalContext: [String: MLXLMCommon.JSONValue]? | ||||||
|
|
||||||
| var additionalContextForUserInput: [String: any Sendable]? { | ||||||
| additionalContext?.mapValues { $0.toSendable() } | ||||||
| } | ||||||
|
|
||||||
| /// Creates MLX-specific generation options. | ||||||
| /// | ||||||
|
|
@@ -218,16 +224,20 @@ import Foundation | |||||
| /// Pass `nil` to disable KV quantization. | ||||||
| /// - kvGroupSize: The token group size used for KV quantization. | ||||||
| /// - quantizedKVStart: The token index where quantized KV storage begins. | ||||||
| /// - additionalContext: Additional key-value pairs injected into the chat | ||||||
| /// template rendering context. | ||||||
| public init( | ||||||
| maxKVSize: Int? = nil, | ||||||
| kvBits: Int? = nil, | ||||||
| kvGroupSize: Int = 64, | ||||||
| quantizedKVStart: Int = 0 | ||||||
| quantizedKVStart: Int = 0, | ||||||
| additionalContext: [String: MLXLMCommon.JSONValue]? = nil | ||||||
| ) { | ||||||
| self.maxKVSize = maxKVSize | ||||||
| self.kvBits = kvBits | ||||||
| self.kvGroupSize = kvGroupSize | ||||||
| self.quantizedKVStart = quantizedKVStart | ||||||
| self.additionalContext = additionalContext | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
|
|
@@ -813,6 +823,9 @@ import Foundation | |||||
| // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters | ||||||
| let generateParameters = toGenerateParameters(options) | ||||||
|
|
||||||
| // Extract additional context from custom options | ||||||
| let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput | ||||||
|
|
||||||
| // Build chat history from full transcript | ||||||
| var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) | ||||||
|
|
||||||
|
|
@@ -828,7 +841,8 @@ import Foundation | |||||
| let userInput = MLXLMCommon.UserInput( | ||||||
| chat: chat, | ||||||
| processing: .init(resize: .init(width: 512, height: 512)), | ||||||
| tools: toolSpecs | ||||||
| tools: toolSpecs, | ||||||
| additionalContext: additionalContext, | ||||||
| ) | ||||||
|
Comment on lines
841
to
846
|
||||||
| let lmInput = try await context.processor.prepare(input: userInput) | ||||||
| let resolved = resolveCache( | ||||||
|
|
@@ -991,10 +1005,15 @@ import Foundation | |||||
|
|
||||||
| // Build chat inside task to avoid Sendable issues | ||||||
| let generateParameters = toGenerateParameters(options) | ||||||
| let userInput = makeUserInput( | ||||||
| session: session, | ||||||
| fallbackPrompt: prompt.description, | ||||||
| tools: nil | ||||||
| let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) | ||||||
|
|
||||||
| let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput | ||||||
|
|
||||||
| let userInput = MLXLMCommon.UserInput( | ||||||
| chat: chat, | ||||||
| processing: .init(resize: .init(width: 512, height: 512)), | ||||||
| tools: nil, | ||||||
| additionalContext: additionalContext | ||||||
| ) | ||||||
| let lmInput = try await context.processor.prepare(input: userInput) | ||||||
| let resolved = resolveCache( | ||||||
|
|
@@ -1529,10 +1548,14 @@ import Foundation | |||||
| let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) | ||||||
| let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil | ||||||
| let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt) | ||||||
|
|
||||||
| let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput | ||||||
|
|
||||||
| let userInput = MLXLMCommon.UserInput( | ||||||
| chat: chat, | ||||||
| processing: .init(resize: .init(width: 512, height: 512)), | ||||||
| tools: nil | ||||||
| tools: nil, | ||||||
| additionalContext: additionalContext, | ||||||
| ) | ||||||
| let lmInput = try await context.processor.prepare(input: userInput) | ||||||
|
|
||||||
|
|
@@ -1773,4 +1796,18 @@ import Foundation | |||||
| return sampledToken.item(Int.self) | ||||||
| } | ||||||
| } | ||||||
| extension MLXLMCommon.JSONValue { | ||||||
| /// Recursively converts a `JSONValue` to its primitive Swift equivalent. | ||||||
| func toSendable() -> any Sendable { | ||||||
| switch self { | ||||||
| case .string(let s): return s | ||||||
| case .int(let i): return i | ||||||
| case .double(let d): return d | ||||||
| case .bool(let b): return b | ||||||
| case .null: return NSNull() | ||||||
|
||||||
| case .null: return NSNull() | |
| case .null: return MLXLMCommon.JSONValue.null |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
additionalContextis exposed as[String: MLXLMCommon.JSONValue]?, which leaks the MLXLMCommon dependency into AnyLanguageModel’s public API. Other models’ custom options use the package-wideJSONValue(e.g., Anthropic’sextraBody: [String: JSONValue]?), so consider changing this to[String: JSONValue]?(or a typealias) and converting internally before buildingMLXLMCommon.UserInput.