// Sources/FluidAudio/ASR/Canary/CanaryConfig.swift
@preconcurrency import CoreML
import Foundation

/// Canary encoder/decoder weight precision.
///
/// `int4` (per-block-32 symmetric) runs on the Neural Engine and is the
/// smallest build (~573 MB) — but int4 weight payloads require iOS 18 / macOS 15.
/// `fp16` is the iOS 17 parity fallback (exact match to PyTorch). `int8`
/// (per-channel) decodes correctly only on CPU — it crashes the GPU/ANE MPSGraph
/// backend — so it is a CPU/size-only option.
public enum CanaryPrecision: String, Sendable, CaseIterable {
    case int4
    case fp16
    case int8

    /// Name of the encoder model build for this precision (from `ModelNames.Canary`).
    var encoderName: String {
        switch self {
        case .int4: return ModelNames.Canary.encoderInt4
        case .fp16: return ModelNames.Canary.encoder
        case .int8: return ModelNames.Canary.encoderInt8
        }
    }

    /// Name of the decoder model build for this precision (from `ModelNames.Canary`).
    var decoderName: String {
        switch self {
        case .int4: return ModelNames.Canary.decoderInt4
        case .fp16: return ModelNames.Canary.decoder
        case .int8: return ModelNames.Canary.decoderInt8
        }
    }

    /// int8 only decodes correctly on CPU; int4/fp16 run on the Neural Engine.
    ///
    /// Written as an exhaustive switch so that adding a precision forces an
    /// explicit compute-unit decision instead of silently inheriting a default.
    var computeUnits: MLComputeUnits {
        switch self {
        case .int8: return .cpuOnly
        case .int4, .fp16: return .cpuAndNeuralEngine
        }
    }
}

/// Fixed-shape contract for the canary-1b-v2 CoreML pipeline (15 s window).
public enum CanaryConfig {
    /// Input sample rate in Hz.
    public static let sampleRate = 16000
    /// 15 s window — the preprocessor input is fixed at this sample count
    /// (15 × `sampleRate`).
    public static let maxSamples = 240_000
    /// Overlap between adjacent windows when chunking audio longer than 15 s.
    /// 3 s (~19 tokens) gives the seam LCS-merge enough shared context to align
    /// reliably while wasting little recompute.
    public static let chunkOverlapSeconds = 3.0
    /// `chunkOverlapSeconds` expressed in samples (48 000 at 16 kHz). Derived
    /// rather than hard-coded so the two values cannot drift apart.
    /// Hop between windows = `maxSamples` − `chunkOverlapSamples`.
    public static let chunkOverlapSamples = Int(
        CanaryConfig.chunkOverlapSeconds * Double(CanaryConfig.sampleRate))
    public static let melDim = 128
    public static let melFrames = 1501
    public static let encoderHidden = 1024
    public static let encoderFrames = 188
    /// Decoder is exported at a fixed `[1, maxDecoderSteps]`. 128 covers a 15 s
    /// window (max observed ~108 tokens incl. prompt) and is ~1.5× faster than 256.
    /// `CanaryManager` reads the real length from the loaded model, so this is just
    /// the contract/fallback value.
    public static let maxDecoderSteps = 128
    public static let vocabSize = 16384

    // Special token ids (the model's real decoder ids — see vocab.json).
    public static let eosId = 3  // <|endoftext|>
    public static let padId = 2  // padding token
    public static let bosId = 4  // <|startoftranscript|>

    /// canary2 prompt for English transcribe + punctuation/capitalization:
    /// ▁ <|startofcontext|> <|startoftranscript|> <|emo:undefined|> <|en|> <|en|>
    /// <|pnc|> <|noitn|> <|notimestamp|> <|nodiarize|>
    public static let promptEnTranscribePnc: [Int32] = [16053, 7, 4, 16, 64, 64, 5, 9, 11, 13]
}

// Sources/FluidAudio/ASR/Canary/CanaryKeywordBooster.swift
import Foundation

/// Applies custom-vocabulary keyword boosting to a Canary (AED) transcript using
/// the existing CTC keyword spotter — the same detector the parakeet "ctc custom
/// vocab" path uses.
///
/// Canary decodes autoregressively and emits no per-frame timestamps, so the
/// timestamp-constrained CTC rescorer (`VocabularyRescorer.ctcTokenRescore`)
/// cannot be applied directly.
/// Instead this reuses the engine-independent
/// `CtcKeywordSpotter` to detect dictionary terms from the audio, then injects
/// each detected term into Canary's transcript by fuzzy string match: a span that
/// is close-but-not-exact to a detected term (i.e. Canary mis-spelled the domain
/// word) is replaced with the canonical term.
public struct CanaryKeywordBooster: Sendable {

