diff --git a/README.md b/README.md index 2abda96..c54846f 100644 --- a/README.md +++ b/README.md @@ -488,6 +488,32 @@ let response = try await session.respond { } ``` +You can tune MLX KV-cache behavior per request with model-specific options: + +```swift +var options = GenerationOptions(temperature: 0.7) +options[custom: MLXLanguageModel.self] = .init( + maxKVSize: 4096, + kvBits: 4, + kvGroupSize: 64, + quantizedKVStart: 128 +) + +let response = try await session.respond( + to: "Summarize this transcript", + options: options +) +``` + +GPU cache behavior can be configured when creating the model: + +```swift +let model = MLXLanguageModel( + modelId: "mlx-community/Qwen3-0.6B-4bit", + gpuMemory: .automatic +) +``` + Vision support depends on the specific MLX model you load. Use a vision‑capable model for multimodal prompts (for example, a VLM variant). diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index f4be593..8f23b2f 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -191,6 +191,341 @@ import Foundation case failedToLoad(String) } + /// Configures MLX-specific generation behavior. + /// + /// Set these values through ``GenerationOptions`` using + /// `GenerationOptions[custom: MLXLanguageModel.self]`. + public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions, Codable { + /// Limits how many tokens the KV cache retains. + /// + /// Set this to `nil` to use the backend default. + public var maxKVSize: Int? + /// Sets the KV-cache quantization bit width. + /// + /// Set this to `nil` to disable KV quantization. + public var kvBits: Int? + /// Sets the token group size used for KV quantization. + public var kvGroupSize: Int + /// Sets the token offset where quantized KV storage starts. + public var quantizedKVStart: Int + + /// Creates MLX-specific generation options. + /// + /// - Parameters: + /// - maxKVSize: The maximum number of tokens to retain in KV cache storage. + /// Pass `nil` to use the backend default. + /// - kvBits: The KV-cache quantization bit width. + /// Pass `nil` to disable KV quantization. + /// - kvGroupSize: The token group size used for KV quantization. + /// - quantizedKVStart: The token index where quantized KV storage begins. + public init( + maxKVSize: Int? = nil, + kvBits: Int? = nil, + kvGroupSize: Int = 64, + quantizedKVStart: Int = 0 + ) { + self.maxKVSize = maxKVSize + self.kvBits = kvBits + self.kvGroupSize = kvGroupSize + self.quantizedKVStart = quantizedKVStart + } + } + + /// Controls GPU buffer-pool limits during active and idle phases. + public struct GPUMemoryConfiguration: Sendable, Hashable { + /// The cache limit applied while at least one generation is active. + public var activeCacheLimit: Int + /// The cache limit applied when no generations are active. + public var idleCacheLimit: Int + /// Indicates whether MLX clears cached GPU buffers on safe eviction. + public var clearCacheOnEviction: Bool + + /// Creates a GPU-memory configuration for MLX generations. + /// + /// - Parameters: + /// - activeCacheLimit: The GPU cache-limit value used during active generation. + /// - idleCacheLimit: The GPU cache-limit value used while idle. + /// - clearCacheOnEviction: A Boolean value that indicates whether to clear + /// cached GPU buffers when eviction is safe. + public init( + activeCacheLimit: Int, + idleCacheLimit: Int, + clearCacheOnEviction: Bool = true + ) { + self.activeCacheLimit = activeCacheLimit + self.idleCacheLimit = idleCacheLimit + self.clearCacheOnEviction = clearCacheOnEviction + } + + /// Returns a memory configuration using physical-memory heuristics. + /// + /// The active limit scales with device RAM, + /// and the idle limit stays conservative to reduce background memory pressure. + public static var automatic: GPUMemoryConfiguration { + let ramBytes = ProcessInfo.processInfo.physicalMemory + let ramGB = ramBytes / (1024 * 1024 * 1024) + let active: Int + switch ramGB { + case ..<4: + active = 128_000_000 + case ..<6: + active = 256_000_000 + case ..<8: + active = 512_000_000 + default: + active = 768_000_000 + } + + return .init( + activeCacheLimit: active, + idleCacheLimit: 50_000_000, + clearCacheOnEviction: true + ) + } + + /// Returns a memory configuration that leaves GPU cache effectively unconstrained. + /// + /// Use this when your application prefers maximum reuse over memory reclamation. + public static var unconstrained: GPUMemoryConfiguration { + .init( + activeCacheLimit: Int.max, + idleCacheLimit: Int.max, + clearCacheOnEviction: false + ) + } + } + + private struct CacheConfigSignature: Equatable { + let maxKVSize: Int? + let kvBits: Int? + let kvGroupSize: Int + let quantizedKVStart: Int + } + + private final class SessionCacheEntry: @unchecked Sendable { + let kvCache: [MLXLMCommon.KVCache] + let prefillTokenCount: Int + let prefixTokens: [Int32] + let cacheConfigSignature: CacheConfigSignature + + init( + kvCache: [MLXLMCommon.KVCache], + prefillTokenCount: Int, + prefixTokens: [Int32], + cacheConfigSignature: CacheConfigSignature + ) { + self.kvCache = kvCache + self.prefillTokenCount = prefillTokenCount + self.prefixTokens = prefixTokens + self.cacheConfigSignature = cacheConfigSignature + } + } + + private final class SessionKVStore: @unchecked Sendable { + private final class WeakSessionReference: @unchecked Sendable { + weak var session: LanguageModelSession? + + init(_ session: LanguageModelSession) { + self.session = session + } + } + + private struct SessionBucket { + let sessionReference: WeakSessionReference + var modelEntries: [String: SessionCacheEntry] + } + + private let lock = NSLock() + private var buckets: [ObjectIdentifier: SessionBucket] = [:] + + func entry( + for session: LanguageModelSession, + modelKey: String + ) -> SessionCacheEntry? { + lock.withLock { + reapDeadSessionsLocked() + return buckets[ObjectIdentifier(session)]?.modelEntries[modelKey] + } + } + + func set( + _ entry: SessionCacheEntry, + for session: LanguageModelSession, + modelKey: String + ) { + lock.withLock { + reapDeadSessionsLocked() + let id = ObjectIdentifier(session) + var bucket = + buckets[id] + ?? SessionBucket( + sessionReference: WeakSessionReference(session), + modelEntries: [:] + ) + bucket.modelEntries[modelKey] = entry + buckets[id] = bucket + } + } + + func removeEntry( + for session: LanguageModelSession, + modelKey: String + ) { + lock.withLock { + reapDeadSessionsLocked() + let id = ObjectIdentifier(session) + guard var bucket = buckets[id] else { + return + } + bucket.modelEntries[modelKey] = nil + if bucket.modelEntries.isEmpty { + buckets[id] = nil + } else { + buckets[id] = bucket + } + } + } + + func removeEntries(forModelKey modelKey: String) { + lock.withLock { + reapDeadSessionsLocked() + for id in Array(buckets.keys) { + guard var bucket = buckets[id] else { + continue + } + bucket.modelEntries[modelKey] = nil + if bucket.modelEntries.isEmpty { + buckets[id] = nil + } else { + buckets[id] = bucket + } + } + } + } + + func removeAll() { + lock.withLock { + buckets.removeAll() + } + } + + private func reapDeadSessionsLocked() { + let deadSessionIDs = buckets.compactMap { id, bucket in + bucket.sessionReference.session == nil ? id : nil + } + for id in deadSessionIDs { + buckets[id] = nil + } + } + } + + private final class SessionGenerationGate: @unchecked Sendable { + private let lock = NSLock() + private var activeSessions: Set = [] + + func acquire(session: LanguageModelSession) -> Bool { + lock.withLock { + let id = ObjectIdentifier(session) + guard !activeSessions.contains(id) else { + return false + } + activeSessions.insert(id) + return true + } + } + + func release(session: LanguageModelSession) { + _ = lock.withLock { + activeSessions.remove(ObjectIdentifier(session)) + } + } + } + + private final class GPUMemoryManager: @unchecked Sendable { + static let shared = GPUMemoryManager() + + private let lock = NSLock() + private var knownConfigs: Set = [] + private var activeScopes: [UUID: GPUMemoryConfiguration] = [:] + + private init() { + GPU.set(cacheLimit: GPUMemoryConfiguration.automatic.idleCacheLimit) + } + + func register(_ configuration: GPUMemoryConfiguration) { + var cacheLimitToSet: Int? + lock.withLock { + knownConfigs.insert(configuration) + if activeScopes.isEmpty { + cacheLimitToSet = effectiveIdleLimit() + } + } + if let cacheLimitToSet { + GPU.set(cacheLimit: cacheLimitToSet) + } + } + + func markActive(_ configuration: GPUMemoryConfiguration) -> UUID { + let id = UUID() + let cacheLimitToSet = lock.withLock { + knownConfigs.insert(configuration) + activeScopes[id] = configuration + return effectiveActiveLimit() + } + GPU.set(cacheLimit: cacheLimitToSet) + return id + } + + func markIdle(scope id: UUID) { + let cacheLimitToSet = lock.withLock { + activeScopes.removeValue(forKey: id) + if activeScopes.isEmpty { + return effectiveIdleLimit() + } + return effectiveActiveLimit() + } + GPU.set(cacheLimit: cacheLimitToSet) + } + + func evictIfSafe() { + var shouldUpdateCacheLimit = false + var cacheLimitToSet = 0 + var shouldClearCache = false + lock.withLock { + guard activeScopes.isEmpty else { return } + shouldUpdateCacheLimit = true + cacheLimitToSet = effectiveIdleLimit() + shouldClearCache = shouldClearOnEviction() + } + guard shouldUpdateCacheLimit else { return } + GPU.set(cacheLimit: cacheLimitToSet) + if shouldClearCache { + GPU.clearCache() + } + } + + private func effectiveActiveLimit() -> Int { + let limits = activeScopes.values.map(\.activeCacheLimit) + return limits.max() ?? effectiveIdleLimit() + } + + private func idlePolicyConfiguration() -> GPUMemoryConfiguration { + knownConfigs.max(by: { $0.idleCacheLimit < $1.idleCacheLimit }) + ?? GPUMemoryConfiguration.automatic + } + + private func effectiveIdleLimit() -> Int { + idlePolicyConfiguration().idleCacheLimit + } + + private func shouldClearOnEviction() -> Bool { + idlePolicyConfiguration().clearCacheOnEviction + } + } + + private static let sessionKVCache = SessionKVStore() + private static let sessionGenerationGate = SessionGenerationGate() + /// The model identifier. public let modelId: String @@ -200,16 +535,27 @@ import Foundation /// The local directory containing the model files. public let directory: URL? + /// GPU memory behavior used for this model's generation scopes. + public let gpuMemory: GPUMemoryConfiguration + /// Creates an MLX language model. /// /// - Parameters: /// - modelId: The model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit"). /// - hub: An optional Hub API instance for downloading models. If not provided, the default Hub API is used. /// - directory: An optional local directory URL containing the model files. If provided, the model is loaded from this directory instead of downloading. - public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil) { + /// - gpuMemory: The GPU-memory behavior used for this model's active and idle phases. + public init( + modelId: String, + hub: HubApi? = nil, + directory: URL? = nil, + gpuMemory: GPUMemoryConfiguration = .automatic + ) { self.modelId = modelId self.hub = hub self.directory = directory + self.gpuMemory = gpuMemory + GPUMemoryManager.shared.register(gpuMemory) } /// The current availability of this model in memory. @@ -233,11 +579,15 @@ import Foundation public func removeFromCache() async { let key = directory?.absoluteString ?? modelId await modelCache.removeAndCancel(for: key) + Self.removeSessionCaches(forModelKey: modelSessionCacheKey()) + GPUMemoryManager.shared.evictIfSafe() } /// Removes all MLX models from the shared cache and cancels in-flight loads. public static func removeAllFromCache() async { await modelCache.removeAllAndCancel() + sessionKVCache.removeAll() + GPUMemoryManager.shared.evictIfSafe() } /// Get or load model context with caching @@ -253,6 +603,176 @@ import Foundation } } + private static func sessionKey(model: MLXLanguageModel) -> String { + let directoryKey = model.directory?.absoluteString ?? "" + return "\(model.modelId)|\(directoryKey)" + } + + private func modelSessionCacheKey() -> String { + Self.sessionKey(model: self) + } + + private func getSessionCache(for session: LanguageModelSession) -> SessionCacheEntry? { + Self.sessionKVCache.entry(for: session, modelKey: modelSessionCacheKey()) + } + + private func setSessionCache(_ entry: SessionCacheEntry, for session: LanguageModelSession) { + Self.sessionKVCache.set(entry, for: session, modelKey: modelSessionCacheKey()) + } + + private func removeSessionCache(for session: LanguageModelSession) { + Self.sessionKVCache.removeEntry(for: session, modelKey: modelSessionCacheKey()) + } + + private static func removeSessionCaches(forModelKey modelKey: String) { + sessionKVCache.removeEntries(forModelKey: modelKey) + } + + private static func concurrentSessionError() -> LanguageModelSession.GenerationError { + .concurrentRequests( + .init( + debugDescription: + "Concurrent requests on the same LanguageModelSession are not supported for MLX due to cache and memory management constraints." + ) + ) + } + + private static func maxToolIterationsExceededError(limit: Int) -> LanguageModelSession.GenerationError { + .decodingFailure( + .init( + debugDescription: + "Exceeded maximum tool iterations (\(limit)) while processing MLX tool calls." + ) + ) + } + + private static func repeatedToolCallLoopError() -> LanguageModelSession.GenerationError { + .decodingFailure( + .init( + debugDescription: + "Detected repeated MLX tool-call signature and aborted to avoid an infinite tool loop." + ) + ) + } + + private static func acquireGenerationSlot(for session: LanguageModelSession) -> Bool { + sessionGenerationGate.acquire(session: session) + } + + private static func releaseGenerationSlot(for session: LanguageModelSession) { + sessionGenerationGate.release(session: session) + } + + private func cacheSignature(from parameters: MLXLMCommon.GenerateParameters) -> CacheConfigSignature { + CacheConfigSignature( + maxKVSize: parameters.maxKVSize, + kvBits: parameters.kvBits, + kvGroupSize: parameters.kvGroupSize, + quantizedKVStart: parameters.quantizedKVStart + ) + } + + private func tokens(from input: MLXLMCommon.LMInput) -> [Int32] { + input.text.tokens.asArray(Int32.self) + } + + private func isCacheHit( + entry: SessionCacheEntry, + currentTokens: [Int32], + signature: CacheConfigSignature, + lmInput: MLXLMCommon.LMInput + ) -> Bool { + guard lmInput.image == nil, lmInput.video == nil else { + return false + } + guard entry.cacheConfigSignature == signature else { + return false + } + guard entry.prefillTokenCount > 0, currentTokens.count > entry.prefillTokenCount else { + return false + } + guard entry.prefixTokens.count == entry.prefillTokenCount else { + return false + } + return currentTokens.starts(with: entry.prefixTokens) + } + + private func resolveCache( + session: LanguageModelSession, + lmInput: MLXLMCommon.LMInput, + generateParameters: MLXLMCommon.GenerateParameters, + context: ModelContext + ) -> (cache: [MLXLMCommon.KVCache], input: MLXLMCommon.LMInput, fullTokens: [Int32]) { + let signature = cacheSignature(from: generateParameters) + let fullTokens = tokens(from: lmInput) + let existingEntry = getSessionCache(for: session) + + if let existingEntry, + isCacheHit(entry: existingEntry, currentTokens: fullTokens, signature: signature, lmInput: lmInput) + { + let cachedCount = existingEntry.prefillTokenCount + let newTokens = lmInput.text.tokens[cachedCount...] + let newMask = lmInput.text.mask?[cachedCount...] + let partialText = MLXLMCommon.LMInput.Text(tokens: newTokens, mask: newMask) + return (existingEntry.kvCache, MLXLMCommon.LMInput(text: partialText), fullTokens) + } + + if existingEntry != nil { + removeSessionCache(for: session) + } + + let newCache = context.model.newCache(parameters: generateParameters) + return (newCache, lmInput, fullTokens) + } + + private func storeSessionCache( + cache: [MLXLMCommon.KVCache], + fullTokens: [Int32], + generateParameters: MLXLMCommon.GenerateParameters, + session: LanguageModelSession + ) { + let offset = cache.first?.offset ?? 0 + let prefillCount = max(0, min(offset, fullTokens.count)) + guard prefillCount > 0 else { + removeSessionCache(for: session) + return + } + + let prefixTokens = Array(fullTokens.prefix(prefillCount)) + let entry = SessionCacheEntry( + kvCache: cache, + prefillTokenCount: prefillCount, + prefixTokens: prefixTokens, + cacheConfigSignature: cacheSignature(from: generateParameters) + ) + setSessionCache(entry, for: session) + } + + private func beginGenerationScope() -> UUID { + GPUMemoryManager.shared.markActive(gpuMemory) + } + + private func endGenerationScope(_ id: UUID) { + GPUMemoryManager.shared.markIdle(scope: id) + } + + private func mlxToolSpecs(for session: LanguageModelSession) -> [ToolSpec]? { + session.tools.isEmpty ? nil : session.tools.map { convertToolToMLXSpec($0) } + } + + private func makeUserInput( + session: LanguageModelSession, + fallbackPrompt: String, + tools: [ToolSpec]? + ) -> MLXLMCommon.UserInput { + let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: fallbackPrompt) + return MLXLMCommon.UserInput( + chat: chat, + processing: .init(resize: .init(width: 512, height: 512)), + tools: tools + ) + } + public func respond( within session: LanguageModelSession, to prompt: Prompt, @@ -260,8 +780,15 @@ import Foundation includeSchemaInPrompt: Bool, options: GenerationOptions ) async throws -> LanguageModelSession.Response where Content: Generable { + guard Self.acquireGenerationSlot(for: session) else { + throw Self.concurrentSessionError() + } + defer { Self.releaseGenerationSlot(for: session) } + // Get cached or load fresh ModelContext let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + let generationScope = beginGenerationScope() + defer { endGenerationScope(generationScope) } if type != String.self { let jsonString = try await generateStructuredJSON( @@ -281,13 +808,7 @@ import Foundation ) } - // Convert session tools to MLX ToolSpec format - let toolSpecs: [ToolSpec]? = - session.tools.isEmpty - ? nil - : session.tools.map { tool in - convertToolToMLXSpec(tool) - } + let toolSpecs = mlxToolSpecs(for: session) // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters let generateParameters = toGenerateParameters(options) @@ -297,6 +818,9 @@ import Foundation var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] + let maxToolIterations = 8 + var toolIteration = 0 + var previousToolCallSignature: String? // Loop until no more tool calls while true { @@ -304,13 +828,20 @@ import Foundation let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: toolSpecs, + tools: toolSpecs ) let lmInput = try await context.processor.prepare(input: userInput) + let resolved = resolveCache( + session: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) // Generate let stream = try MLXLMCommon.generate( - input: lmInput, + input: resolved.input, + cache: resolved.cache, parameters: generateParameters, context: context ) @@ -328,6 +859,12 @@ import Foundation collectedToolCalls.append(call) } } + storeSessionCache( + cache: resolved.cache, + fullTokens: resolved.fullTokens, + generateParameters: generateParameters, + session: session + ) let assistantText = chunks.joined() allTextChunks.append(assistantText) @@ -339,6 +876,24 @@ import Foundation // If there are tool calls, execute them and continue if !collectedToolCalls.isEmpty { + toolIteration += 1 + if toolIteration > maxToolIterations { + let unresolvedCalls = try makeTranscriptToolCalls(from: collectedToolCalls) + allEntries.append(Transcript.Entry.toolCalls(Transcript.ToolCalls(unresolvedCalls))) + throw Self.maxToolIterationsExceededError(limit: maxToolIterations) + } + + let signature = + collectedToolCalls + .map { "\($0.function.name):\($0.function.arguments)" } + .joined(separator: "|") + if signature == previousToolCallSignature { + let unresolvedCalls = try makeTranscriptToolCalls(from: collectedToolCalls) + allEntries.append(Transcript.Entry.toolCalls(Transcript.ToolCalls(unresolvedCalls))) + throw Self.repeatedToolCallLoopError() + } + previousToolCallSignature = signature + let resolution = try await resolveToolCalls(collectedToolCalls, session: session) switch resolution { case .stop(let calls): @@ -391,31 +946,67 @@ import Foundation guard type == String.self else { fatalError("MLXLanguageModel streaming only supports String content") } + guard Self.acquireGenerationSlot(for: session) else { + let error = Self.concurrentSessionError() + let stream: AsyncThrowingStream.Snapshot, any Error> = + .init { continuation in + continuation.finish(throwing: error) + } + return LanguageModelSession.ResponseStream(stream: stream) + } let modelId = self.modelId let hub = self.hub let directory = self.directory + let gpuMemory = self.gpuMemory let stream: AsyncThrowingStream.Snapshot, any Error> = .init { continuation in + let didEndScope = Locked(false) + let didReleaseGenerationSlot = Locked(false) + let generationScope = GPUMemoryManager.shared.markActive(gpuMemory) + let task = Task { @Sendable in + func finishScope() { + didEndScope.withLock { done in + if !done { + GPUMemoryManager.shared.markIdle(scope: generationScope) + done = true + } + } + } + + func finishGenerationSlot() { + didReleaseGenerationSlot.withLock { done in + if !done { + Self.releaseGenerationSlot(for: session) + done = true + } + } + } + do { // Get cached or load fresh ModelContext let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) // Build chat inside task to avoid Sendable issues let generateParameters = toGenerateParameters(options) - let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) - - let userInput = MLXLMCommon.UserInput( - chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), + let userInput = makeUserInput( + session: session, + fallbackPrompt: prompt.description, tools: nil ) let lmInput = try await context.processor.prepare(input: userInput) + let resolved = resolveCache( + session: session, + lmInput: lmInput, + generateParameters: generateParameters, + context: context + ) let mlxStream = try MLXLMCommon.generate( - input: lmInput, + input: resolved.input, + cache: resolved.cache, parameters: generateParameters, context: context ) @@ -436,12 +1027,36 @@ import Foundation } } + storeSessionCache( + cache: resolved.cache, + fullTokens: resolved.fullTokens, + generateParameters: generateParameters, + session: session + ) + finishScope() + finishGenerationSlot() continuation.finish() } catch { + finishScope() + finishGenerationSlot() continuation.finish(throwing: error) } } - continuation.onTermination = { _ in task.cancel() } + continuation.onTermination = { _ in + didEndScope.withLock { done in + if !done { + GPUMemoryManager.shared.markIdle(scope: generationScope) + done = true + } + } + didReleaseGenerationSlot.withLock { done in + if !done { + Self.releaseGenerationSlot(for: session) + done = true + } + } + task.cancel() + } } return LanguageModelSession.ResponseStream(stream: stream) @@ -457,8 +1072,37 @@ import Foundation let directory = self.directory Task { + guard Self.acquireGenerationSlot(for: session) else { + return + } + defer { Self.releaseGenerationSlot(for: session) } + + let generationScope = beginGenerationScope() + defer { endGenerationScope(generationScope) } + do { - _ = try await loadContext(modelId: modelId, hub: hub, directory: directory) + let context = try await loadContext(modelId: modelId, hub: hub, directory: directory) + guard let instructions = session.instructions?.description, !instructions.isEmpty else { + return + } + + let toolSpecs = mlxToolSpecs(for: session) + + let params = toGenerateParameters(.init()) + let newCache = context.model.newCache(parameters: params) + let userInput = MLXLMCommon.UserInput( + chat: [.init(role: .system, content: instructions)], + processing: .init(resize: .init(width: 512, height: 512)), + tools: toolSpecs + ) + let lmInput = try await context.processor.prepare(input: userInput) + _ = try context.model.prepare(lmInput, cache: newCache, windowSize: params.prefillStepSize) + storeSessionCache( + cache: newCache, + fullTokens: tokens(from: lmInput), + generateParameters: params, + session: session + ) } catch { // Ignore errors during prewarm } @@ -469,12 +1113,13 @@ import Foundation // MARK: - Options Mapping private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { - MLXLMCommon.GenerateParameters( + let custom = options[custom: MLXLanguageModel.self] + return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: nil, - kvBits: nil, - kvGroupSize: 64, - quantizedKVStart: 0, + maxKVSize: custom?.maxKVSize, + kvBits: custom?.kvBits, + kvGroupSize: custom?.kvGroupSize ?? 64, + quantizedKVStart: custom?.quantizedKVStart ?? 0, temperature: Float(options.temperature ?? 0.6), topP: 1.0, repetitionPenalty: nil, @@ -484,12 +1129,13 @@ import Foundation /// Builds MLX parameters tuned for structured generation. private func toStructuredGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { - MLXLMCommon.GenerateParameters( + let custom = options[custom: MLXLanguageModel.self] + return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: nil, - kvBits: nil, - kvGroupSize: 64, - quantizedKVStart: 0, + maxKVSize: custom?.maxKVSize, + kvBits: custom?.kvBits, + kvGroupSize: custom?.kvGroupSize ?? 64, + quantizedKVStart: custom?.quantizedKVStart ?? 0, temperature: Float(options.temperature ?? 0.2), topP: 0.95, repetitionPenalty: 1.1, @@ -684,19 +1330,9 @@ import Foundation case invocations([ToolInvocationResult]) } - private func resolveToolCalls( - _ toolCalls: [MLXLMCommon.ToolCall], - session: LanguageModelSession - ) async throws -> ToolResolutionOutcome { - if toolCalls.isEmpty { return .invocations([]) } - - var toolsByName: [String: any Tool] = [:] - for tool in session.tools { - if toolsByName[tool.name] == nil { - toolsByName[tool.name] = tool - } - } - + private func makeTranscriptToolCalls( + from toolCalls: [MLXLMCommon.ToolCall] + ) throws -> [Transcript.ToolCall] { var transcriptCalls: [Transcript.ToolCall] = [] transcriptCalls.reserveCapacity(toolCalls.count) for call in toolCalls { @@ -710,6 +1346,23 @@ import Foundation ) ) } + return transcriptCalls + } + + private func resolveToolCalls( + _ toolCalls: [MLXLMCommon.ToolCall], + session: LanguageModelSession + ) async throws -> ToolResolutionOutcome { + if toolCalls.isEmpty { return .invocations([]) } + + var toolsByName: [String: any Tool] = [:] + for tool in session.tools { + if toolsByName[tool.name] == nil { + toolsByName[tool.name] = tool + } + } + + let transcriptCalls = try makeTranscriptToolCalls(from: toolCalls) if let delegate = session.toolExecutionDelegate { await delegate.didGenerateToolCalls(transcriptCalls, in: session) diff --git a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift index fa81049..89539a1 100644 --- a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift +++ b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift @@ -862,6 +862,51 @@ struct GeminiCustomOptionsTests { } } +#if MLX + @Suite("MLX CustomGenerationOptions") + struct MLXCustomOptionsTests { + @Test func initialization() { + let options = MLXLanguageModel.CustomGenerationOptions( + maxKVSize: 4096, + kvBits: 4, + kvGroupSize: 64, + quantizedKVStart: 128 + ) + #expect(options.maxKVSize == 4096) + #expect(options.kvBits == 4) + #expect(options.kvGroupSize == 64) + #expect(options.quantizedKVStart == 128) + } + + @Test func integrationWithGenerationOptions() { + var options = GenerationOptions(temperature: 0.7) + options[custom: MLXLanguageModel.self] = .init( + maxKVSize: 2048, + kvBits: 8, + kvGroupSize: 32, + quantizedKVStart: 256 + ) + let retrieved = options[custom: MLXLanguageModel.self] + #expect(retrieved?.maxKVSize == 2048) + #expect(retrieved?.kvBits == 8) + #expect(retrieved?.kvGroupSize == 32) + #expect(retrieved?.quantizedKVStart == 256) + } + + @Test func codable() throws { + let options = MLXLanguageModel.CustomGenerationOptions( + maxKVSize: 8192, + kvBits: 4, + kvGroupSize: 64, + quantizedKVStart: 0 + ) + let data = try JSONEncoder().encode(options) + let decoded = try JSONDecoder().decode(MLXLanguageModel.CustomGenerationOptions.self, from: data) + #expect(decoded == options) + } + } +#endif + #if Llama @Suite("Llama CustomGenerationOptions") struct LlamaCustomOptionsTests { diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 82c5f85..937b32e 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -69,6 +69,41 @@ import Testing #expect(!chunks.isEmpty) } + @Test func multiTurnSameSession() async throws { + let session = LanguageModelSession(model: model) + let first = try await session.respond(to: "Say hello in one sentence.") + #expect(!first.content.isEmpty) + + let second = try await session.respond(to: "Now answer with one more short sentence.") + #expect(!second.content.isEmpty) + } + + @Test func rejectsConcurrentRequestsForSameSession() async throws { + let session = LanguageModelSession(model: model) + let stream = session.streamResponse( + to: "Count from 1 to 400 with one number per line.", + options: .init(maximumResponseTokens: 256) + ) + + do { + _ = try await session.respond(to: "This concurrent request should fail.") + Issue.record("Expected concurrent request to throw.") + } catch let error as LanguageModelSession.GenerationError { + switch error { + case .concurrentRequests: + break + default: + Issue.record("Expected .concurrentRequests, got \(error)") + } + } catch { + Issue.record("Expected GenerationError.concurrentRequests, got \(error)") + } + + for try await _ in stream { + break + } + } + @Test func withGenerationOptions() async throws { let session = LanguageModelSession(model: model) @@ -239,5 +274,12 @@ import Testing } #expect(model.isAvailable == false) } + + @Test func removeAllFromCacheThenRespond() async throws { + await MLXLanguageModel.removeAllFromCache() + let session = LanguageModelSession(model: model) + let response = try await session.respond(to: "Say hello after cache clear") + #expect(!response.content.isEmpty) + } } #endif // MLX