diff --git a/Documentation/Benchmarks.md b/Documentation/Benchmarks.md
index 2a3d07ed2..f4d9ef757 100644
--- a/Documentation/Benchmarks.md
+++ b/Documentation/Benchmarks.md
@@ -745,6 +745,46 @@ AVERAGE 31.7 21.5 0.5 9.7 - 126.7
 ======================================================================
 ```

+### Offline throughput (M5 Pro)
+
+A single fused offline graph (`mel[1,128,3072] → speaker_preds`, 30.72 s window) exported via the
+NeMo offline path — one CoreML call per window, no streaming state. ComputeUnit.ALL, median of 120
+runs after 12 warmup:
+
+| variant | model-exec (mel → preds) | RTFx |
+|---|---:|---:|
+| **fp16** | **10.65 ms** | **2884×** |
+| 6-bit palettized | 10.93 ms | 2809× |
+
+End-to-end incl. mel (fp16): 12.49 ms · 2459×. One fused GPU graph — no per-call dispatch or
+ANE→GPU handoff. Numerical parity vs the PyTorch reference: 100% speaker-argmax agreement (fp16).
+
+### Offline diarizer (whole-file, Swift)
+
+`OfflineSortformerDiarizer` runs the fused offline model end-to-end: mel extraction → fused graph
+per 30.72 s window (no streaming state) → timeline. Run with `fluidaudio sortformer --offline`
+(`--palettized` for the 6-bit set). Models: `FluidInference/diar-streaming-sortformer-coreml`
+`v3/{fp16,palettized}/SortformerOffline_v2.1.mlmodelc`; NeMo-reference parity = 100% speaker-argmax
+(fp16), 96.4% (6-bit palettized).
+
+**Throughput** is excellent — ~1000–1400× RTFx (one fused call per window, no per-chunk state). Voice
+detection matches streaming exactly (identical Miss/FA on AMI).
+
+**Quality caveat — not for long multi-speaker audio.** Each 30.72 s window diarizes independently
+with no speaker cache, so on long meetings with several speakers it produces large speaker confusion.
+AMI-SDM (collar 0.25, full test set), same harness:
+
+| | DER | Miss | FA | Speaker confusion | RTFx |
+|---|---|---|---|---|---|
+| Streaming (`highContextV2_1`) | **26.4%** | 23.4 | 0.6 | **2.4** | 835× |
+| Offline (whole-file) | 56.7% | 23.4 | 0.6 | **32.7** | 1418× |
+
+Detection is identical; the entire gap is speaker confusion the streaming `spkcache` avoids by
+construction (accumulating speaker profiles across the whole history). Cross-window re-stitching does
+not recover it — the confusion is generated *within* each window — so **use offline for short clips
+(≤ ~30 s), few-speaker audio, or throughput-bound batch jobs, and use the streaming variants for
+accurate long-form multi-speaker diarization.**
+
 ## LS-EEND Streaming Diarization

 A research prototype from Westlake University for streaming speaker diarization.
diff --git a/Documentation/Diarization/Sortformer.md b/Documentation/Diarization/Sortformer.md
index 14733d3f1..5da4acca3 100644
--- a/Documentation/Diarization/Sortformer.md
+++ b/Documentation/Diarization/Sortformer.md
@@ -415,15 +415,92 @@ let config = DiarizerTimelineConfig(

 ## Model Variants

-Three CoreML models are available on HuggingFace:
+CoreML models live on HuggingFace under
+[FluidInference/diar-streaming-sortformer-coreml](https://huggingface.co/FluidInference/diar-streaming-sortformer-coreml),
+in `v3/fp16/` (default) and `v3/palettized/` (see [Precision](#precision-fp16-vs-palettized)).
+The `v3/` set is the BNNS-fixed rebuild — the older root-level models hit a
+"tensor as both input and output" graph-compile crash on newer BNNS ([#726](https://github.com/FluidInference/FluidAudio/issues/726)).

-| Variant | File | Config |
-|---------|------|--------|
-| Default | `Sortformer.mlmodelc` | `SortformerConfig.default` |
-| Balanced | `SortformerNvidiaLow.mlmodelc` | `SortformerConfig.balancedV2_1` |
-| High Context | `SortformerNvidiaHigh.mlmodelc` | `SortformerConfig.highContextV2_1` |
+| Variant | Config | File (under `v3//`) | Output latency |
+|---------|--------|-------------------------------|----------------|
+| Fast (v2.1) | `.fastV2_1` | `Sortformer_v2.1.mlmodelc` | ~1.04 s |
+| Balanced (v2.1) | `.balancedV2_1` | `SortformerNvidiaLow_v2.1.mlmodelc` | ~1.5 s |
+| High Context (v2.1) | `.highContextV2_1` | `SortformerNvidiaHigh_v2.1.mlmodelc` | ~3.5 s |
+| Efficient (v2.1) | `.efficientV2_1` | `SortformerEfficient_v2.1.mlmodelc` | ~2.0 s (highest throughput) |
+
+(The `v2` weight variants — `.fastV2`, `.balancedV2`, `.highContextV2` — ship alongside each `v2.1`.)

 **Important:** Each model has baked-in static shapes. You must use the matching configuration.
+The diarizer logs a loud config-mismatch error at `initialize()` if the `SortformerConfig` does
+not match the streaming parameters embedded in the model (issue [#726](https://github.com/FluidInference/FluidAudio/issues/726)).
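The config-mismatch check described above can be sketched in isolation. This is a minimal illustration under stated assumptions, not the library's implementation: `StreamingShape`, `parseShape(from:)`, and `mismatches(host:model:)` are hypothetical names, and the host config shown is an invented mismatch. Only the five metadata keys and the `efficientV2_1` values are taken from the change itself.

```swift
// Hypothetical stand-in for the five shape-defining streaming parameters.
struct StreamingShape: Equatable {
    var chunkLen: Int
    var chunkLeftContext: Int
    var chunkRightContext: Int
    var fifoLen: Int
    var spkcacheLen: Int
}

/// Parse the shape parameters from string-valued model metadata.
/// Returns nil if any key is missing or not an integer (e.g. an older export).
func parseShape(from meta: [String: String]) -> StreamingShape? {
    func value(_ key: String) -> Int? { meta[key].flatMap { Int($0) } }
    guard
        let chunkLen = value("chunk_len"),
        let left = value("chunk_left_context"),
        let right = value("chunk_right_context"),
        let fifoLen = value("fifo_len"),
        let spkcacheLen = value("spkcache_len")
    else { return nil }
    return StreamingShape(
        chunkLen: chunkLen, chunkLeftContext: left, chunkRightContext: right,
        fifoLen: fifoLen, spkcacheLen: spkcacheLen)
}

/// List every field where the host config disagrees with the model.
func mismatches(host: StreamingShape, model: StreamingShape) -> [String] {
    let fields: [(String, KeyPath<StreamingShape, Int>)] = [
        ("chunkLen", \.chunkLen),
        ("chunkLeftContext", \.chunkLeftContext),
        ("chunkRightContext", \.chunkRightContext),
        ("fifoLen", \.fifoLen),
        ("spkcacheLen", \.spkcacheLen),
    ]
    var out: [String] = []
    for (name, path) in fields {
        let hostValue = host[keyPath: path]
        let modelValue = model[keyPath: path]
        if hostValue != modelValue {
            out.append("\(name): config=\(hostValue), model=\(modelValue)")
        }
    }
    return out
}

// Metadata as an `efficientV2_1` export would carry it (values from that preset).
let metadata = [
    "chunk_len": "25", "chunk_left_context": "1", "chunk_right_context": "7",
    "fifo_len": "40", "spkcache_len": "188",
]
// Assumed host config that differs only in chunk length.
let host = StreamingShape(
    chunkLen: 6, chunkLeftContext: 1, chunkRightContext: 7, fifoLen: 40, spkcacheLen: 188)

let problems = parseShape(from: metadata).map { mismatches(host: host, model: $0) } ?? []
for line in problems { print("config mismatch: \(line)") }
```

Reporting each differing field by name, rather than a single boolean, keeps the log message actionable when a config and a model disagree on only one parameter.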
+
+### Precision: fp16 vs palettized
+
+Each variant is built at two weight precisions, selected via `SortformerConfig.precision`:
+
+| Precision | Head weights | `highContextV2_1` RAM | DER impact | When |
+|-----------|--------------|----------------------|------------|------|
+| `.fp16` (default) | full | ~2.4 GB | baseline | Best accuracy; Apple Silicon Macs, recent iPhones/iPads |
+| `.palettized` | 6-bit k-means LUT | ~330 MB | +0.9 pp avg (streaming); larger on high-context | RAM-constrained / older devices |
+
+```swift
+var config = SortformerConfig.highContextV2_1
+config.precision = .palettized  // ~2.4 GB -> ~330 MB
+```
+
+Palettization is **opt-in**, not the default, because 6-bit quantization perturbs the embeddings, and
+the streaming speaker cache cascades that drift over time (worse on the high-context variant). For
+offline/batch or RAM-limited devices it's a good trade; for best streaming DER keep `.fp16`.
+
+**Old-device compute units.** The ~2.4 GB fp16 high-context head triggers a multi-minute ANE
+program-compile hang on RAM-constrained devices (A14, ~4 GB). `recommendedComputeUnits(for:)`
+automatically falls back to `.cpuOnly` for those variants on devices with less than 8 GB of RAM;
+everything else (including the ~330 MB palettized high-context head, which loads fine on ANE) keeps
+`.all`. Pass `computeUnits:` explicitly to override. On A14 the recommended path is
+`precision = .palettized`.
+
+### Benchmarks
+
+Streaming DER/RTFx and offline-throughput numbers live in
+[Documentation/Benchmarks.md](../Benchmarks.md#sortformer-streaming-diarization).
+
+## Offline (whole-file) mode
+
+When the entire audio is available up front, `OfflineSortformerDiarizer` runs the **fused offline
+model** — a single graph `mel -> speaker_preds` over a fixed 30.72 s window (3072 mel → 384 output
+frames) with **no streaming state** (no spkcache/FIFO threaded across calls).
+One CoreML call per
+window makes it the fastest path for batch diarization (~2880× RTFx model-exec on M5 Pro; see
+[Benchmarks](#benchmarks)). The model ships at both precisions:
+`v3/fp16/SortformerOffline_v2.1.mlmodelc` and `v3/palettized/SortformerOffline_v2.1.mlmodelc`.
+
+This differs from `SortformerDiarizer.processComplete(...)`, which runs the *streaming* model over
+all chunks (threading speaker-cache state). Use the offline diarizer when you have the whole file and
+want maximum throughput on short or few-speaker audio.
+
+**Scope — short clips / few speakers / throughput.** Each 30.72 s window is diarized independently
+with no speaker cache, so long multi-speaker audio accumulates large speaker confusion: on AMI-SDM
+the offline path scores ~57% DER vs ~26% for the streaming `highContextV2_1` (voice detection is
+identical — the gap is entirely speaker confusion the `spkcache` prevents). Cross-window re-stitching
+can't recover it because the confusion is generated within each window. **For accurate long-form
+multi-speaker diarization use the streaming variants; reach for offline for ≤ ~30 s clips,
+few-speaker audio, or throughput-bound batch jobs.** Longer inputs are tiled into 30.72 s windows
+(`overlapOutputFrames` controls the overlap) with activity-based stitching across boundaries.
+
+```swift
+let diarizer = OfflineSortformerDiarizer(config: .offlineV2_1)
+try await diarizer.initializeFromHuggingFace()  // or initialize(modelPath:)
+
+let timeline = try diarizer.processComplete(audioSamples, sourceSampleRate: 16_000)
+// Or load + resample a file directly:
+let fileTimeline = try diarizer.processComplete(audioFileURL: audioURL)
+
+for (index, speaker) in timeline.speakers {
+    for segment in speaker.finalizedSegments {
+        print("Speaker \(index): \(segment.startTime)s - \(segment.endTime)s")
+    }
+}
+```
+
+CLI: `fluidaudio sortformer audio.wav --offline` (add `--palettized` for the 6-bit set).
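The window tiling mentioned above comes down to a little arithmetic, sketched here as a standalone function. `windowRanges` is a hypothetical helper, not part of FluidAudio. The constants (384 output frames per window, subsampling factor 8, default overlap of 100 output frames) and the overlap clamp come from `OfflineSortformerConfig` and `processComplete` in this change. The loop's advance-by-one-hop step and its stop condition are assumptions, since that part of the loop is not shown here.

```swift
/// Hypothetical helper: mel-frame ranges an offline pass would visit, assuming
/// each window starts one hop after the previous one.
func windowRanges(
    totalMelFrames: Int,
    windowOutputFrames: Int = 384,
    subsamplingFactor: Int = 8,
    overlapOutputFrames: Int = 100
) -> [Range<Int>] {
    let windowMel = windowOutputFrames * subsamplingFactor
    // Clamp the overlap so the hop is always at least one output frame.
    let overlap = max(0, min(overlapOutputFrames, windowOutputFrames - 1))
    let hopMel = (windowOutputFrames - overlap) * subsamplingFactor

    var ranges: [Range<Int>] = []
    var start = 0
    while start < totalMelFrames {
        let end = min(start + windowMel, totalMelFrames)
        ranges.append(start..<end)
        if end == totalMelFrames { break }  // assumption: stop once the tail is covered
        start += hopMel
    }
    return ranges
}

// 60 s of 16 kHz audio at a 10 ms mel stride is roughly 6000 mel frames.
let ranges = windowRanges(totalMelFrames: 6000)
for range in ranges {
    print("mel frames \(range.lowerBound) to \(range.upperBound)")
}
```

With the defaults, consecutive windows share 800 mel frames (100 output frames, about 8 s), which is the region the stitcher compares when aligning speaker columns.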
## Usage Examples diff --git a/Sources/FluidAudio/Diarizer/Offline/Clustering/KMeansClustering.swift b/Sources/FluidAudio/Diarizer/Offline/Clustering/KMeansClustering.swift index 296c03624..bbda2029a 100644 --- a/Sources/FluidAudio/Diarizer/Offline/Clustering/KMeansClustering.swift +++ b/Sources/FluidAudio/Diarizer/Offline/Clustering/KMeansClustering.swift @@ -123,9 +123,10 @@ struct KMeansClustering { best = result } } - return best ?? clusterWithCentroids( - embeddings: embeddings, numClusters: numClusters, - maxIterations: maxIterations, seed: baseSeed) + return best + ?? clusterWithCentroids( + embeddings: embeddings, numClusters: numClusters, + maxIterations: maxIterations, seed: baseSeed) } private static func normalizeEmbeddings(_ embeddings: [[Double]]) -> [[Double]] { diff --git a/Sources/FluidAudio/Diarizer/Sortformer/Offline/OfflineSortformerDiarizer.swift b/Sources/FluidAudio/Diarizer/Sortformer/Offline/OfflineSortformerDiarizer.swift new file mode 100644 index 000000000..16951f128 --- /dev/null +++ b/Sources/FluidAudio/Diarizer/Sortformer/Offline/OfflineSortformerDiarizer.swift @@ -0,0 +1,381 @@ +@preconcurrency import CoreML +import Foundation +import OSLog + +// MARK: - Configuration + +/// Configuration for the offline (whole-window) Sortformer diarizer. +/// +/// The offline model is a single fused graph `mel -> speaker_preds` over a fixed window +/// (3072 mel frames -> 384 output frames, 30.72s) with no streaming state. Long audio is tiled +/// into overlapping windows and the per-window speaker columns are stitched back together by +/// ``SortformerSpeakerStitcher``. +public struct OfflineSortformerConfig: Sendable { + /// Head-weight precision (selects the HuggingFace subdirectory). + public var precision: ModelNames.Sortformer.ModelPrecision + + /// Output (post-subsampling) frames per window. Fixed by the exported model. + public let windowOutputFrames: Int = 384 + + /// Encoder subsampling factor (mel frames per output frame). 
+ public let subsamplingFactor: Int = 8 + + /// Number of speaker slots (fixed at 4 for current models). + public let numSpeakers: Int = 4 + + /// Mel filterbank feature count. + public let melFeatures: Int = 128 + + /// Model sample rate in Hz. + public let sampleRate: Int = 16000 + + /// Mel window / stride in samples (25ms / 10ms). + public let melWindow: Int = 400 + public let melStride: Int = 160 + + /// Duration of one output frame in seconds (subsampling * stride / sampleRate). + public var frameDurationSeconds: Float { + Float(subsamplingFactor) * Float(melStride) / Float(sampleRate) + } + + /// Output-frame overlap between consecutive windows used to stitch speaker identities. + /// ~8s of context — enough to correlate speaker activity, well under one window. + public var overlapOutputFrames: Int = 100 + + /// Mel frames per window (`windowOutputFrames * subsamplingFactor`). + public var windowMelFrames: Int { windowOutputFrames * subsamplingFactor } + + public init( + precision: ModelNames.Sortformer.ModelPrecision = .fp16, + overlapOutputFrames: Int = 100 + ) { + self.precision = precision + self.overlapOutputFrames = overlapOutputFrames + } + + /// Default offline configuration (fp16, ~8s overlap). + public static let offlineV2_1 = OfflineSortformerConfig() +} + +// MARK: - Model Container + +/// Loads and runs the fused offline Sortformer model (`mel -> speaker_preds`). 
+public struct OfflineSortformerModels { + public let mainModel: MLModel + public let compilationDuration: TimeInterval + + private let config: OfflineSortformerConfig + private let memoryOptimizer: ANEMemoryOptimizer + private let melArray: MLMultiArray + private let melLengthArray: MLMultiArray + + public init( + config: OfflineSortformerConfig, + main: MLModel, + compilationDuration: TimeInterval = 0 + ) throws { + self.config = config + self.mainModel = main + self.compilationDuration = compilationDuration + self.memoryOptimizer = .init() + // Model input is channels-first [1, melFeatures, windowMelFrames]. + self.melArray = try memoryOptimizer.createAlignedArray( + shape: [1, NSNumber(value: config.melFeatures), NSNumber(value: config.windowMelFrames)], + dataType: .float32) + self.melLengthArray = try memoryOptimizer.createAlignedArray(shape: [1], dataType: .int32) + } + + /// Run the offline model on one window. + /// + /// - Parameters: + /// - melTimeMajor: Mel features for the window, flat `[validMelFrames * melFeatures]` in + /// time-major order (frame `t`, feature `c` at `t * melFeatures + c`) as produced by + /// ``AudioMelSpectrogram/computeFlatTransposed(audio:)``. + /// - validMelFrames: Number of valid mel frames (≤ window). Shorter windows are zero-padded + /// and masked via `mel_length`. + /// - Returns: Per-frame speaker probabilities, flat `[windowOutputFrames * numSpeakers]` + /// (frame-major). + public func runOffline(melTimeMajor: [Float], validMelFrames: Int) throws -> [Float] { + let featureCount = config.melFeatures + let windowFrames = config.windowMelFrames + let frames = min(validMelFrames, windowFrames) + + let dst = melArray.dataPointer.assumingMemoryBound(to: Float.self) + // Transpose time-major [t, c] -> channels-first [c, t]; zero-pad the tail. + for t in 0.. 
OfflineSortformerModels { + let start = Date() + let compiledURL: URL + if modelPath.pathExtension == "mlmodelc" { + compiledURL = modelPath + } else { + compiledURL = try await MLModel.compileModel(at: modelPath) + } + let mlConfig = MLModelConfiguration() + mlConfig.computeUnits = .all + let model = try MLModel(contentsOf: compiledURL, configuration: mlConfig) + let duration = Date().timeIntervalSince(start) + logger.info("Loaded offline Sortformer model in \(String(format: "%.2f", duration))s") + return try OfflineSortformerModels(config: config, main: model, compilationDuration: duration) + } + + /// Download (if needed) and load the offline model from HuggingFace. + public static func loadFromHuggingFace( + config: OfflineSortformerConfig = .offlineV2_1, + cacheDirectory: URL? = nil, + computeUnits: MLComputeUnits = .all, + progressHandler: DownloadUtils.ProgressHandler? = nil + ) async throws -> OfflineSortformerModels { + let start = Date() + + let directory: URL + if let cache = cacheDirectory { + directory = cache + } else { + directory = FileManager.default.urls(for: .applicationSupportDirectory, in: .userDomainMask)[0] + .appendingPathComponent("FluidAudio/Models") + } + + let bundle = ModelNames.Sortformer.offlineBundle(precision: config.precision) + logger.info("Downloading offline Sortformer model: \(bundle)...") + + let models = try await DownloadUtils.loadModels( + .sortformer, + modelNames: [bundle], + directory: directory, + computeUnits: computeUnits, + variant: bundle, + progressHandler: progressHandler + ) + + guard let model = models[bundle] else { + throw SortformerError.modelLoadFailed("Failed to load offline Sortformer model from HuggingFace") + } + + let duration = Date().timeIntervalSince(start) + logger.info("Offline Sortformer model loaded from HuggingFace in \(String(format: "%.2f", duration))s") + return try OfflineSortformerModels(config: config, main: model, compilationDuration: duration) + } +} + +// MARK: - Diarizer + +/// 
Whole-file (offline) speaker diarizer built on the fused Sortformer graph.
+///
+/// Faster than the streaming path when the entire audio is available up front: each 30.72s window
+/// is one CoreML call (no per-chunk state threading). Windows overlap and are stitched into a
+/// single, globally speaker-consistent ``DiarizerTimeline`` via ``SortformerSpeakerStitcher``.
+public final class OfflineSortformerDiarizer {
+
+    private let logger = AppLogger(category: "OfflineSortformerDiarizer")
+    public let config: OfflineSortformerConfig
+    private let timelineConfig: DiarizerTimelineConfig
+    private let melSpectrogram: AudioMelSpectrogram
+    private let lock = NSLock()
+    private var _models: OfflineSortformerModels?
+
+    public var isAvailable: Bool {
+        withLock { _models != nil }
+    }
+
+    /// Execute a closure while holding the lock (async-safe: the lock ops stay in a sync frame).
+    private func withLock<T>(_ body: () throws -> T) rethrows -> T {
+        lock.lock()
+        defer { lock.unlock() }
+        return try body()
+    }
+
+    public init(
+        config: OfflineSortformerConfig = .offlineV2_1,
+        timelineConfig: DiarizerTimelineConfig? = nil
+    ) {
+        self.config = config
+        self.timelineConfig =
+            timelineConfig
+            ?? DiarizerTimelineConfig.default(
+                numSpeakers: config.numSpeakers,
+                frameDurationSeconds: config.frameDurationSeconds)
+        self.melSpectrogram = AudioMelSpectrogram()
+    }
+
+    /// Initialize from a local model path.
+    public func initialize(modelPath: URL) async throws {
+        let models = try await OfflineSortformerModels.load(config: config, modelPath: modelPath)
+        withLock { _models = models }
+    }
+
+    /// Initialize from HuggingFace.
+    public func initializeFromHuggingFace(
+        computeUnits: MLComputeUnits = .all,
+        progressHandler: DownloadUtils.ProgressHandler?
= nil + ) async throws { + let models = try await OfflineSortformerModels.loadFromHuggingFace( + config: config, computeUnits: computeUnits, progressHandler: progressHandler) + withLock { _models = models } + } + + /// Initialize with pre-loaded models. + public func initialize(models: OfflineSortformerModels) { + withLock { _models = models } + } + + /// Diarize a complete audio buffer. + /// + /// - Parameters: + /// - samples: Mono audio buffer. + /// - sourceSampleRate: Sample rate of `samples`, or `nil` if already at the model rate. + /// - Returns: A finalized timeline with globally consistent speaker IDs. + public func processComplete( + _ samples: [Float], + sourceSampleRate: Double? = nil + ) throws -> DiarizerTimeline { + lock.lock() + defer { lock.unlock() } + guard let models = _models else { throw SortformerError.notInitialized } + + let normalized = try normalizeSamples(samples, sourceSampleRate: sourceSampleRate) + guard !normalized.isEmpty else { + return DiarizerTimeline(config: timelineConfig) + } + + let (melFlat, _, numMelFrames) = melSpectrogram.computeFlatTransposed(audio: normalized) + let featureCount = config.melFeatures + let windowMel = config.windowMelFrames + let outPerWindow = config.windowOutputFrames + let speakers = config.numSpeakers + let sub = config.subsamplingFactor + + guard numMelFrames > 0 else { + return DiarizerTimeline(config: timelineConfig) + } + + let overlapOut = max(0, min(config.overlapOutputFrames, outPerWindow - 1)) + let hopOut = outPerWindow - overlapOut + let hopMel = hopOut * sub + + let totalOut = (numMelFrames + sub - 1) / sub // ceil to cover the tail + var global = [Float](repeating: 0, count: totalOut * speakers) + var filled = [Bool](repeating: false, count: totalOut) + + var melStart = 0 + var windowIndex = 0 + while melStart < numMelFrames { + let validMel = min(windowMel, numMelFrames - melStart) + let slice = Array(melFlat[(melStart * featureCount)..<((melStart + validMel) * featureCount)]) + let 
preds = try models.runOffline(melTimeMajor: slice, validMelFrames: validMel) + + let validOut = min(outPerWindow, (validMel + sub - 1) / sub) + let gStart = melStart / sub + + // Align this window's speaker columns to the global timeline over the overlap region. + var mapping = Array(0.. 0, overlapOut > 0 { + let ov = min(overlapOut, validOut, max(0, totalOut - gStart)) + if ov > 0 { + var gOverlap = [Float](repeating: 0, count: ov * speakers) + var wOverlap = [Float](repeating: 0, count: ov * speakers) + for j in 0.. DiarizerTimeline { + let converter = AudioConverter(sampleRate: Double(config.sampleRate)) + let audio = try converter.resampleAudioFile(audioFileURL) + return try processComplete(audio, sourceSampleRate: nil) + } + + private func normalizeSamples(_ samples: [Float], sourceSampleRate: Double?) throws -> [Float] { + guard let sourceSampleRate, sourceSampleRate != Double(config.sampleRate) else { + return samples + } + return try AudioConverter(sampleRate: Double(config.sampleRate)) + .resample(samples, from: sourceSampleRate) + } +} diff --git a/Sources/FluidAudio/Diarizer/Sortformer/Offline/SortformerSpeakerStitcher.swift b/Sources/FluidAudio/Diarizer/Sortformer/Offline/SortformerSpeakerStitcher.swift new file mode 100644 index 000000000..43a91c426 --- /dev/null +++ b/Sources/FluidAudio/Diarizer/Sortformer/Offline/SortformerSpeakerStitcher.swift @@ -0,0 +1,91 @@ +import Foundation + +/// Cross-window speaker-permutation alignment for offline Sortformer. +/// +/// The offline model processes independent 30.72s windows with no speaker cache, so window +/// `N`'s speaker columns are in an arbitrary order relative to the accumulated global timeline. +/// Each window overlaps the previous one by a fixed number of output frames; over that overlap +/// the same speakers are active in both, so we recover the permutation that best matches the +/// two windows' per-frame speaker activity and remap the new window into global speaker IDs. 
+/// +/// Speaker count is small and fixed (4), so we brute-force all `numSpeakers!` bijections — 24 +/// for 4 speakers — and keep the one maximizing summed activity agreement. No external solver. +enum SortformerSpeakerStitcher { + + /// Best bijection between a new window's speaker columns and the global speaker columns over + /// the overlap region. + /// + /// - Parameters: + /// - global: Overlap-region probabilities already committed to the global timeline, + /// flat `[frames * numSpeakers]` (frame-major), in global speaker order. + /// - window: The new window's probabilities for the same overlap frames, flat + /// `[frames * numSpeakers]`, in the window's own speaker order. + /// - frames: Number of overlap frames compared. + /// - numSpeakers: Speaker-column count (4 for current models). + /// - Returns: `mapping` of length `numSpeakers` where `mapping[windowSpeaker] == globalSpeaker`. + /// Identity (`[0, 1, ...]`) when there is nothing to align on. + static func alignment( + global: [Float], + window: [Float], + frames: Int, + numSpeakers: Int + ) -> [Int] { + let identity = Array(0.. 0, numSpeakers > 0, + global.count >= frames * numSpeakers, + window.count >= frames * numSpeakers + else { + return identity + } + + // correlation[g][w] = summed per-frame activity agreement between global speaker g and + // window speaker w over the overlap. Higher = more likely the same physical speaker. + var correlation = [[Float]]( + repeating: [Float](repeating: 0, count: numSpeakers), count: numSpeakers) + for f in 0.. bestScore { + bestScore = score + bestPerm = candidate + } + } + + // Invert perm[g] = w into mapping[w] = g (window column -> global column). + var mapping = identity + for g in 0.. Void) { + if k == array.count { + body(array) + return + } + for i in k..(_ body: () throws -> T) rethrows -> T { lock.lock() @@ -126,6 +160,8 @@ public final class SortformerDiarizer: Diarizer { /// Initialize with pre-loaded models. 
public func initialize(models: SortformerModels) { + validateConfigMatch(models) + lock.lock() defer { lock.unlock() } diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerModelInference.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerModelInference.swift index 22415f978..0bd393c7d 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerModelInference.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerModelInference.swift @@ -58,12 +58,14 @@ extension SortformerModels { /// - Parameters: /// - preprocessorPath: Path to SortformerPreprocessor.mlpackage /// - mainModelPath: Path to Sortformer.mlpackage - /// - configuration: Optional MLModel configuration + /// - computeUnits: CoreML compute units. Pass `nil` (default) to auto-resolve via + /// `recommendedComputeUnits(for:)`, which avoids the multi-minute ANE compile hang on + /// RAM-constrained devices for the large fp16 high-context variants (issue #726). /// - Returns: Loaded SortformerModels public static func load( config: SortformerConfig, mainModelPath: URL, - configuration: MLModelConfiguration? = nil + computeUnits: MLComputeUnits? = nil ) async throws -> SortformerModels { logger.info("Loading Sortformer models from local paths (combined pipeline mode)") @@ -73,11 +75,10 @@ extension SortformerModels { logger.info("Compiling main model...") let compiledMainModelURL = try await MLModel.compileModel(at: mainModelPath) - // Load main model - .all lets CoreML pick optimal compute units let mainConfig = MLModelConfiguration() - mainConfig.computeUnits = .all + mainConfig.computeUnits = computeUnits ?? 
recommendedComputeUnits(for: config) let mainModel = try MLModel(contentsOf: compiledMainModelURL, configuration: mainConfig) - logger.info("Loaded main Sortformer model") + logger.info("Loaded main Sortformer model (computeUnits=\(mainConfig.computeUnits.rawValue))") let duration = Date().timeIntervalSince(startTime) logger.info("Models loaded in \(String(format: "%.2f", duration))s") @@ -95,18 +96,49 @@ extension SortformerModels { return MLModelConfigurationUtils.defaultConfiguration(computeUnits: isCI ? .cpuAndNeuralEngine : .all) } + /// Memory (GiB) below which the large fp16 high-context head is loaded on CPU only. + /// A14-class devices ship ~4GB; iPhone 14 has 6GB. 8GB cleanly separates those from + /// Apple Silicon Macs / recent iPads that compile the ANE program without hanging. + private static let highContextAneRamThresholdGiB: Double = 8 + + /// Pick CoreML compute units for `config`, defaulting to `.all` but avoiding a known + /// load-time pathology: the ~2.4GB fp16 high-context head triggers a multi-minute ANE + /// program-compile hang on RAM-constrained devices (A14, ~4GB), which `.cpuOnly` avoids + /// (issue #726). The palettized high-context head (~330MB) loads fine on ANE, so it keeps + /// `.all`; only the large fp16 high-context variants on low-RAM devices fall back. + public static func recommendedComputeUnits(for config: SortformerConfig) -> MLComputeUnits { + let isLargeHighContext = + (config.modelVariant == .highContextV2 || config.modelVariant == .highContextV2_1) + && config.precision == .fp16 + guard isLargeHighContext else { return .all } + + let physicalGiB = Double(ProcessInfo.processInfo.physicalMemory) / 1_073_741_824 + guard physicalGiB < highContextAneRamThresholdGiB else { return .all } + + logger.warning( + """ + Loading large fp16 \(config.modelVariant.map(String.init(describing:)) ?? 
"high-context") \ + head on a \(String(format: "%.1f", physicalGiB))GB device with .cpuOnly to avoid the \ + multi-minute ANE compile hang (issue #726). Use SortformerConfig.precision = .palettized \ + for an ANE-friendly ~330MB build, or pass computeUnits explicitly to override. + """ + ) + return .cpuOnly + } + /// Load Sortformer models from HuggingFace. /// /// Downloads models from FluidInference/diar-streaming-sortformer-coreml if not cached. /// /// - Parameters: /// - cacheDirectory: Directory to cache downloaded models (defaults to app support) - /// - computeUnits: CoreML compute units to use (default: cpuOnly for consistency) + /// - computeUnits: CoreML compute units. Pass `nil` (default) to auto-resolve via + /// `recommendedComputeUnits(for:)` (issue #726). /// - Returns: Loaded SortformerModels public static func loadFromHuggingFace( config: SortformerConfig, cacheDirectory: URL? = nil, - computeUnits: MLComputeUnits = .all, + computeUnits: MLComputeUnits? = nil, progressHandler: DownloadUtils.ProgressHandler? = nil ) async throws -> SortformerModels { logger.info("Loading Sortformer models from HuggingFace...") @@ -131,11 +163,13 @@ extension SortformerModels { // Download models if needed + let resolvedComputeUnits = computeUnits ?? recommendedComputeUnits(for: config) + let models = try await DownloadUtils.loadModels( .sortformer, modelNames: [bundle], directory: directory, - computeUnits: computeUnits, + computeUnits: resolvedComputeUnits, variant: bundle, progressHandler: progressHandler ) @@ -156,6 +190,52 @@ extension SortformerModels { } } +// MARK: - Embedded Configuration + +extension SortformerModels { + + /// The model-shape-defining streaming parameters the converter embeds in the CoreML model + /// metadata. These determine the input tensor shapes and must match the host config. 
+ /// + /// Note: `spkcache_update_period` is intentionally excluded — `SortformerConfig.init` clamps + /// it host-side (`max(min(period, fifoLen+chunkLen), chunkLen)`), so the host value legitimately + /// differs from the raw value baked into the model and is not a compatibility signal. + public struct EmbeddedConfig: Equatable, Sendable { + public let chunkLen: Int + public let chunkLeftContext: Int + public let chunkRightContext: Int + public let fifoLen: Int + public let spkcacheLen: Int + } + + /// The variant-defining streaming parameters the converter writes into the CoreML model + /// metadata. Returns `nil` for older exports that don't carry them. Used to detect a + /// `SortformerConfig` that doesn't match the model — a mismatch yields incorrect and much + /// slower diarization (issue #726). + public var embeddedConfig: EmbeddedConfig? { + guard let meta = mainModel.modelDescription.metadata[.creatorDefinedKey] as? [String: String] else { + return nil + } + func value(_ key: String) -> Int? { meta[key].flatMap(Int.init) } + guard + let chunkLen = value("chunk_len"), + let chunkLeftContext = value("chunk_left_context"), + let chunkRightContext = value("chunk_right_context"), + let fifoLen = value("fifo_len"), + let spkcacheLen = value("spkcache_len") + else { + return nil + } + return EmbeddedConfig( + chunkLen: chunkLen, + chunkLeftContext: chunkLeftContext, + chunkRightContext: chunkRightContext, + fifoLen: fifoLen, + spkcacheLen: spkcacheLen + ) + } +} + // MARK: - Main Model Inference extension SortformerModels { diff --git a/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift b/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift index c34511353..405e6c365 100644 --- a/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift +++ b/Sources/FluidAudio/Diarizer/Sortformer/SortformerTypes.swift @@ -13,6 +13,12 @@ public struct SortformerConfig: Sendable { public let modelVariant: ModelVariant? 
+ /// Weight-precision build to download for `modelVariant`. `.fp16` (default) gives the best + /// DER; `.palettized` is the 6-bit, ~2.5x-smaller set for RAM-constrained devices (issue #726). + /// Mutable post-construction (e.g. `var c = .highContextV2_1; c.precision = .palettized`) since + /// it does not affect tensor shapes and has no bearing on `isCompatible`. + public var precision: ModelNames.Sortformer.ModelPrecision + /// Number of speaker slots (fixed at 4 for current model) public let numSpeakers: Int = 4 @@ -194,9 +200,25 @@ public struct SortformerConfig: Sendable { spkcacheUpdatePeriod: 300 ) + /// Higher-throughput streaming config with Sortformer v2.1 weights (~2s output latency). + /// Same context as `fastV2_1` but a larger 25-frame chunk: per-inference cost is dominated + /// by the static speaker-cache + FIFO context, so a bigger chunk advances ~4x more audio per + /// call at near-identical latency (~4x real-time factor vs `fastV2_1`). Use when ~2s latency + /// is acceptable and throughput matters. + public static let efficientV2_1 = SortformerConfig( + modelVariant: .efficientV2_1, + chunkLen: 25, + chunkLeftContext: 1, + chunkRightContext: 7, + fifoLen: 40, + spkcacheLen: 188, + spkcacheUpdatePeriod: 31 + ) + /// - Warning: If you don't use one of the default configurations, you must use a local model converted with that configuration. public init( modelVariant: ModelVariant? 
= .fastV2_1, + precision: ModelNames.Sortformer.ModelPrecision = .fp16, chunkLen: Int = 6, chunkLeftContext: Int = 1, chunkRightContext: Int = 7, @@ -213,6 +235,7 @@ public struct SortformerConfig: Sendable { debugMode: Bool = false ) { self.modelVariant = modelVariant + self.precision = precision self.chunkLen = max(1, chunkLen) self.chunkLeftContext = chunkLeftContext self.chunkRightContext = chunkRightContext diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift index 63dc59204..642fb6f95 100644 --- a/Sources/FluidAudio/ModelNames.swift +++ b/Sources/FluidAudio/ModelNames.swift @@ -635,6 +635,26 @@ public enum ModelNames { /// Sortformer streaming diarization model names public enum Sortformer { + /// Selects which weight-precision build of the model set to download. + /// + /// Both sets are the BNNS-fixed v3 rebuild (issue #726); they differ only in head weights: + /// - `.fp16`: full-precision head. Default. Best DER. + /// - `.palettized`: 6-bit k-means LUT head, ~2.5x smaller on disk and ~330MB RAM + /// (vs ~2.4GB for `highContextV2_1` fp16), at ~+0.9pp DER. Use on RAM-constrained + /// devices where the fp16 set crashes (older iPhones). + public enum ModelPrecision: String, Sendable, CaseIterable { + case fp16 + case palettized + + /// Repo subdirectory holding this precision's model set. + public var subdirectory: String { + switch self { + case .fp16: return "v3/fp16" + case .palettized: return "v3/palettized" + } + } + } + public enum Variant: CaseIterable, Sendable { case fastV2 case fastV2_1 @@ -642,6 +662,9 @@ public enum ModelNames { case balancedV2_1 case highContextV2 case highContextV2_1 + /// Higher-throughput streaming: larger chunk (~2s output latency) for ~4x the + /// real-time factor of `fastV2_1` at near-identical per-inference cost. 
+ case efficientV2_1 public var name: String { switch self { @@ -657,6 +680,8 @@ public enum ModelNames { return "SortformerNvidiaHigh_v2" case .highContextV2_1: return "SortformerNvidiaHigh_v2.1" + case .efficientV2_1: + return "SortformerEfficient_v2.1" } } @@ -674,11 +699,19 @@ public enum ModelNames { return .highContextV2 case .highContextV2_1: return .highContextV2_1 + case .efficientV2_1: + return .efficientV2_1 } } + /// Compiled-model path for this variant at the default (`.fp16`) precision. public var fileName: String { - return "\(name).mlmodelc" + return fileName(precision: .fp16) + } + + /// Compiled-model path for this variant at the given weight precision. + public func fileName(precision: ModelPrecision) -> String { + return "\(precision.subdirectory)/\(name).mlmodelc" } public func isCompatible(with config: SortformerConfig) -> Bool { @@ -686,21 +719,33 @@ public enum ModelNames { } } + /// Repo subdirectory holding the default (`.fp16`) model set. Both v3 sets are the + /// BNNS-fixed rebuild (the older root-level models hit a "tensor as both input and output" + /// graph-compile crash on newer BNNS — issue #726). Select the 6-bit, ~2.5x-smaller set + /// via `SortformerConfig.precision = .palettized` (fixes RAM-driven crashes on older + /// devices at ~+0.9pp DER). 
+ public static let modelsSubdirectory = ModelPrecision.fp16.subdirectory + /// Lowest latency for streaming public static let defaultVariant: Variant = .fastV2_1 - /// Bundle name for a specific variant + /// Bundle name for a specific variant at the default (`.fp16`) precision public static func bundle(for variant: Variant) -> String { return variant.fileName } - /// Bundle name for a given configuration + /// Bundle name for a specific variant at the given precision + public static func bundle(for variant: Variant, precision: ModelPrecision) -> String { + return variant.fileName(precision: precision) + } + + /// Bundle name for a given configuration (honors `config.precision`) public static func bundle(for config: SortformerConfig) -> String? { guard let variant = config.modelVariant else { return nil } assert(variant.isCompatible(with: config), "ERROR: Model variant and configuration are not compatible.") - return variant.fileName + return variant.fileName(precision: config.precision) } /// Default bundle name @@ -712,6 +757,18 @@ public enum ModelNames { public static var requiredModels: Set<String> { Set(Variant.allCases.map(\.fileName)) } + + // MARK: - Offline (whole-window) model + + /// Fused offline model name (v2.1 weights). Unlike the streaming variants this is a single + /// graph `mel -> speaker_preds` over a fixed 30.72s window (3072 mel -> 384 output frames), + /// with no spkcache/FIFO state — see ``OfflineSortformerDiarizer``. + public static let offlineModelName = "SortformerOffline_v2.1" + + /// Compiled-model path for the offline model at the given precision.
+ public static func offlineBundle(precision: ModelPrecision = .fp16) -> String { + return "\(precision.subdirectory)/\(offlineModelName).mlmodelc" + } } /// LS-EEND streaming diarization model names diff --git a/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift b/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift index 08d2f10ff..68bfd171c 100644 --- a/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift +++ b/Sources/FluidAudioCLI/Commands/SortformerBenchmark.swift @@ -68,6 +68,9 @@ enum SortformerBenchmark { var singleFile: String? var maxFiles: Int? var threshold: Float = 0.5 + var collarSeconds: Double = 0 + var onsetThreshold: Float? + var offsetThreshold: Float? var modelPath: String? var outputFile: String? var verbose = false @@ -77,6 +80,8 @@ enum SortformerBenchmark { var useNvidiaHighLatency = false var useHuggingFace = true var useLocalModels = false + var offline = false + var palettized = false var progressFile: String = ".sortformer_progress.json" var resumeFromProgress = false var dataset: Dataset = .ami @@ -116,6 +121,21 @@ enum SortformerBenchmark { threshold = Float(arguments[i + 1]) ?? 0.5 i += 1 } + case "--collar": + if i + 1 < arguments.count { + collarSeconds = Double(arguments[i + 1]) ?? 0 + i += 1 + } + case "--onset": + if i + 1 < arguments.count { + onsetThreshold = Float(arguments[i + 1]) + i += 1 + } + case "--offset": + if i + 1 < arguments.count { + offsetThreshold = Float(arguments[i + 1]) + i += 1 + } case "--model": if i + 1 < arguments.count { modelPath = arguments[i + 1] @@ -180,6 +200,10 @@ enum SortformerBenchmark { useHuggingFace = true case "--local": useLocalModels = true + case "--offline": + offline = true + case "--palettized": + palettized = true case "--help": printUsage() return @@ -274,6 +298,24 @@ enum SortformerBenchmark { print("") fflush(stdout) + // Offline (whole-file fused model) path — no streaming state, one fused call per window. 
+ if offline { + await runOfflineBenchmark( + filesToProcess: filesToProcess, + completedResults: completedResults, + completedMeetings: completedMeetings, + dataset: dataset, + palettized: palettized, + modelPath: modelPath, + useHuggingFace: useHuggingFace, + threshold: threshold, + collarSeconds: collarSeconds, + verbose: verbose, + progressFile: progressFile, + outputFile: outputFile) + return + } + // Initialize Sortformer print("Loading Sortformer models...") fflush(stdout) @@ -294,7 +336,19 @@ enum SortformerBenchmark { if let v = weakBoostRate { config.weakBoostRate = v } if let v = minPosScoresRate { config.minPosScoresRate = v } if let v = spkcacheSilFramesPerSpk { config.spkcacheSilFramesPerSpk = v } - let diarizer = SortformerDiarizer(config: config) + // Allow overriding the timeline binarization thresholds (sortformerDefault = 0.5/0.5). + let diarizer: SortformerDiarizer + if onsetThreshold != nil || offsetThreshold != nil { + let timeline = DiarizerTimelineConfig( + numSpeakers: config.numSpeakers, + frameDurationSeconds: Float(config.frameDurationSeconds), + onsetThreshold: onsetThreshold ?? 0.5, + offsetThreshold: offsetThreshold ?? onsetThreshold ?? 0.5 + ) + diarizer = SortformerDiarizer(config: config, timelineConfig: timeline) + } else { + diarizer = SortformerDiarizer(config: config) + } do { if useHuggingFace { @@ -340,6 +394,7 @@ enum SortformerBenchmark { diarizer: diarizer, modelLoadTime: modelLoadTime, threshold: threshold, + collarSeconds: collarSeconds, verbose: verbose ) @@ -375,12 +430,95 @@ enum SortformerBenchmark { } } + /// Whole-file offline benchmark over the dataset using the fused offline Sortformer model. 
+ private static func runOfflineBenchmark( + filesToProcess: [String], + completedResults: [BenchmarkResult], + completedMeetings: Set<String>, + dataset: Dataset, + palettized: Bool, + modelPath: String?, + useHuggingFace: Bool, + threshold: Float, + collarSeconds: Double, + verbose: Bool, + progressFile: String, + outputFile: String? + ) async { + print("Loading offline Sortformer model...") + fflush(stdout) + let modelLoadStart = Date() + + var offlineConfig = OfflineSortformerConfig.offlineV2_1 + if palettized { offlineConfig.precision = .palettized } + let diarizer = OfflineSortformerDiarizer(config: offlineConfig) + + do { + if let modelPath = modelPath, !useHuggingFace { + try await diarizer.initialize(modelPath: URL(fileURLWithPath: modelPath)) + } else { + try await diarizer.initializeFromHuggingFace() + } + } catch { + print("Failed to initialize offline Sortformer: \(error)") + return + } + let modelLoadTime = Date().timeIntervalSince(modelLoadStart) + print( + "Model loaded in \(String(format: "%.2f", modelLoadTime))s (precision: \(offlineConfig.precision.rawValue))\n" + ) + fflush(stdout) + + var allResults: [BenchmarkResult] = completedResults + for (fileIndex, meetingName) in filesToProcess.enumerated() { + if completedMeetings.contains(meetingName) { + print("[\(fileIndex + 1)/\(filesToProcess.count)] Skipping (already done): \(meetingName)") + continue + } + print(String(repeating: "=", count: 60)) + print("[\(fileIndex + 1)/\(filesToProcess.count)] Processing (offline): \(meetingName)") + print(String(repeating: "=", count: 60)) + fflush(stdout) + + let result = await processMeeting( + meetingName: meetingName, + dataset: dataset, + diarizer: nil, + offlineDiarizer: diarizer, + modelLoadTime: modelLoadTime, + threshold: threshold, + collarSeconds: collarSeconds, + verbose: verbose) + + if let result = result { + allResults.append(result) + print("Results for \(meetingName):") + print(" DER: \(String(format: "%.1f", result.der))%") + print(" RTFx:
\(String(format: "%.1f", result.rtfx))x") + print(" Speakers: \(result.detectedSpeakers) detected / \(result.groundTruthSpeakers) truth") + DiarizationBenchmarkUtils.saveProgress(results: allResults, to: progressFile) + } + fflush(stdout) + } + + DiarizationBenchmarkUtils.printFinalSummary( + results: allResults, + title: "SORTFORMER OFFLINE BENCHMARK SUMMARY", + derTargets: [15, 20]) + + if let outputPath = outputFile { + DiarizationBenchmarkUtils.saveJSONResults(results: allResults, to: outputPath) + } + } + private static func processMeeting( meetingName: String, dataset: Dataset, - diarizer: SortformerDiarizer, + diarizer: SortformerDiarizer?, + offlineDiarizer: OfflineSortformerDiarizer? = nil, modelLoadTime: Double, threshold: Float, + collarSeconds: Double, verbose: Bool ) async -> BenchmarkResult? { @@ -408,20 +546,29 @@ enum SortformerBenchmark { // Process with progress reporting let startTime = Date() var lastProgressPrint = Date() - let result = try diarizer.processComplete(audioSamples) { processed, total, chunks in - // Print progress every 2 seconds - let now = Date() - if now.timeIntervalSince(lastProgressPrint) >= 2.0 { - let percent = Float(processed) / Float(total) * 100 - let elapsed = now.timeIntervalSince(startTime) - let processedSeconds = Float(processed) / 16000.0 - let currentRtfx = processedSeconds / Float(elapsed) - print( - " Progress: \(String(format: "%.1f", percent))% | Chunks: \(chunks) | RTFx: \(String(format: "%.1f", currentRtfx))x" - ) - fflush(stdout) - lastProgressPrint = now + let result: DiarizerTimeline + if let offlineDiarizer = offlineDiarizer { + // Offline: one fused call per window, no streaming progress callback. 
+ result = try offlineDiarizer.processComplete(audioSamples) + } else if let diarizer = diarizer { + result = try diarizer.processComplete(audioSamples) { processed, total, chunks in + // Print progress every 2 seconds + let now = Date() + if now.timeIntervalSince(lastProgressPrint) >= 2.0 { + let percent = Float(processed) / Float(total) * 100 + let elapsed = now.timeIntervalSince(startTime) + let processedSeconds = Float(processed) / 16000.0 + let currentRtfx = processedSeconds / Float(elapsed) + print( + " Progress: \(String(format: "%.1f", percent))% | Chunks: \(chunks) | RTFx: \(String(format: "%.1f", currentRtfx))x" + ) + fflush(stdout) + lastProgressPrint = now + } } + } else { + print("No diarizer provided") + return nil } let processingTime = Date().timeIntervalSince(startTime) @@ -484,7 +631,7 @@ enum SortformerBenchmark { ref: referenceSegments, hyp: hypothesisSegments, frameStep: derFrameStepSeconds, - collar: 0 + collar: collarSeconds ) let totalRefSpeech = max(derResult.totalRefSpeech, .leastNonzeroMagnitude) let derPercent = Float(derResult.der * 100) diff --git a/Sources/FluidAudioCLI/Commands/SortformerCommand.swift b/Sources/FluidAudioCLI/Commands/SortformerCommand.swift index 482520a93..615940b19 100644 --- a/Sources/FluidAudioCLI/Commands/SortformerCommand.swift +++ b/Sources/FluidAudioCLI/Commands/SortformerCommand.swift @@ -41,6 +41,9 @@ enum SortformerCommand { var weakBoostRate: Float? var minPosScoresRate: Float? var spkcacheSilFramesPerSpk: Int? 
+ var configName = "default" + var palettized = false + var offline = false // Parse remaining arguments var i = 1 @@ -88,6 +91,15 @@ enum SortformerCommand { modelPath = arguments[i + 1] i += 1 } + case "--config": + if i + 1 < arguments.count { + configName = arguments[i + 1].lowercased() + i += 1 + } + case "--offline": + offline = true + case "--palettized": + palettized = true case "--threshold": if i + 1 < arguments.count, let v = Float(arguments[i + 1]) { predScoreThreshold = v @@ -129,13 +141,43 @@ enum SortformerCommand { i += 1 } + if offline { + var postConfig = DiarizerTimelineConfig.sortformerDefault + if let v = onset { postConfig.onsetThreshold = v } + if let v = offset { postConfig.offsetThreshold = v } + if let v = padOnset { postConfig.onsetPadSeconds = v } + if let v = padOffset { postConfig.offsetPadSeconds = v } + if let v = minDurationOn { postConfig.minDurationOn = v } + if let v = minDurationOff { postConfig.minDurationOff = v } + await runOffline( + audioFile: audioFile, + modelPath: modelPath, + palettized: palettized, + outputFile: outputFile, + postConfig: postConfig) + return + } + print("Sortformer Streaming Diarization") print(" Audio: \(audioFile)") - // Initialize Sortformer with default config (NVIDIA low latency: 1.04s) - var config = SortformerConfig.default + // Select config (default = NVIDIA low latency ~1.04s). `--config efficient` = chunk_len=25 (~2s, higher throughput). 
+ var config: SortformerConfig + switch configName { + case "efficient": + config = .efficientV2_1 + case "fast", "fastv2_1": + config = .fastV2_1 + case "low", "balanced": + config = .balancedV2_1 + case "high", "highcontext": + config = .highContextV2_1 + default: + config = .default + } var postConfig = DiarizerTimelineConfig.sortformerDefault config.debugMode = debugMode + if palettized { config.precision = .palettized } if let v = predScoreThreshold { config.predScoreThreshold = v } if let v = silenceThreshold { config.silenceThreshold = v } if let v = scoresBoostLatest { config.scoresBoostLatest = v } @@ -289,6 +331,95 @@ enum SortformerCommand { } } + /// Whole-file diarization via the fused offline Sortformer model (no streaming state). + private static func runOffline( + audioFile: String, + modelPath: String?, + palettized: Bool, + outputFile: String?, + postConfig: DiarizerTimelineConfig + ) async { + print("Sortformer Offline Diarization") + print(" Audio: \(audioFile)") + + var offlineConfig = OfflineSortformerConfig.offlineV2_1 + if palettized { offlineConfig.precision = .palettized } + let diarizer = OfflineSortformerDiarizer(config: offlineConfig, timelineConfig: postConfig) + + do { + let loadStart = Date() + if let modelPath = modelPath { + print("Loading offline model from local path: \(modelPath)") + try await diarizer.initialize(modelPath: URL(fileURLWithPath: modelPath)) + } else { + print("Loading offline model from HuggingFace...") + try await diarizer.initializeFromHuggingFace(computeUnits: .all) + } + print("Model loaded in \(String(format: "%.2f", Date().timeIntervalSince(loadStart)))s") + } catch { + print("ERROR: Failed to initialize offline Sortformer: \(error)") + exit(1) + } + + do { + print("Loading audio...") + let audioSamples = try AudioConverter().resampleAudioFile(path: audioFile) + let duration = Float(audioSamples.count) / 16000.0 + print("Loaded \(audioSamples.count) samples (\(String(format: "%.1f", duration))s)") + + 
print("Processing...") + fflush(stdout) + let startTime = Date() + let result = try diarizer.processComplete(audioSamples) + let processingTime = Date().timeIntervalSince(startTime) + let rtfx = duration / Float(processingTime) + + print("Processing completed in \(String(format: "%.2f", processingTime))s") + print(" Real-time factor (RTFx): \(String(format: "%.1f", rtfx))x") + print(" Total frames: \(result.numFinalizedFrames)") + + let segments = result.speakers.values.flatMap { $0.finalizedSegments } + print(" Found \(segments.count) segments") + print("\n--- Speaker Segments ---") + for segment in segments.sorted() { + let start = String(format: "%.2f", segment.startTime) + let end = String(format: "%.2f", segment.endTime) + let dur = String(format: "%.2f", segment.duration) + print("\(segment.speakerLabel): \(start)s - \(end)s (\(dur)s)") + } + + if let outputFile = outputFile { + var segmentDicts: [[String: Any]] = [] + for segment in segments.sorted() { + segmentDicts.append([ + "speaker": segment.speakerLabel, + "speakerIndex": segment.speakerIndex, + "startTimeSeconds": segment.startTime, + "endTimeSeconds": segment.endTime, + "durationSeconds": segment.duration, + ]) + } + let output: [String: Any] = [ + "audioFile": audioFile, + "mode": "offline", + "durationSeconds": duration, + "processingTimeSeconds": processingTime, + "rtfx": rtfx, + "totalFrames": result.numFinalizedFrames, + "segmentCount": segments.count, + "segments": segmentDicts, + ] + let jsonData = try JSONSerialization.data( + withJSONObject: output, options: [.prettyPrinted, .sortedKeys]) + try jsonData.write(to: URL(fileURLWithPath: outputFile)) + print("Results saved to: \(outputFile)") + } + } catch { + print("ERROR: Failed to process audio: \(error)") + exit(1) + } + } + private static func printUsage() { let usage = """ @@ -297,6 +428,8 @@ enum SortformerCommand { Options: --model-path Path to local CoreML model (.mlpackage or .mlmodelc) + --offline Whole-file mode: fused offline model 
(no streaming state) + --palettized Use the 6-bit palettized model set (~2.5x smaller) --debug Enable debug mode --output Save results to JSON file --onset Onset threshold for speech detection (default: 0.5) @@ -320,6 +453,9 @@ enum SortformerCommand { # With local model path fluidaudio sortformer audio.wav --model-path ./coreml_models/SortformerPipeline.mlpackage + # Offline whole-file mode (fused model, fastest when the full audio is available) + fluidaudio sortformer audio.wav --offline + # Tune streaming parameters fluidaudio sortformer audio.wav --threshold 0.3 --silence-threshold 0.15 diff --git a/Tests/FluidAudioTests/ASR/Parakeet/ModelNamesTests.swift b/Tests/FluidAudioTests/ASR/Parakeet/ModelNamesTests.swift index fb392b8bf..3e94da56c 100644 --- a/Tests/FluidAudioTests/ASR/Parakeet/ModelNamesTests.swift +++ b/Tests/FluidAudioTests/ASR/Parakeet/ModelNamesTests.swift @@ -119,6 +119,38 @@ final class ModelNamesTests: XCTestCase { ) } + func testSortformerPrecisionSubdirectories() { + XCTAssertEqual(ModelNames.Sortformer.ModelPrecision.fp16.subdirectory, "v3/fp16") + XCTAssertEqual(ModelNames.Sortformer.ModelPrecision.palettized.subdirectory, "v3/palettized") + // Default subdirectory must track the fp16 precision. + XCTAssertEqual( + ModelNames.Sortformer.modelsSubdirectory, ModelNames.Sortformer.ModelPrecision.fp16.subdirectory) + } + + func testSortformerBundleHonorsPrecision() { + for variant in ModelNames.Sortformer.Variant.allCases { + let fp16 = ModelNames.Sortformer.bundle(for: variant, precision: .fp16) + let palettized = ModelNames.Sortformer.bundle(for: variant, precision: .palettized) + XCTAssertTrue(fp16.hasPrefix("v3/fp16/"), "fp16 bundle '\(fp16)' should live under v3/fp16/") + XCTAssertTrue( + palettized.hasPrefix("v3/palettized/"), + "palettized bundle '\(palettized)' should live under v3/palettized/") + // Default (no precision) bundle must equal the fp16 path. 
+ XCTAssertEqual(ModelNames.Sortformer.bundle(for: variant), fp16) + } + } + + func testSortformerConfigPrecisionDrivesBundle() { + var config = SortformerConfig.highContextV2_1 + XCTAssertEqual(config.precision, .fp16, "precision should default to fp16") + XCTAssertEqual(ModelNames.Sortformer.bundle(for: config), config.modelVariant?.fileName(precision: .fp16)) + + config.precision = .palettized + XCTAssertEqual( + ModelNames.Sortformer.bundle(for: config), config.modelVariant?.fileName(precision: .palettized), + "Flipping config.precision must redirect the bundle to the palettized set") + } + // MARK: - Specific Model Names func testASRModelNamesEndInMlmodelc() { diff --git a/Tests/FluidAudioTests/Diarizer/Sortformer/OfflineSortformerTests.swift b/Tests/FluidAudioTests/Diarizer/Sortformer/OfflineSortformerTests.swift new file mode 100644 index 000000000..c4576b186 --- /dev/null +++ b/Tests/FluidAudioTests/Diarizer/Sortformer/OfflineSortformerTests.swift @@ -0,0 +1,121 @@ +import Foundation +import XCTest + +@testable import FluidAudio + +final class OfflineSortformerTests: XCTestCase { + + // MARK: - SortformerSpeakerStitcher + + /// Window speakers already in global order -> identity mapping. + func testStitcherIdentityWhenAligned() { + let frames = 4 + let ns = 4 + // One-hot active speaker per frame, same order in both. + var global = [Float](repeating: 0, count: frames * ns) + for f in 0.. stitcher recovers the inverse mapping. + func testStitcherRecoversPermutation() { + let frames = 8 + let ns = 4 + // global speaker g active on frame g (and g+4). window uses permutation perm[g]=w. + let perm = [2, 0, 3, 1] // global g -> window column perm[g] + var global = [Float](repeating: 0, count: frames * ns) + var window = [Float](repeating: 0, count: frames * ns) + for f in 0.. 0 { + XCTAssertEqual(mapping[w], f % ns, "window col \(w) should map to global \(f % ns)") + } + } + } + + /// Soft (non one-hot) activity still aligns by maximum correlation. 
+ func testStitcherSoftActivity() { + let frames = 3 + let ns = 4 + // global: speaker 1 dominant; window: column 3 dominant -> col3 should map to global 1. + var global = [Float](repeating: 0.1, count: frames * ns) + var window = [Float](repeating: 0.1, count: frames * ns) + for f in 0.. identity (nothing to align on). + func testStitcherZeroFramesIsIdentity() { + let mapping = SortformerSpeakerStitcher.alignment( + global: [], window: [], frames: 0, numSpeakers: 4) + XCTAssertEqual(mapping, [0, 1, 2, 3]) + } + + /// Mapping is always a valid bijection of speaker indices. + func testStitcherMappingIsBijection() { + let frames = 5 + let ns = 4 + var global = [Float](repeating: 0, count: frames * ns) + var window = [Float](repeating: 0, count: frames * ns) + for f in 0..