diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 0ef37ef..551c9ef 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -287,6 +287,31 @@ import Foundation additionalContext?.mapValues { $0.toSendable() } } + /// Top-p (nucleus) sampling threshold. + /// + /// Set this to `nil` to fall back to the nucleus threshold from `GenerationOptions.sampling`, if any, or else the default (`1.0`, i.e. disabled; `0.95` for structured generation). + public var topP: Float? + + /// Top-k sampling: restricts sampling to the `k` most likely tokens. + /// + /// Set this to `0` to disable top-k sampling, or to `nil` to fall back to the top-k count from `GenerationOptions.sampling`, if any (otherwise disabled). + public var topK: Int? + + /// Min-p sampling threshold, relative to the most likely token's probability. + /// + /// Set this to `nil` or `0` to disable min-p sampling. + public var minP: Float? + + /// Penalty factor applied to recently generated tokens to reduce repetition. + /// + /// Set this to `nil` to use the default: no repetition penalty, except for structured generation, which applies `1.1`. + public var repetitionPenalty: Float? + + /// Number of recent tokens considered by the repetition penalty. + /// + /// Set this to `nil` to use the default (`20`, or `64` for structured generation). + public var repetitionContextSize: Int? + /// Creates MLX-specific generation options. /// /// - Parameters: @@ -295,14 +320,30 @@ import Foundation /// template rendering context. /// - userInputProcessing: Processing to apply to user media before input preparation. /// Defaults to `nil`, which lets MLX use its default media handling. + /// - topP: Top-p (nucleus) sampling threshold. Defaults to `nil` (sampling-derived, else the default). + /// - topK: Top-k sampling count. Defaults to `nil` (sampling-derived, else disabled). + /// - minP: Min-p sampling threshold. Defaults to `nil` (disabled). + /// - repetitionPenalty: Repetition penalty factor. Defaults to `nil` (disabled; `1.1` for structured generation). + /// - repetitionContextSize: Repetition-penalty token window. Defaults to `nil` + /// (`20`; `64` for structured generation). 
public init( kvCache: KVCache, userInputProcessing: UserInputProcessing?, - additionalContext: [String: AnyLanguageModel.JSONValue]? + additionalContext: [String: AnyLanguageModel.JSONValue]?, + topP: Float? = nil, + topK: Int? = nil, + minP: Float? = nil, + repetitionPenalty: Float? = nil, + repetitionContextSize: Int? = nil ) { self.kvCache = kvCache self.additionalContext = additionalContext self.userInputProcessing = userInputProcessing + self.topP = topP + self.topK = topK + self.minP = minP + self.repetitionPenalty = repetitionPenalty + self.repetitionContextSize = repetitionContextSize } /// Default MLX generation options used when none are provided at runtime. @@ -1214,34 +1255,63 @@ import Foundation // MARK: - Options Mapping - private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { + /// Derives MLX sampler parameters from the core ``GenerationOptions/sampling`` (`SamplingMode`), + /// so `.sampling` acts as a unified sampling surface across backends (the same one Apple + /// FoundationModels consumes). Returns `nil` for any field the sampling mode doesn't express. + /// + /// Precedence at the call sites is custom-block → this (sampling) → existing default, so an + /// explicit `CustomGenerationOptions` value always wins. The `SamplingMode` seed is not + /// forwarded: `MLXLMCommon.GenerateParameters` has no per-call seed field. `greedyTemperature` is only a fallback: an explicit `GenerationOptions.temperature` takes precedence over it at the call sites. + func samplingDerivedParameters( + from options: GenerationOptions + ) -> (topP: Float?, topK: Int?, greedyTemperature: Float?) { + switch options.sampling?.mode { + case .greedy: + // Greedy = argmax; MLX realizes this with temperature 0. 
+ return (topP: nil, topK: nil, greedyTemperature: 0) + case .topK(let k, _): + return (topP: nil, topK: k, greedyTemperature: nil) + case .nucleus(let threshold, _): + return (topP: Float(threshold), topK: nil, greedyTemperature: nil) + case nil: + return (topP: nil, topK: nil, greedyTemperature: nil) + } + } + + func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { let custom = options[custom: MLXLanguageModel.self] + let derived = samplingDerivedParameters(from: options) return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, maxKVSize: custom?.kvCache.maxSize, kvBits: custom?.kvCache.bits, kvGroupSize: custom?.kvCache.groupSize ?? 64, quantizedKVStart: custom?.kvCache.quantizedStart ?? 0, - temperature: Float(options.temperature ?? 0.6), - topP: 1.0, - repetitionPenalty: nil, - repetitionContextSize: 20 + temperature: Float(options.temperature ?? derived.greedyTemperature.map(Double.init) ?? 0.6), + topP: custom?.topP ?? derived.topP ?? 1.0, + topK: custom?.topK ?? derived.topK ?? 0, + minP: custom?.minP ?? 0.0, + repetitionPenalty: custom?.repetitionPenalty, + repetitionContextSize: custom?.repetitionContextSize ?? 20 ) } /// Builds MLX parameters tuned for structured generation. - private func toStructuredGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { + func toStructuredGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters { let custom = options[custom: MLXLanguageModel.self] + let derived = samplingDerivedParameters(from: options) return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, maxKVSize: custom?.kvCache.maxSize, kvBits: custom?.kvCache.bits, kvGroupSize: custom?.kvCache.groupSize ?? 64, quantizedKVStart: custom?.kvCache.quantizedStart ?? 0, - temperature: Float(options.temperature ?? 
0.2), - topP: 0.95, - repetitionPenalty: 1.1, - repetitionContextSize: 64 + temperature: Float(options.temperature ?? derived.greedyTemperature.map(Double.init) ?? 0.2), + topP: custom?.topP ?? derived.topP ?? 0.95, + topK: custom?.topK ?? derived.topK ?? 0, + minP: custom?.minP ?? 0.0, + repetitionPenalty: custom?.repetitionPenalty ?? 1.1, + repetitionContextSize: custom?.repetitionContextSize ?? 64 ) } diff --git a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift index ff2e8f2..c1aa2af 100644 --- a/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/SystemLanguageModel.swift @@ -384,6 +384,22 @@ @available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *) extension GenerationOptions { fileprivate func toFoundationModels() -> FoundationModels.GenerationOptions { + // OS 27 support: `#if compiler(>=6.4)` compiles the 27 path only on the Xcode 27 + // (beta) toolchain — Xcode 26 skips it entirely and never sees the 27-only symbol. + // `#available` then gates it to OS 27 at runtime; deployment floor stays 26 via the + // fall-through below. Self-contained (no FM-parity mapping required on this branch). + // KNOBS to confirm on the first Xcode 27 build: the compiler gate version, and the + // initializer's parameter labels if Apple changed them. 
+ #if compiler(>=6.4) + if #available(macOS 27.0, iOS 27.0, visionOS 27.0, *) { + return FoundationModels.GenerationOptions( + temperature: temperature, + maximumResponseTokens: maximumResponseTokens, + toolCallingMode: nil + ) + } + #endif + var options = FoundationModels.GenerationOptions() if let temperature = self.temperature { diff --git a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift index 3a4a812..fbfed73 100644 --- a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift +++ b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift @@ -901,6 +901,68 @@ struct GeminiCustomOptionsTests { #expect(retrieved?.kvCache.quantizedStart == 256) } + // MARK: - SamplingMode → MLX derivation + + @Test func samplingDerivationGreedy() { + let derived = samplingDerivedParameters(from: GenerationOptions(sampling: .greedy)) + #expect(derived.topP == nil) + #expect(derived.topK == nil) + #expect(derived.greedyTemperature == 0) + } + + @Test func samplingDerivationTopK() { + let derived = samplingDerivedParameters(from: GenerationOptions(sampling: .random(top: 40, seed: 7))) + #expect(derived.topK == 40) + #expect(derived.topP == nil) + #expect(derived.greedyTemperature == nil) + } + + @Test func samplingDerivationNucleus() { + let derived = samplingDerivedParameters(from: GenerationOptions(sampling: .random(probabilityThreshold: 0.9))) + #expect(derived.topP == 0.9) + #expect(derived.topK == nil) + #expect(derived.greedyTemperature == nil) + } + + @Test func samplingDerivationNil() { + let derived = samplingDerivedParameters(from: GenerationOptions()) + #expect(derived.topP == nil) + #expect(derived.topK == nil) + #expect(derived.greedyTemperature == nil) + } + + // MARK: - Mapping precedence (custom-wins → sampling-fills → default) + + @Test func samplingFillsWhenNoCustomBlock() { + let params = toGenerateParameters(GenerationOptions(sampling: .random(top: 12))) + 
#expect(params.topK == 12) // top-k now reaches MLX via sampling + #expect(params.topP == 1.0) // untouched default + } + + @Test func customBlockWinsOverSampling() { + var options = GenerationOptions(sampling: .random(probabilityThreshold: 0.9)) + options[custom: MLXLanguageModel.self] = .init( + kvCache: .default, + userInputProcessing: nil, + additionalContext: nil, + topP: 0.3, + topK: 5 + ) + let params = toGenerateParameters(options) + #expect(params.topP == 0.3) // custom wins over sampling's 0.9 + #expect(params.topK == 5) // custom wins (sampling expressed no top-k) + } + + @Test func greedyMapsToZeroTemperature() { + let params = toGenerateParameters(GenerationOptions(sampling: .greedy)) + #expect(params.temperature == 0) + } + + @Test func explicitTemperatureWinsOverGreedy() { + let params = toGenerateParameters(GenerationOptions(sampling: .greedy, temperature: 0.7)) + #expect(params.temperature == Float(0.7)) + } + @Test func codable() throws { let options = MLXLanguageModel.CustomGenerationOptions( kvCache: .init(