    public struct Result: Sendable {
        /// Transcript after boosting.
        public let text: String
        /// Distinct terms the CTC spotter detected in the audio.
        public let detected: [String]
        /// Terms actually written into the transcript (replaced or inserted).
        public let applied: [String]
    }

    private let spotter: CtcKeywordSpotter
    private let tokenizer: CtcTokenizer
    /// CTC detection score floor (log-prob; higher = stronger). Matches the
    /// permissive detection threshold the earnings benchmark uses.
    private let minScore: Float
    /// Replace a transcript span only when its similarity to the term is at least
    /// this (close enough to be the same word mis-transcribed).
    private let minSimilarity: Float
    /// …and below this (above it the span is already essentially the term).
    private let maxSimilarity: Float
    /// When a detected term has no fuzzy-matchable span (canary missed it entirely),
    /// insert it at the position implied by the CTC detection time.
    private let insertOnMiss: Bool
    /// Only insert (vs replace) when the detection score clears this stronger floor —
    /// protects precision against weak detections being force-inserted.
    private let insertScoreFloor: Float

    private static let logger = AppLogger(category: "CanaryKeywordBooster")

    public init(
        spotter: CtcKeywordSpotter,
        tokenizer: CtcTokenizer,
        minScore: Float = -15.0,
        minSimilarity: Float = 0.60,
        maxSimilarity: Float = 0.97,
        insertOnMiss: Bool = true,
        insertScoreFloor: Float = -6.0
    ) {
        self.spotter = spotter
        self.tokenizer = tokenizer
        self.minScore = minScore
        self.minSimilarity = minSimilarity
        self.maxSimilarity = maxSimilarity
        self.insertOnMiss = insertOnMiss
        self.insertScoreFloor = insertScoreFloor
    }

    /// Load the CTC spotter + tokenizer (parakeet-tdt_ctc-110m) and build a booster.
    ///
    /// The tuning parameters mirror `init`; `maxSimilarity` is exposed here too
    /// (defaulted, so existing call sites are unaffected).
    public static func load(
        minScore: Float = -15.0,
        minSimilarity: Float = 0.60,
        maxSimilarity: Float = 0.97,
        insertOnMiss: Bool = true,
        insertScoreFloor: Float = -6.0
    ) async throws -> CanaryKeywordBooster {
        let models = try await CtcModels.downloadAndLoad()
        let tokenizer = try await CtcTokenizer.load()
        return CanaryKeywordBooster(
            spotter: CtcKeywordSpotter(models: models),
            tokenizer: tokenizer,
            minScore: minScore,
            minSimilarity: minSimilarity,
            maxSimilarity: maxSimilarity,
            insertOnMiss: insertOnMiss,
            insertScoreFloor: insertScoreFloor)
    }

    /// Ensure every term carries CTC token IDs (the spotter scores by them).
    /// Terms that already have non-empty `ctcTokenIds` are passed through untouched.
    private func tokenized(_ vocabulary: CustomVocabularyContext) -> CustomVocabularyContext {
        let terms = vocabulary.terms.map { term -> CustomVocabularyTerm in
            if let ids = term.ctcTokenIds, !ids.isEmpty { return term }
            let ids = tokenizer.encode(term.text)
            return CustomVocabularyTerm(
                text: term.text, weight: term.weight, aliases: term.aliases, tokenIds: term.tokenIds,
                ctcTokenIds: ids)
        }
        return CustomVocabularyContext(
            terms: terms, alpha: vocabulary.alpha, minCtcScore: vocabulary.minCtcScore,
            minSimilarity: vocabulary.minSimilarity, minCombinedConfidence: vocabulary.minCombinedConfidence,
            minTermLength: vocabulary.minTermLength)
    }

