Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 81 additions & 11 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,31 @@ import Foundation
additionalContext?.mapValues { $0.toSendable() }
}

/// Top-p (nucleus) sampling threshold.
///
/// Set this to `nil` to defer to ``GenerationOptions/sampling``: a nucleus sampling mode
/// supplies its probability threshold, and otherwise the fallback is `1.0` (i.e. disabled)
/// for regular generation or `0.95` for structured generation.
public var topP: Float?

/// Top-k sampling: restricts sampling to the `k` most likely tokens.
///
/// Set this to `0` to disable top-k sampling. Set this to `nil` to defer to
/// ``GenerationOptions/sampling``: a top-k sampling mode supplies its count, and otherwise
/// the fallback is `0` (disabled).
public var topK: Int?

/// Min-p sampling threshold, relative to the most likely token's probability.
///
/// Set this to `nil` or `0` to disable min-p sampling; `nil` is mapped to `0.0`.
public var minP: Float?

/// Penalty factor applied to recently generated tokens to reduce repetition.
///
/// Set this to `nil` to pass no penalty for regular generation; structured generation
/// falls back to `1.1` instead.
public var repetitionPenalty: Float?

/// Number of recent tokens considered by the repetition penalty.
///
/// Set this to `nil` to use the fallback window: `20` tokens for regular generation,
/// `64` tokens for structured generation.
public var repetitionContextSize: Int?

/// Creates MLX-specific generation options.
///
/// - Parameters:
Expand All @@ -295,14 +320,30 @@ import Foundation
/// template rendering context.
/// - userInputProcessing: Processing to apply to user media before input preparation.
/// Defaults to `nil`, which lets MLX use its default media handling.
/// - topP: Top-p (nucleus) sampling threshold. Defaults to `nil` (backend default).
/// - topK: Top-k sampling count. Defaults to `nil` (disabled).
/// - minP: Min-p sampling threshold. Defaults to `nil` (disabled).
/// - repetitionPenalty: Repetition penalty factor. Defaults to `nil` (disabled).
/// - repetitionContextSize: Repetition-penalty token window. Defaults to `nil`
/// (backend default).
public init(
kvCache: KVCache,
userInputProcessing: UserInputProcessing?,
// NOTE(review): the next two lines declare `additionalContext` twice; the first (no trailing
// comma) looks like the pre-change side of this diff — confirm only the second one remains.
additionalContext: [String: AnyLanguageModel.JSONValue]?
additionalContext: [String: AnyLanguageModel.JSONValue]?,
// The sampling arguments below all default to `nil`, so callers that pass only the first
// three arguments keep working.
topP: Float? = nil,
topK: Int? = nil,
minP: Float? = nil,
repetitionPenalty: Float? = nil,
repetitionContextSize: Int? = nil
) {
// Every argument is stored unchanged in the property of the same name.
self.kvCache = kvCache
self.additionalContext = additionalContext
self.userInputProcessing = userInputProcessing
self.topP = topP
self.topK = topK
self.minP = minP
self.repetitionPenalty = repetitionPenalty
self.repetitionContextSize = repetitionContextSize
}

/// Default MLX generation options used when none are provided at runtime.
Expand Down Expand Up @@ -1214,34 +1255,63 @@ import Foundation

// MARK: - Options Mapping