    /// Inject CTC-spotted custom-vocabulary terms into `transcript`.
    ///
    /// For each detected term (multi-word terms first, then by score):
    /// 1. if the term already occurs in the transcript, do nothing;
    /// 2. otherwise find the best same-length word span by string similarity:
    ///    - similarity ≥ `maxSimilarity`: the span is already essentially the
    ///      term, so it is left alone (and nothing is inserted);
    ///    - similarity in `[minSimilarity, maxSimilarity)`: the span is replaced;
    /// 3. otherwise, if `insertOnMiss` is set and the detection score clears
    ///    `insertScoreFloor`, insert the term at the word position proportional
    ///    to the detection start time.
    ///
    /// - Parameters:
    ///   - transcript: Canary's transcript; split on single spaces into words.
    ///   - audioSamples: The audio the transcript came from, at `CanaryConfig.sampleRate`.
    ///   - vocabulary: Custom vocabulary; missing CTC token ids are filled in.
    /// - Returns: Boosted text plus the detected and applied term lists.
    /// - Throws: Whatever the keyword spotter throws.
    public func boost(
        transcript: String, audioSamples: [Float], vocabulary: CustomVocabularyContext
    ) async throws -> Result {
        let vocab = tokenized(vocabulary)
        let spot = try await spotter.spotKeywordsWithLogProbs(
            audioSamples: audioSamples, customVocabulary: vocab, minScore: minScore)

        // Best CTC detection (score + start time) per detected term.
        var detByTerm: [String: (term: CustomVocabularyTerm, score: Float, startTime: TimeInterval)] = [:]
        for d in spot.detections where d.score >= minScore {
            let key = d.term.textLowercased
            if let cur = detByTerm[key], cur.score >= d.score { continue }
            detByTerm[key] = (d.term, d.score, d.startTime)
        }
        let detected = detByTerm.values.map { $0.term.text }.sorted()
        guard !detByTerm.isEmpty else {
            return Result(text: transcript, detected: detected, applied: [])
        }
        let duration = max(0.001, Double(audioSamples.count) / Double(CanaryConfig.sampleRate))

        // Longer phrases before shorter (so a single word cannot steal part of a
        // multi-word match); ties broken by stronger detection first. Word counts
        // are computed once per term instead of on every comparison.
        let ordered = detByTerm.values
            .map { (detection: $0, wordCount: $0.term.text.split(separator: " ").count) }
            .sorted { lhs, rhs in
                if lhs.wordCount != rhs.wordCount { return lhs.wordCount > rhs.wordCount }
                return lhs.detection.score > rhs.detection.score
            }
            .map { $0.detection }

        var words = transcript.split(separator: " ").map(String.init)
        var applied: [String] = []

        for entry in ordered {
            let term = entry.term
            let termLower = term.textLowercased
            // Already present (case-insensitive substring) → nothing to fix.
            // Re-joined each iteration because `words` changes as terms are applied.
            if words.joined(separator: " ").lowercased().contains(termLower) { continue }

            let termWords = term.text.split(separator: " ").map(String.init)
            let span = max(1, termWords.count)

            // 1) Fuzzy replace: a close-but-wrong span is canary mis-spelling the term.
            var bestIdx = -1
            var bestSim: Float = 0
            if words.count >= span {
                for i in 0...(words.count - span) {
                    let window = normalize(words[i..<(i + span)].joined(separator: " "))
                    let sim = VocabularyRescorer.stringSimilarity(window, termLower)
                    if sim > bestSim {
                        bestSim = sim
                        bestIdx = i
                    }
                }
            }

            if bestIdx >= 0 {
                // The span is already essentially the term (e.g. it differs only by
                // punctuation that defeated the substring check). Leave it alone —
                // and crucially do not fall through to insertion, which would
                // duplicate a word that is already in the transcript.
                if bestSim >= maxSimilarity { continue }

                if bestSim >= minSimilarity {
                    words.replaceSubrange(bestIdx..<(bestIdx + span), with: termWords)
                    applied.append(term.text)
                    continue
                }
            }

            // 2) Timestamp-guided insertion: canary missed the word entirely (no fuzzy
            // span). The CTC detection still localizes it in time, so insert it at the
            // proportional word position. Gated by a stronger score floor to protect
            // precision.
            if insertOnMiss, entry.score >= insertScoreFloor, !words.isEmpty {
                let frac = min(1.0, max(0.0, entry.startTime / duration))
                let pos = min(words.count, Int((frac * Double(words.count)).rounded()))
                words.insert(contentsOf: termWords, at: pos)
                applied.append(term.text)
            }
        }

        return Result(text: words.joined(separator: " "), detected: detected, applied: applied)
    }

    /// Lowercase and strip punctuation so transcript spans compare on letters only.
    private func normalize(_ s: String) -> String {
        s.lowercased().filter { !$0.isPunctuation }
    }
}

// Sources/FluidAudio/ASR/Canary/CanaryManager.swift
@preconcurrency import CoreML
import Foundation

/// Manager for NVIDIA Canary-1B-v2 transcription (attention encoder-decoder).
///
/// Pipeline: waveform → [Preprocessor fp32/CPU] mel → [Encoder int4/ANE] →
/// transpose to [1, T, D] → greedy autoregressive loop ([Decoder] → last hidden
/// → [Projection] → argmax until EOS) → SentencePiece detokenize.
///
/// The decoder carries no KV cache: each step re-runs the full fixed-length
/// `[1, S]` token sequence, where S is read from the loaded decoder and falls
/// back to `CanaryConfig.maxDecoderSteps` (matches the converted CoreML model).
/// The 15 s window is fixed; audio longer than 15 s is split into overlapping
/// 15 s windows that are decoded independently and merged at the seams
/// (see `transcribe(audio:)`).
public actor CanaryManager {

    private let models: CanaryModels
    /// Decoder prompt token ids supplied at construction.
    private let prompt: [Int32]
    private static let logger = AppLogger(category: "CanaryManager")

    /// Build a manager around already-loaded models.
    ///
    /// - Parameters:
    ///   - models: Loaded Canary CoreML models.
    ///   - prompt: Decoder prompt token ids; defaults to the English
    ///     transcribe + punctuation/capitalization prompt.
    public init(models: CanaryModels, prompt: [Int32] = CanaryConfig.promptEnTranscribePnc) {
        self.models = models
        self.prompt = prompt
    }

    /// Load models from the default cache (downloading if needed), then build a manager.
    ///
    /// The manager is created with the default prompt.
    public static func load(
        precision: CanaryPrecision = .int4,
        progressHandler: DownloadUtils.ProgressHandler? = nil
    ) async throws -> CanaryManager {
        let models = try await CanaryModels.downloadAndLoad(precision: precision, progressHandler: progressHandler)
        return CanaryManager(models: models)
    }

    /// Transcribe an audio file.
    ///
    /// The file is passed through `AudioConverter.resampleAudioFile` with a
    /// converter targeting `CanaryConfig.sampleRate`, and the resulting samples
    /// are handed to `transcribe(audio:)`.
    public func transcribe(audioURL: URL) throws -> String {
        let converter = AudioConverter(sampleRate: Double(CanaryConfig.sampleRate))
        let samples = try converter.resampleAudioFile(audioURL)
        return try transcribe(audio: samples)
    }

    /// Transcribe 16 kHz mono float samples (in [-1, 1]).
    ///
    /// Audio within the 15 s window is decoded in one pass. Longer audio is split
    /// into overlapping 15 s windows (hop = 15 s − `chunkOverlapSeconds`), decoded
    /// independently, and stitched at the seams via token-level
    /// longest-common-substring (`mergeTokenStreams`). No model change — each
    /// window still sees the fixed 15 s contract and the decoder is reset per window.
+ public func transcribe(audio: [Float]) throws -> String { + let maxN = CanaryConfig.maxSamples + if audio.count <= maxN { + return detokenize(try transcribeWindow(audio: audio)) + } + + let hop = maxN - CanaryConfig.chunkOverlapSamples + var merged: [Int] = [] + var start = 0 + var chunkIndex = 0 + while start < audio.count { + let end = min(start + maxN, audio.count) + // Don't decode a final tail that is pure overlap — the previous window + // already covered it. + if chunkIndex > 0, (end - start) <= (maxN - hop) { break } + + let tokens = try transcribeWindow(audio: Array(audio[start..= audio.count { break } + start += hop + } + return detokenize(merged) + } + + /// Run the 4-stage pipeline over a single ≤15 s window; returns generated + /// token ids (prompt stripped, EOS excluded). + private func transcribeWindow(audio: [Float]) throws -> [Int] { + let (mel, melLength) = try runPreprocessor(audio: audio) + let (encoder, encoderLength) = try runEncoder(mel: mel, melLength: melLength) + let (embeddings, encoderMask) = try makeDecoderContext(encoder: encoder, encoderLength: encoderLength) + return try greedyDecode(embeddings: embeddings, encoderMask: encoderMask) + } + + /// Merge two adjacent window token streams using longest-common-substring. + /// + /// Both windows transcribe `chunkOverlapSeconds` of identical audio at their + /// seam, so their token ids share a common substring near the prefix's tail / + /// the suffix's head. Search a bounded window (`windowTokens` at the boundary) + /// for the longest common substring of length ≥ `minMatch`. On a hit, drop the + /// suffix's matched head so the seam is not duplicated; on a miss, concatenate + /// plainly — better to duplicate a few tokens than to lose content. 
+ static func mergeTokenStreams( + prefix: [Int], + suffix: [Int], + windowTokens: Int = 32, + minMatch: Int = 4 + ) -> [Int] { + if prefix.isEmpty { return suffix } + if suffix.isEmpty { return prefix } + + let pTail = Array(prefix.suffix(windowTokens)) + let sHead = Array(suffix.prefix(windowTokens)) + let m = pTail.count + let n = sHead.count + if m == 0 || n == 0 { return prefix + suffix } + + // Classic LCS-substring DP (O(m·n), m,n ≤ windowTokens). + var dp = [Int](repeating: 0, count: n + 1) + var bestLen = 0 + var bestSEnd = 0 // index in sHead (exclusive) where the match ends + for i in 1...m { + var prev = 0 + for j in 1...n { + let temp = dp[j] + if pTail[i - 1] == sHead[j - 1] { + dp[j] = prev + 1 + if dp[j] > bestLen { + bestLen = dp[j] + bestSEnd = j + } + } else { + dp[j] = 0 + } + prev = temp + } + } + + guard bestLen >= minMatch else { return prefix + suffix } + return prefix + Array(suffix.dropFirst(bestSEnd)) + } + + // MARK: - Pipeline + + /// waveform → mel [1, 128, 1501]. Audio is padded/truncated to the fixed 15 s window. 
+ private func runPreprocessor(audio: [Float]) throws -> (MLMultiArray, MLMultiArray) { + let maxN = CanaryConfig.maxSamples + let validN = min(audio.count, maxN) + if audio.count > maxN { + Self.logger.warning("Audio \(audio.count) samples > 15 s window; truncating to \(maxN)") + } + + let signal = try MLMultiArray(shape: [1, maxN as NSNumber], dataType: .float32) + let sptr = signal.dataPointer.assumingMemoryBound(to: Float32.self) + memset(sptr, 0, maxN * MemoryLayout.size) + audio.prefix(validN).withUnsafeBufferPointer { src in + sptr.update(from: src.baseAddress!, count: validN) + } + + let length = try MLMultiArray(shape: [1], dataType: .int32) + length[0] = NSNumber(value: validN) + + let input = try MLDictionaryFeatureProvider(dictionary: [ + "audio_signal": MLFeatureValue(multiArray: signal), + "audio_length": MLFeatureValue(multiArray: length), + ]) + let out = try models.preprocessor.prediction(from: input) + guard let mel = out.featureValue(for: "processed")?.multiArrayValue, + let melLen = out.featureValue(for: "processed_length")?.multiArrayValue + else { + throw ASRError.processingFailed("Canary preprocessor produced no `processed`") + } + return (mel, melLen) + } + + /// mel → encoder [1, 1024, 188]. + private func runEncoder(mel: MLMultiArray, melLength: MLMultiArray) throws -> (MLMultiArray, Int) { + let featLen = try MLMultiArray(shape: [1], dataType: .int32) + featLen[0] = NSNumber(value: melLength[0].intValue) + + let input = try MLDictionaryFeatureProvider(dictionary: [ + "features": MLFeatureValue(multiArray: mel), + "features_length": MLFeatureValue(multiArray: featLen), + ]) + let out = try models.encoder.prediction(from: input) + guard let enc = out.featureValue(for: "encoder")?.multiArrayValue else { + throw ASRError.processingFailed("Canary encoder produced no `encoder`") + } + let encLen = out.featureValue(for: "encoder_length")?.multiArrayValue?[0].intValue ?? 
CanaryConfig.encoderFrames + return (enc, encLen) + } + + /// encoder [1, D, T] → encoder_embeddings [1, T, D] + encoder_mask [1, T]. + /// + /// CoreML pads the encoder's last dim to a 64-element boundary (T=188 → + /// stride 192), so the transpose must use the array's real strides, not a + /// dense linear read. + private func makeDecoderContext(encoder: MLMultiArray, encoderLength: Int) throws -> (MLMultiArray, MLMultiArray) { + let d = CanaryConfig.encoderHidden + let t = CanaryConfig.encoderFrames + let emb = try MLMultiArray(shape: [1, t as NSNumber, d as NSNumber], dataType: .float32) + let eptr = emb.dataPointer.assumingMemoryBound(to: Float32.self) + let strides = encoder.strides.map { $0.intValue } + let sD = strides[1] + let sT = strides[2] + let read = floatReader(encoder) + for ti in 0.. [Int] { + // Use the decoder's actual sequence length (the exported `[1, S]` shape), + // so a shorter decoder export (e.g. S=128) is picked up automatically. + let s = + models.decoder.modelDescription.inputDescriptionsByName["input_ids"]? + .multiArrayConstraint?.shape.last?.intValue ?? CanaryConfig.maxDecoderSteps + + let inputIds = try MLMultiArray(shape: [1, s as NSNumber], dataType: .int32) + let decoderMask = try MLMultiArray(shape: [1, s as NSNumber], dataType: .float32) + let idptr = inputIds.dataPointer.assumingMemoryBound(to: Int32.self) + let mkptr = decoderMask.dataPointer.assumingMemoryBound(to: Float32.self) + for i in 0..