private func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters {
/// Translates the backend-neutral ``GenerationOptions/sampling`` mode into the MLX sampler
/// fields it can express, leaving every other field `nil`.
///
/// - Greedy sampling yields `greedyTemperature == 0` (argmax is realized as temperature 0).
/// - Top-k sampling yields `topK`.
/// - Nucleus sampling yields `topP`, converted to `Float`.
/// - No sampling mode yields all-`nil`.
///
/// Callers layer the result between the MLX custom options block (which wins) and their own
/// fallback defaults. The seed carried by the sampling mode is ignored here because
/// `MLXLMCommon.GenerateParameters` has no per-call seed field.
///
/// - Parameter options: The generation options whose sampling mode is inspected.
/// - Returns: The sampler fields expressed by the sampling mode; unexpressed fields are `nil`.
func samplingDerivedParameters(
    from options: GenerationOptions
) -> (topP: Float?, topK: Int?, greedyTemperature: Float?) {
    var nucleusThreshold: Float?
    var candidateCount: Int?
    var argmaxTemperature: Float?

    switch options.sampling?.mode {
    case .greedy:
        // Temperature 0 is how MLX performs argmax decoding.
        argmaxTemperature = 0
    case .topK(let count, _):
        candidateCount = count
    case .nucleus(let probability, _):
        nucleusThreshold = Float(probability)
    case nil:
        // No sampling mode: nothing to derive.
        break
    }

    return (topP: nucleusThreshold, topK: candidateCount, greedyTemperature: argmaxTemperature)
}

/// Maps ``GenerationOptions`` to the MLX parameters used for regular generation.
///
/// Resolution order per field:
/// - `temperature`: `options.temperature`, then `0` when the sampling mode is greedy, then `0.6`.
/// - `topP` / `topK`: MLX custom options block, then the value derived from the sampling
///   mode, then `1.0` / `0`.
/// - `minP` / `repetitionContextSize`: MLX custom options block, then `0.0` / `20`.
/// - `repetitionPenalty`: MLX custom options block only; `nil` otherwise.
///
/// NOTE(review): the argument list below carries both the literal `temperature:` / `topP:` /
/// `repetitionPenalty:` / `repetitionContextSize:` lines and their custom-aware replacements.
/// The literal set looks like the superseded side of this change — confirm only one set remains.
func toGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters {
// MLX-specific options, if the caller attached any.
let custom = options[custom: MLXLanguageModel.self]
// Sampler fields expressed by the backend-neutral sampling mode.
let derived = samplingDerivedParameters(from: options)
return MLXLMCommon.GenerateParameters(
maxTokens: options.maximumResponseTokens,
maxKVSize: custom?.kvCache.maxSize,
kvBits: custom?.kvCache.bits,
// KV-cache group size and quantization start fall back to 64 and 0 without a custom block.
kvGroupSize: custom?.kvCache.groupSize ?? 64,
quantizedKVStart: custom?.kvCache.quantizedStart ?? 0,
temperature: Float(options.temperature ?? 0.6),
topP: 1.0,
repetitionPenalty: nil,
repetitionContextSize: 20
temperature: Float(options.temperature ?? derived.greedyTemperature.map(Double.init) ?? 0.6),
topP: custom?.topP ?? derived.topP ?? 1.0,
topK: custom?.topK ?? derived.topK ?? 0,
minP: custom?.minP ?? 0.0,
repetitionPenalty: custom?.repetitionPenalty,
repetitionContextSize: custom?.repetitionContextSize ?? 20
)
}

/// Builds MLX parameters tuned for structured generation.
///
/// Resolution order per field:
/// - `temperature`: `options.temperature`, then `0` when the sampling mode is greedy, then `0.2`.
/// - `topP` / `topK`: MLX custom options block, then the value derived from the sampling
///   mode, then `0.95` / `0`.
/// - `minP` / `repetitionPenalty` / `repetitionContextSize`: MLX custom options block, then
///   `0.0` / `1.1` / `64`.
///
/// NOTE(review): two signatures follow (`private func` and `func`), and the argument list
/// carries both the literal `temperature:` / `topP:` / `repetitionPenalty:` /
/// `repetitionContextSize:` lines and their custom-aware replacements. The `private` signature
/// and the literal set look like the superseded side of this change — confirm only one remains.
private func toStructuredGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters {
func toStructuredGenerateParameters(_ options: GenerationOptions) -> MLXLMCommon.GenerateParameters {
// MLX-specific options, if the caller attached any.
let custom = options[custom: MLXLanguageModel.self]
// Sampler fields expressed by the backend-neutral sampling mode.
let derived = samplingDerivedParameters(from: options)
return MLXLMCommon.GenerateParameters(
maxTokens: options.maximumResponseTokens,
maxKVSize: custom?.kvCache.maxSize,
kvBits: custom?.kvCache.bits,
// KV-cache group size and quantization start fall back to 64 and 0 without a custom block.
kvGroupSize: custom?.kvCache.groupSize ?? 64,
quantizedKVStart: custom?.kvCache.quantizedStart ?? 0,
temperature: Float(options.temperature ?? 0.2),
topP: 0.95,
repetitionPenalty: 1.1,
repetitionContextSize: 64
temperature: Float(options.temperature ?? derived.greedyTemperature.map(Double.init) ?? 0.2),
topP: custom?.topP ?? derived.topP ?? 0.95,
topK: custom?.topK ?? derived.topK ?? 0,
minP: custom?.minP ?? 0.0,
repetitionPenalty: custom?.repetitionPenalty ?? 1.1,
repetitionContextSize: custom?.repetitionContextSize ?? 64
)
}

Expand Down
16 changes: 16 additions & 0 deletions Sources/AnyLanguageModel/Models/SystemLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,22 @@
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
extension GenerationOptions {
fileprivate func toFoundationModels() -> FoundationModels.GenerationOptions {
// OS 27 support: `#if compiler(>=6.4)` compiles the 27 path only on the Xcode 27
// (beta) toolchain — Xcode 26 skips it entirely and never sees the 27-only symbol.
// `#available` then gates it to OS 27 at runtime; deployment floor stays 26 via the
// fall-through below. Self-contained (no FM-parity mapping required on this branch).
// KNOBS to confirm on the first Xcode 27 build: the compiler gate version, and the
// initializer's parameter labels if Apple changed them.
#if compiler(>=6.4)
if #available(macOS 27.0, iOS 27.0, visionOS 27.0, *) {
return FoundationModels.GenerationOptions(
temperature: temperature,
maximumResponseTokens: maximumResponseTokens,
toolCallingMode: nil
)
}
#endif

var options = FoundationModels.GenerationOptions()

if let temperature = self.temperature {
Expand Down
62 changes: 62 additions & 0 deletions Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,68 @@ struct GeminiCustomOptionsTests {
#expect(retrieved?.kvCache.quantizedStart == 256)
}

// MARK: - SamplingMode → MLX derivation

/// Greedy sampling derives only a zero temperature; `topP` and `topK` stay unset.
@Test func samplingDerivationGreedy() {
    let options = GenerationOptions(sampling: .greedy)
    let result = samplingDerivedParameters(from: options)

    #expect(result.greedyTemperature == 0)
    #expect(result.topK == nil)
    #expect(result.topP == nil)
}

/// Top-k sampling derives only `topK`; the seed passed alongside it is not surfaced.
@Test func samplingDerivationTopK() {
    let options = GenerationOptions(sampling: .random(top: 40, seed: 7))
    let result = samplingDerivedParameters(from: options)

    #expect(result.greedyTemperature == nil)
    #expect(result.topP == nil)
    #expect(result.topK == 40)
}

/// Nucleus sampling derives only `topP`, carrying the probability threshold through.
@Test func samplingDerivationNucleus() {
    let options = GenerationOptions(sampling: .random(probabilityThreshold: 0.9))
    let result = samplingDerivedParameters(from: options)

    #expect(result.greedyTemperature == nil)
    #expect(result.topK == nil)
    #expect(result.topP == 0.9)
}

/// Without a sampling mode, nothing is derived.
@Test func samplingDerivationNil() {
    let options = GenerationOptions()
    let result = samplingDerivedParameters(from: options)

    #expect(result.greedyTemperature == nil)
    #expect(result.topK == nil)
    #expect(result.topP == nil)
}

// MARK: - Mapping precedence (custom-wins → sampling-fills → default)

/// With no MLX custom block, the sampling mode supplies `topK` and `topP` keeps its default.
@Test func samplingFillsWhenNoCustomBlock() {
    let options = GenerationOptions(sampling: .random(top: 12))
    let mapped = toGenerateParameters(options)

    // Top-p is not expressed by a top-k sampling mode, so the default survives.
    #expect(mapped.topP == 1.0)
    // Top-k reaches MLX through the sampling mode.
    #expect(mapped.topK == 12)
}

/// Values set on the MLX custom block take precedence over anything the sampling mode derives.
@Test func customBlockWinsOverSampling() {
    let customBlock = MLXLanguageModel.CustomGenerationOptions(
        kvCache: .default,
        userInputProcessing: nil,
        additionalContext: nil,
        topP: 0.3,
        topK: 5
    )
    var options = GenerationOptions(sampling: .random(probabilityThreshold: 0.9))
    options[custom: MLXLanguageModel.self] = customBlock

    let mapped = toGenerateParameters(options)

    // The sampling mode expressed no top-k, so the custom value is used.
    #expect(mapped.topK == 5)
    // The custom 0.3 beats the sampling mode's 0.9.
    #expect(mapped.topP == 0.3)
}

/// Greedy sampling with no explicit temperature maps to temperature 0.
@Test func greedyMapsToZeroTemperature() {
    let options = GenerationOptions(sampling: .greedy)
    let mapped = toGenerateParameters(options)

    #expect(mapped.temperature == 0)
}

/// An explicit temperature overrides the zero temperature implied by greedy sampling.
@Test func explicitTemperatureWinsOverGreedy() {
    let options = GenerationOptions(sampling: .greedy, temperature: 0.7)
    let mapped = toGenerateParameters(options)

    #expect(mapped.temperature == Float(0.7))
}

@Test func codable() throws {
let options = MLXLanguageModel.CustomGenerationOptions(
kvCache: .init(
Expand Down