diff --git a/.changeset/busy-aliens-wink.md b/.changeset/busy-aliens-wink.md new file mode 100644 index 000000000..cc66d2184 --- /dev/null +++ b/.changeset/busy-aliens-wink.md @@ -0,0 +1,6 @@ +--- +"@livekit/agents": minor +"livekit-agents-examples": patch +--- + +feat(eot): add audio eot model support diff --git a/.changeset/multimodal-eou.md b/.changeset/multimodal-eou.md new file mode 100644 index 000000000..f2cbafc84 --- /dev/null +++ b/.changeset/multimodal-eou.md @@ -0,0 +1,17 @@ +--- +"@livekit/agents": minor +"@livekit/agents-plugin-silero": minor +"@livekit/agents-plugin-livekit": minor +--- + +feat(core): multimodal end-of-turn detection with cloud → local fallback (AGT-2520) + +- New `inference.AudioTurnDetector`: WebSocket cloud EOT transport (`model: 'turn-detector'`) with automatic fallback to the local native model (`model: 'turn-detector-mini'`) via `@livekit/local-inference`. Auto-selects `'turn-detector'` when `LIVEKIT_REMOTE_EOT_URL` is set, `'turn-detector-mini'` otherwise. +- The local EOT model runs in the shared inference process (the same `InferenceProcExecutor` the text turn detector uses), loaded once per worker host (~138 MB) instead of in every job worker. The runner is registered by default when the native binding is available, so the inference process spawns on worker startup; on platforms where the binding can't load, local EOT degrades to a positive-default prediction and the worker still starts. (This is a JS-specific divergence from Python, which keeps EOT in-process and relies on forkserver COW sharing.) +- No prewarm helpers: EOT auto-warms in the inference process; the in-process silero VAD lazy-loads on first stream. (The `inference.prewarm*` helpers added during development were removed before release.) +- New `inference.VAD` (local-only streaming VAD via `@livekit/local-inference`). +- `AgentSession` now auto-provisions a bundled silero VAD when `vad` is omitted (`isDefault=true`). Pass `vad: null` to opt out. 
+- `@livekit/agents-plugin-silero` is deprecated; pass `vad: null` to opt out of the bundled default, or use `inference.VAD({ model: 'silero', ... })` to customise. +- The `@livekit/agents-plugin-livekit` turn detector is deprecated in favor of `inference.AudioTurnDetector`. +- New `EOTInferenceMetrics` and `EOTModelUsage`; new telemetry span attributes (`lk.eou.source`, `lk.eou.from_cache`, `lk.eou.detection_delay`); new `eot_prediction` event forwarded over remote sessions. +- Requires `@livekit/protocol` >= 1.46.2 (exposes the `AgentInference` message namespace used by the cloud transport). diff --git a/agents/package.json b/agents/package.json index e543404ff..4d3aa996a 100644 --- a/agents/package.json +++ b/agents/package.json @@ -52,8 +52,9 @@ "dependencies": { "@bufbuild/protobuf": "^1.10.0", "@ffmpeg-installer/ffmpeg": "^1.1.0", + "@livekit/local-inference": "^0.2.5", "@livekit/mutex": "^1.1.1", - "@livekit/protocol": "^1.45.7", + "@livekit/protocol": "^1.46.2", "@livekit/typed-emitter": "^3.0.0", "@livekit/throws-transformer": "0.1.8", "@opentelemetry/api": "^1.9.0", diff --git a/agents/src/inference/_warmup.ts b/agents/src/inference/_warmup.ts new file mode 100644 index 000000000..0c77e6814 --- /dev/null +++ b/agents/src/inference/_warmup.ts @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Loader for the bundled `@livekit/local-inference` native binding. + * + * Memory model (measured ~138 MB for the EOT model, ~2 MB for VAD): Node has + * no forkserver/COW, so anything loaded in a job worker is private to that + * worker. To avoid paying ~138 MB per worker, the EOT model is NOT loaded in + * job workers — it runs in the shared `InferenceProcExecutor` (see + * `inference/eot/runner.ts`), loaded once per host. The VAD stays in-process + * (it's small and runs continuously) and is reached via this loader. 
+ * + * There are intentionally no public `prewarm*` helpers: EOT auto-warms via + * the inference runner's `initialize()` at proc startup, and the VAD lazy- + * loads on first stream. + */ +import { createRequire } from 'node:module'; +import { log } from '../log.js'; + +const cjsRequire = createRequire(import.meta.url); + +let nativeMod: typeof import('@livekit/local-inference') | undefined; +let triedLoad = false; + +function getNative(): typeof import('@livekit/local-inference') | undefined { + if (triedLoad) return nativeMod; + triedLoad = true; + try { + nativeMod = cjsRequire('@livekit/local-inference') as typeof import('@livekit/local-inference'); + return nativeMod; + } catch (err) { + log().warn( + { err: err instanceof Error ? err.message : String(err) }, + '@livekit/local-inference native binding not loadable; local VAD/EOT paths disabled', + ); + return undefined; + } +} + +/** @internal Returns the loaded native module, or `undefined` if unavailable. */ +export function _getLocalInferenceModule(): typeof import('@livekit/local-inference') | undefined { + return getNative(); +} diff --git a/agents/src/inference/eot/base.test.ts b/agents/src/inference/eot/base.test.ts new file mode 100644 index 000000000..94fad4e15 --- /dev/null +++ b/agents/src/inference/eot/base.test.ts @@ -0,0 +1,384 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * FSM tests for `AudioTurnDetectorStream`. + * + * Covers the warmup → activate → deactivate / flush lifecycle and the + * regression cases: + * + * - `deactivate()` from a pre-active state must stop the inference cleanly + * so a late prediction for the cancelled request isn't acted on by the + * next activate. + * - a confident prediction (at or above the per-language threshold) + * early-deactivates inline while active, or at `activate()` if it resolved + * during warmup. 
+ * - `predictEndOfTurn` timeout must leave the FSM consistent so the next + * `warmup()` can proceed. + * + * Port of Python `tests/test_turn_detection_fsm.py`. + */ +import type { AudioFrame } from '@livekit/rtc-node'; +import { describe, expect, it } from 'vitest'; +import { + type AudioTurnDetectionTransport, + AudioTurnDetector, + AudioTurnDetectorStream, + type FlushSentinel, + Status, + type TurnDetectorOptions, +} from './base.js'; +import type { TurnDetectorModel } from './languages.js'; + +class FakeTransport implements AudioTurnDetectionTransport { + events: Array<[string, string]> = []; + private _stream: AudioTurnDetectorStream | undefined; + + attach(stream: AudioTurnDetectorStream): void { + this._stream = stream; + } + async run(): Promise { + if (this._stream === undefined) { + throw new Error('stream not bound'); + } + await this._stream._drainAudioChannel(); + } + startInference(requestId: string): void { + this.events.push(['start_inference', requestId]); + } + stopInference(reason?: string): void { + this.events.push(['stop_inference', reason ?? '']); + } + async pushFrame(_frame: AudioFrame): Promise { + // no-op + } + async flush(_sentinel: FlushSentinel): Promise { + // no-op + } + detach(): void { + // no-op + } +} + +class FakeDetector extends AudioTurnDetector { + get model(): TurnDetectorModel { + return 'turn-detector'; + } + stream(): AudioTurnDetectorStream { + throw new Error('unused in FSM tests'); + } +} + +class FakeBackend extends AudioTurnDetectorStream { + fakeTransport: FakeTransport; + + constructor(opts: TurnDetectorOptions) { + const transport = new FakeTransport(); + super({ detector: new FakeDetector(opts), opts, transport }); + this.fakeTransport = transport; + } + + get events(): Array<[string, string]> { + return this.fakeTransport.events; + } + + simulatePrediction(requestId: string, probability: number): void { + this._handlePrediction(requestId, probability); + } + + // Exposed for assertions. 
+ get status(): Status { + return this._status; + } + get preemptiveRequestFut() { + return this._preemptiveRequestFut; + } +} + +function makeOpts(thresholds: Record = {}): TurnDetectorOptions { + return { sampleRate: 16000, thresholds }; +} + +function makeStream(thresholds: Record = {}): FakeBackend { + return new FakeBackend(makeOpts(thresholds)); +} + +/** Did the stream record an early-deactivate (`stop_inference` with the + * positive-EOT trigger)? */ +const earlyDeactivated = (events: Array<[string, string]>) => + events.some((e) => e[0] === 'stop_inference' && e[1] === 'positive eou prediction'); + +const countStartInference = (events: Array<[string, string]>) => + events.filter((e) => e[0] === 'start_inference').length; + +describe('AudioTurnDetectionFSM', () => { + it('warmup starts inference', async () => { + const s = makeStream(); + try { + const fut = s.warmup(); + expect(s.status).toBe(Status.IDLE); + expect(s.isInferenceRunning).toBe(true); + expect(s.preemptiveRequestId).toBeDefined(); + expect(fut.done).toBe(false); + expect(s.events).toEqual([['start_inference', s.preemptiveRequestId!]]); + } finally { + await s.aclose(); + } + }); + + it('warmup is idempotent', async () => { + const s = makeStream(); + try { + s.warmup(); + const firstId = s.preemptiveRequestId; + s.warmup(); + expect(s.preemptiveRequestId).toBe(firstId); + expect(countStartInference(s.events)).toBe(1); + } finally { + await s.aclose(); + } + }); + + it('activate from warmed up', async () => { + const s = makeStream(); + try { + s.warmup(); + s.activate('vad eos'); + expect(s.status).toBe(Status.ACTIVE); + expect(s.isInferenceRunning).toBe(true); + } finally { + await s.aclose(); + } + }); + + it('activate without warmup auto-warms-up', async () => { + const s = makeStream(); + try { + s.activate('manual'); + expect(s.status).toBe(Status.ACTIVE); + expect(countStartInference(s.events)).toBe(1); + } finally { + await s.aclose(); + } + }); + + it('deactivate during preemptive 
phase stops inference', async () => { + const s = makeStream(); + try { + s.warmup(); + s.deactivate('vad sos'); + expect(s.status).toBe(Status.IDLE); + expect(s.preemptiveRequestId).toBeUndefined(); + expect(s.isInferenceRunning).toBe(false); + expect(s.events).toContainEqual(['stop_inference', 'vad sos']); + } finally { + await s.aclose(); + } + }); + + it('late prediction after deactivate not acted on', async () => { + const s = makeStream({ en: 0.5 }); + try { + s.warmup(); + const cancelledId = s.preemptiveRequestId!; + s.deactivate('vad sos'); + + s.simulatePrediction(cancelledId, 0.9); + // Request-id mismatch → dropped, not cached for a later activate(). + expect(s.lastPrediction).toBeUndefined(); + + s.warmup(); + s.activate('vad eos'); + // No cached prediction for the fresh window → activate must not + // early-deactivate; inference stays running. + expect(s.isInferenceRunning).toBe(true); + expect(earlyDeactivated(s.events)).toBe(false); + } finally { + await s.aclose(); + } + }); + + it('deactivate when idle is a no-op', async () => { + const s = makeStream(); + try { + s.deactivate('vad sos'); + expect(s.events).toEqual([]); + expect(s.status).toBe(Status.IDLE); + } finally { + await s.aclose(); + } + }); + + it('deactivate during warmup resolves future with zero', async () => { + const s = makeStream(); + try { + const fut = s.warmup(); + s.deactivate(); + expect(fut.done).toBe(true); + expect(await fut.await).toBe(0.0); + expect(s.status).toBe(Status.IDLE); + expect(s.preemptiveRequestId).toBeUndefined(); + } finally { + await s.aclose(); + } + }); + + it('predictEndOfTurn timeout leaves fsm consistent', async () => { + const s = makeStream(); + try { + const prob = await s.predictEndOfTurn(undefined, { timeoutMs: 10 }); + expect(prob).toBe(1.0); + expect(s.status).toBe(Status.IDLE); + expect(s.preemptiveRequestId).toBeUndefined(); + expect(s.preemptiveRequestFut).toBeUndefined(); + expect(s.events).toContainEqual(['stop_inference', 
'predict_end_of_turn timeout']); + } finally { + await s.aclose(); + } + }); + + it('predictEndOfTurn timeout allows next warmup', async () => { + const s = makeStream(); + try { + await s.predictEndOfTurn(undefined, { timeoutMs: 10 }); + const fut = s.warmup(); + expect(s.preemptiveRequestId).toBeDefined(); + expect(fut.done).toBe(false); + } finally { + await s.aclose(); + } + }); + + it('flush deactivates and emits stop_inference', async () => { + const s = makeStream(); + try { + s.warmup(); + s.activate(); + s.flush('turn committed'); + expect(s.status).toBe(Status.IDLE); + expect(s.isInferenceRunning).toBe(false); + expect(s.events).toContainEqual(['stop_inference', 'turn committed']); + } finally { + await s.aclose(); + } + }); + + it('positive prediction while active early-deactivates', async () => { + const s = makeStream({ en: 0.5 }); + try { + s.warmup(); + s.activate('vad eos'); + const requestId = s.preemptiveRequestId!; + + s.simulatePrediction(requestId, 0.9); // >= 0.5 + expect(s.isInferenceRunning).toBe(false); + expect(s.events).toContainEqual(['stop_inference', 'positive eou prediction']); + } finally { + await s.aclose(); + } + }); + + it('subthreshold prediction while active keeps running', async () => { + const s = makeStream({ en: 0.5 }); + try { + s.warmup(); + s.activate('vad eos'); + const requestId = s.preemptiveRequestId!; + + s.simulatePrediction(requestId, 0.3); // < 0.5 + expect(s.isInferenceRunning).toBe(true); + expect(s.lastPrediction?.endOfTurnProbability).toBe(0.3); + expect(earlyDeactivated(s.events)).toBe(false); + } finally { + await s.aclose(); + } + }); + + it('preemptive positive prediction acted on at activate', async () => { + const s = makeStream({ en: 0.5 }); + try { + s.warmup(); + const requestId = s.preemptiveRequestId!; + s.simulatePrediction(requestId, 0.9); + // Cached, but not active yet → inference still running. 
+ expect(s.isInferenceRunning).toBe(true); + expect(s.lastPrediction?.endOfTurnProbability).toBe(0.9); + + s.activate('vad eos'); + expect(s.isInferenceRunning).toBe(false); + expect(s.events).toContainEqual(['stop_inference', 'positive eou prediction']); + } finally { + await s.aclose(); + } + }); +}); + +describe('PredictOnSilenceGuard', () => { + it('predict short-circuits after flush', async () => { + const s = makeStream(); + try { + s.flush('turn committed'); + const prob = await s.predictEndOfTurn(undefined, { timeoutMs: 1000 }); + expect(prob).toBe(1.0); + expect(countStartInference(s.events)).toBe(0); + expect(s.preemptiveRequestId).toBeUndefined(); + } finally { + await s.aclose(); + } + }); + + it('predict runs after deactivate("vad sos")', async () => { + const s = makeStream(); + try { + s.flush('turn committed'); + s.deactivate('vad sos'); + const prob = await s.predictEndOfTurn(undefined, { timeoutMs: 10 }); + expect(prob).toBe(1.0); // timed out + expect(countStartInference(s.events)).toBe(1); + } finally { + await s.aclose(); + } + }); + + it('predict returns cached prediction before short-circuit', async () => { + const s = makeStream(); + try { + s.warmup(); + const requestId = s.preemptiveRequestId!; + s.simulatePrediction(requestId, 0.4); + const prob = await s.predictEndOfTurn(undefined, { timeoutMs: 1000 }); + expect(prob).toBe(0.4); + } finally { + await s.aclose(); + } + }); + + it('deactivate("vad sos") stops in-flight inference', async () => { + const s = makeStream(); + try { + s.warmup(); + s.activate(); + expect(s.isInferenceRunning).toBe(true); + + s.deactivate('vad sos'); + + expect(s.isInferenceRunning).toBe(false); + expect(s.status).toBe(Status.IDLE); + expect(s.events).toContainEqual(['stop_inference', 'vad sos']); + } finally { + await s.aclose(); + } + }); + + it('initial state does not short-circuit', async () => { + const s = makeStream(); + try { + const prob = await s.predictEndOfTurn(undefined, { timeoutMs: 10 }); + 
expect(prob).toBe(1.0); // timeout default + expect(countStartInference(s.events)).toBe(1); + } finally { + await s.aclose(); + } + }); +}); diff --git a/agents/src/inference/eot/base.ts b/agents/src/inference/eot/base.ts new file mode 100644 index 000000000..6b52685f9 --- /dev/null +++ b/agents/src/inference/eot/base.ts @@ -0,0 +1,710 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Audio EOT (end-of-turn) detector base, stream state machine, and the + * transport interface that concrete cloud/local backends implement. + * + * Concrete implementations live in `agents/src/inference/eot/`. + * + * Port of Python `livekit.agents.voice.turn.audio`. + */ +import type { AudioFrame } from '@livekit/rtc-node'; +import { AudioResampler, AudioResamplerQuality } from '@livekit/rtc-node'; +import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; +import { EventEmitter } from 'node:events'; +import type { LanguageCode } from '../../language.js'; +import type { ChatContext } from '../../llm/chat_context.js'; +import { log } from '../../log.js'; +import type { EOTInferenceMetrics } from '../../metrics/base.js'; +import { type StreamChannel, createStreamChannel } from '../../stream/stream_channel.js'; +import { Future, Task, cancelAndWait, shortuuid } from '../../utils.js'; +import type { TurnDetectorModel } from './languages.js'; + +export const DEFAULT_SAMPLE_RATE = 16000; +export const MIN_SILENCE_DURATION_MS = 200; + +export enum Status { + IDLE = 'idle', + ACTIVE = 'active', +} + +/** + * Options shared by the audio EOT stream and every transport. + * + * Cloud-only transport concerns (base URL, credentials, conn options) + * live on a separate options class owned by the cloud transport. + */ +export interface TurnDetectorOptions { + sampleRate: number; + thresholds: Record; +} + +/** + * Event emitted on each EOT prediction. 
+ */ +export interface TurnDetectionEvent { + type: 'eot_prediction'; + endOfTurnProbability: number; + /** Wall-clock time when the prediction landed (milliseconds since epoch). */ + lastSpeakingTimeMs: number; + /** Latest input-audio creation time → prediction receive time (ms). */ + detectionDelay?: number; + /** Server-side model inference time (ms). */ + inferenceDuration?: number; +} + +/** + * Sentinel value carried alongside flush requests. Transports use + * `keepTailMs` to optionally retain trailing audio for the next turn. + */ +export interface FlushSentinel { + readonly kind: 'flush'; + reason?: string; + keepTailMs: number; +} + +export function isFlushSentinel(value: unknown): value is FlushSentinel { + return typeof value === 'object' && value !== null && (value as FlushSentinel).kind === 'flush'; +} + +/** + * Transport adapter for `AudioTurnDetectorStream` — owns the I/O (WebSocket + * session, in-process predict, etc.). The stream calls these methods + * directly; transports report predictions back via + * `stream._handlePrediction(requestId, probability, ...)`. + */ +export interface AudioTurnDetectionTransport { + attach(stream: AudioTurnDetectorStream): void; + run(): Promise; + startInference(requestId: string): void; + pushFrame(frame: AudioFrame): Promise; + flush(sentinel: FlushSentinel): Promise; + stopInference(reason?: string): void; + detach(): void; +} + +export type AudioTurnDetectorCallbacks = { + metrics_collected: (metrics: EOTInferenceMetrics) => void; +}; + +/** + * Abstract base for audio EOT detectors. Holds the threshold table and + * provides `stream()` to create a per-turn FSM instance. + * + * Subclasses (`AudioTurnDetector` in `inference/eot/detector.ts`) wire up + * concrete transports. + */ +export abstract class AudioTurnDetector extends (EventEmitter as new () => TypedEmitter) { + protected _opts: TurnDetectorOptions; + /** + * Active streams the detector tracks for bulk teardown via `aclose()`. 
+ * `Set` rather than `WeakSet` because we need iteration; each stream + * removes itself on its own `aclose` (see `AudioTurnDetectorStream.aclose`) + * so the strong refs are released without requiring the caller to call + * `detector.aclose()`. + */ + protected _streams: Set = new Set(); + + constructor(opts: TurnDetectorOptions) { + super(); + this._opts = opts; + } + + /** @internal Stream lifecycle hook — called by the stream itself on close. */ + _unregisterStream(stream: AudioTurnDetectorStream): void { + this._streams.delete(stream); + } + + abstract get model(): TurnDetectorModel; + + get provider(): string { + return 'livekit'; + } + + /** Most-recent threshold map (after any cloud→local fallback rescale). */ + get thresholds(): Record { + return this._opts.thresholds; + } + + /** Threshold below which the detector treats the prediction as "unlikely + * to be end-of-turn". Returns `undefined` when the language isn't covered. */ + async unlikelyThreshold(language: LanguageCode | undefined): Promise { + const key = language ?? 'en'; + return this._opts.thresholds[key]; + } + + async supportsLanguage(language: LanguageCode | undefined): Promise { + return (await this.unlikelyThreshold(language)) !== undefined; + } + + abstract stream(): AudioTurnDetectorStream; + + async aclose(): Promise { + const streams = Array.from(this._streams); + this._streams.clear(); + await Promise.allSettled(streams.map((s) => s.aclose())); + } +} + +/** + * Per-turn FSM: + * + * - `warmup()` opens an inference window (transport.startInference). + * - `activate(trigger)` flips IDLE→ACTIVE; commits early if a confident + * prediction already resolved during warmup. + * - `deactivate(trigger)` clears the request id, resolves the in-flight + * future with 0.0, calls transport.stopInference. + * - `flush(reason, keepTailMs)` deactivates and signals turn boundary to + * the transport via a `FlushSentinel`. Clears the cached prediction so + * it can't leak into the next turn. 
+ * - `predictEndOfTurn(chatCtx?, { timeoutMs })` returns a probability, + * defaulting to 1.0 on timeout. + */ +export class SwapAbortError extends Error { + constructor() { + super('__swap__'); + this.name = 'SwapAbortError'; + } +} + +export class AudioTurnDetectorStream { + protected _detector: AudioTurnDetector; + protected _opts: TurnDetectorOptions; + protected _transport: AudioTurnDetectionTransport; + + private _audioInputSampleRate: number | undefined; + private _audioInputNumChannels: number | undefined; + private _audioResampler: AudioResampler | undefined; + private _audioChannel: StreamChannel = createStreamChannel(); + + protected _status: Status = Status.IDLE; + protected _preemptiveRequestId: string | undefined; + protected _preemptiveRequestFut: Future | undefined; + /** + * Latest resolved prediction in the current inference window. Cleared + * when a new window starts (next warmup) or on commit (flush). Lets + * `predictEndOfTurn` return immediately when a prediction is already + * in hand. + */ + protected _lastPrediction: TurnDetectionEvent | undefined; + /** + * Most recent detected language, pushed by `AudioRecognition` on each STT + * transcript. Used by the inline early-deactivation check (`_isLikely`) to + * resolve the per-language unlikely-EOT threshold. + */ + protected _lastLanguage: LanguageCode | undefined; + /** + * True between VAD start-of-speech (when `deactivate('vad sos')` re-arms it) + * and the next `flush()` — i.e. a user turn is open and `predictEndOfTurn` + * should run. When false, predict short-circuits to a positive default (the + * audio EOT model has already committed; an STT final arriving after has + * nothing fresh to evaluate). Initialized true so the first turn isn't + * gated before any flush. + */ + protected _userTurnStarted = true; + /** Warn once per stream when predict is called after a commit. 
*/ + protected _latePredictWarned = false; + + protected _mainTask: Task; + protected _logger = log(); + /** + * Aborted whenever the main loop needs to retry on a new transport (e.g. + * fallback). The base FSM also aborts it from `aclose()` so idle + * transports that are awaiting forever can be unstuck. Listeners check + * `signal.aborted` and surface a sentinel rejection so the `_run` loop + * can decide whether to continue or exit. + */ + protected _swapController = new AbortController(); + + constructor(args: { + detector: AudioTurnDetector; + opts: TurnDetectorOptions; + transport: AudioTurnDetectionTransport; + }) { + this._detector = args.detector; + this._opts = args.opts; + this._transport = args.transport; + this._transport.attach(this); + + this._mainTask = Task.from((controller) => this._mainTaskBody(controller)); + } + + // region: _TurnDetector protocol proxies + + get model(): TurnDetectorModel { + return this._detector.model; + } + + get provider(): string { + return this._detector.provider; + } + + async unlikelyThreshold(language: LanguageCode | undefined): Promise { + const key = language ?? 'en'; + return this._opts.thresholds[key]; + } + + async supportsLanguage(language: LanguageCode | undefined): Promise { + return (await this.unlikelyThreshold(language)) !== undefined; + } + + /** + * Record the most recent detected language so the inline early-deactivation + * check can resolve the unlikely-EOT threshold. Pushed by `AudioRecognition` + * on each STT transcript. + */ + updateLanguage(language: LanguageCode | undefined): void { + this._lastLanguage = language; + } + + /** + * A prediction at or above `unlikelyThreshold` is no longer "unlikely" — it's + * a confident end-of-turn. Mirrors that method's `undefined → "en"` fallback: + * an unknown language still gets the English threshold; an explicitly + * unsupported code misses the table and is never treated as likely. 
+ */ + protected _isLikely(probability: number): boolean { + const key = this._lastLanguage ?? 'en'; + const threshold = this._opts.thresholds[key]; + return threshold !== undefined && probability >= threshold; + } + + // endregion + + // region: state machine + + get isActive(): boolean { + return this._status === Status.ACTIVE; + } + + get isInferenceRunning(): boolean { + return this._preemptiveRequestId !== undefined; + } + + get preemptiveRequestId(): string | undefined { + return this._preemptiveRequestId; + } + + get status(): Status { + return this._status; + } + + get lastPrediction(): TurnDetectionEvent | undefined { + return this._lastPrediction; + } + + /** Start an inference window if one isn't already open. Returns the + * in-flight future. Idempotent. */ + warmup(): Future { + if (this._preemptiveRequestId === undefined) { + const requestId = shortuuid('turn_request_'); + this._preemptiveRequestId = requestId; + this._preemptiveRequestFut = new Future(); + // New inference window — drop any cached prediction from the previous + // window so `predictEndOfTurn` won't return stale. + this._lastPrediction = undefined; + this._transport.startInference(requestId); + } + if (this._preemptiveRequestFut === undefined) { + throw new Error('eot detection warmup failed, no request future'); + } + return this._preemptiveRequestFut; + } + + activate(_trigger?: string): void { + if (this._status === Status.ACTIVE) { + return; + } + if (this._preemptiveRequestId === undefined) { + this._logger.trace( + 'eot detector not warmed up before activation, likely due to overlapping speech', + ); + this.warmup(); + } + this._status = Status.ACTIVE; + // A prediction may have resolved during the preemptive warmup window, + // before activation. We deliberately hold off acting on the threshold + // until now: a confident EOT only commits once VAD confirms end-of-speech + // (the trigger that calls `activate`). 
+ if ( + this._lastPrediction !== undefined && + this._isLikely(this._lastPrediction.endOfTurnProbability) + ) { + this.deactivate('positive eou prediction'); + } + } + + deactivate(trigger?: string): void { + // Mirror Python: clear the "turn committed" guard at the top so a VAD + // start-of-speech (which calls `deactivate('vad sos')`) re-arms the + // user turn even if the FSM was already idle. + this._userTurnStarted = true; + if (this._preemptiveRequestId === undefined && this._status === Status.IDLE) { + return; + } + this._preemptiveRequestId = undefined; + if (this._preemptiveRequestFut !== undefined) { + if (!this._preemptiveRequestFut.done) { + this._preemptiveRequestFut.resolve(0.0); + } + this._preemptiveRequestFut = undefined; + } + this._status = Status.IDLE; + this._transport.stopInference(trigger); + } + + flush(reason?: string, opts: { keepTailMs?: number } = {}): void { + if (this._audioChannel.closed) { + return; + } + const keepTailMs = opts.keepTailMs ?? 0; + for (const resampled of this._flushAudioResampler()) { + void this._audioChannel.write(resampled); + } + const sentinel: FlushSentinel = { + kind: 'flush', + reason, + keepTailMs, + }; + void this._audioChannel.write(sentinel); + // Turn boundary — the cached prediction belongs to the turn we just + // closed and must not leak into the next one. + this._lastPrediction = undefined; + this.deactivate(reason); + // Close the user turn AFTER deactivate (which re-arms the guard on its + // way out): until the next VAD start-of-speech calls `deactivate('vad sos')` + // to flip it back on, `predictEndOfTurn` short-circuits. 
+ this._userTurnStarted = false; + } + + // endregion + + // region: audio ingress + + pushAudio(frame: AudioFrame): void { + if (this._audioChannel.closed) { + return; + } + for (const resampled of this._resampleAudioFrame(frame)) { + void this._audioChannel.write(resampled); + } + } + + endInput(): void { + this.flush(); + void this._audioChannel.close(); + } + + private _resampleAudioFrame(frame: AudioFrame): AudioFrame[] { + if (this._audioInputSampleRate === undefined || this._audioInputNumChannels === undefined) { + this._audioInputSampleRate = frame.sampleRate; + this._audioInputNumChannels = frame.channels; + if (this._audioInputSampleRate !== this._opts.sampleRate) { + this._audioResampler = new AudioResampler( + this._audioInputSampleRate, + this._opts.sampleRate, + this._audioInputNumChannels, + AudioResamplerQuality.QUICK, + ); + } + } else if ( + frame.sampleRate !== this._audioInputSampleRate || + frame.channels !== this._audioInputNumChannels + ) { + this._logger.error( + { + sampleRate: frame.sampleRate, + expectedSampleRate: this._audioInputSampleRate, + numChannels: frame.channels, + expectedNumChannels: this._audioInputNumChannels, + }, + 'a frame with different audio format was already pushed', + ); + return []; + } + if (this._audioResampler === undefined) { + return [frame]; + } + return this._audioResampler.push(frame); + } + + private _flushAudioResampler(): AudioFrame[] { + const frames = this._audioResampler?.flush() ?? []; + this._resetAudioResampler(); + return frames; + } + + private _resetAudioResampler(): void { + this._audioResampler = undefined; + this._audioInputSampleRate = undefined; + this._audioInputNumChannels = undefined; + } + + // endregion + + // region: results + + /** + * Accept a prediction from a transport. The stream owns dedup (by + * requestId), future resolution, and the inline early-deactivate. 
+ */ + _handlePrediction( + requestId: string, + probability: number, + opts: { inferenceDuration?: number; detectionDelay?: number } = {}, + ): void { + // Drop predictions that land after teardown — an in-flight transport + // predict can resolve after `aclose` closed the channels. + if (this._closing) { + return; + } + if (requestId !== this._preemptiveRequestId) { + return; + } + if (this._preemptiveRequestFut !== undefined && !this._preemptiveRequestFut.done) { + this._preemptiveRequestFut.resolve(probability); + } + const event: TurnDetectionEvent = { + type: 'eot_prediction', + endOfTurnProbability: probability, + lastSpeakingTimeMs: Date.now(), + detectionDelay: opts.detectionDelay, + inferenceDuration: opts.inferenceDuration, + }; + this._lastPrediction = event; + // Early-deactivate: stop inference as soon as a confident EOT lands so a + // later intra-speech silence can warm up a fresh window. Only while active + // — predictions during preemptive warmup are cached and re-checked in + // `activate()`. `deactivate` just sends a non-blocking `stopInference`, so + // calling it inline from the transport's prediction callback is safe (no + // reentrant await). + if (this.isActive && this._isLikely(probability)) { + this.deactivate('positive eou prediction'); + } + } + + /** + * Run a warmup inference and wait for a prediction within `timeoutMs`. + * + * Returns the cached prediction if one has already arrived for the + * current inference window. `chatCtx` is accepted (and ignored) so the + * call site stays uniform with text-based `_TurnDetector` impls. + */ + async predictEndOfTurn( + _chatCtx?: ChatContext, + optsOrTimeoutMs?: { timeoutMs?: number } | number, + ): Promise { + // Accept both the options-bag form (FSM-native) and the positional-ms + // form (matches the `_TurnDetector` Protocol so audio detectors are a + // drop-in for text-based detectors). + const opts: { timeoutMs?: number } = + typeof optsOrTimeoutMs === 'number' ? 
{ timeoutMs: optsOrTimeoutMs } : optsOrTimeoutMs ?? {}; + if (this._lastPrediction !== undefined) { + return this._lastPrediction.endOfTurnProbability; + } + if (!this._userTurnStarted) { + if (!this._latePredictWarned) { + this._latePredictWarned = true; + this._logger.warn( + 'predictEndOfTurn called after the audio eot model already committed ' + + 'the turn (likely a late stt final). consider raising `minDelay` in ' + + 'the endpointing options to accommodate slow stt. subsequent ' + + 'occurrences on this stream will log at debug level.', + ); + } else { + this._logger.debug('stt transcript arrived after a turn commit, short-circuiting'); + } + return 1.0; + } + + const timeoutMs = opts.timeoutMs ?? 500; + let fut: Future | undefined; + let timeoutId: ReturnType | undefined; + try { + fut = this.warmup(); + this.activate(); + const winner = await Promise.race([ + fut.await.then((v) => ({ kind: 'value', v }) as const), + new Promise<{ kind: 'timeout' }>((resolve) => { + timeoutId = setTimeout(() => resolve({ kind: 'timeout' }), timeoutMs); + }), + ]); + if (winner.kind === 'value') { + return winner.v; + } + throw new Error('__eot_predict_timeout__'); + } catch (err) { + const isTimeout = err instanceof Error && err.message === '__eot_predict_timeout__'; + if (!isTimeout) throw err; + // Contract on timeout: we couldn't tell within `timeoutMs`, so assume + // the turn is over. Resolve the future with 1.0 (so any concurrent + // waiter sees the same value) and deactivate the inference window + // (a stale prediction arriving later must not fire an event). + this._logger.warn( + { + timeoutMs, + requestId: this._preemptiveRequestId, + default: 1.0, + }, + 'eot prediction timed out, returning a default value', + ); + if (fut !== undefined && !fut.done) { + fut.resolve(1.0); + } + this.deactivate('predict_end_of_turn timeout'); + this._onPredictTimeout(); + // Positive default so minEndpointingDelay applies. 
+ return 1.0; + } finally { + // Always release the timer — on the value path the timeout would + // otherwise keep the event loop alive until it fires, and N + // concurrent turns would queue N pending timers. + if (timeoutId !== undefined) clearTimeout(timeoutId); + } + } + + // endregion + + // region: teardown + + /** + * Synchronously release this stream's registration on its owning detector, + * so a replacement stream can be created before this one's async teardown + * finishes. Base is a no-op; detectors that enforce single-stream ownership + * override it. Idempotent. + */ + detach(): void { + return; + } + + async aclose(): Promise { + this.endInput(); + this._closing = true; + this._swapController.abort(); + await cancelAndWait([this._mainTask]); + if (this._preemptiveRequestFut !== undefined && !this._preemptiveRequestFut.done) { + this._preemptiveRequestFut.resolve(0.0); + } + this._preemptiveRequestFut = undefined; + this._preemptiveRequestId = undefined; + this._status = Status.IDLE; + // Drop our strong reference on the parent detector so callers that + // forget `detector.aclose()` don't leak the stream graph. + this._detector._unregisterStream(this); + } + + /** True once `aclose()` has been called. The `_run` loop uses this to + * distinguish swap-aborts (continue with new transport) from teardown + * aborts (exit). */ + protected _closing = false; + + // endregion + + // region: main task scaffolding + + private async _mainTaskBody(_controller: AbortController): Promise { + await this._run(); + } + + /** + * Drain the shared audio channel into the current transport. + * + * The audio channel exposes a single `ReadableStream` (one underlying + * `transform.readable`), so only one reader may hold its lock at a time. + * When `signal` aborts (a transport being swapped out — e.g. 
cloud→local + * fallback — fires it via `detach()`), we release the reader lock right + * away: on a pending `read()` this rejects that read and frees the lock so + * the swapped-in transport's `_drainAudioChannel` can re-acquire it. + * Without this an orphaned drain would hold the lock forever and the next + * `getReader()` would throw "ReadableStream is locked". + */ + async _drainAudioChannel(signal?: AbortSignal): Promise { + const stream = this._audioChannel.stream(); + const reader = stream.getReader(); + const release = () => { + try { + reader.releaseLock(); + } catch { + // already released + } + }; + if (signal?.aborted) { + release(); + return; + } + signal?.addEventListener('abort', release, { once: true }); + try { + while (true) { + const { done, value } = await reader.read(); + if (done) return; + if (isFlushSentinel(value)) { + await this._transport.flush(value); + } else { + await this._transport.pushFrame(value); + } + } + } catch (err) { + // The pending `read()` rejects when `release()` runs on abort — a clean + // swap-driven exit, not a drain failure. + if (signal?.aborted) return; + throw err; + } finally { + signal?.removeEventListener('abort', release); + release(); + } + } + + // endregion + + // region: subclass hooks + + /** Default: hand control to the transport. Subclasses override for + * cross-transport orchestration (e.g. cloud→local fallback). */ + protected async _run(): Promise { + await this._raceWithSwap(this._transport.run()); + } + + /** + * Race `inner` against `_swapController.signal`. If the signal aborts + * while `inner` is still pending, throw a `SwapAbortError` so the + * subclass loop can decide whether to continue or exit. Resets the + * controller after a swap-abort so subsequent races have a fresh signal. + * + * `aclose()` aborts during teardown — subclasses observe `_closing` to + * exit cleanly instead of looping. 
+ */ + protected async _raceWithSwap(inner: Promise): Promise { + const signal = this._swapController.signal; + const abortPromise = new Promise((_, reject) => { + if (signal.aborted) { + reject(new SwapAbortError()); + return; + } + signal.addEventListener('abort', () => reject(new SwapAbortError()), { once: true }); + }); + try { + return await Promise.race([inner, abortPromise]); + } finally { + if (signal.aborted) { + // Reset for the next iteration of the subclass loop. + this._swapController = new AbortController(); + } + } + } + + /** @internal Wake up an idle transport so the main loop can pick up a + * new one after fallback. Subclasses call this from their swap logic. */ + protected _signalSwap(): void { + this._swapController.abort(); + } + + /** `predictEndOfTurn` timed out. Subclasses may override to react (e.g. + * promote local on cloud timeout). */ + protected _onPredictTimeout(): void { + return; + } + + // endregion +} diff --git a/agents/src/inference/eot/detector.test.ts b/agents/src/inference/eot/detector.test.ts new file mode 100644 index 000000000..dc23c252e --- /dev/null +++ b/agents/src/inference/eot/detector.test.ts @@ -0,0 +1,535 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Tests for the unified `AudioTurnDetector` (auto-select + fallback). + * + * Covers: + * + * - Auto-select via `LIVEKIT_REMOTE_EOT_URL` env var (with creds present, + * with creds missing → silent downgrade). + * - Explicit-cloud missing creds throws. + * - Cloud → local fallback triggers (transport raise, predict timeout). + * - Fallback persistence across turns. + * - Local-failure handling (default 1.0, retry on next turn). + * - Per-session warning dedupe (one warning per failure mode). + * - Threshold scaling: pass-through for cloud / explicit-local, multiplicative + * scaling only on actual fallback. + * + * Port of Python `tests/test_audio_turn_detector_fallback.py`. 
+ */ +import { AudioFrame } from '@livekit/rtc-node'; +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { APIConnectionError } from '../../_exceptions.js'; +import type { InferenceExecutor } from '../../ipc/inference_executor.js'; +import { DEFAULT_API_CONNECT_OPTIONS } from '../../types.js'; +import type { AudioTurnDetectorStream } from './base.js'; +import { + type AudioTurnDetectionTransport, + type FlushSentinel, + type TurnDetectorOptions, +} from './base.js'; +import { AudioTurnDetector, AudioTurnDetectorStreamImpl } from './detector.js'; +import { CLOUD_LANGUAGES, LOCAL_LANGUAGES, materializeThresholds } from './languages.js'; +import { EOT_INFERENCE_METHOD } from './runner.js'; +import { LocalTransport } from './transports.js'; + +async function waitFor(predicate: () => boolean, ticks = 50): Promise { + for (let i = 0; i < ticks; i++) { + if (predicate()) return; + await new Promise((r) => setImmediate(r)); + } +} + +interface ScriptedTransportOptions { + runBehavior?: 'idle' | 'raise' | 'return'; + runExc?: Error; +} + +class ScriptedTransport implements AudioTurnDetectionTransport { + runBehavior: 'idle' | 'raise' | 'return'; + runExc: Error | undefined; + runCalls = 0; + events: Array<[string, unknown]> = []; + private _stream: AudioTurnDetectorStream | undefined; + + constructor(opts: ScriptedTransportOptions = {}) { + this.runBehavior = opts.runBehavior ?? 'idle'; + this.runExc = opts.runExc; + } + + attach(stream: AudioTurnDetectorStream): void { + this._stream = stream; + } + async run(): Promise { + this.runCalls += 1; + if (this.runBehavior === 'raise') { + if (!this.runExc) throw new Error('runExc not set'); + throw this.runExc; + } + if (this.runBehavior === 'return') { + return; + } + // idle — wait until cancelled (resolved by `detach()` via the + // scripted transport's no-op; in our tests the parent stream + // cancels via `aclose`). 
+ await new Promise(() => undefined); + } + startInference(requestId: string): void { + this.events.push(['start_inference', requestId]); + } + async pushFrame(frame: AudioFrame): Promise { + this.events.push(['push_frame', frame]); + } + async flush(sentinel: FlushSentinel): Promise { + this.events.push(['flush', sentinel]); + } + stopInference(reason?: string): void { + this.events.push(['stop_inference', reason]); + } + detach(): void { + this.events.push(['detach', null]); + } +} + +function makeOpts(thresholds?: Record): TurnDetectorOptions { + return { sampleRate: 16000, thresholds: thresholds ?? {} }; +} + +interface MakeStreamOpts { + model?: 'turn-detector' | 'turn-detector-mini'; + userThreshold?: number | Record; + detector?: AudioTurnDetector; +} + +function makeStreamWithTransport( + transport: AudioTurnDetectionTransport, + opts: MakeStreamOpts = {}, +): AudioTurnDetectorStreamImpl { + const model = opts.model ?? 'turn-detector'; + const detector = + opts.detector ?? + makeMockDetector(model, makeOpts(materializeThresholds(opts.userThreshold, model))); + const stream = new AudioTurnDetectorStreamImpl({ + detector, + opts: detector['_opts'] as TurnDetectorOptions, + cloudOpts: + model === 'turn-detector' + ? { + baseUrl: 'ws://test', + apiKey: 'x', + apiSecret: 'x', + connOptions: DEFAULT_API_CONNECT_OPTIONS, + } + : undefined, + model, + transport, + }); + return stream; +} + +/** Build an `AudioTurnDetector` for assertions without going through env + * resolution — useful when we want a specific model + threshold table for + * a stream we'll build separately. */ +function makeMockDetector( + model: 'turn-detector' | 'turn-detector-mini', + opts: TurnDetectorOptions, +): AudioTurnDetector { + // Construct via the public constructor, then override the internal + // model + threshold view to match what we want for the assertion. 
+ const originalEnv = { ...process.env }; + if (model === 'turn-detector-mini') { + delete process.env.LIVEKIT_REMOTE_EOT_URL; + } else { + process.env.LIVEKIT_REMOTE_EOT_URL = 'ws://test'; + process.env.LIVEKIT_API_KEY = 'x'; + process.env.LIVEKIT_API_SECRET = 'x'; + } + const det = new AudioTurnDetector(); + process.env = originalEnv; + const internals = det as unknown as { _model: typeof model; _opts: TurnDetectorOptions }; + internals._model = model; + internals._opts = { ...internals._opts, thresholds: opts.thresholds }; + return det; +} + +function withEnv( + overrides: Record, + fn: () => void | Promise, +): void | Promise { + const original = { ...process.env }; + for (const [k, v] of Object.entries(overrides)) { + if (v === undefined) delete process.env[k]; + else process.env[k] = v; + } + try { + const result = fn(); + if (result instanceof Promise) { + return result.finally(() => { + process.env = original; + }); + } + process.env = original; + return result; + } catch (err) { + process.env = original; + throw err; + } +} + +// Stub `LocalTransport.run` so the fallback FSM doesn't hang on a real +// drain loop. The behavior under test is the swap, not the post-swap I/O. 
+let runSpy: ReturnType; +beforeEach(() => { + runSpy = vi.spyOn(LocalTransport.prototype, 'run').mockImplementation(async () => undefined); +}); +afterEach(() => { + runSpy.mockRestore(); +}); + +describe('AutoSelect', () => { + it('selects local when no remote EOT url', () => { + void withEnv({ LIVEKIT_REMOTE_EOT_URL: undefined }, () => { + const detector = new AudioTurnDetector(); + expect(detector.model).toBe('turn-detector-mini'); + }); + }); + + it('selects cloud when remote EOT url set', () => { + void withEnv( + { + LIVEKIT_REMOTE_EOT_URL: 'ws://gateway', + LIVEKIT_API_KEY: 'k', + LIVEKIT_API_SECRET: 's', + }, + () => { + const detector = new AudioTurnDetector(); + expect(detector.model).toBe('turn-detector'); + }, + ); + }); + + it('downgrades to local when creds missing', () => { + void withEnv( + { + LIVEKIT_REMOTE_EOT_URL: 'ws://gateway', + LIVEKIT_API_KEY: undefined, + LIVEKIT_API_SECRET: undefined, + LIVEKIT_INFERENCE_API_KEY: undefined, + LIVEKIT_INFERENCE_API_SECRET: undefined, + }, + () => { + const detector = new AudioTurnDetector(); + expect(detector.model).toBe('turn-detector-mini'); + }, + ); + }); +}); + +describe('ExplicitModelErrors', () => { + it('explicit cloud missing creds throws', () => { + void withEnv( + { + LIVEKIT_REMOTE_EOT_URL: undefined, + LIVEKIT_API_KEY: undefined, + LIVEKIT_API_SECRET: undefined, + LIVEKIT_INFERENCE_API_KEY: undefined, + LIVEKIT_INFERENCE_API_SECRET: undefined, + }, + () => { + expect(() => new AudioTurnDetector({ model: 'turn-detector' })).toThrow(); + }, + ); + }); +}); + +describe('Fallback', () => { + it('fallback on transport error swaps to local', async () => { + const transport = new ScriptedTransport({ + runBehavior: 'raise', + runExc: new APIConnectionError({ message: 'boom' }), + }); + const stream = makeStreamWithTransport(transport); + await waitFor(() => stream.model === 'turn-detector-mini'); + expect(stream.model).toBe('turn-detector-mini'); + expect(stream.isFallback).toBe(true); + 
expect(stream.warnedCloudFailure).toBe(true); + expect(transport.events).toContainEqual(['detach', null]); + await stream.aclose(); + }); + + it('fallback on predict timeout', async () => { + const transport = new ScriptedTransport({ runBehavior: 'idle' }); + const stream = makeStreamWithTransport(transport); + const prob = await stream.predictEndOfTurn(undefined, { timeoutMs: 10 }); + expect(prob).toBe(1.0); + expect(stream.model).toBe('turn-detector-mini'); + expect(stream.isFallback).toBe(true); + await stream.aclose(); + }); + + it('fallback persists across turns', async () => { + const transport = new ScriptedTransport({ + runBehavior: 'raise', + runExc: new APIConnectionError({ message: 'boom' }), + }); + const stream = makeStreamWithTransport(transport); + await waitFor(() => stream.model === 'turn-detector-mini'); + expect(transport.runCalls).toBe(1); + stream.warmup(); + expect(stream.model).toBe('turn-detector-mini'); + await stream.aclose(); + }); +}); + +describe('MultiStreamOwnership', () => { + it('multiple streams can coexist', async () => { + let detector!: AudioTurnDetector; + withEnv({ LIVEKIT_REMOTE_EOT_URL: undefined }, () => { + detector = new AudioTurnDetector({ model: 'turn-detector-mini' }); + }); + // Detector no longer enforces single-stream ownership; fallback state lives + // on the stream itself so multiple streams off the same detector are safe. 
+ const s1 = detector.stream(); + const s2 = detector.stream(); + await s1.aclose(); + await s2.aclose(); + }); +}); + +describe('DetectorViewReflectsConstructionDefaults', () => { + it('detector model/threshold stay at construction-time defaults across fallbacks', async () => { + let detector!: AudioTurnDetector; + await withEnv( + { + LIVEKIT_REMOTE_EOT_URL: 'ws://gateway', + LIVEKIT_API_KEY: 'k', + LIVEKIT_API_SECRET: 's', + }, + async () => { + detector = new AudioTurnDetector({ unlikelyThreshold: 0.5 }); + expect(detector.model).toBe('turn-detector'); + expect(await detector.unlikelyThreshold('en')).toBeCloseTo(0.5); + }, + ); + + const transport = new ScriptedTransport({ + runBehavior: 'raise', + runExc: new APIConnectionError({ message: 'boom' }), + }); + const stream = new AudioTurnDetectorStreamImpl({ + detector, + opts: (detector as unknown as { _opts: TurnDetectorOptions })._opts, + cloudOpts: undefined, + model: 'turn-detector', + transport, + }); + await waitFor(() => stream.model === 'turn-detector-mini'); + + // The stream reflects the fallback... + expect(stream.model).toBe('turn-detector-mini'); + const expected = LOCAL_LANGUAGES.en! * (0.5 / CLOUD_LANGUAGES.en!); + expect(await stream.unlikelyThreshold('en')).toBeCloseTo(expected); + + // ...but the detector still reports the construction-time defaults: the + // fallback state lives on the stream, never written back to the detector, + // so other streams off the same detector aren't corrupted. 
+ expect(detector.model).toBe('turn-detector'); + expect(await detector.unlikelyThreshold('en')).toBeCloseTo(0.5); + await stream.aclose(); + }); +}); + +describe('LocalFailureRetry', () => { + it('local failure emits default and retries on next turn', async () => { + const transport = new ScriptedTransport({ + runBehavior: 'raise', + runExc: new Error('local boom'), + }); + const stream = makeStreamWithTransport(transport, { model: 'turn-detector-mini' }); + await waitFor(() => stream.warnedLocalFailure); + expect(stream.model).toBe('turn-detector-mini'); + expect(stream.isFallback).toBe(false); + expect(stream.warnedLocalFailure).toBe(true); + expect(stream.transport).toBe(transport); + await stream.aclose(); + }); +}); + +describe('WarningDedupe', () => { + it('cloud→local warning logged once per session', async () => { + const transport = new ScriptedTransport({ + runBehavior: 'raise', + runExc: new APIConnectionError({ message: 'boom' }), + }); + const stream = makeStreamWithTransport(transport); + await waitFor(() => stream.model === 'turn-detector-mini'); + // Trigger a second fallback path directly. + stream._fallBackToLocal(new APIConnectionError({ message: 'boom2' })); + // Across both invocations only one warning was emitted — tracked by + // the `warnedCloudFailure` flag staying flipped after the first call. 
+ expect(stream.warnedCloudFailure).toBe(true); + await stream.aclose(); + }); + + it('local warning logged once per session', async () => { + const transport = new ScriptedTransport({ runBehavior: 'idle' }); + const stream = makeStreamWithTransport(transport, { model: 'turn-detector-mini' }); + stream._onLocalFailure(new Error('a')); + stream._onLocalFailure(new Error('b')); + expect(stream.warnedLocalFailure).toBe(true); + await stream.aclose(); + }); +}); + +describe('ThresholdScaling', () => { + it('cloud user threshold is pass-through pre-stream', async () => { + await withEnv( + { + LIVEKIT_REMOTE_EOT_URL: 'ws://gateway', + LIVEKIT_API_KEY: 'k', + LIVEKIT_API_SECRET: 's', + }, + async () => { + const detector = new AudioTurnDetector({ unlikelyThreshold: 0.5 }); + expect(detector.model).toBe('turn-detector'); + const value = await detector.unlikelyThreshold('en'); + expect(value).toBeCloseTo(0.5); + }, + ); + }); + + it('explicit-local user threshold passes through (no rescale)', async () => { + await withEnv({ LIVEKIT_REMOTE_EOT_URL: undefined }, async () => { + const detector = new AudioTurnDetector({ + model: 'turn-detector-mini', + unlikelyThreshold: 0.5, + }); + const value = await detector.unlikelyThreshold('en'); + expect(value).toBeCloseTo(0.5); + }); + }); + + it('post-fallback threshold rescales on stream', async () => { + const transport = new ScriptedTransport({ + runBehavior: 'raise', + runExc: new APIConnectionError({ message: 'boom' }), + }); + const stream = makeStreamWithTransport(transport, { userThreshold: 0.5 }); + await waitFor(() => stream.model === 'turn-detector-mini'); + expect(stream.isFallback).toBe(true); + const value = await stream.unlikelyThreshold('en'); + const expected = LOCAL_LANGUAGES.en! 
* (0.5 / CLOUD_LANGUAGES.en!); + expect(value).toBeCloseTo(expected); + await stream.aclose(); + }); + + it('threshold default unchanged when user threshold not set', async () => { + await withEnv( + { + LIVEKIT_REMOTE_EOT_URL: 'ws://gateway', + LIVEKIT_API_KEY: 'k', + LIVEKIT_API_SECRET: 's', + }, + async () => { + const detector = new AudioTurnDetector(); + const cloudDefault = await detector.unlikelyThreshold('en'); + expect(cloudDefault).toBeCloseTo(CLOUD_LANGUAGES.en!); + }, + ); + + await withEnv({ LIVEKIT_REMOTE_EOT_URL: undefined }, async () => { + const detector = new AudioTurnDetector(); + const localDefault = await detector.unlikelyThreshold('en'); + expect(localDefault).toBeCloseTo(LOCAL_LANGUAGES.en!); + }); + }); +}); + +describe('ThresholdDictOverride', () => { + it('dict override applies per language', async () => { + await withEnv( + { + LIVEKIT_REMOTE_EOT_URL: 'ws://gateway', + LIVEKIT_API_KEY: 'k', + LIVEKIT_API_SECRET: 's', + }, + async () => { + const detector = new AudioTurnDetector({ + unlikelyThreshold: { en: 0.55, ja: 0.25 }, + }); + expect(await detector.unlikelyThreshold('en')).toBeCloseTo(0.55); + expect(await detector.unlikelyThreshold('ja')).toBeCloseTo(0.25); + expect(await detector.unlikelyThreshold('fr')).toBeCloseTo(CLOUD_LANGUAGES.fr!); + }, + ); + }); + + it('dict keys normalized via language code', async () => { + await withEnv({ LIVEKIT_REMOTE_EOT_URL: undefined }, async () => { + const detector = new AudioTurnDetector({ + unlikelyThreshold: { English: 0.55, 'en-US': 0.55 }, + }); + expect(await detector.unlikelyThreshold('en')).toBeCloseTo(0.55); + }); + }); + + it('dict override rescaled per language on fallback', async () => { + const transport = new ScriptedTransport({ + runBehavior: 'raise', + runExc: new APIConnectionError({ message: 'boom' }), + }); + const stream = makeStreamWithTransport(transport, { + userThreshold: { en: 0.55, ja: 0.25 }, + }); + await waitFor(() => stream.model === 'turn-detector-mini'); + 
expect(stream.isFallback).toBe(true); + expect(await stream.unlikelyThreshold('en')).toBeCloseTo( + LOCAL_LANGUAGES.en! * (0.55 / CLOUD_LANGUAGES.en!), + ); + expect(await stream.unlikelyThreshold('ja')).toBeCloseTo( + LOCAL_LANGUAGES.ja! * (0.25 / CLOUD_LANGUAGES.ja!), + ); + expect(await stream.unlikelyThreshold('fr')).toBeCloseTo(LOCAL_LANGUAGES.fr!); + await stream.aclose(); + }); +}); + +describe('LocalModelExecutor', () => { + function pcmFrame(samples = 320): AudioFrame { + return new AudioFrame(new Int16Array(samples), 16000, 1, samples); + } + + it('routes local predict through the injected executor (base64 PCM)', async () => { + const doInference = vi.fn(async (method: string, data: unknown) => { + expect(method).toBe(EOT_INFERENCE_METHOD); + expect(typeof (data as { pcm: string }).pcm).toBe('string'); + return { probability: 0.7, inferenceDurationMs: 5 }; + }); + const executor: InferenceExecutor = { doInference }; + const detector = new AudioTurnDetector({ model: 'turn-detector-mini', executor }); + const stream = detector.stream(); + try { + stream.pushAudio(pcmFrame()); + const p = await stream.predictEndOfTurn(undefined, { timeoutMs: 1000 }); + expect(p).toBe(0.7); + expect(doInference).toHaveBeenCalledWith(EOT_INFERENCE_METHOD, expect.anything()); + } finally { + await stream.aclose(); + } + }); + + it('degrades to a positive default when no executor is available', async () => { + // explicit undefined → constructor falls through to getJobContext() + // (throws outside a job) → executor stays undefined. 
+ const detector = new AudioTurnDetector({ model: 'turn-detector-mini', executor: undefined }); + const stream = detector.stream(); + try { + const p = await stream.predictEndOfTurn(undefined, { timeoutMs: 1000 }); + expect(p).toBe(1.0); + } finally { + await stream.aclose(); + } + }); +}); diff --git a/agents/src/inference/eot/detector.ts b/agents/src/inference/eot/detector.ts new file mode 100644 index 000000000..a54d95791 --- /dev/null +++ b/agents/src/inference/eot/detector.ts @@ -0,0 +1,325 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Audio end-of-turn detector with `turn-detector` → `turn-detector-mini` + * (cloud → local) fallback. + * + * Port of Python `livekit.agents.inference.eot.detector`. + */ +import type { InferenceExecutor } from '../../ipc/inference_executor.js'; +import { getJobContext } from '../../job.js'; +import type { LanguageCode } from '../../language.js'; +import { log } from '../../log.js'; +import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../../types.js'; +import { isDevMode, isHosted, resolveEnvVar } from '../../utils.js'; +import { + type AudioTurnDetectionTransport, + AudioTurnDetector as AudioTurnDetectorBase, + AudioTurnDetectorStream, + DEFAULT_SAMPLE_RATE, + SwapAbortError, + type TurnDetectorOptions, +} from './base.js'; +import { getDefaultInferenceUrl } from '../utils.js'; +import { + type TurnDetectorModel, + materializeThresholds, + rescaleForLocalFallback, +} from './languages.js'; +import { CloudTransport, type CloudTransportOptions, LocalTransport } from './transports.js'; + +export interface AudioTurnDetectorOptions { + /** + * Which turn-detector checkpoint to run. `'turn-detector'` is the full + * cloud model (served over the inference gateway); `'turn-detector-mini'` + * is the local in-process model. 
When omitted, auto-selects `'turn-detector'` + * on hosted/dev environments (falling back to `'turn-detector-mini'` if cloud + * creds are missing) and `'turn-detector-mini'` otherwise. + */ + model?: TurnDetectorModel; + unlikelyThreshold?: number | Record; + baseUrl?: string; + apiKey?: string; + apiSecret?: string; + /** Sample rate (Hz). Defaults to 16000. */ + sampleRate?: number; + connOptions?: APIConnectOptions; + /** + * Inference executor that runs the local `turn-detector-mini` model in the + * shared inference process. Defaults to the current job's + * `getJobContext().inferenceExecutor`. `undefined` (no job context / binding + * unavailable) degrades the local model to a positive-default prediction. + * Mainly an override seam for tests. + */ + executor?: InferenceExecutor; +} + +export class AudioTurnDetector extends AudioTurnDetectorBase { + protected _model: TurnDetectorModel; + protected _cloudOpts: CloudTransportOptions | undefined; + protected _executor: InferenceExecutor | undefined; + + constructor(opts: AudioTurnDetectorOptions = {}) { + // auto = caller didn't pin a model; missing cloud creds warn-and- + // fall-back instead of raising. + const auto = opts.model === undefined; + let resolvedModel: TurnDetectorModel = + opts.model ?? (isHosted() || isDevMode() ? 
'turn-detector' : 'turn-detector-mini'); + + let cloudOpts: CloudTransportOptions | undefined; + if (resolvedModel === 'turn-detector') { + const baseUrl = resolveEnvVar( + opts.baseUrl, + ['LIVEKIT_INFERENCE_URL'], + getDefaultInferenceUrl(), + ); + const apiKey = resolveEnvVar(opts.apiKey, ['LIVEKIT_INFERENCE_API_KEY', 'LIVEKIT_API_KEY']); + const apiSecret = resolveEnvVar(opts.apiSecret, [ + 'LIVEKIT_INFERENCE_API_SECRET', + 'LIVEKIT_API_SECRET', + ]); + const missing: string[] = []; + if (!baseUrl) missing.push('LIVEKIT_INFERENCE_URL'); + if (!apiKey) missing.push('LIVEKIT_API_KEY'); + if (!apiSecret) missing.push('LIVEKIT_API_SECRET'); + if (missing.length > 0) { + if (auto) { + log().warn( + { missing }, + "LIVEKIT_INFERENCE_URL is set but creds are missing; falling back to 'turn-detector-mini'", + ); + resolvedModel = 'turn-detector-mini'; + } else { + throw new Error( + `AudioTurnDetector(model='turn-detector') requires ${missing.join(', ')} ` + + '(env or constructor argument).', + ); + } + } else { + cloudOpts = { + baseUrl, + apiKey, + apiSecret, + connOptions: opts.connOptions ?? DEFAULT_API_CONNECT_OPTIONS, + }; + } + } + + const detectorOpts: TurnDetectorOptions = { + sampleRate: opts.sampleRate ?? DEFAULT_SAMPLE_RATE, + thresholds: materializeThresholds(opts.unlikelyThreshold, resolvedModel), + }; + super(detectorOpts); + this._model = resolvedModel; + this._cloudOpts = cloudOpts; + // Default to the current job's shared inference executor. `getJobContext` + // throws outside a job (tests, standalone) — degrade to `undefined` + // (the local model then resolves a positive default) rather than throwing. + if (opts.executor !== undefined) { + this._executor = opts.executor; + } else { + try { + this._executor = getJobContext().inferenceExecutor; + } catch { + this._executor = undefined; + } + } + } + + /** Construction-time model. 
Per-stream `turn-detector`→`turn-detector-mini` + * (cloud→local) fallback state lives on the stream itself + * (`AudioTurnDetectorStreamImpl.model`) and is never written back here, so a + * fallback on one stream can't corrupt another stream spun off the same + * detector. */ + override get model(): TurnDetectorModel { + return this._model; + } + + override stream(opts: { connOptions?: APIConnectOptions } = {}): AudioTurnDetectorStream { + const cloudOpts = + this._cloudOpts !== undefined + ? { ...this._cloudOpts, connOptions: opts.connOptions ?? this._cloudOpts.connOptions } + : undefined; + const stream = new AudioTurnDetectorStreamImpl({ + detector: this, + opts: this._opts, + cloudOpts, + model: this._model, + executor: this._executor, + }); + this._streams.add(stream); + return stream; + } +} + +export interface AudioTurnDetectorStreamImplArgs { + detector: AudioTurnDetector; + opts: TurnDetectorOptions; + cloudOpts: CloudTransportOptions | undefined; + model: TurnDetectorModel; + /** Shared inference executor for the `turn-detector-mini` (local) model + * (undefined degrades to a positive-default prediction). */ + executor?: InferenceExecutor; + /** Optional transport override (for tests). When omitted, a transport is + * constructed from `model` + `cloudOpts`. */ + transport?: AudioTurnDetectionTransport; +} + +/** + * Stream that owns the `turn-detector` → `turn-detector-mini` (cloud → local) + * fallback FSM. On cloud transport failure (`transport.run()` raises, or + * `predictEndOfTurn` times out), the stream swaps the transport and rescales + * per-language thresholds. The fallback state lives on the stream; the detector + * reads it back through its active-stream delegation rather than being mutated. 
+ */ +export class AudioTurnDetectorStreamImpl extends AudioTurnDetectorStream { + protected _model: TurnDetectorModel; + protected _cloudOpts: CloudTransportOptions | undefined; + protected _executor: InferenceExecutor | undefined; + protected _isFallback = false; + protected _warnedCloudFailure = false; + protected _warnedLocalFailure = false; + private _detLogger = log(); + + constructor(args: AudioTurnDetectorStreamImplArgs) { + const transport = + args.transport ?? + (args.model === 'turn-detector' + ? new CloudTransport({ + detector: args.detector, + opts: args.opts, + cloudOpts: args.cloudOpts!, + }) + : new LocalTransport({ opts: args.opts, executor: args.executor })); + super({ detector: args.detector, opts: args.opts, transport }); + this._model = args.model; + this._cloudOpts = args.cloudOpts; + this._executor = args.executor; + } + + /** This stream's *current* model (flips to `'turn-detector-mini'` after a + * cloud→local fallback). Self-contained: the stream owns its active model, + * so this reports `'turn-detector-mini'` after a fallback while the detector + * delegates here for its own view. */ + override get model(): TurnDetectorModel { + return this._model; + } + + get isFallback(): boolean { + return this._isFallback; + } + + /** @internal Test-visible. */ + get warnedCloudFailure(): boolean { + return this._warnedCloudFailure; + } + /** @internal Test-visible. */ + get warnedLocalFailure(): boolean { + return this._warnedLocalFailure; + } + /** @internal Test-visible. */ + get transport(): AudioTurnDetectionTransport { + return this._transport; + } + + /** @internal Test-visible: same logic as the path taken when `_run` catches + * a cloud transport error. Tests call this directly to verify the warning + * dedupe across multiple invocations on the same stream. 
*/ + _fallBackToLocal(reason: Error): void { + if (!this._warnedCloudFailure) { + this._detLogger.warn( + { reason: reason.message }, + 'cloud audio eot failed; falling back to local mini model', + ); + this._warnedCloudFailure = true; + } + this._emitDefaultForInflight(); + try { + this._transport.detach(); + } catch { + // ignore detach errors during swap + } + const rescaled = rescaleForLocalFallback(this._opts.thresholds); + this._opts = { ...this._opts, thresholds: rescaled }; + this._transport = new LocalTransport({ opts: this._opts, executor: this._executor }); + this._transport.attach(this); + this._model = 'turn-detector-mini'; + this._isFallback = true; + // The fallback view is owned by this stream (`model`/`_opts`); + // we deliberately don't write it back onto the shared detector, so a + // fallback here can't corrupt another stream off the same detector. + } + + /** @internal Test-visible: same logic as the path taken when `_run` sees a + * local transport error. */ + _onLocalFailure(reason: Error): void { + if (!this._warnedLocalFailure) { + this._detLogger.warn( + { reason: reason.message }, + 'local audio eot mini failed; defaulting to 1.0 and retrying on next turn', + ); + this._warnedLocalFailure = true; + } + this._emitDefaultForInflight(); + } + + protected _emitDefaultForInflight(): void { + const requestId = this._preemptiveRequestId; + if (requestId !== undefined) { + this._handlePrediction(requestId, 1.0); + } + } + + override async aclose(): Promise { + // Detach the transport first so the cloud send channel closes and its + // background sender/recv tasks tear down, then run the base teardown + // (which closes the audio channel and cancels the main task). 
+ try { + this._transport.detach(); + } catch { + // ignore detach errors during teardown + } + await super.aclose(); + } + + protected override async _run(): Promise { + while (true) { + try { + await this._raceWithSwap(this._transport.run()); + return; + } catch (err) { + if (err instanceof SwapAbortError) { + if (this._closing) return; + // A swap already happened (e.g. predict timeout → fallback). + // The new transport is mounted; loop and run it. Routing the + // swap through `SwapAbortError` (rather than through the + // cloud/local branch below) is what prevents the "timeout + // flips model mid-await" misclassification — the catch + // exits early before ever consulting `_model`. + continue; + } + const e = err instanceof Error ? err : new Error(String(err)); + if (this._model === 'turn-detector') { + this._fallBackToLocal(e); + continue; + } + this._onLocalFailure(e); + return; + } + } + } + + protected override _onPredictTimeout(): void { + if (this._model === 'turn-detector') { + // Signal the swap BEFORE mutating model/transport state. The + // race in `_raceWithSwap` is rejected with `SwapAbortError` + // immediately, so the main loop exits through the + // SwapAbortError branch and never consults `_model` for a + // classification that would race with the assignment below. + this._signalSwap(); + this._fallBackToLocal(new Error('predict_end_of_turn')); + } + } +} diff --git a/agents/src/inference/eot/index.ts b/agents/src/inference/eot/index.ts new file mode 100644 index 000000000..e2e6d2fbe --- /dev/null +++ b/agents/src/inference/eot/index.ts @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +export { AudioTurnDetector, AudioTurnDetectorStreamImpl } from './detector.js'; +export type { AudioTurnDetectorOptions } from './detector.js'; +export { + CLOUD_LANGUAGES, + LOCAL_LANGUAGES, + materializeThresholds, + rescaleForLocalFallback, +} from './languages.js'; +export type { TurnDetectorModel } from './languages.js'; +export { CloudTransport, LocalTransport, type CloudTransportOptions } from './transports.js'; diff --git a/agents/src/inference/eot/languages.ts b/agents/src/inference/eot/languages.ts new file mode 100644 index 000000000..1d205954e --- /dev/null +++ b/agents/src/inference/eot/languages.ts @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +/** + * Per-language `unlikely` thresholds for the audio EOT detector. + * + * Calibrated separately per checkpoint — do NOT unify CLOUD and LOCAL tables. + */ + +export type TurnDetectorModel = 'turn-detector' | 'turn-detector-mini'; + +export const CLOUD_LANGUAGES: Readonly> = { + ar: 0.355, + de: 0.495, + en: 0.56, + es: 0.59, + fr: 0.575, + hi: 0.575, + id: 0.47, + it: 0.64, + ja: 0.37, + ko: 0.695, + nl: 0.75, + pt: 0.665, + tr: 0.65, + zh: 0.59, +}; + +export const LOCAL_LANGUAGES: Readonly> = { + ar: 0.35, + de: 0.245, + en: 0.36, + es: 0.35, + fr: 0.285, + hi: 0.305, + id: 0.345, + it: 0.23, + ja: 0.295, + ko: 0.4, + nl: 0.2, + pt: 0.32, + tr: 0.255, + zh: 0.355, +}; + +const BASE: Record>> = { + 'turn-detector': CLOUD_LANGUAGES, + 'turn-detector-mini': LOCAL_LANGUAGES, +}; + +/** + * BCP-47 language tag (or human-readable name) → ISO 639-1 two-letter code. + * + * Minimal port of Python's `LanguageCode` — covers the languages present in + * the threshold tables. Unknown inputs are returned lowercased and unchanged + * (callers should pass `en`, `en-US`, `English`, etc.). 
+ */ +function normalizeLanguage(input: string): string { + const lower = input.toLowerCase().trim(); + if (lower.length === 2) return lower; + const dashIdx = lower.indexOf('-'); + if (dashIdx === 2) return lower.slice(0, 2); + // long-name aliases for languages in our tables + const aliases: Record = { + arabic: 'ar', + german: 'de', + english: 'en', + spanish: 'es', + french: 'fr', + hindi: 'hi', + indonesian: 'id', + italian: 'it', + japanese: 'ja', + korean: 'ko', + dutch: 'nl', + portuguese: 'pt', + turkish: 'tr', + chinese: 'zh', + mandarin: 'zh', + }; + return aliases[lower] ?? lower; +} + +/** + * Resolve user override + per-model defaults into a complete per-language + * threshold map. + * + * - `undefined`: returns a copy of the bare model table. + * - `number`: fills every language with the same value. + * - object: overrides per-language (keys are normalized so `English` / + * `en` / `en-US` all collapse to `en`); unmapped languages keep the default. + */ +export function materializeThresholds( + userValue: number | Record | undefined, + model: TurnDetectorModel, +): Record { + const base = BASE[model]; + if (userValue === undefined) { + return { ...base }; + } + if (typeof userValue === 'number') { + const out: Record = {}; + for (const lang of Object.keys(base)) { + out[lang] = userValue; + } + return out; + } + const norm: Record = {}; + for (const [k, v] of Object.entries(userValue)) { + norm[normalizeLanguage(k)] = Number(v); + } + const out: Record = {}; + for (const [lang, defaultValue] of Object.entries(base)) { + out[lang] = norm[lang] ?? defaultValue; + } + return out; +} + +/** + * Preserve the user's cloud-vs-default ratio when promoting local: + * `local = LOCAL[lang] * (cloud_t / CLOUD[lang])` per language. 
+ */ +export function rescaleForLocalFallback( + cloudThresholds: Record, +): Record { + const out: Record = {}; + for (const [lang, cloudT] of Object.entries(cloudThresholds)) { + const cloudDefault = CLOUD_LANGUAGES[lang]; + const localDefault = LOCAL_LANGUAGES[lang]; + if (cloudDefault !== undefined && localDefault !== undefined && cloudDefault !== 0) { + out[lang] = localDefault * (cloudT / cloudDefault); + } + } + return out; +} diff --git a/agents/src/inference/eot/runner.test.ts b/agents/src/inference/eot/runner.test.ts new file mode 100644 index 000000000..4058ca9d9 --- /dev/null +++ b/agents/src/inference/eot/runner.test.ts @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { afterEach, describe, expect, it, vi } from 'vitest'; +import * as warmup from '../_warmup.js'; +import EotRunner from './runner.js'; + +describe('EotRunner', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('initializes the native EOT model and predicts on decoded PCM', async () => { + const received: Int16Array[] = []; + const fakeMod = { + initEot: vi.fn(), + initVad: vi.fn(), + createVad: vi.fn(), + VAD_WINDOW_SAMPLES: 512, + predict: vi.fn(async (pcm: Int16Array) => { + received.push(pcm); + return 0.83; + }), + }; + vi.spyOn(warmup, '_getLocalInferenceModule').mockReturnValue( + fakeMod as unknown as ReturnType, + ); + + const runner = new EotRunner(); + await runner.initialize(); + expect(fakeMod.initEot).toHaveBeenCalledOnce(); + + // 4 samples of s16le PCM → base64 + const samples = Int16Array.from([1, -2, 3, -4]); + const pcm = Buffer.from(samples.buffer, samples.byteOffset, samples.byteLength).toString( + 'base64', + ); + + const out = await runner.run({ pcm }); + expect(out.probability).toBe(0.83); + expect(out.inferenceDurationMs).toBeGreaterThanOrEqual(0); + + // the runner decoded the base64 back to the same samples + expect(received).toHaveLength(1); + 
expect(Array.from(received[0]!)).toEqual([1, -2, 3, -4]); + + await runner.close(); + }); + + it('throws on initialize when the native binding is unavailable', async () => { + vi.spyOn(warmup, '_getLocalInferenceModule').mockReturnValue(undefined); + const runner = new EotRunner(); + await expect(runner.initialize()).rejects.toThrow(/native binding unavailable/); + }); +}); diff --git a/agents/src/inference/eot/runner.ts b/agents/src/inference/eot/runner.ts new file mode 100644 index 000000000..d79ed7ad5 --- /dev/null +++ b/agents/src/inference/eot/runner.ts @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Audio EOT inference runner — runs inside the shared `InferenceProcExecutor` + * so the ~138 MB native model loads once per host instead of once per job + * worker. Job-side transports reach it via `executor.doInference(...)`. + * + * The inference proc instantiates this with `new Runner()` (no args) and + * calls `initialize()` once at startup, then dispatches `run(data)` per + * request — see `ipc/inference_proc_lazy_main.ts`. Hence the default export + * + no-arg constructor. + */ +import { InferenceRunner } from '../../inference_runner.js'; +import { log } from '../../log.js'; +import { _getLocalInferenceModule } from '../_warmup.js'; + +/** Inference method id used to register + dispatch the audio EOT runner. */ +export const EOT_INFERENCE_METHOD = 'lk_eot_audio'; + +/** Request payload: base64-encoded 16 kHz s16le PCM (up to 1.2 s). 
*/ +export interface EotInferenceInput { + pcm: string; +} + +export interface EotInferenceOutput { + probability: number; + inferenceDurationMs: number; +} + +export default class EotRunner extends InferenceRunner { + #logger = log(); + #mod: ReturnType; + + async initialize(): Promise { + this.#mod = _getLocalInferenceModule(); + if (this.#mod === undefined) { + throw new Error( + 'EotRunner: @livekit/local-inference native binding unavailable in the inference process', + ); + } + // Eagerly page in the EOT model singleton (~138 MB) so the first + // request doesn't pay the load on the hot path. + this.#mod.initEot(); + } + + async run(data: EotInferenceInput): Promise { + if (this.#mod === undefined) { + throw new Error('EotRunner not initialized'); + } + // base64 → bytes → Int16Array view (PCM is 16 kHz s16le) + const bytes = Buffer.from(data.pcm, 'base64'); + const pcm = new Int16Array(bytes.buffer, bytes.byteOffset, Math.floor(bytes.byteLength / 2)); + const t0 = performance.now(); + let probability = 0.0; + try { + probability = await this.#mod.predict(pcm); + } catch (err) { + this.#logger.error( + { err: err instanceof Error ? err.message : String(err) }, + 'local audio EOT prediction failed', + ); + } + return { probability, inferenceDurationMs: performance.now() - t0 }; + } + + async close(): Promise { + return; + } +} diff --git a/agents/src/inference/eot/transports.test.ts b/agents/src/inference/eot/transports.test.ts new file mode 100644 index 000000000..445a27066 --- /dev/null +++ b/agents/src/inference/eot/transports.test.ts @@ -0,0 +1,224 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Tests for `CloudTransport` (cloud WS body, driven by the unified + * `AudioTurnDetectorStreamImpl` stream). + * + * Uses an in-process fake WebSocket to drive the transport + * deterministically. 
Covers: + * + * - Retry counter resets after a successful connect (so transient drops + * across the session lifetime don't accumulate toward `maxRetry`). + * - All outbound messages are FIFO-ordered on the wire, even when control + * hooks fire synchronously between two awaited audio frames. + * + * Port of Python `tests/test_turn_detection_cloud_stream.py`. + */ +import { AgentInference } from '@livekit/protocol'; +import { AudioFrame } from '@livekit/rtc-node'; +import { describe, expect, it } from 'vitest'; +import { APIConnectionError } from '../../_exceptions.js'; +import { DEFAULT_API_CONNECT_OPTIONS } from '../../types.js'; +import { AudioTurnDetector, type TurnDetectorOptions } from './base.js'; +import { AudioTurnDetectorStreamImpl } from './detector.js'; +import type { TurnDetectorModel } from './languages.js'; +import { CloudTransport, type CloudWebSocket } from './transports.js'; + +const { ClientMessage } = AgentInference; + +/** Fake WebSocket capturing outbound frames as parsed `ClientMessage`s. 
*/ +class FakeWS implements CloudWebSocket { + sent: InstanceType[] = []; + readyState = 1; // OPEN + private closeCbs: Array<() => void> = []; + + send(data: Uint8Array): void { + if (this.readyState !== 1) throw new Error('ws closed'); + this.sent.push(ClientMessage.fromBinary(data)); + } + close(): void { + this.readyState = 3; // CLOSED + for (const cb of this.closeCbs) cb(); + } + on(event: 'message' | 'close' | 'error', cb: (...args: never[]) => void): void { + if (event === 'close') this.closeCbs.push(cb as () => void); + // message/error not driven in these tests + } +} + +class FakeDetector extends AudioTurnDetector { + get model(): TurnDetectorModel { + return 'turn-detector'; + } + stream(): never { + throw new Error('unused'); + } +} + +interface MakeStreamResult { + stream: AudioTurnDetectorStreamImpl; + fakeWs: FakeWS; + transport: CloudTransport; +} + +function makeStream(opts: { + connectScript?: Array; + maxRetry?: number; + retryIntervalMs?: number; +}): MakeStreamResult { + const fakeWs = new FakeWS(); + const script = [...(opts.connectScript ?? [])]; + const turnOpts: TurnDetectorOptions = { sampleRate: 16000, thresholds: {} }; + const detector = new FakeDetector(turnOpts); + const cloudOpts = { + baseUrl: '', + apiKey: 'x', + apiSecret: 'x', + connOptions: { + ...DEFAULT_API_CONNECT_OPTIONS, + maxRetry: opts.maxRetry ?? 3, + retryIntervalMs: opts.retryIntervalMs ?? 0, + }, + }; + // Scripted connect: consume the script left-to-right. An Error rejects; + // null (or exhausted) returns the fake ws. 
+ const connect = async (): Promise => { + if (script.length > 0) { + const r = script.shift(); + if (r instanceof Error) throw r; + } + fakeWs.readyState = 1; + return fakeWs; + }; + const transport = new CloudTransport({ detector, opts: turnOpts, cloudOpts, connect }); + const stream = new AudioTurnDetectorStreamImpl({ + detector, + opts: turnOpts, + cloudOpts, + model: 'turn-detector', + transport, + }); + return { stream, fakeWs, transport }; +} + +async function tick(): Promise { + await new Promise((r) => setImmediate(r)); +} + +async function waitUntilConnected(transport: CloudTransport, ticks = 50): Promise { + for (let i = 0; i < ticks; i++) { + if (transport.transportReady()) return; + await tick(); + } + throw new Error('transport did not connect within timeout'); +} + +async function drainSendQueue(_transport: CloudTransport, ticks = 50): Promise { + // Let the sender task flush the buffered ClientMsgs to the fake socket. + for (let i = 0; i < ticks; i++) { + await tick(); + } +} + +async function waitForCond(predicate: () => boolean, ticks = 50): Promise { + for (let i = 0; i < ticks; i++) { + if (predicate()) return; + await tick(); + } +} + +function pcmFrame(samples = 320): AudioFrame { + return new AudioFrame(new Int16Array(samples), 16000, 1, samples); +} + +describe('CloudStreamRetry', () => { + it('num retries resets after a successful connect', async () => { + const { stream, transport } = makeStream({ + connectScript: [new APIConnectionError({ message: 'transient' }), null], + maxRetry: 3, + retryIntervalMs: 0, + }); + try { + await waitUntilConnected(transport); + // Two attempts: first raised (counter 0→1), second succeeded → reset to 0. 
+ expect(transport.connectCalls).toBe(2); + expect(transport.numRetries).toBe(0); + } finally { + await stream.aclose(); + } + }); +}); + +describe('CloudToLocalFallback', () => { + it('releases the shared audio reader lock on fallback (regression)', async () => { + const { stream, transport } = makeStream({ connectScript: [null] }); + try { + await waitUntilConnected(transport); + // Drive a frame so the cloud drain task is actively parked on + // `reader.read()`, holding the audio channel's single reader lock. + stream.pushAudio(pcmFrame()); + await tick(); + + // Predict timeout triggers a cloud→local fallback. The orphaned cloud + // drain must release the shared reader lock before the real + // `LocalTransport.run()` re-acquires it — otherwise `getReader()` throws + // "ReadableStream is locked", which the FSM mis-reports as a local + // failure. + const prob = await stream.predictEndOfTurn(undefined, { timeoutMs: 10 }); + expect(prob).toBe(1.0); + + await waitForCond(() => stream.model === 'turn-detector-mini'); + expect(stream.isFallback).toBe(true); + + // Let the swapped-in LocalTransport.run() re-acquire the reader and start + // draining. A freed lock ⇒ no "ReadableStream is locked" TypeError ⇒ no + // local failure flagged. 
+ for (let i = 0; i < 10; i++) await tick(); + expect(stream.warnedLocalFailure).toBe(false); + } finally { + await stream.aclose(); + } + }); +}); + +describe('CloudStreamSendOrdering', () => { + it('inferenceStart precedes inputAudio (FIFO)', async () => { + const { stream, fakeWs, transport } = makeStream({ connectScript: [null] }); + try { + await waitUntilConnected(transport); + stream.warmup(); + stream.pushAudio(pcmFrame()); + await drainSendQueue(transport); + + const kinds = fakeWs.sent.map((m) => m.message.case); + const startIdx = kinds.indexOf('inferenceStart'); + const audioIdx = kinds.indexOf('inputAudio'); + expect(startIdx).toBeGreaterThanOrEqual(0); + expect(audioIdx).toBeGreaterThanOrEqual(0); + expect(startIdx).toBeLessThan(audioIdx); + } finally { + await stream.aclose(); + } + }); + + it('inferenceStart precedes inferenceStop (FIFO)', async () => { + const { stream, fakeWs, transport } = makeStream({ connectScript: [null] }); + try { + await waitUntilConnected(transport); + stream.warmup(); + stream.deactivate('vad sos'); + await drainSendQueue(transport); + + const kinds = fakeWs.sent.map((m) => m.message.case); + const startIdx = kinds.indexOf('inferenceStart'); + const stopIdx = kinds.indexOf('inferenceStop'); + expect(startIdx).toBeGreaterThanOrEqual(0); + expect(stopIdx).toBeGreaterThanOrEqual(0); + expect(startIdx).toBeLessThan(stopIdx); + } finally { + await stream.aclose(); + } + }); +}); diff --git a/agents/src/inference/eot/transports.ts b/agents/src/inference/eot/transports.ts new file mode 100644 index 000000000..55192b450 --- /dev/null +++ b/agents/src/inference/eot/transports.ts @@ -0,0 +1,614 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Audio EOT transports: cloud (WebSocket) + local (@livekit/local-inference). + * + * Port of Python `livekit.agents.inference.eot.transports`. 
+ */ +import { type Duration, Timestamp } from '@bufbuild/protobuf'; +import { AgentInference } from '@livekit/protocol'; +import type { AudioFrame } from '@livekit/rtc-node'; +import { APIConnectionError, APIError, APIStatusError } from '../../_exceptions.js'; +import type { InferenceExecutor } from '../../ipc/inference_executor.js'; +import { log } from '../../log.js'; +import { type StreamChannel, createStreamChannel } from '../../stream/stream_channel.js'; +import { type APIConnectOptions, intervalForRetry } from '../../types.js'; +import { Task, delay } from '../../utils.js'; +import { + type AudioTurnDetectionTransport, + type AudioTurnDetectorStream, + DEFAULT_SAMPLE_RATE, + type FlushSentinel, + type TurnDetectorOptions, +} from './base.js'; +import { buildMetadataHeaders, connectWs, createAccessToken } from '../utils.js'; +import type { AudioTurnDetector } from './detector.js'; +import { EOT_INFERENCE_METHOD } from './runner.js'; + +const AudioEncoding = AgentInference.AudioEncoding; +const ClientMessageCtor = AgentInference.ClientMessage; +const ServerMessageCtor = AgentInference.ServerMessage; +const InferenceStart = AgentInference.InferenceStart; +const InferenceStop = AgentInference.InferenceStop; +const InputAudio = AgentInference.InputAudio; +const SessionClose = AgentInference.SessionClose; +const SessionCreate = AgentInference.SessionCreate; +const SessionFlush = AgentInference.SessionFlush; +const SessionSettings = AgentInference.SessionSettings; +type ClientMsg = InstanceType; +type ServerMsg = InstanceType; + +export interface CloudTransportOptions { + baseUrl: string; + apiKey: string; + apiSecret: string; + connOptions: APIConnectOptions; +} + +/** + * Minimal WebSocket shape both the real `ws` socket and test fakes satisfy. + * The cloud transport only needs send/close/readyState + the three events. 
+ */ +export interface CloudWebSocket { + send(data: Uint8Array): void; + close(): void; + readonly readyState: number; + on(event: 'message', cb: (data: Buffer | ArrayBuffer | Buffer[]) => void): void; + on(event: 'close', cb: () => void): void; + on(event: 'error', cb: (err: Error) => void): void; +} + +const WS_OPEN = 1; + +function nowTimestamp(): Timestamp { + const now = Date.now(); + return new Timestamp({ + seconds: BigInt(Math.floor(now / 1000)), + nanos: (now % 1000) * 1_000_000, + }); +} + +function timestampToMs(ts?: Timestamp): number { + if (ts === undefined) return 0; + return Number(ts.seconds) * 1000 + Math.floor(ts.nanos / 1_000_000); +} + +function durationToMs(d?: Duration): number { + if (d === undefined) return 0; + return Number(d.seconds) * 1000 + Math.floor(d.nanos / 1_000_000); +} + +// Native model operates on up to 1.2 s of 16 kHz s16le PCM per predict. +const CLIENT_BUFFER_SECONDS = 1.2; +const CLIENT_BUFFER_SAMPLES = Math.floor(CLIENT_BUFFER_SECONDS * DEFAULT_SAMPLE_RATE); + +/** + * Append-only ring buffer of 16-bit PCM samples used by the local transport + * to keep the last ~1.2 s of audio available for per-window prediction. + */ +class PcmRingBuffer { + private buf: Int16Array; + private writeIdx = 0; + private filled = 0; + + constructor(public readonly capacity: number) { + this.buf = new Int16Array(capacity); + } + + pushFrame(frame: AudioFrame): void { + const src = frame.data; // Int16Array + for (let i = 0; i < src.length; i++) { + this.buf[this.writeIdx] = src[i]!; + this.writeIdx = (this.writeIdx + 1) % this.capacity; + } + this.filled = Math.min(this.filled + src.length, this.capacity); + } + + /** Returns a contiguous Int16Array snapshot of the last `filled` samples. 
*/ + read(): Int16Array { + const out = new Int16Array(this.filled); + const start = (this.writeIdx - this.filled + this.capacity) % this.capacity; + if (start + this.filled <= this.capacity) { + out.set(this.buf.subarray(start, start + this.filled)); + } else { + const tail = this.capacity - start; + out.set(this.buf.subarray(start, this.capacity), 0); + out.set(this.buf.subarray(0, this.filled - tail), tail); + } + return out; + } + + /** Drop the oldest `n` samples. */ + shift(n: number): void { + this.filled = Math.max(0, this.filled - n); + } + + get length(): number { + return this.filled; + } +} + +/** + * Transport for the local `turn-detector-mini` model. + * + * The native model runs in the shared `InferenceProcExecutor` (one load per + * host, ~138 MB) rather than in every job worker. Audio is buffered locally + * in the job process (no per-frame IPC); on each inference window the last + * ~1.2 s is snapshotted, base64-encoded, and sent over IPC to the runner + * (`inference/eot/runner.ts`) via `executor.doInference(...)`. + * + * When no executor is available (binding couldn't load on this platform), + * predictions resolve to a positive default (1.0) so the session still + * commits turns after `minDelay` — same as the existing local-failure path. 
+ */ +export class LocalTransport implements AudioTurnDetectionTransport { + protected _opts: TurnDetectorOptions; + protected _executor: InferenceExecutor | undefined; + protected _buf: PcmRingBuffer; + protected _streamRef: WeakRef | undefined; + protected _tasks = new Set>(); + protected _warnedNoExecutor = false; + protected _logger = log(); + + constructor(opts: { opts: TurnDetectorOptions; executor: InferenceExecutor | undefined }) { + this._opts = opts.opts; + this._executor = opts.executor; + this._buf = new PcmRingBuffer(CLIENT_BUFFER_SAMPLES); + } + + attach(stream: AudioTurnDetectorStream): void { + this._streamRef = new WeakRef(stream); + } + + startInference(requestId: string): void { + const snapshot = this._buf.read(); + const task = this._predict(requestId, snapshot); + this._tasks.add(task); + void task.finally(() => this._tasks.delete(task)); + } + + protected async _predict(requestId: string, pcmSnapshot: Int16Array): Promise { + const stream = this._streamRef?.deref(); + if (stream === undefined) return; + + if (this._executor === undefined) { + if (!this._warnedNoExecutor) { + this._warnedNoExecutor = true; + this._logger.warn( + 'local audio EOT unavailable (no inference executor / native binding); ' + + 'defaulting predictions to 1.0 so turns still commit after minDelay', + ); + } + stream._handlePrediction(requestId, 1.0); + return; + } + + // base64-encode the s16le PCM so it survives the default JSON IPC + // serialization compactly (a raw Int16Array would balloon to an + // array-of-numbers). Only the snapshot crosses the boundary. 
+ const pcm = Buffer.from( + pcmSnapshot.buffer, + pcmSnapshot.byteOffset, + pcmSnapshot.byteLength, + ).toString('base64'); + + let prob = 0.0; + let inferenceDurationMs = 0; + try { + const out = (await this._executor.doInference(EOT_INFERENCE_METHOD, { + pcm, + })) as { probability: number; inferenceDurationMs: number }; + prob = out.probability; + inferenceDurationMs = out.inferenceDurationMs; + } catch (err) { + this._logger.error( + { err: err instanceof Error ? err.message : String(err) }, + 'local audio EOT inference (executor) failed', + ); + } + const freshStream = this._streamRef?.deref(); + if (freshStream === undefined) return; + freshStream._handlePrediction(requestId, prob, { inferenceDuration: inferenceDurationMs }); + } + + async pushFrame(frame: AudioFrame): Promise { + this._buf.pushFrame(frame); + } + + async flush(sentinel: FlushSentinel): Promise { + const keepSamples = Math.floor((sentinel.keepTailMs * DEFAULT_SAMPLE_RATE) / 1000); + if (this._buf.length > keepSamples) { + this._buf.shift(this._buf.length - keepSamples); + } + } + + stopInference(_reason?: string): void { + // In-flight predictions run to completion; `_predict` drops stale results. + return; + } + + detach(): void { + this._tasks.clear(); + } + + async run(): Promise { + const stream = this._streamRef?.deref(); + if (stream === undefined) return; + await stream._drainAudioChannel(); + } +} + +/** + * WebSocket transport for the `turn-detector` (cloud) model. + * + * Maintains one inference session against the LiveKit Agent Gateway: + * connect → `SessionCreate` → three concurrent tasks (drain audio, send, + * receive) → protobuf encode/decode → `stream._handlePrediction(...)` + + * `EOTInferenceMetrics` on the detector. Mirrors Python `_CloudTransport`. + * + * All outbound messages flow through a single FIFO send channel so control + * hooks fired synchronously between two awaited audio frames (e.g. + * `inferenceStart` then `inputAudio`) reach the wire in call order. 
+ */ +export class CloudTransport implements AudioTurnDetectionTransport { + protected _detectorRef: WeakRef; + protected _opts: TurnDetectorOptions; + protected _cloudOpts: CloudTransportOptions; + protected _connOptions: APIConnectOptions; + protected _streamRef: WeakRef | undefined; + protected _ws: CloudWebSocket | undefined; + protected _numRetries = 0; + protected _connectCalls = 0; + /** Outbound FIFO for the active connection; recreated per `_runOnce`. */ + protected _sendChannel: StreamChannel | undefined; + /** Set by `detach()`; stops the retry loop and suppresses the + * connection-closed throw so a teardown can't trigger a reconnect. */ + protected _detached = false; + /** Aborted by `detach()` to release the audio-drain reader lock so a + * swapped-in transport can re-acquire the shared audio stream. */ + protected _runAbort: AbortController | undefined; + protected _logger = log(); + /** Optional connect override for tests; defaults to a real WS handshake. */ + private _connectImpl: (() => Promise) | undefined; + + constructor(args: { + detector: AudioTurnDetector; + opts: TurnDetectorOptions; + cloudOpts: CloudTransportOptions; + /** @internal test seam — supply a fake WebSocket factory. */ + connect?: (transport: CloudTransport) => Promise; + }) { + this._detectorRef = new WeakRef(args.detector); + this._opts = args.opts; + this._cloudOpts = args.cloudOpts; + this._connOptions = args.cloudOpts.connOptions; + this._connectImpl = args.connect ? () => args.connect!(this) : undefined; + } + + /** @internal Test-visible: number of connect attempts. */ + get connectCalls(): number { + return this._connectCalls; + } + /** @internal Test-visible: retry counter (resets to 0 after a connect). */ + get numRetries(): number { + return this._numRetries; + } + + attach(stream: AudioTurnDetectorStream): void { + this._streamRef = new WeakRef(stream); + } + + /** @internal Test-visible: true once the WS handshake is open. 
Not part of + * the transport interface — the stream FSM no longer gates on this. */ + transportReady(): boolean { + return this._ws !== undefined && this._ws.readyState === WS_OPEN; + } + + startInference(requestId: string): void { + this._enqueue( + new ClientMessageCtor({ + message: { case: 'inferenceStart', value: new InferenceStart({ requestId }) }, + }), + ); + } + + stopInference(_reason?: string): void { + this._enqueue( + new ClientMessageCtor({ + message: { case: 'inferenceStop', value: new InferenceStop() }, + createdAt: nowTimestamp(), + }), + ); + } + + async pushFrame(frame: AudioFrame): Promise { + if (frame.data.byteLength === 0) return; + this._enqueue( + new ClientMessageCtor({ + message: { + case: 'inputAudio', + value: new InputAudio({ + audio: new Uint8Array(frame.data.buffer, frame.data.byteOffset, frame.data.byteLength), + numSamples: frame.samplesPerChannel, + createdAt: nowTimestamp(), + }), + }, + }), + ); + } + + async flush(_sentinel: FlushSentinel): Promise { + this._enqueue( + new ClientMessageCtor({ message: { case: 'sessionFlush', value: new SessionFlush() } }), + ); + } + + detach(): void { + this._detached = true; + // Abort the active run: this releases the audio-drain reader lock (held by + // `stream._drainAudioChannel`) so a swapped-in transport can re-acquire the + // shared audio stream, and unblocks the recv/send tasks below. + this._runAbort?.abort(); + void this._sendChannel?.close(); + const ws = this._ws; + this._ws = undefined; + try { + ws?.close(); + } catch { + // ignore + } + } + + private _enqueue(msg: ClientMsg): void { + // The WS handle is cleared synchronously by `detach()` while + // `_sendChannel.close()` is still in flight (its `closed` flag flips + // asynchronously). Gate on `_ws` to drop late control hooks that the + // stream FSM may fire after the transport is being torn down. 
+ if (this._ws === undefined || this._ws.readyState !== WS_OPEN) return; + const channel = this._sendChannel; + if (channel === undefined || channel.closed) return; + void channel.write(msg).catch(() => {}); + } + + private async _defaultConnect(): Promise { + let baseUrl = this._cloudOpts.baseUrl; + if (baseUrl.startsWith('http://')) baseUrl = baseUrl.replace('http://', 'ws://'); + else if (baseUrl.startsWith('https://')) baseUrl = baseUrl.replace('https://', 'wss://'); + const token = await createAccessToken(this._cloudOpts.apiKey, this._cloudOpts.apiSecret); + const headers = { ...buildMetadataHeaders(), Authorization: `Bearer ${token}` }; + const ws = await connectWs(`${baseUrl}/eot`, headers, this._connOptions.timeoutMs); + return ws as unknown as CloudWebSocket; + } + + protected _processServerMessage(msg: ServerMsg): void { + const stream = this._streamRef?.deref(); + if (stream === undefined) return; + const kind = msg.message.case; + if (kind === 'eotPrediction') { + const prediction = msg.message.value; + const stats = prediction.inferenceStats; + const requestSentAtMs = timestampToMs(stats?.latestClientCreatedAt); + const detectionDelayMs = requestSentAtMs > 0 ? Date.now() - requestSentAtMs : 0; + const inferenceDurationMs = durationToMs(stats?.serverE2eLatency); + stream._handlePrediction(msg.requestId ?? 
'', prediction.probability, { + detectionDelay: detectionDelayMs, + inferenceDuration: inferenceDurationMs, + }); + const detector = this._detectorRef.deref(); + if (detector !== undefined) { + detector.emit('metrics_collected', { + type: 'eot_inference_metrics', + timestamp: Date.now(), + totalDuration: durationToMs(stats?.clientE2eLatency), + predictionDuration: inferenceDurationMs, + detectionDelay: detectionDelayMs, + numRequests: 1, + metadata: { modelName: detector.model, modelProvider: detector.provider }, + }); + } + } else if (kind === 'error') { + const err = msg.message.value; + throw new APIStatusError({ + message: err.message, + options: { statusCode: err.code, requestId: msg.requestId }, + }); + } else if ( + kind === 'sessionCreated' || + kind === 'sessionClosed' || + kind === 'inferenceStarted' || + kind === 'inferenceStopped' + ) { + const clientCreatedAtMs = timestampToMs(msg.clientCreatedAt); + const transportLatency = Date.now() - clientCreatedAtMs; + if (transportLatency > 500 && clientCreatedAtMs > 0) { + this._logger.warn( + { transportLatencyMs: transportLatency }, + 'turn detection transport latency is too high', + ); + } + } else { + this._logger.warn({ kind }, 'unexpected turn detector message'); + } + } + + async run(): Promise { + const maxRetries = this._connOptions.maxRetry; + while (!this._detached && this._numRetries <= maxRetries) { + try { + await this._runOnce(); + return; + } catch (err) { + // A detach (e.g. cloud→local fallback) tears the session down; don't + // surface that as a connection error or retry into a reconnect. 
+ if (this._detached) return; + if (!(err instanceof APIError) || maxRetries === 0 || !err.retryable) throw err; + if (this._numRetries === maxRetries) { + throw new APIConnectionError({ + message: `failed to connect livekit turn detector after ${this._numRetries} attempts`, + }); + } + const retryIntervalMs = intervalForRetry(this._connOptions, this._numRetries); + this._logger.warn( + { err: err.message, attempt: this._numRetries, retryIntervalMs }, + 'livekit turn detector connection failed; retrying', + ); + await delay(retryIntervalMs); + this._numRetries += 1; + } + } + } + + protected async _runOnce(): Promise { + const stream = this._streamRef?.deref(); + if (stream === undefined) return; + + // Per-run abort: `detach()` fires it to release the audio-drain reader + // lock and stop the recv/send tasks without a spurious "closed" throw. + const runAbort = new AbortController(); + this._runAbort = runAbort; + + this._connectCalls += 1; + const ws = await (this._connectImpl ?? this._defaultConnect.bind(this))(); + + // Detached while the handshake was in flight — don't revive the session. + if (this._detached) { + try { + ws.close(); + } catch { + // ignore + } + return; + } + + // Successful connect — reset transient-failure counter so drops across + // the session lifetime don't accumulate toward maxRetry. + this._numRetries = 0; + this._ws = ws; + const sendChannel = createStreamChannel(); + this._sendChannel = sendChannel; + + // Send the SessionCreate handshake first, before any queued control msg. 
+ ws.send( + new ClientMessageCtor({ + message: { + case: 'sessionCreate', + value: new SessionCreate({ + settings: new SessionSettings({ + sampleRate: this._opts.sampleRate, + encoding: AudioEncoding.PCM_S16LE, + }), + }), + }, + createdAt: nowTimestamp(), + }).toBinary(), + ); + + let closingWs = false; + let socketErr: Error | undefined; + // Closing the recv channel makes the reader drain buffered frames and then + // observe `done`; we use it (not `abort`) on socket close/error so the + // post-drain throw below still decides the outcome. + const recvChannel = createStreamChannel(); + + ws.on('message', (data) => { + const chunk = + data instanceof Buffer + ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength) + : Array.isArray(data) + ? new Uint8Array(Buffer.concat(data)) + : new Uint8Array(data); + void recvChannel.write(chunk).catch(() => {}); + }); + ws.on('close', () => { + void recvChannel.close(); + void sendChannel.close(); + }); + ws.on('error', (err) => { + socketErr = err; + void recvChannel.close(); + void sendChannel.close(); + }); + + const drainAudioTask = Task.from(async () => { + await stream._drainAudioChannel(runAbort.signal); + // Detached mid-drain (fallback/teardown): the lock is already released; + // skip the graceful sessionClose — the session is being abandoned. + if (runAbort.signal.aborted) return; + closingWs = true; + this._enqueue( + new ClientMessageCtor({ message: { case: 'sessionClose', value: new SessionClose() } }), + ); + // Close after enqueue so the sender flushes `sessionClose` before exiting. 
+ await sendChannel.close(); + }); + + const senderTask = Task.from(async () => { + const reader = sendChannel.stream().getReader(); + try { + while (true) { + const { done, value: msg } = await reader.read(); + if (done) return; + if (msg.createdAt === undefined) msg.createdAt = nowTimestamp(); + if (ws.readyState !== WS_OPEN) return; + try { + ws.send(msg.toBinary()); + } catch { + return; + } + } + } finally { + reader.releaseLock(); + } + }); + + const recvTask = Task.from(async () => { + const reader = recvChannel.stream().getReader(); + try { + while (true) { + const { done, value: chunk } = await reader.read(); + if (done) break; + this._processServerMessage(ServerMessageCtor.fromBinary(chunk)); + } + } finally { + reader.releaseLock(); + } + // A detach-driven ws close is expected teardown, not a failure. + if (socketErr !== undefined && !closingWs && !runAbort.signal.aborted) { + throw new APIConnectionError({ + message: `turn detector connection error: ${socketErr.message}`, + }); + } + if (!closingWs && !runAbort.signal.aborted) { + throw new APIStatusError({ + message: 'turn detector connection closed unexpectedly', + options: { statusCode: -1 }, + }); + } + }); + + try { + await Promise.all([drainAudioTask.result, senderTask.result, recvTask.result]); + } finally { + drainAudioTask.cancel(); + senderTask.cancel(); + recvTask.cancel(); + void sendChannel.close(); + void recvChannel.close(); + this._ws = undefined; + try { + ws.close(); + } catch { + // ignore + } + } + } +} + +// Re-export the transport interface from the FSM module so callers that +// import `AudioTurnDetectionTransport` from this package barrel see the +// same type. +export type { AudioTurnDetectionTransport }; +// Expose APIError so detector + fallback code can narrow on it. 
+export type { APIError }; diff --git a/agents/src/inference/index.ts b/agents/src/inference/index.ts index b6847f59c..3ecdc9155 100644 --- a/agents/src/inference/index.ts +++ b/agents/src/inference/index.ts @@ -1,10 +1,28 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 +import * as eot from './eot/index.js'; import * as llm from './llm.js'; import * as stt from './stt.js'; import * as tts from './tts.js'; +export { eot }; +export { + AudioTurnDetector, + AudioTurnDetectorStreamImpl, + CLOUD_LANGUAGES, + LOCAL_LANGUAGES, + CloudTransport, + LocalTransport, + materializeThresholds, + rescaleForLocalFallback, + type AudioTurnDetectorOptions, + type CloudTransportOptions, + type TurnDetectorModel, +} from './eot/index.js'; + +export { VAD, type VADOptions, type VADModels } from './vad.js'; + export { LLM, LLMStream, diff --git a/agents/src/inference/vad.ts b/agents/src/inference/vad.ts new file mode 100644 index 000000000..d59e70433 --- /dev/null +++ b/agents/src/inference/vad.ts @@ -0,0 +1,369 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Voice Activity Detection backed by `@livekit/local-inference`. + * + * Provides the same streaming VAD shape as `plugins/silero` but routes + * inference through the bundled native model so a default instance can be + * auto-provisioned by `AgentSession` without an explicit plugin import. + * + * Port of Python `livekit.agents.inference.vad`. + */ +import { AudioFrame, AudioResampler, AudioResamplerQuality } from '@livekit/rtc-node'; +import { log } from '../log.js'; +import { VAD as BaseVAD, VADStream as BaseVADStream, VADEventType } from '../vad.js'; +import { _getLocalInferenceModule } from './_warmup.js'; + +const SLOW_INFERENCE_THRESHOLD_MS = 200; +const MODEL_SAMPLE_RATE = 16000; + +export type VADModels = 'silero'; + +export interface VADOptions { + /** Minimum speech duration (ms) before reporting START_OF_SPEECH. 
*/ + minSpeechDuration: number; + /** Trailing silence (ms) before reporting END_OF_SPEECH. */ + minSilenceDuration: number; + /** Pre-roll (ms) included in the speech buffer ahead of START_OF_SPEECH. */ + prefixPaddingDuration: number; + /** Maximum (ms) of buffered speech per utterance. */ + maxBufferedSpeech: number; + /** Sigmoid probability threshold for activation. */ + activationThreshold: number; + /** Sigmoid probability threshold for deactivation (defaults to + * `max(activationThreshold - 0.15, 0.01)`). */ + deactivationThreshold: number; +} + +const defaultVADOptions: VADOptions = { + minSpeechDuration: 50, + // 250ms (= MIN_SILENCE_DURATION_MS + 50) so the default satisfies the audio + // end-of-turn detector's silence-window requirement out of the box. + minSilenceDuration: 250, + prefixPaddingDuration: 500, + maxBufferedSpeech: 60_000, + activationThreshold: 0.5, + deactivationThreshold: 0.35, +}; + +export class VAD extends BaseVAD { + protected _opts: VADOptions; + protected _model: VADModels; + label = 'inference.VAD'; + + constructor(opts: Partial & { model?: VADModels } = {}) { + super({ updateInterval: 32 }); + const model: VADModels = opts.model ?? 'silero'; + if (model !== 'silero') { + throw new Error(`Unknown VAD model: ${String(model)}. Supported: 'silero'.`); + } + this._model = model; + const activation = opts.activationThreshold ?? defaultVADOptions.activationThreshold; + this._opts = { + ...defaultVADOptions, + ...opts, + activationThreshold: activation, + deactivationThreshold: opts.deactivationThreshold ?? Math.max(activation - 0.15, 0.01), + }; + } + + get model(): string { + return this._model; + } + + get provider(): string { + return 'livekit-local-inference'; + } + + override get minSilenceDuration(): number { + return this._opts.minSilenceDuration; + } + + /** Update one or more knobs at runtime. 
*/ + updateOptions(opts: Partial): void { + this._opts = { ...this._opts, ...opts }; + } + + stream(): BaseVADStream { + return new InferenceVADStream(this, { ...this._opts }); + } +} + +class InferenceVADStream extends BaseVADStream { + private _opts: VADOptions; + private _logger = log(); + private _nativeVad: + | ReturnType>['createVad']> + | undefined; + private _windowSamples: number; + private _inputSampleRate = 0; + private _resampler: AudioResampler | undefined; + private _speechBuffer: Int16Array | null = null; + private _speechBufferMaxReached = false; + private _prefixPaddingSamples = 0; + private _pumpTask: Promise; + + constructor(parent: VAD, opts: VADOptions) { + super(parent); + this._opts = opts; + const mod = _getLocalInferenceModule(); + if (mod === undefined) { + this._logger.warn( + 'inference.VAD created without @livekit/local-inference; stream will be a no-op', + ); + this._windowSamples = 512; + } else { + this._nativeVad = mod.createVad(); + this._windowSamples = mod.VAD_WINDOW_SAMPLES; + } + this._pumpTask = this._pump().catch((err) => { + this._logger.error( + { err: err instanceof Error ? err.message : String(err) }, + 'VAD pump failed', + ); + }); + } + + private async _pump(): Promise { + let pubSpeaking = false; + let pubSpeechDurationMs = 0; + let pubSilenceDurationMs = 0; + let pubCurrentSample = 0; + let pubTimestampMs = 0; + let speechThresholdDurationMs = 0; + let silenceThresholdDurationMs = 0; + let inputFrames: AudioFrame[] = []; + let inferenceFrames: AudioFrame[] = []; + let inputCopyRemainingFrac = 0; + let extraInferenceTime = 0; + // Write cursor into `_speechBuffer`. The buffer holds: + // [ ...prefix-padding (sliding pre-roll) ..., ...active speech... ] + // and is reset on END_OF_SPEECH (and on silence while idle) so the next + // turn starts from a fresh pre-roll window. 
+ let speechBufferIndex = 0; + + const resetWriteCursor = () => { + if (this._speechBuffer === null) return; + if (speechBufferIndex <= this._prefixPaddingSamples) return; + // Slide the most-recent `prefixPaddingSamples` samples to the head + // of the buffer so the next utterance has continuous pre-roll + // context (the audio that immediately preceded START_OF_SPEECH). + const paddingData = this._speechBuffer.subarray( + speechBufferIndex - this._prefixPaddingSamples, + speechBufferIndex, + ); + this._speechBuffer.set(paddingData, 0); + speechBufferIndex = this._prefixPaddingSamples; + this._speechBufferMaxReached = false; + }; + + const copySpeechBuffer = (): AudioFrame => { + if (this._speechBuffer === null) { + return new AudioFrame(new Int16Array(0), this._inputSampleRate, 1, 0); + } + return new AudioFrame( + this._speechBuffer.subarray(0, speechBufferIndex), + this._inputSampleRate, + 1, + speechBufferIndex, + ); + }; + + while (!this.closed) { + const { done, value: frame } = await this.inputReader.read(); + if (done) break; + if (typeof frame === 'symbol') continue; + + if (!this._inputSampleRate) { + this._inputSampleRate = frame.sampleRate; + this._prefixPaddingSamples = Math.trunc( + (this._opts.prefixPaddingDuration * this._inputSampleRate) / 1000, + ); + const bufferSize = + Math.trunc((this._opts.maxBufferedSpeech * this._inputSampleRate) / 1000) + + this._prefixPaddingSamples; + this._speechBuffer = new Int16Array(bufferSize); + if (this._inputSampleRate !== MODEL_SAMPLE_RATE) { + this._resampler = new AudioResampler( + this._inputSampleRate, + MODEL_SAMPLE_RATE, + 1, + AudioResamplerQuality.QUICK, + ); + } + } else if (frame.sampleRate !== this._inputSampleRate) { + this._logger.error('a frame with a different sample rate was already pushed'); + continue; + } + + if (this._speechBuffer === null) continue; + + inputFrames.push(frame); + if (this._resampler !== undefined) { + inferenceFrames.push(...this._resampler.push(frame)); + } else { + 
inferenceFrames.push(frame); + } + + while (!this.closed) { + const startTime = performance.now(); + const availableInferenceSamples = inferenceFrames.reduce( + (acc, f) => acc + f.samplesPerChannel, + 0, + ); + if (availableInferenceSamples < this._windowSamples) break; + + const inputFrame = mergeFrames(inputFrames); + const inferenceFrame = mergeFrames(inferenceFrames); + const inferenceWindow = inferenceFrame.data.subarray(0, this._windowSamples); + + let p = 0.0; + if (this._nativeVad !== undefined) { + p = await this._nativeVad.predict(inferenceWindow); + } + + const windowDurationMs = (this._windowSamples / MODEL_SAMPLE_RATE) * 1000; + pubCurrentSample += this._windowSamples; + pubTimestampMs += windowDurationMs; + const resamplingRatio = this._inputSampleRate / MODEL_SAMPLE_RATE; + const toCopy = this._windowSamples * resamplingRatio + inputCopyRemainingFrac; + const toCopyInt = Math.trunc(toCopy); + inputCopyRemainingFrac = toCopy - toCopyInt; + + // Append the input-rate samples we just consumed into the + // speech buffer so START_OF_SPEECH / END_OF_SPEECH events can + // hand downstream consumers (STT, transcription) the prefix- + // padded audio they need. 
+ const availableSpace = this._speechBuffer.length - speechBufferIndex; + const toCopyBuffer = Math.min(toCopyInt, availableSpace); + if (toCopyBuffer > 0) { + this._speechBuffer.set(inputFrame.data.subarray(0, toCopyBuffer), speechBufferIndex); + speechBufferIndex += toCopyBuffer; + } else if (!this._speechBufferMaxReached) { + this._speechBufferMaxReached = true; + this._logger.warn( + 'maxBufferedSpeech reached, ignoring further data for the current speech input', + ); + } + + const inferenceDuration = performance.now() - startTime; + extraInferenceTime = Math.max(0, extraInferenceTime + inferenceDuration - windowDurationMs); + if (extraInferenceTime > SLOW_INFERENCE_THRESHOLD_MS) { + this._logger.warn( + { extraInferenceTimeMs: extraInferenceTime }, + 'VAD slower than realtime', + ); + } + + if (pubSpeaking) pubSpeechDurationMs += windowDurationMs; + else pubSilenceDurationMs += windowDurationMs; + + this.sendVADEvent({ + type: VADEventType.INFERENCE_DONE, + samplesIndex: pubCurrentSample, + timestamp: pubTimestampMs, + silenceDuration: pubSilenceDurationMs, + speechDuration: pubSpeechDurationMs, + probability: p, + inferenceDuration, + frames: [ + new AudioFrame( + inputFrame.data.subarray(0, toCopyInt), + this._inputSampleRate, + 1, + toCopyInt, + ), + ], + speaking: pubSpeaking, + rawAccumulatedSilence: silenceThresholdDurationMs, + rawAccumulatedSpeech: speechThresholdDurationMs, + }); + + if ( + p >= this._opts.activationThreshold || + (pubSpeaking && p > this._opts.deactivationThreshold) + ) { + speechThresholdDurationMs += windowDurationMs; + silenceThresholdDurationMs = 0; + if (!pubSpeaking && speechThresholdDurationMs >= this._opts.minSpeechDuration) { + pubSpeaking = true; + pubSilenceDurationMs = 0; + pubSpeechDurationMs = speechThresholdDurationMs; + this.sendVADEvent({ + type: VADEventType.START_OF_SPEECH, + samplesIndex: pubCurrentSample, + timestamp: pubTimestampMs, + silenceDuration: pubSilenceDurationMs, + speechDuration: pubSpeechDurationMs, 
+ probability: p, + inferenceDuration, + frames: [copySpeechBuffer()], + speaking: true, + rawAccumulatedSilence: 0, + rawAccumulatedSpeech: 0, + }); + } + } else { + silenceThresholdDurationMs += windowDurationMs; + speechThresholdDurationMs = 0; + // Keep a sliding pre-roll window while we're not in active + // speech — without this the buffer would fill with idle + // silence and the next START_OF_SPEECH would lose its + // prefix-padding context. + if (!pubSpeaking) resetWriteCursor(); + if (pubSpeaking && silenceThresholdDurationMs >= this._opts.minSilenceDuration) { + pubSpeaking = false; + pubSilenceDurationMs = silenceThresholdDurationMs; + this.sendVADEvent({ + type: VADEventType.END_OF_SPEECH, + samplesIndex: pubCurrentSample, + timestamp: pubTimestampMs, + silenceDuration: pubSilenceDurationMs, + speechDuration: Math.max(0, pubSpeechDurationMs - silenceThresholdDurationMs), + probability: p, + inferenceDuration, + frames: [copySpeechBuffer()], + speaking: false, + rawAccumulatedSilence: 0, + rawAccumulatedSpeech: 0, + }); + pubSpeechDurationMs = 0; + resetWriteCursor(); + } + } + + inputFrames = []; + inferenceFrames = []; + if (inputFrame.data.length > toCopyInt) { + const data = inputFrame.data.subarray(toCopyInt); + inputFrames.push(new AudioFrame(data, this._inputSampleRate, 1, Math.trunc(data.length))); + } + if (inferenceFrame.data.length > this._windowSamples) { + const data = inferenceFrame.data.subarray(this._windowSamples); + inferenceFrames.push(new AudioFrame(data, MODEL_SAMPLE_RATE, 1, Math.trunc(data.length))); + } + } + } + this._resampler?.close?.(); + } +} + +/** Minimal frame-merging helper. The silero plugin uses `mergeFrames` from + * the agents package — for the inference VAD we keep a local copy to avoid + * an import cycle through `index.ts`. 
*/ +function mergeFrames(frames: AudioFrame[]): AudioFrame { + if (frames.length === 1) return frames[0]!; + const sampleRate = frames[0]!.sampleRate; + const channels = frames[0]!.channels; + let total = 0; + for (const f of frames) total += f.samplesPerChannel; + const buf = new Int16Array(total * channels); + let offset = 0; + for (const f of frames) { + buf.set(f.data, offset); + offset += f.samplesPerChannel * channels; + } + return new AudioFrame(buf, sampleRate, channels, total); +} diff --git a/agents/src/metrics/base.ts b/agents/src/metrics/base.ts index f6af79ec9..3ae5546ac 100644 --- a/agents/src/metrics/base.ts +++ b/agents/src/metrics/base.ts @@ -15,6 +15,7 @@ export type AgentMetrics = | TTSMetrics | VADMetrics | EOUMetrics + | EOTInferenceMetrics | RealtimeModelMetrics | InterruptionMetrics | AvatarMetrics; @@ -197,6 +198,25 @@ export type RealtimeModelMetrics = { metadata?: MetricsMetadata; }; +/** + * Per-prediction telemetry for the audio EOT (end-of-turn) detector. Emitted + * by transports on each cloud or local prediction so we can track detection + * latency and inference time per call. + */ +export type EOTInferenceMetrics = { + type: 'eot_inference_metrics'; + timestamp: number; + /** Latest RTT time taken to perform inference, in milliseconds. */ + totalDuration: number; + /** Latest time taken by the model side, in milliseconds. */ + predictionDuration: number; + /** Latest total time from audio-frame creation to prediction receive, in milliseconds. */ + detectionDelay: number; + /** Number of prediction requests served (incremental). 
*/ + numRequests: number; + metadata?: MetricsMetadata; +}; + export type InterruptionMetrics = { type: 'interruption_metrics'; timestamp: number; diff --git a/agents/src/metrics/model_usage.ts b/agents/src/metrics/model_usage.ts index 5e723fb51..c2a3e5f51 100644 --- a/agents/src/metrics/model_usage.ts +++ b/agents/src/metrics/model_usage.ts @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import type { AgentMetrics, + EOTInferenceMetrics, InterruptionMetrics, LLMMetrics, RealtimeModelMetrics, @@ -84,7 +85,23 @@ export type InterruptionModelUsage = { totalRequests: number; }; -export type ModelUsage = LLMModelUsage | TTSModelUsage | STTModelUsage | InterruptionModelUsage; +/** Aggregate per-provider usage for the audio EOT detector. */ +export type EOTModelUsage = { + type: 'eot_usage'; + /** The provider name (e.g., 'livekit'). */ + provider: string; + /** The model name (e.g., 'turn-detector' for cloud, 'turn-detector-mini' for local). */ + model: string; + /** Total number of EOT prediction requests served. */ + totalRequests: number; +}; + +export type ModelUsage = + | LLMModelUsage + | TTSModelUsage + | STTModelUsage + | InterruptionModelUsage + | EOTModelUsage; export function filterZeroValues(usage: T): Partial { const result: Partial = {} as Partial; @@ -102,10 +119,17 @@ export class ModelUsageCollector { private sttUsage: Map = new Map(); private interruptionUsage: Map = new Map(); + private eotUsage: Map = new Map(); /** Extract provider and model from metrics metadata. 
*/ private extractProviderModel( - metrics: LLMMetrics | STTMetrics | TTSMetrics | RealtimeModelMetrics | InterruptionMetrics, + metrics: + | LLMMetrics + | STTMetrics + | TTSMetrics + | RealtimeModelMetrics + | InterruptionMetrics + | EOTInferenceMetrics, ): [string, string] { let provider = ''; let model = ''; @@ -195,6 +219,21 @@ export class ModelUsageCollector { return usage; } + private getEotUsage(provider: string, model: string): EOTModelUsage { + const key = `${provider}:${model}`; + let usage = this.eotUsage.get(key); + if (!usage) { + usage = { + type: 'eot_usage', + provider, + model, + totalRequests: 0, + }; + this.eotUsage.set(key, usage); + } + return usage; + } + /** Collect metrics and aggregate usage by model/provider. */ collect(metrics: AgentMetrics): void { if (metrics.type === 'llm_metrics') { @@ -239,8 +278,13 @@ export class ModelUsageCollector { const [provider, model] = this.extractProviderModel(metrics); const usage = this.getInterruptionUsage(provider, model); usage.totalRequests += metrics.numRequests; + } else if (metrics.type === 'eot_inference_metrics') { + const [provider, model] = this.extractProviderModel(metrics); + const usage = this.getEotUsage(provider, model); + usage.totalRequests += metrics.numRequests; } - // VAD and EOU metrics are not aggregated for usage tracking. + // VAD and EOU (session-level summary) metrics are not aggregated for + // usage tracking; only per-prediction EOT inference metrics are. 
} flatten(): ModelUsage[] { @@ -257,6 +301,9 @@ export class ModelUsageCollector { for (const u of this.interruptionUsage.values()) { result.push({ ...u }); } + for (const u of this.eotUsage.values()) { + result.push({ ...u }); + } return result; } } diff --git a/agents/src/telemetry/trace_types.ts b/agents/src/telemetry/trace_types.ts index 7c1bb159a..1f79eca63 100644 --- a/agents/src/telemetry/trace_types.ts +++ b/agents/src/telemetry/trace_types.ts @@ -65,6 +65,13 @@ export const ATTR_EOU_PROBABILITY = 'lk.eou.probability'; export const ATTR_EOU_UNLIKELY_THRESHOLD = 'lk.eou.unlikely_threshold'; export const ATTR_EOU_DELAY = 'lk.eou.endpointing_delay'; export const ATTR_EOU_LANGUAGE = 'lk.eou.language'; +/** Which signal triggered the EOU detection: 'vad' | 'stt' | 'manual'. */ +export const ATTR_EOU_SOURCE = 'lk.eou.source'; +/** True when the audio EOT detector resolved this prediction from its + * inference-window cache instead of running a fresh predict. */ +export const ATTR_EOU_FROM_CACHE = 'lk.eou.from_cache'; +/** Latest input-audio creation time → prediction receive time (ms). */ +export const ATTR_EOU_DETECTION_DELAY = 'lk.eou.detection_delay'; export const ATTR_USER_TRANSCRIPT = 'lk.user_transcript'; export const ATTR_TRANSCRIPT_CONFIDENCE = 'lk.transcript_confidence'; export const ATTR_TRANSCRIPTION_DELAY = 'lk.transcription_delay'; diff --git a/agents/src/utils.ts b/agents/src/utils.ts index 9988da762..8ca4d8b89 100644 --- a/agents/src/utils.ts +++ b/agents/src/utils.ts @@ -1338,6 +1338,31 @@ export function asError(maybeError: unknown): Error { return new Error(String(maybeError)); } +/** + * Resolve a value that may come from an explicit argument, one of several + * environment variables (checked in order), or a final default. + * + * Mirrors Python `livekit.agents.utils.resolve_env_var`. Used by inference + * transports to plumb credentials and URLs (e.g. `LIVEKIT_REMOTE_EOT_URL`, + * `LIVEKIT_INFERENCE_API_KEY`). 
+ */ +export function resolveEnvVar( + value: string | undefined, + envVars: readonly string[], + defaultValue = '', +): string { + if (value !== undefined && value !== '') { + return value; + } + for (const name of envVars) { + const v = process.env[name]; + if (v !== undefined && v !== '') { + return v; + } + } + return defaultValue; +} + /** * Tagged template literal that strips common leading indentation from every line, * trims the first empty line and any trailing whitespace. diff --git a/agents/src/utils_env.test.ts b/agents/src/utils_env.test.ts new file mode 100644 index 000000000..71ecb531a --- /dev/null +++ b/agents/src/utils_env.test.ts @@ -0,0 +1,88 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Tests for the `resolveEnvVar` helper contract. + * + * Port of Python `tests/test_utils_env.py`. + */ +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; +import { resolveEnvVar } from './utils.js'; + +const ENV_KEYS = ['LIVEKIT_INFERENCE_URL', 'LIVEKIT_URL'] as const; +const saved: Record = {}; + +beforeEach(() => { + for (const k of ENV_KEYS) { + saved[k] = process.env[k]; + delete process.env[k]; + } +}); + +afterEach(() => { + for (const k of ENV_KEYS) { + if (saved[k] === undefined) delete process.env[k]; + else process.env[k] = saved[k]; + } +}); + +describe('resolveEnvVar', () => { + it('returns empty string when no env or default', () => { + expect(resolveEnvVar(undefined, ['LIVEKIT_INFERENCE_URL'])).toBe(''); + }); + + it('returns default when no matching env exists', () => { + expect(resolveEnvVar(undefined, ['LIVEKIT_INFERENCE_URL'], 'https://default.example.com')).toBe( + 'https://default.example.com', + ); + }); + + it('returns first matching env value', () => { + process.env.LIVEKIT_INFERENCE_URL = 'https://inference.example.com'; + process.env.LIVEKIT_URL = 'https://livekit.example.com'; + expect( + resolveEnvVar( + undefined, + ['LIVEKIT_INFERENCE_URL', 'LIVEKIT_URL'], 
+ 'https://default.example.com', + ), + ).toBe('https://inference.example.com'); + }); + + it('falls back to later env when earlier env missing', () => { + process.env.LIVEKIT_URL = 'https://livekit.example.com'; + expect( + resolveEnvVar( + undefined, + ['LIVEKIT_INFERENCE_URL', 'LIVEKIT_URL'], + 'https://default.example.com', + ), + ).toBe('https://livekit.example.com'); + }); + + it('prefers explicit value over environment', () => { + process.env.LIVEKIT_INFERENCE_URL = 'https://env.example.com'; + expect( + resolveEnvVar( + 'https://explicit.example.com', + ['LIVEKIT_INFERENCE_URL'], + 'https://default.example.com', + ), + ).toBe('https://explicit.example.com'); + }); + + it('treats empty env value as missing', () => { + process.env.LIVEKIT_INFERENCE_URL = ''; + expect(resolveEnvVar(undefined, ['LIVEKIT_INFERENCE_URL'], 'https://default.example.com')).toBe( + 'https://default.example.com', + ); + }); + + it('treats whitespace env value as set', () => { + process.env.LIVEKIT_INFERENCE_URL = ' '; + expect(resolveEnvVar(undefined, ['LIVEKIT_INFERENCE_URL'], 'https://default.example.com')).toBe( + ' ', + ); + }); +}); diff --git a/agents/src/vad.ts b/agents/src/vad.ts index 422c2654c..09220ab0f 100644 --- a/agents/src/vad.ts +++ b/agents/src/vad.ts @@ -28,11 +28,11 @@ export interface VADEvent { * Index of the audio sample where the event occurred, relative to the inference sample rate. */ samplesIndex: number; - /** Timestamp when the event was fired. */ + /** Timestamp (milliseconds since epoch) when the event was fired. */ timestamp: number; - /** Duration of the speech segment in seconds. */ + /** Duration of the speech segment in milliseconds. */ speechDuration: number; - /** Duration of the silence segment in seconds. */ + /** Duration of the silence segment in milliseconds. */ silenceDuration: number; /** * List of audio frames associated with the speech. 
@@ -45,7 +45,7 @@ export interface VADEvent { frames: AudioFrame[]; /** Probability that speech is present (only for `INFERENCE_DONE` events). */ probability: number; - /** Time taken to perform the inference, in seconds (only for `INFERENCE_DONE` events). */ + /** Time taken to perform the inference, in milliseconds (only for `INFERENCE_DONE` events). */ inferenceDuration: number; /** Indicates whether speech was detected in the frames. */ speaking: boolean; @@ -77,6 +77,19 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter 0) { const outputs = await ThrowsPromise.allSettled(tasks); @@ -222,6 +234,14 @@ export class AgentActivity implements RecognitionHooks { private isInterruptionByAudioActivityEnabled: boolean; private isDefaultInterruptionByAudioActivityEnabled: boolean; + /** + * Validated turn detection for this activity. Equals `this.turnDetection` + * except when an `AudioTurnDetector` instance fails the runtime preconditions + * (no VAD, or RealtimeModel with server-side turn detection enabled), in + * which case it is downgraded to `undefined` and a warning is logged. + */ + private _resolvedTurnDetection: TurnDetectionMode | undefined; + // for false interruption handling private pausedSpeech?: PausedSpeechInfo; private falseInterruptionTimer?: NodeJS.Timeout; @@ -288,8 +308,9 @@ export class AgentActivity implements RecognitionHooks { }); this.q_updated = new Future(); + this._resolvedTurnDetection = this._resolveTurnDetection(this.turnDetection); this.turnDetectionMode = - typeof this.turnDetection === 'string' ? this.turnDetection : undefined; + typeof this._resolvedTurnDetection === 'string' ? 
this._resolvedTurnDetection : undefined; if (this.turnDetectionMode === 'vad' && this.vad === undefined) { this.logger.warn( @@ -338,10 +359,13 @@ export class AgentActivity implements RecognitionHooks { this.turnDetectionMode = undefined; } - // fallback to VAD if server side turn detection is disabled and VAD is available + // fallback to VAD if server side turn detection is disabled and the + // user explicitly supplied a VAD. The bundled-default VAD is treated + // as absent here so behavior matches "no vad passed" sessions. if ( !this.llm.capabilities.turnDetection && this.vad && + !this.usingDefaultVad && this.turnDetectionMode === undefined ) { this.turnDetectionMode = 'vad'; @@ -512,12 +536,27 @@ export class AgentActivity implements RecognitionHooks { this.vad.on('metrics_collected', this.onMetricsCollected); } + if (this._resolvedTurnDetection instanceof AudioTurnDetector) { + this._resolvedTurnDetection.on('metrics_collected', this.onMetricsCollected); + } + + // Bundled-default VAD is treated as absent when the RealtimeModel does + // its own server-side turn detection — the realtime session is already + // canonical and an extra audio pipeline would just pay the native model + // load for no behavioral gain. User-supplied VADs still flow through + // (e.g. when the user wants adaptive interruption). + const realtimeUsesServerVad = + this.llm instanceof RealtimeModel && this.llm.capabilities.turnDetection === true; + const recognitionVad = this.usingDefaultVad && realtimeUsesServerVad ? undefined : this.vad; + this.audioRecognition = new AudioRecognition({ recognitionHooks: this, // Disable stt node if stt is not provided stt: this.stt ? (...args) => this.agent.sttNode(...args) : undefined, - vad: this.vad, - turnDetector: typeof this.turnDetection === 'string' ? undefined : this.turnDetection, + vad: recognitionVad, + usingDefaultVad: this.usingDefaultVad, + turnDetector: + typeof this._resolvedTurnDetection === 'string' ? 
undefined : this._resolvedTurnDetection, turnDetectionMode: this.turnDetectionMode, interruptionDetection: this.interruptionDetector, backchannelBoundary: @@ -534,12 +573,19 @@ export class AgentActivity implements RecognitionHooks { shouldDiscardAudioForStt: () => this.shouldDiscardInputAudio(), }); - if (reuseResources?.sttPipeline) { + const sttPipeline = reuseResources?.sttPipeline; + const turnDetectorStream = reuseResources?.turnDetectorStream; + if (sttPipeline) { this.logger.debug('reusing STT pipeline from previous activity'); - await this.audioRecognition.start({ sttPipeline: reuseResources.sttPipeline }); - reuseResources.sttPipeline = undefined; // ownership transferred - } else { - await this.audioRecognition.start(); + } + if (turnDetectorStream) { + this.logger.debug('reusing turn detector stream from previous activity'); + } + await this.audioRecognition.start({ sttPipeline, turnDetectorStream }); + if (reuseResources) { + // ownership transferred to the new AudioRecognition + reuseResources.sttPipeline = undefined; + reuseResources.turnDetectorStream = undefined; } this.started = true; @@ -578,6 +624,15 @@ export class AgentActivity implements RecognitionHooks { resources.sttPipeline = await this.audioRecognition.detachSttPipeline(); } + // reuse the turn detector stream during a handoff whenever we can + if ( + this.audioRecognition && + this._resolvedTurnDetection instanceof AudioTurnDetector && + this._resolvedTurnDetection === newActivity._resolvedTurnDetection + ) { + resources.turnDetectorStream = this.audioRecognition.detachTurnDetector(); + } + // rt session if ( this.realtimeSession && @@ -641,6 +696,18 @@ export class AgentActivity implements RecognitionHooks { return this.agent.vad || this.agentSession.vad; } + /** + * True iff the effective VAD for this activity is the framework-auto-provisioned + * default. False when the user passed `vad=` to either the agent or the + * session, even if the value happens to be the same silero model. 
+ */ + get usingDefaultVad(): boolean { + if (this.agent.vad !== undefined) { + return false; + } + return this.agentSession._usingDefaultVad; + } + get stt(): STT | undefined { return this.agent.stt || this.agentSession.stt; } @@ -966,7 +1033,13 @@ export class AgentActivity implements RecognitionHooks { // -- Metrics and errors -- private onMetricsCollected = ( - ev: STTMetrics | TTSMetrics | VADMetrics | LLMMetrics | RealtimeModelMetrics, + ev: + | STTMetrics + | TTSMetrics + | VADMetrics + | LLMMetrics + | RealtimeModelMetrics + | EOTInferenceMetrics, ) => { const speechHandle = speechHandleStorage.getStore(); if (speechHandle && (ev.type === 'llm_metrics' || ev.type === 'tts_metrics')) { @@ -1018,7 +1091,11 @@ export class AgentActivity implements RecognitionHooks { onInputSpeechStarted(_ev: InputSpeechStartedEvent): void { this.logger.info('onInputSpeechStarted'); - if (!this.vad) { + // Bundled-default VAD is treated as absent here so the realtime + // session's own server-side turn detection drives the user-state / + // overlap-detection update, identical to a session that didn't + // configure any VAD. + if (!this.vad || this.usingDefaultVad) { this.agentSession._updateUserState('speaking'); if (this.isInterruptionDetectionEnabled && this.audioRecognition) { this.audioRecognition.onStartOfOverlapSpeech( @@ -1044,7 +1121,7 @@ export class AgentActivity implements RecognitionHooks { onInputSpeechStopped(ev: InputSpeechStoppedEvent): void { this.logger.info(ev, 'onInputSpeechStopped'); - if (!this.vad) { + if (!this.vad || this.usingDefaultVad) { if (this.isInterruptionDetectionEnabled && this.audioRecognition) { this.audioRecognition.onEndOfOverlapSpeech(Date.now(), this.agentSession._userSpeakingSpan); } @@ -1376,6 +1453,12 @@ export class AgentActivity implements RecognitionHooks { this.cancelSpeechPauseTask = this.cancelSpeechPause(); } + /** Forward audio EOT predictions up to the session so listeners (e.g. 
+ * remote-session forwarders) can observe them. */ + onEotPrediction(ev: EotPredictionEvent): void { + this.agentSession.emit(AgentSessionEventTypes.EotPrediction, ev); + } + onPreemptiveGeneration(info: PreemptiveGenerationInfo): void { const preemptiveOpts = this.agentSession.sessionOptions.turnHandling.preemptiveGeneration; if ( @@ -3708,6 +3791,29 @@ export class AgentActivity implements RecognitionHooks { } } + private _resolveTurnDetection( + turnDetection: TurnDetectionMode | undefined, + ): TurnDetectionMode | undefined { + if (turnDetection !== undefined && typeof turnDetection !== 'string') { + if (turnDetection instanceof AudioTurnDetector) { + if (this.vad === undefined) { + this.logger.warn( + 'AudioTurnDetector requires a VAD model. Pass vad=inference.VAD() to AgentSession/Agent or turnDetection=null to disable the default AudioTurnDetector', + ); + return undefined; + } + if (this.llm instanceof RealtimeModel && this.llm.capabilities.turnDetection) { + this.logger.warn( + 'turnDetection is an AudioTurnDetector, but the LLM is a RealtimeModel with server-side turn detection enabled, ignoring the turnDetection setting', + ); + return undefined; + } + } + return turnDetection; + } + return turnDetection; + } + private resolveInterruptionDetector(): AdaptiveInterruptionDetector | undefined { const agentInterruptionDetection = this.agent.turnHandling?.interruption?.mode; const sessionInterruptionDetection = this.agentSession.interruptionDetection; @@ -3716,7 +3822,7 @@ export class AgentActivity implements RecognitionHooks { this.stt && this.stt.capabilities.alignedTranscript && this.stt.capabilities.streaming && - this.vad && + this.vad !== undefined && this.turnDetection !== 'manual' && this.turnDetection !== 'realtime_llm' && !(this.llm instanceof RealtimeModel) @@ -3970,6 +4076,10 @@ export class AgentActivity implements RecognitionHooks { this.vad.off('metrics_collected', this.onMetricsCollected); } + if (this._resolvedTurnDetection instanceof 
AudioTurnDetector) { + this._resolvedTurnDetection.off('metrics_collected', this.onMetricsCollected); + } + this.detachAudioInput(); this.realtimeSpans?.clear(); await this.realtimeSession?.close(); diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index 6449cb080..1fd27b6da 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -12,8 +12,10 @@ import type { ReadableStream } from 'node:stream/web'; import type { z } from 'zod'; import { LLM as InferenceLLM, + AudioTurnDetector as InferenceAudioTurnDetector, STT as InferenceSTT, TTS as InferenceTTS, + VAD as InferenceVAD, type LLMModels, type STTModelString, type TTSModelString, @@ -61,6 +63,7 @@ import { type CloseEvent, CloseReason, type ConversationItemAddedEvent, + type EotPredictionEvent, type ErrorEvent, type FunctionToolsExecutedEvent, type MetricsCollectedEvent, @@ -88,6 +91,7 @@ import type { UnknownUserData } from './run_context.js'; import type { SpeechHandle } from './speech_handle.js'; import { RunResult } from './testing/run_result.js'; import type { TextTransform } from './transcription/text_transforms.js'; +import type { AudioTurnDetector } from '../inference/eot/base.js'; import type { EndpointingOptions } from './turn_config/endpointing.js'; import type { InterruptionOptions } from './turn_config/interruption.js'; import type { @@ -143,7 +147,13 @@ export type VoiceOptions = { maxEndpointingDelay?: number; }; -export type TurnDetectionMode = 'stt' | 'vad' | 'realtime_llm' | 'manual' | _TurnDetector; +export type TurnDetectionMode = + | 'stt' + | 'vad' + | 'realtime_llm' + | 'manual' + | _TurnDetector + | AudioTurnDetector; export type AgentSessionCallbacks = { [AgentSessionEventTypes.UserInputTranscribed]: (ev: UserInputTranscribedEvent) => void; @@ -158,11 +168,18 @@ export type AgentSessionCallbacks = { [AgentSessionEventTypes.Error]: (ev: ErrorEvent) => void; [AgentSessionEventTypes.Close]: (ev: CloseEvent) => void; 
[AgentSessionEventTypes.OverlappingSpeech]: (ev: OverlappingSpeechEvent) => void; + [AgentSessionEventTypes.EotPrediction]: (ev: EotPredictionEvent) => void; }; export type AgentSessionOptions = { stt?: STT | STTModelString; - vad?: VAD; + /** + * Voice Activity Detection. When omitted, `AgentSession` auto-provisions a + * bundled `inference.VAD({ model: 'silero' })` and marks it as the default + * (so sites that previously distinguished "user supplied a VAD" continue + * to treat the bundled one as absent). Pass `null` to opt out entirely. + */ + vad?: VAD | null; llm?: LLM | RealtimeModel | LLMModels; tts?: TTS | TTSModelString; userData?: UserData; @@ -305,6 +322,15 @@ export class AgentSession< private _interruptionDetection?: InterruptionOptions['mode']; + /** + * True iff this session auto-provisioned the bundled silero VAD because the + * caller passed no `vad=`. Set once in the constructor; immutable from then + * on. Read it via `AgentActivity.usingDefaultVad` from voice-pipeline code. + * + * @internal + */ + _usingDefaultVad: boolean = false; + /** @internal */ _usageCollector: ModelUsageCollector = new ModelUsageCollector(); @@ -371,7 +397,19 @@ export class AgentSession< DEFAULT_SESSION_CONNECT_OPTIONS.maxUnrecoverableErrors, }; - this.vad = vad; + // VAD: undefined → auto-provision bundled inference.VAD (silero). The + // `_usingDefaultVad` marker is the single source of truth for "this VAD + // was framework-provisioned" — code paths that should ignore a default + // VAD read it via `AgentActivity.usingDefaultVad`. null → leave VAD off + // entirely. Otherwise use what the caller supplied. 
+ this._usingDefaultVad = vad === undefined; + if (vad === undefined) { + this.vad = new InferenceVAD({ model: 'silero' }); + } else if (vad === null) { + this.vad = undefined; + } else { + this.vad = vad; + } if (typeof stt === 'string') { this.stt = InferenceSTT.fromModelString(stt); @@ -391,7 +429,11 @@ export class AgentSession< this.tts = tts; } - this.turnDetection = resolvedSessionOptions.turnHandling.turnDetection; + // Default turn_detection: when the caller didn't pin a mode or supply a + // detector instance, fall back to a fresh inference.AudioTurnDetector so + // every session ships with multimodal EOT out of the box. + this.turnDetection = + resolvedSessionOptions.turnHandling.turnDetection ?? new InferenceAudioTurnDetector(); this._interruptionDetection = resolvedSessionOptions.turnHandling.interruption?.mode; this._userData = userData; diff --git a/agents/src/voice/agent_session_default_vad.test.ts b/agents/src/voice/agent_session_default_vad.test.ts new file mode 100644 index 000000000..28711c62a --- /dev/null +++ b/agents/src/voice/agent_session_default_vad.test.ts @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/** + * Tests for the bundled-default VAD behavior on `AgentSession`. 
+ * + * Port of three test additions on `tests/test_agent_session.py`: + * + * - `test_default_vad_is_auto_provisioned` + * - `test_explicit_vad_none_opts_out` + * - `test_user_supplied_vad_keeps_using_default_false` + */ +import { describe, expect, it } from 'vitest'; +import type { VADStream } from '../vad.js'; +import { VAD as BaseVAD } from '../vad.js'; +import { AgentSession } from './agent_session.js'; + +class FakeVAD extends BaseVAD { + label = 'FakeVAD'; + constructor() { + super({ updateInterval: 32 }); + } + stream(): VADStream { + throw new Error('not used in this test'); + } +} + +describe('AgentSession default VAD', () => { + it('auto-provisions a default VAD when none passed', async () => { + const session = new AgentSession(); + try { + expect(session.vad).toBeDefined(); + expect(session._usingDefaultVad).toBe(true); + } finally { + await session.close().catch(() => {}); + } + }); + + it('explicit `vad: null` opts out', async () => { + const session = new AgentSession({ vad: null }); + try { + expect(session.vad).toBeUndefined(); + expect(session._usingDefaultVad).toBe(false); + } finally { + await session.close().catch(() => {}); + } + }); + + it('user-supplied VAD keeps _usingDefaultVad false', async () => { + const userVad = new FakeVAD(); + const session = new AgentSession({ vad: userVad }); + try { + expect(session.vad).toBe(userVad); + expect(session._usingDefaultVad).toBe(false); + } finally { + await session.close().catch(() => {}); + } + }); +}); diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index 1b2d76742..e3ca6c025 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -14,6 +14,12 @@ import { import type { ReadableStream, WritableStreamDefaultWriter } from 'node:stream/web'; import { TransformStream } from 'node:stream/web'; import { isAPIError } from '../_exceptions.js'; +import { + AudioTurnDetector, + AudioTurnDetectorStream, + 
MIN_SILENCE_DURATION_MS, + type TurnDetectionEvent, +} from '../inference/eot/base.js'; import { apiConnectDefaults, intervalForRetry } from '../inference/interruption/defaults.js'; import { InterruptionDetectionError } from '../inference/interruption/errors.js'; import type { AdaptiveInterruptionDetector } from '../inference/interruption/interruption_detector.js'; @@ -32,10 +38,15 @@ import { type StreamChannel, createStreamChannel } from '../stream/stream_channe import { type SpeechEvent, SpeechEventType } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; import { splitWords } from '../tokenize/basic/word.js'; -import { Task, cancelAndWait, delay, readStream, waitForAbort } from '../utils.js'; +import { Event, Task, cancelAndWait, delay, readStream, waitForAbort } from '../utils.js'; import { type VAD, type VADEvent, VADEventType } from '../vad.js'; import type { TurnDetectionMode } from './agent_session.js'; -import { type UserTurnExceededEvent, createUserTurnExceededEvent } from './events.js'; +import { + type EotPredictionEvent, + type UserTurnExceededEvent, + createEotPredictionEvent, + createUserTurnExceededEvent, +} from './events.js'; import type { STTNode } from './io.js'; import { type BaseEndpointing, @@ -79,6 +90,7 @@ export interface RecognitionHooks { onInterimTranscript: (ev: SpeechEvent, speaking: boolean | undefined) => void; onFinalTranscript: (ev: SpeechEvent, speaking: boolean | undefined) => void; onEndOfTurn: (info: EndOfTurnInfo) => Promise; + onEotPrediction: (ev: EotPredictionEvent) => void; onPreemptiveGeneration: (info: PreemptiveGenerationInfo) => void; onUserTurnExceeded: (ev: UserTurnExceededEvent) => void; @@ -91,6 +103,42 @@ interface UserTurnTracker { startedAt?: number; } +/** + * Edge-triggered event with an abort-aware `waitOnce` helper. + * + * Used by the audio-EOT bounce race: the bounce task awaits either the + * endpointing delay or a fresh "user started speaking" signal. 
We extend + * the base `Event` rather than reimplementing it because the base already + * handles the resolver / waiter bookkeeping; this subclass just layers a + * `waitOnce(signal)` that rejects on cancel so the race can tear down + * cleanly when the parent task is aborted. + */ +class SpeakingEvent extends Event { + /** + * Resolves on the next `set()`. Rejects (and cleans up the listener) if + * `signal` aborts first. Returns immediately if the event is already set. + */ + async waitOnce(signal: AbortSignal): Promise { + if (this.isSet) return; + let abortListener: (() => void) | undefined; + try { + await Promise.race([ + this.wait().then(() => undefined), + new Promise((_resolve, reject) => { + if (signal.aborted) { + reject(signal.reason ?? new Error('aborted')); + return; + } + abortListener = () => reject(signal.reason ?? new Error('aborted')); + signal.addEventListener('abort', abortListener, { once: true }); + }), + ]); + } finally { + if (abortListener !== undefined) signal.removeEventListener('abort', abortListener); + } + } +} + export class STTPipeline { static readonly PUMP_TASK_CANCEL_TIMEOUT = 5000; @@ -141,7 +189,11 @@ export interface _TurnDetector { readonly provider: string; unlikelyThreshold: (language?: LanguageCode) => Promise; supportsLanguage: (language?: LanguageCode) => Promise; - predictEndOfTurn(chatCtx: ChatContext, timeout?: number): Promise; + /** + * @param timeoutMs - Optional inference wait budget in milliseconds. The audio + * EOT detector honors it; text-based detectors currently ignore it. + */ + predictEndOfTurn(chatCtx: ChatContext, timeoutMs?: number): Promise; } export interface AudioRecognitionOptions { @@ -151,8 +203,17 @@ export interface AudioRecognitionOptions { stt?: STTNode; /** Voice activity detection. */ vad?: VAD; - /** Turn detector for end-of-turn prediction. */ - turnDetector?: _TurnDetector; + /** + * True iff the wired VAD was auto-provisioned by `AgentSession` rather than + * supplied by the caller. 
Read at every "is VAD configured?" call site so + * a framework-default VAD behaves like no VAD for downstream eligibility + * decisions (e.g. STT-hook `speaking=` payload). + */ + usingDefaultVad?: boolean; + /** Turn detector for end-of-turn prediction. Accepts text-based detectors + * via `_TurnDetector` (e.g. plugins/livekit) or audio-based detectors via + * `AudioTurnDetector` (e.g. `inference.AudioTurnDetector`). */ + turnDetector?: _TurnDetector | AudioTurnDetector; /** Turn detection mode. */ turnDetectionMode?: TurnDetectionMode; interruptionDetection?: AdaptiveInterruptionDetector; @@ -199,7 +260,27 @@ export class AudioRecognition { private stt?: STTNode; private sttPipeline?: STTPipeline; private vad?: VAD; - private turnDetector?: _TurnDetector; + private usingDefaultVad: boolean; + private turnDetector?: _TurnDetector | AudioTurnDetector; + private turnDetectorStream?: AudioTurnDetectorStream; + /** + * The last `TurnDetectionEvent` we forwarded via `onEotPrediction`, kept + * by reference to dedupe: both EOU triggers in a turn read the same + * cached prediction, but the event should fire once per inference window. + */ + private lastEmittedEotPrediction?: TurnDetectionEvent; + /** + * Edge-triggered "user is speaking" event used by the audio-EOT bounce + * race. Set on VAD `START_OF_SPEECH` (and on any `INFERENCE_DONE` with + * accumulated speech), cleared on `END_OF_SPEECH`. Mirrors Python + * `_user_speaking_event`. + * + * `Event.set()` is idempotent (re-setting an already-set event resolves + * any new waiters immediately); cleared on EOS so subsequent waiters + * park until the next utterance. 
+ */ + private userSpeakingEvent = new SpeakingEvent(); + private warnedTurnDetectorPushFailure = false; private turnDetectionMode?: TurnDetectionMode; private endpointing: BaseEndpointing; private userTurnLimit?: UserTurnLimitOptions; @@ -278,7 +359,12 @@ export class AudioRecognition { this.hooks = opts.recognitionHooks; this.stt = opts.stt; this.vad = opts.vad; + this.usingDefaultVad = opts.usingDefaultVad ?? false; this.turnDetector = opts.turnDetector; + this.checkVadSilenceRequirement(); + // The FSM stream is opened on `start()` so callers can hand off the + // previous activity's stream (cloud↔local fallback state, in-flight + // inference) instead of forcing a cold restart. this.turnDetectionMode = opts.turnDetectionMode; this.userTurnLimit = opts.userTurnLimit; this.endpointing = @@ -327,6 +413,26 @@ export class AudioRecognition { { transform: (chunk, controller) => { controller.enqueue(chunk); + // Fan the same frame into the audio EOT detector stream when + // one is attached. The FSM accepts arbitrary-rate input and + // resamples internally. `pushAudio` is a no-op when the stream's + // internal channel is closed; any actual throw indicates a bug + // (e.g. resampler init failure, sample-rate mismatch). Log once + // when we hit that path so a regression doesn't silently drop + // every audio frame. + if (this.turnDetectorStream !== undefined) { + try { + this.turnDetectorStream.pushAudio(chunk); + } catch (err) { + if (!this.warnedTurnDetectorPushFailure) { + this.warnedTurnDetectorPushFailure = true; + this.logger.warn( + { err: err instanceof Error ? 
err.message : String(err) }, + 'audio EOT stream pushAudio failed; dropping frames for this turn', + ); + } + } + } if (this.subscriberWriters.length === 0) return; for (const writer of this.subscriberWriters) { writer.write(chunk).catch(() => { @@ -405,7 +511,147 @@ export class AudioRecognition { } } - async start(options?: { sttPipeline?: STTPipeline }) { + /** True iff the user supplied their own VAD (default-VAD is treated as + * absent at sites that decide between "use VAD signal" and "STT-derived + * speaking"). */ + private get hasUserVad(): boolean { + return this.vad !== undefined && !this.usingDefaultVad; + } + + /** + * Swap the active turn detector at runtime. When an `AudioTurnDetector` + * is provided, opens a per-turn FSM stream after retiring the prior one. + * + * When `stream` is provided it is adopted as-is (handoff reuse) instead of + * opening a fresh stream on `detector`; the live transport stream — and its + * per-session cloud→local fallback state — survives the handoff. + */ + updateTurnDetector( + detector: _TurnDetector | AudioTurnDetector | undefined, + options?: { stream?: AudioTurnDetectorStream }, + ): void { + // Validate against the incoming detector before swapping in so the error + // — when raised — names the configuration that failed. + this.checkVadSilenceRequirement(detector); + this.turnDetector = detector; + + const reuseStream = options?.stream; + // Retire the prior stream before creating the new one. `detach()` frees + // the detector's single-stream slot synchronously (so `stream()` below + // won't throw if the same detector is reused), while the network teardown + // runs in the background. 
+ const oldStream = this.turnDetectorStream; + if (oldStream !== undefined && oldStream !== reuseStream) { + oldStream.detach(); + void oldStream.aclose().catch(() => undefined); + } + // Cross-detector state should not leak: the cached speaking signal + // from the prior detector's turn must not race the new detector's + // first bounce. + this.userSpeakingEvent.clear(); + if (reuseStream !== undefined) { + this.turnDetectorStream = reuseStream; + } else { + this.turnDetectorStream = + detector instanceof AudioTurnDetector ? detector.stream() : undefined; + } + } + + /** + * Detach the turn detector stream for handoff to another AudioRecognition. + * + * Returns the live stream (transport run loop intact) without closing it. + * The caller passes it to the new AudioRecognition via + * `start({ turnDetectorStream })`. The stream stays attached to its + * detector, retaining the detector's single-stream slot, so the new + * AudioRecognition must adopt it rather than open a second stream. + */ + detachTurnDetector(): AudioTurnDetectorStream | undefined { + const stream = this.turnDetectorStream; + this.turnDetectorStream = undefined; + return stream; + } + + /** + * The audio EOT detector needs a wider silence window than typical VAD + * defaults. Rather than mutate the VAD's knob, require the caller to + * configure it: raise if the bound VAD exposes `minSilenceDuration` and it + * is below the floor. VADs that don't expose the knob are left untouched. + */ + private checkVadSilenceRequirement( + detector: _TurnDetector | AudioTurnDetector | undefined = this.turnDetector, + ): void { + if (!(detector instanceof AudioTurnDetector) || this.vad === undefined) { + return; + } + const current = this.vad.minSilenceDuration; + if (current === null) { + return; + } + const required = MIN_SILENCE_DURATION_MS + 50; + if (current < required) { + throw new Error( + `vad minSilenceDuration=${current}ms is too low for the AudioTurnDetector. 
` + + `Raise the VAD's minSilenceDuration to at least ${required}ms.`, + ); + } + } + + /** + * Speaking-guard wrapper for the bounce-EOU task, mirroring Python's + * `_bounce_eou_task_with_speaking_guard`. When an `AudioTurnDetector` + * is active, the bounce task races against the `userSpeakingEvent`: + * + * - if the user is already speaking, skip the EOU outright; + * - if the user starts speaking during the endpointing delay (e.g. + * the LLM hadn't returned yet but the user added another phrase), + * abort the inner bounce so the next turn drives the decision. + * + * VAD `START_OF_SPEECH` also calls `bounceEOUTask?.cancel()`, but the + * cancel path only races VAD sessions. STT-only audio-EOT setups need + * the explicit event-driven race here. + */ + private async bounceEOUTaskWithSpeakingGuard( + controller: AbortController, + inner: (innerController: AbortController) => Promise, + context: { + lastSpeakingTime: number | undefined; + lastFinalTranscriptTime: number; + speechStartTime: number | undefined; + }, + ): Promise { + if (this.speaking) { + this.logger.debug(context, 'user is still speaking, skipping end of turn task'); + return; + } + const innerController = new AbortController(); + // Propagate outer cancellation into the inner task. + const onOuterAbort = () => innerController.abort(); + controller.signal.addEventListener('abort', onOuterAbort, { once: true }); + + let speakingWon = false; + try { + await Promise.race([ + inner(innerController), + this.userSpeakingEvent.waitOnce(controller.signal).then(() => { + speakingWon = true; + }), + ]); + if (speakingWon) { + this.logger.debug(context, 'user spoke during endpointing, cancelling end of turn task'); + } + } finally { + controller.signal.removeEventListener('abort', onOuterAbort); + // If the speaking-event branch won (or the outer was aborted), tear + // down the inner bounce so it doesn't keep awaiting the delay. 
+ innerController.abort(); + } + } + + async start(options?: { + sttPipeline?: STTPipeline; + turnDetectorStream?: AudioTurnDetectorStream; + }) { this.startSttTasks(options?.sttPipeline); this.vadTask = Task.from(({ signal }) => this.createVadTask(this.vad, signal)); @@ -419,6 +665,14 @@ export class AudioRecognition { this.interruptionTask.result.catch((err) => { this.logger.error(`Error running interruption task: ${err}`); }); + + // Open (or adopt) the audio EOT detector stream now that the activity is + // running. We only call `updateTurnDetector` for AudioTurnDetector / + // undefined detectors — plugin-based `_TurnDetector` instances are + // text-only and don't carry a stream. + if (this.turnDetector instanceof AudioTurnDetector || this.turnDetector === undefined) { + this.updateTurnDetector(this.turnDetector, { stream: options?.turnDetectorStream }); + } } async stop() { @@ -426,6 +680,11 @@ export class AudioRecognition { await this.sttForwardTask?.cancelAndWait(); await this.vadTask?.cancelAndWait(); await this.interruptionTask?.cancelAndWait(); + if (this.turnDetectorStream !== undefined) { + const stream = this.turnDetectorStream; + this.turnDetectorStream = undefined; + await stream.aclose().catch(() => undefined); + } } async disableInterruptionDetection(): Promise { @@ -819,6 +1078,7 @@ export class AudioRecognition { const transcript = ev.alternatives?.[0]?.text; const confidence = ev.alternatives?.[0]?.confidence ?? 0; this.lastLanguage = ev.alternatives?.[0]?.language; + this.turnDetectorStream?.updateLanguage(this.lastLanguage); if (!transcript) { // stt final transcript received but no transcript @@ -827,7 +1087,7 @@ export class AudioRecognition { this.hooks.onFinalTranscript( ev, - this.vad || this.turnDetectionMode === 'stt' ? this.speaking : undefined, + this.hasUserVad || this.turnDetectionMode === 'stt' ? 
this.speaking : undefined, ); this.logger.debug( @@ -846,7 +1106,7 @@ export class AudioRecognition { this.audioInterimTranscript = ''; this.audioPreflightTranscript = ''; - if (!this.vad || this.lastSpeakingTime === undefined) { + if (!this.hasUserVad || this.lastSpeakingTime === undefined) { // vad disabled, use stt timestamp // TODO: this would screw up transcription latency metrics // but we'll live with it for now. @@ -877,14 +1137,14 @@ export class AudioRecognition { if (!this.speaking) { const chatCtx = this.hooks.retrieveChatCtx(); this.logger.debug('running EOU detection on stt FINAL_TRANSCRIPT'); - this.runEOUDetection(chatCtx); + this.runEOUDetection(chatCtx, 'stt'); } } break; case SpeechEventType.PREFLIGHT_TRANSCRIPT: this.hooks.onInterimTranscript( ev, - this.vad || this.turnDetectionMode === 'stt' ? this.speaking : undefined, + this.hasUserVad || this.turnDetectionMode === 'stt' ? this.speaking : undefined, ); const preflightTranscript = ev.alternatives?.[0]?.text ?? ''; const preflightConfidence = ev.alternatives?.[0]?.confidence ?? 0; @@ -896,6 +1156,7 @@ export class AudioRecognition { (preflightLanguage && preflightTranscript.length > MIN_LANGUAGE_DETECTION_LENGTH) ) { this.lastLanguage = preflightLanguage; + this.turnDetectorStream?.updateLanguage(this.lastLanguage); } if (!preflightTranscript) { @@ -917,7 +1178,7 @@ export class AudioRecognition { `${this.audioTranscript} ${preflightTranscript}`.trimStart(); this.audioInterimTranscript = preflightTranscript; - if (!this.vad || this.lastSpeakingTime === undefined) { + if (!this.hasUserVad || this.lastSpeakingTime === undefined) { // vad disabled, use stt timestamp this.lastSpeakingTime = Date.now(); } @@ -947,7 +1208,7 @@ export class AudioRecognition { this.logger.debug({ transcript: ev.alternatives?.[0]?.text }, 'interim transcript'); this.hooks.onInterimTranscript( ev, - this.vad || this.turnDetectionMode === 'stt' ? 
this.speaking : undefined, + this.hasUserVad || this.turnDetectionMode === 'stt' ? this.speaking : undefined, ); this.audioInterimTranscript = ev.alternatives?.[0]?.text ?? ''; break; @@ -977,6 +1238,10 @@ export class AudioRecognition { } this.speaking = true; this.lastSpeakingTime = Date.now(); + // STT-only sessions never see VAD events; surface the speaking + // signal here so the audio-EOT bounce race can still abort on a + // mid-window fresh utterance. + this.userSpeakingEvent.set(); this.bounceEOUTask?.cancel(); break; @@ -1013,18 +1278,21 @@ export class AudioRecognition { // and user state won't be updated until a new VAD SOS is received. // Reset VAD so that incorrect end of turn from STT can be corrected by VAD interruption. // If user is still speaking (an immediate VAD SOS will interrupt the agent). - if (this.vad && this.speaking) { + // Default-bundled VAD is treated as absent here — only user-supplied VADs + // are reset, matching the matrix in PR_DESCRIPTION. + if (this.hasUserVad && this.speaking) { this.logger.warn('stt end of speech received while user is speaking, resetting vad'); this.resetVad(); } this.speaking = false; + this.userSpeakingEvent.clear(); this.userTurnCommitted = true; this.lastSpeakingTime = Date.now(); if (!this.speaking) { const chatCtx = this.hooks.retrieveChatCtx(); this.logger.debug('running EOU detection on stt END_OF_SPEECH'); - this.runEOUDetection(chatCtx); + this.runEOUDetection(chatCtx, 'stt'); } } } @@ -1044,7 +1312,7 @@ export class AudioRecognition { } } - private runEOUDetection(chatCtx: ChatContext) { + private runEOUDetection(chatCtx: ChatContext, trigger: 'vad' | 'stt' | 'manual' = 'vad') { this.logger.debug( { stt: this.stt, @@ -1061,11 +1329,32 @@ export class AudioRecognition { } chatCtx = chatCtx.copy(); - chatCtx.addMessage({ role: 'user', content: this.audioTranscript }); + if (this.audioTranscript) { + chatCtx.addMessage({ role: 'user', content: this.audioTranscript }); + } - const turnDetector = - // 
disable EOU model if manual turn detection enabled - this.audioTranscript && this.turnDetectionMode !== 'manual' ? this.turnDetector : undefined; + // Pick the right detector: + // - manual mode: no detector (turn boundary decided externally) + // - audio EOT detector: prefer the per-turn stream (it caches the + // prediction for the current inference window so the bounce task + // can short-circuit on cache) + // - text-based detector: only run when we have a transcript to score + const hasAudioDetector = this.turnDetector instanceof AudioTurnDetector; + const useDetector = + this.turnDetectionMode !== 'manual' && (this.audioTranscript || hasAudioDetector); + // The unified type only covers the predict surface; the audio + // detector's per-turn stream stands in for the parent when one is + // attached so the cached prediction is available. + let turnDetector: _TurnDetector | AudioTurnDetectorStream | undefined; + if (!useDetector) { + turnDetector = undefined; + } else if (hasAudioDetector) { + turnDetector = this.turnDetectorStream; + } else { + // text-based detector — `this.turnDetector` cannot be the audio + // base class here, because `hasAudioDetector` already screened it. + turnDetector = this.turnDetector as _TurnDetector | undefined; + } const bounceEOUTask = ( @@ -1086,27 +1375,97 @@ export class AudioRecognition { let endOfTurnProbability = 0.0; let unlikelyThreshold: number | undefined; + // Audio detectors cache the resolved prediction for the + // current inference window — a non-undefined value here + // means `predictEndOfTurn` will short-circuit on cache. 
+ const fromCache = + turnDetector instanceof AudioTurnDetectorStream && + turnDetector.lastPrediction !== undefined; if (!(await turnDetector.supportsLanguage(this.lastLanguage))) { this.logger.debug(`Turn detector does not support language ${this.lastLanguage}`); } else { try { - endOfTurnProbability = await turnDetector.predictEndOfTurn(chatCtx); + endOfTurnProbability = await turnDetector.predictEndOfTurn( + chatCtx, + endpointingDelay, + ); unlikelyThreshold = await turnDetector.unlikelyThreshold(this.lastLanguage); - this.logger.debug( - { endOfTurnProbability, unlikelyThreshold, language: this.lastLanguage }, - 'end of turn probability', - ); + // A newer trigger (e.g. the STT final after VAD end-of-speech) + // calls `bounceEOUTask?.cancel()` on this task. Unlike asyncio, + // a JS abort does NOT interrupt the in-flight `predictEndOfTurn` + // await — both bounces share the same in-flight inference and + // resolve to the same probability. Bail here so the superseded + // bounce doesn't log + emit a duplicate prediction; only the + // surviving bounce proceeds. Mirrors asyncio task cancellation. + if (controller.signal.aborted) { + return; + } if (unlikelyThreshold && endOfTurnProbability < unlikelyThreshold) { endpointingDelay = this.endpointing.maxDelay; } + + // `fromCache` distinguishes the two EOU runs in a single + // turn: the first (vad) does real inference, the second + // (stt final) is served from the detector's window cache — + // same probability, hence the apparent duplicate. + this.logger.debug( + { + endOfTurnProbability, + unlikelyThreshold, + endpointingDelay, + language: this.lastLanguage, + trigger, + fromCache, + }, + 'end of turn probability', + ); } catch (error) { this.logger.error(error, 'Error predicting end of turn'); } } + // Emit the prediction event for every turn detector, text or + // audio. Text-based detectors (e.g. 
plugins/livekit) have no + // streaming inference window, so `inferenceDurationMs` and + // `detectionDelay` are reported as 0 — that's expected, the + // event still carries the probability + threshold. + // + // Audio detectors cache one prediction per inference window + // and both EOU triggers in a turn (vad + stt final) read the + // same `TurnDetectionEvent`. Dedupe by reference so the event + // fires once per window: the abort guard above drops a + // superseded bounce cancelled mid-await, and this catches the + // race where the first bounce fully completes (and emits) in + // the few ms before the second trigger fires. Text detectors + // run a fresh (slower) inference each bounce with no cache, so + // the abort guard alone is sufficient there. + // Text-based detectors return `undefined` here (no streaming + // inference window); audio detectors return the cached + // `TurnDetectionEvent` for the window. Emit when there's no + // window (text / timeout — always fresh) OR the prediction + // differs from the last one we forwarded; set the tracker on + // every emit (matches Python `_last_emitted_prediction`). + const prediction = + turnDetector instanceof AudioTurnDetectorStream + ? turnDetector.lastPrediction + : undefined; + if (prediction === undefined || prediction !== this.lastEmittedEotPrediction) { + this.lastEmittedEotPrediction = prediction; + const inferenceDurationMs = prediction?.inferenceDuration ?? 0; + const delayMs = lastSpeakingTime !== undefined ? Date.now() - lastSpeakingTime : 0; + this.hooks.onEotPrediction( + createEotPredictionEvent({ + probability: endOfTurnProbability, + threshold: unlikelyThreshold ?? 0, + inferenceDurationMs, + delayMs, + }), + ); + } + span.setAttribute( traceTypes.ATTR_CHAT_CTX, JSON.stringify(chatCtx.toJSON({ excludeTimestamp: false })), @@ -1115,6 +1474,11 @@ export class AudioRecognition { span.setAttribute(traceTypes.ATTR_EOU_UNLIKELY_THRESHOLD, unlikelyThreshold ?? 
0); span.setAttribute(traceTypes.ATTR_EOU_DELAY, endpointingDelay); span.setAttribute(traceTypes.ATTR_EOU_LANGUAGE, this.lastLanguage ?? ''); + span.setAttribute(traceTypes.ATTR_EOU_FROM_CACHE, fromCache); + span.setAttribute(traceTypes.ATTR_EOU_SOURCE, trigger); + if (prediction?.detectionDelay !== undefined) { + span.setAttribute(traceTypes.ATTR_EOU_DETECTION_DELAY, prediction.detectionDelay); + } }, { name: 'eou_detection', @@ -1190,9 +1554,24 @@ export class AudioRecognition { // cancel any existing EOU task this.bounceEOUTask?.cancel(); // copy the values before awaiting (the values can change) - this.bounceEOUTask = Task.from( - bounceEOUTask(this.lastSpeakingTime, this.lastFinalTranscriptTime, this.userTurnStart), - ); + const lastSpeakingTime = this.lastSpeakingTime; + const lastFinalTranscriptTime = this.lastFinalTranscriptTime; + const speechStartTime = this.userTurnStart; + + // Audio-EOT detectors get a speaking-guard wrapper: if the user starts + // speaking again during the endpointing delay, abort the EOU and let + // the next turn drive the decision. Text-based detectors (no audio + // pipeline) keep the simpler bounce task — they can't race against + // mid-window utterances anyway since they don't run during silence. + const factory = hasAudioDetector + ? (controller: AbortController) => + this.bounceEOUTaskWithSpeakingGuard( + controller, + bounceEOUTask(lastSpeakingTime, lastFinalTranscriptTime, speechStartTime), + { lastSpeakingTime, lastFinalTranscriptTime, speechStartTime }, + ) + : bounceEOUTask(lastSpeakingTime, lastFinalTranscriptTime, speechStartTime); + this.bounceEOUTask = Task.from(factory); this.bounceEOUTask.result .then(() => { @@ -1340,6 +1719,11 @@ export class AudioRecognition { otelContext.with(ctx, () => this.hooks.onStartOfSpeech(ev)); } this.speaking = true; + this.userSpeakingEvent.set(); + + // Audio EOT FSM: re-arm the predict guard and tear down any + // in-flight inference for the now-stale prior window. 
+ this.turnDetectorStream?.deactivate('vad sos'); // Capture sample rate from the first VAD event if not already set if (ev.frames.length > 0 && ev.frames[0]) { @@ -1359,6 +1743,29 @@ export class AudioRecognition { // ev.rawAccumulatedSpeech is in ms (VADEvent durations are all ms in TS). this.speechStartTime = Date.now() - ev.rawAccumulatedSpeech; } + // Wake any speaking-guard waiter — STT-only sessions don't + // see START_OF_SPEECH but do see INFERENCE_DONE-with-speech. + this.userSpeakingEvent.set(); + } else if (!this.speaking) { + // A sub-threshold speech spike can set `userSpeakingEvent` without + // ever reaching START_OF_SPEECH, so no END_OF_SPEECH will fire to + // clear it. Clear it here once speech drops back to zero (confirmed + // turns are cleared by EOS). + this.userSpeakingEvent.clear(); + } + + // Audio EOT FSM: warm up inference once we've seen enough + // trailing silence (matches Python's `MIN_SILENCE_DURATION_MS`). + // Skip if a prediction is already in flight or the agent is + // actively speaking (where buffered audio is mostly its own). + if ( + this.turnDetectorStream !== undefined && + ev.rawAccumulatedSilence >= MIN_SILENCE_DURATION_MS && + this.speaking && + !this.turnDetectorStream.isInferenceRunning && + !this.isAgentSpeaking + ) { + this.turnDetectorStream.warmup(); } break; case VADEventType.END_OF_SPEECH: @@ -1378,13 +1785,20 @@ export class AudioRecognition { // when VAD fires END_OF_SPEECH, it already waited for the silence_duration this.speaking = false; + this.userSpeakingEvent.clear(); + this.lastSpeakingTime = Date.now() - ev.silenceDuration - ev.inferenceDuration; + + // Audio EOT FSM: commit the inference window so the bounce + // task's `predictEndOfTurn` can resolve against the cached + // prediction (or wait for one). 
+ this.turnDetectorStream?.activate('vad eos'); if ( this.vadBaseTurnDetection || (this.turnDetectionMode === 'stt' && this.userTurnCommitted) ) { const chatCtx = this.hooks.retrieveChatCtx(); - this.runEOUDetection(chatCtx); + this.runEOUDetection(chatCtx, 'vad'); } break; } @@ -1577,6 +1991,15 @@ export class AudioRecognition { this.speaking = false; this.userTurnCommitted = false; this.userTurnTracker = { words: 0, transcript: '' }; + // Clear the speaking event so a stale `set()` from the just-finished + // turn doesn't immediately trip the next speaking-guard race. + this.userSpeakingEvent.clear(); + // New turn → allow the next window's prediction to emit. + this.lastEmittedEotPrediction = undefined; + + // Any cached prediction on the audio stream belongs to the turn we + // just cleared — flush it so the next predict starts fresh. + this.turnDetectorStream?.flush('clear_user_turn'); if (this.userTurnSpan?.isRecording()) { this.userTurnSpan.end(); @@ -1650,7 +2073,7 @@ export class AudioRecognition { const chatCtx = this.hooks.retrieveChatCtx(); this.logger.debug('running EOU detection on commitUserTurn'); - this.runEOUDetection(chatCtx); + this.runEOUDetection(chatCtx, 'manual'); this.userTurnCommitted = true; }; @@ -1697,6 +2120,13 @@ export class AudioRecognition { await this.vadTask?.cancelAndWait(); await this.bounceEOUTask?.cancelAndWait(); await this.interruptionTask?.cancelAndWait(); + + if (this.turnDetectorStream !== undefined) { + const stream = this.turnDetectorStream; + this.turnDetectorStream = undefined; + await stream.aclose().catch(() => undefined); + } + await this.interruptionStreamChannel?.close(); this.cancelBackchannelBoundary(); } diff --git a/agents/src/voice/audio_recognition_span.test.ts b/agents/src/voice/audio_recognition_span.test.ts index cfe92a821..5ce592042 100644 --- a/agents/src/voice/audio_recognition_span.test.ts +++ b/agents/src/voice/audio_recognition_span.test.ts @@ -110,6 +110,7 @@ describe('AudioRecognition user_turn 
span parity', () => { onInterimTranscript: vi.fn(), onFinalTranscript: vi.fn(), onPreemptiveGeneration: vi.fn(), + onEotPrediction: vi.fn(), retrieveChatCtx: () => ChatContext.empty(), onEndOfTurn: vi.fn(async () => true), }; @@ -191,6 +192,7 @@ describe('AudioRecognition user_turn span parity', () => { onInterimTranscript: vi.fn(), onFinalTranscript: vi.fn(), onPreemptiveGeneration: vi.fn(), + onEotPrediction: vi.fn(), retrieveChatCtx: () => ChatContext.empty(), onEndOfTurn: vi.fn(async () => true), }; diff --git a/agents/src/voice/events.ts b/agents/src/voice/events.ts index 55cd13e0f..46a85b40c 100644 --- a/agents/src/voice/events.ts +++ b/agents/src/voice/events.ts @@ -33,6 +33,8 @@ export enum AgentSessionEventTypes { SpeechCreated = 'speech_created', AgentFalseInterruption = 'agent_false_interruption', OverlappingSpeech = 'overlapping_speech', + /** Audio EOT detector emitted a per-turn prediction. */ + EotPrediction = 'eot_prediction', Error = 'error', Close = 'close', } @@ -245,6 +247,46 @@ export const createSpeechCreatedEvent = ({ createdAt, }); +/** + * Audio EOT prediction landed on the wire. Emitted once per turn boundary + * decision when an `AudioTurnDetector` is wired into the session. + * + * Port of Python `EotPredictionEvent`. + */ +export type EotPredictionEvent = { + type: 'eot_prediction'; + /** End-of-turn probability in [0, 1] returned by the detector. */ + probability: number; + /** Threshold below which the detector treats the prediction as unlikely. */ + threshold: number; + /** Model-side inference time, in milliseconds. */ + inferenceDurationMs: number; + /** End-of-speech → prediction receive time, in milliseconds. 
*/ + delayMs: number; + createdAt: number; +}; + +export const createEotPredictionEvent = ({ + probability, + threshold, + inferenceDurationMs, + delayMs, + createdAt = Date.now(), +}: { + probability: number; + threshold: number; + inferenceDurationMs: number; + delayMs: number; + createdAt?: number; +}): EotPredictionEvent => ({ + type: 'eot_prediction', + probability, + threshold, + inferenceDurationMs, + delayMs, + createdAt, +}); + export type UserTurnExceededEvent = { type: 'user_turn_exceeded'; /** Transcript from the current uncommitted user turn only. */ diff --git a/agents/src/voice/remote_session.ts b/agents/src/voice/remote_session.ts index f970b064e..9d0e0ce10 100644 --- a/agents/src/voice/remote_session.ts +++ b/agents/src/voice/remote_session.ts @@ -18,6 +18,7 @@ import { isInstructions, renderInstructions } from '../llm/chat_context.js'; import type { ToolContext } from '../llm/tool_context.js'; import { log } from '../log.js'; import type { + EOTModelUsage, InterruptionModelUsage, LLMModelUsage, STTModelUsage, @@ -32,6 +33,7 @@ import { type AgentState, type AgentStateChangedEvent, type ConversationItemAddedEvent, + type EotPredictionEvent, type ErrorEvent, type FunctionToolsExecutedEvent, type MetricsCollectedEvent, @@ -62,6 +64,7 @@ export type RemoteSessionEventTypes = | 'function_tools_executed' | 'overlapping_speech' | 'amd_prediction' + | 'eot_prediction' | 'session_usage' | 'error'; @@ -74,6 +77,7 @@ export type RemoteSessionCallbacks = { function_tools_executed: (ev: pb.AgentSessionEvent_FunctionToolsExecuted) => void; overlapping_speech: (ev: pb.AgentSessionEvent_OverlappingSpeech) => void; amd_prediction: (ev: pb.AgentSessionEvent_AmdPrediction) => void; + eot_prediction: (ev: pb.AgentSessionEvent_EotPrediction) => void; session_usage: (ev: pb.AgentSessionEvent_SessionUsageUpdated) => void; error: (ev: pb.AgentSessionEvent_Error) => void; }; @@ -464,6 +468,22 @@ function sessionUsageToProto(usage: AgentSessionUsage): pb.AgentSessionUsage 
{ ); break; } + case 'eot_usage': { + const eu = mu as Partial; + modelUsages.push( + new pb.ModelUsage({ + usage: { + case: 'eot', + value: new pb.EotModelUsage({ + provider: eu.provider ?? '', + model: eu.model ?? '', + totalRequests: eu.totalRequests ?? 0, + }), + }, + }), + ); + break; + } } } return new pb.AgentSessionUsage({ modelUsage: modelUsages }); @@ -521,6 +541,7 @@ export class SessionHost { session.on(AgentSessionEventTypes.FunctionToolsExecuted, this.onFunctionToolsExecuted); session.on(AgentSessionEventTypes.MetricsCollected, this.onMetricsCollected); session.on(AgentSessionEventTypes.OverlappingSpeech, this.onOverlappingSpeech); + session.on(AgentSessionEventTypes.EotPrediction, this.onEotPrediction); session.on(AgentSessionEventTypes.Error, this.onHostError); } } @@ -549,6 +570,7 @@ export class SessionHost { this.session.off(AgentSessionEventTypes.FunctionToolsExecuted, this.onFunctionToolsExecuted); this.session.off(AgentSessionEventTypes.MetricsCollected, this.onMetricsCollected); this.session.off(AgentSessionEventTypes.OverlappingSpeech, this.onOverlappingSpeech); + this.session.off(AgentSessionEventTypes.EotPrediction, this.onEotPrediction); this.session.off(AgentSessionEventTypes.Error, this.onHostError); } @@ -676,6 +698,10 @@ export class SessionHost { ); }; + private onEotPrediction = (event: EotPredictionEvent): void => { + this._onEotPrediction(event); + }; + private onOverlappingSpeech = (event: OverlappingSpeechEvent): void => { const value = new pb.AgentSessionEvent_OverlappingSpeech({ isInterruption: event.isInterruption, @@ -731,6 +757,23 @@ export class SessionHost { }); } + /** + * @internal — forwards an audio-EOT prediction to the connected + * {@link RemoteSession} peer. Mirrors python + * `SessionHost._on_eot_prediction`. 
+ */ + _onEotPrediction(event: EotPredictionEvent): void { + this.emitEvent({ + case: 'eotPrediction', + value: new pb.AgentSessionEvent_EotPrediction({ + probability: event.probability, + threshold: event.threshold, + inferenceDuration: msToDuration(event.inferenceDurationMs), + delay: msToDuration(event.delayMs), + }), + }); + } + private async handleRequestSafe(req: pb.SessionRequest): Promise { try { await this.handleRequest(req); @@ -1001,6 +1044,9 @@ export class RemoteSession extends (EventEmitter as new () => TypedEventEmitter< case 'amdPrediction': this.emit('amd_prediction', ev.value); break; + case 'eotPrediction': + this.emit('eot_prediction', ev.value); + break; case 'sessionUsageUpdated': this.emit('session_usage', ev.value); break; diff --git a/agents/src/worker.ts b/agents/src/worker.ts index 3ecc45dc3..ae70b452d 100644 --- a/agents/src/worker.ts +++ b/agents/src/worker.ts @@ -15,10 +15,13 @@ import type { ParticipantInfo } from 'livekit-server-sdk'; import { AccessToken, RoomServiceClient } from 'livekit-server-sdk'; import { EventEmitter } from 'node:events'; import { availableParallelism } from 'node:os'; +import { extname } from 'node:path'; import { WebSocket } from 'ws'; import { APIStatusError } from './_exceptions.js'; import { getCpuMonitor } from './cpu.js'; import { HTTPServer } from './http_server.js'; +import { _getLocalInferenceModule } from './inference/_warmup.js'; +import { EOT_INFERENCE_METHOD } from './inference/eot/runner.js'; import { InferenceRunner } from './inference_runner.js'; import { InferenceProcExecutor } from './ipc/inference_proc_executor.js'; import { ProcPool } from './ipc/proc_pool.js'; @@ -33,6 +36,32 @@ const ASSIGNMENT_TIMEOUT = 7.5 * 1000; const UPDATE_LOAD_INTERVAL = 2.5 * 1000; const PROJECT_TYPE = 'nodejs'; +let localEotRunnerRegistered = false; +/** + * Register the local audio-EOT inference runner so it runs in the shared + * inference process. 
Idempotent and guarded by native-binding availability; + * a no-op (with a one-time warning) when `@livekit/local-inference` can't be + * loaded so the worker still starts on unsupported platforms. + */ +function maybeRegisterLocalEotRunner(): void { + if (localEotRunnerRegistered) return; + localEotRunnerRegistered = true; + if (InferenceRunner.registeredRunners[EOT_INFERENCE_METHOD]) return; + if (_getLocalInferenceModule() === undefined) { + log().warn( + '@livekit/local-inference native binding unavailable; local audio EOT disabled ' + + '(predictions will degrade to a positive default). cloud EOT and other turn ' + + 'detection modes are unaffected.', + ); + return; + } + const ext = extname(import.meta.url); // '.js' (built) or '.ts' (tsx/ts-node) + InferenceRunner.registerRunner( + EOT_INFERENCE_METHOD, + new URL(`./inference/eot/runner${ext}`, import.meta.url).toString(), + ); +} + class Default { static loadThreshold(production: boolean): number { if (production) { @@ -322,6 +351,13 @@ export class AgentServer { } } + // Register the local audio-EOT runner so it runs in the shared inference + // process (loaded once per host, ~138 MB) instead of in every job worker. + // Guarded by binding availability: on a platform where + // `@livekit/local-inference` can't load, skip registration so the worker + // still starts (local EOT then degrades to a positive-default prediction). 
+ maybeRegisterLocalEotRunner(); + if (Object.entries(InferenceRunner.registeredRunners).length) { this.#inferenceExecutor = new InferenceProcExecutor({ runners: InferenceRunner.registeredRunners, diff --git a/examples/src/anam_realtime_agent.ts b/examples/src/anam_realtime_agent.ts index bc2fd33e4..376bb1ba7 100644 --- a/examples/src/anam_realtime_agent.ts +++ b/examples/src/anam_realtime_agent.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -14,15 +13,11 @@ import { import * as anam from '@livekit/agents-plugin-anam'; import * as livekit from '@livekit/agents-plugin-livekit'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; // Uses OpenAI Advanced Voice (Realtime), so no separate STT/TTS/VAD. export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { initializeLogger({ pretty: true }); @@ -31,7 +26,6 @@ export default defineAgent({ }); const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad! 
as silero.VAD, stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), tts: new inference.TTS({ model: 'cartesia/sonic-3', diff --git a/examples/src/basic_agent.ts b/examples/src/basic_agent.ts index 95ecddb9a..a7f4d0925 100644 --- a/examples/src/basic_agent.ts +++ b/examples/src/basic_agent.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -13,16 +12,15 @@ import { metrics, voice, } from '@livekit/agents'; -import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; +// import * as livekit from '@livekit/agents-plugin-livekit'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; +// No prewarm hook needed: the local EOT model runs in the shared inference +// process (loaded once per host), and the silero VAD (~2MB, in-process) +// lazy-loads on first stream. export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: @@ -43,9 +41,6 @@ export default defineAgent({ const logger = log(); const session = new voice.AgentSession({ - // VAD and turn detection are used to determine when the user is speaking and when the agent should respond - // See more at https://docs.livekit.io/agents/build/turns - vad: ctx.proc.userData.vad! 
as silero.VAD, // Speech-to-text (STT) is your agent's ears, turning the user's speech into text that the LLM can understand // See all available models at https://docs.livekit.io/agents/models/stt/ stt: new inference.STT({ @@ -69,7 +64,11 @@ export default defineAgent({ }), ttsTextTransforms: ['filter_markdown', 'filter_emoji'], turnHandling: { - turnDetection: new livekit.turnDetector.MultilingualModel(), + // turn detection determines when the agent should respond. See https://docs.livekit.io/agents/build/turns + turnDetection: new inference.AudioTurnDetector(), + // To use the local on-device text turn detector instead, re-enable the + // `@livekit/agents-plugin-livekit` import above and use: + // turnDetection: new livekit.turnDetector.MultilingualModel(), interruption: { // Enable false-interruption auto-resume behavior. resumeFalseInterruption: true, @@ -118,7 +117,7 @@ export default defineAgent({ }); session.on(voice.AgentSessionEventTypes.OverlappingSpeech, (ev) => { - logger.warn({ type: ev.type, isInterruption: ev.isInterruption }, 'user overlapping speech'); + logger.info({ type: ev.type, isInterruption: ev.isInterruption }, 'user overlapping speech'); }); await session.start({ diff --git a/examples/src/basic_agent_task.ts b/examples/src/basic_agent_task.ts index aacbeee5c..0549f4197 100644 --- a/examples/src/basic_agent_task.ts +++ b/examples/src/basic_agent_task.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -12,7 +11,6 @@ import { voice, } from '@livekit/agents'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -110,12 +108,8 @@ class SurveyAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { 
const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad as silero.VAD, stt: new inference.STT({ model: 'deepgram/nova-3' }), llm: new openai.responses.LLM({ useWebSocket: true }), tts: new inference.TTS({ diff --git a/examples/src/basic_task_group.ts b/examples/src/basic_task_group.ts index d40befe2a..0c24c2059 100644 --- a/examples/src/basic_task_group.ts +++ b/examples/src/basic_task_group.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, beta, cli, @@ -13,7 +12,6 @@ import { voice, } from '@livekit/agents'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -120,12 +118,8 @@ class TaskGroupDemoAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad as silero.VAD, stt: new inference.STT({ model: 'deepgram/nova-3' }), llm: new openai.responses.LLM({ model: 'gpt-5.2', diff --git a/examples/src/basic_tool_call_agent.ts b/examples/src/basic_tool_call_agent.ts index 5642ef488..d29d72119 100644 --- a/examples/src/basic_tool_call_agent.ts +++ b/examples/src/basic_tool_call_agent.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -12,7 +11,6 @@ import { voice, } from '@livekit/agents'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -39,9 +37,6 @@ class GameAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - 
proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const getWeather = llm.tool({ description: ' Called when the user asks about the weather.', @@ -133,10 +128,7 @@ export default defineAgent({ }, }); - const vad = ctx.proc.userData.vad! as silero.VAD; - const session = new voice.AgentSession({ - vad, stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), llm: new inference.LLM({ model: 'google/gemini-3-flash-preview' }), tts: new inference.TTS({ diff --git a/examples/src/cartesia.ts b/examples/src/cartesia.ts index 34cd47fd4..e49bb623d 100644 --- a/examples/src/cartesia.ts +++ b/examples/src/cartesia.ts @@ -4,7 +4,6 @@ import type { llm as llmModule } from '@livekit/agents'; import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -16,14 +15,10 @@ import { import * as cartesia from '@livekit/agents-plugin-cartesia'; import * as google from '@livekit/agents-plugin-google'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: @@ -31,8 +26,6 @@ export default defineAgent({ }); const logger = log(); - const vad = - ctx.proc.userData.vad instanceof silero.VAD ? 
ctx.proc.userData.vad : await silero.VAD.load(); const apiKey = process.env.CARTESIA_API_KEY; @@ -67,7 +60,6 @@ export default defineAgent({ } const session = new voice.AgentSession({ - vad, stt: new cartesia.STT({ model: 'ink-2', apiKey }), llm, tts: new cartesia.TTS({ model: 'sonic-3.5', apiKey }), diff --git a/examples/src/comprehensive_test.ts b/examples/src/comprehensive_test.ts index 6e9fc8f07..9e9e25225 100644 --- a/examples/src/comprehensive_test.ts +++ b/examples/src/comprehensive_test.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, dedent, @@ -21,7 +20,6 @@ import * as livekit from '@livekit/agents-plugin-livekit'; import * as neuphonic from '@livekit/agents-plugin-neuphonic'; import * as openai from '@livekit/agents-plugin-openai'; import * as resemble from '@livekit/agents-plugin-resemble'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -238,14 +236,9 @@ class TestAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const logger = log(); - const vad = ctx.proc.userData.vad! 
as silero.VAD; const session = new voice.AgentSession({ - vad, userData: { testedSttChoices: new Set(), testedTtsChoices: new Set(), diff --git a/examples/src/custom_text_handler.ts b/examples/src/custom_text_handler.ts index 5ba65e773..29de2b848 100644 --- a/examples/src/custom_text_handler.ts +++ b/examples/src/custom_text_handler.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -11,7 +10,6 @@ import { voice, } from '@livekit/agents'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; @@ -39,19 +37,13 @@ const customTextInputHandler = (session: voice.AgentSession, event: voice.TextIn }; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: "You are a helpful assistant, you can hear the user's message and respond to it.", }); - const vad = ctx.proc.userData.vad! 
as silero.VAD; - const session = new voice.AgentSession({ - vad, stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), tts: new inference.TTS({ diff --git a/examples/src/drive-thru/drivethru_agent.ts b/examples/src/drive-thru/drivethru_agent.ts index 9882f6fcd..0da757a0a 100644 --- a/examples/src/drive-thru/drivethru_agent.ts +++ b/examples/src/drive-thru/drivethru_agent.ts @@ -3,10 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, + inference, llm, voice, } from '@livekit/agents'; @@ -14,7 +14,6 @@ import * as deepgram from '@livekit/agents-plugin-deepgram'; import * as elevenlabs from '@livekit/agents-plugin-elevenlabs'; import * as livekit from '@livekit/agents-plugin-livekit'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; import { @@ -376,15 +375,13 @@ export async function newUserData(): Promise { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const userdata = await newUserData(); - const vad = ctx.proc.userData.vad! as silero.VAD; const session = new voice.AgentSession({ - vad, + // VAD lazy-loads the bundled silero model on first stream, so no + // prewarm hook is needed. 
+ vad: new inference.VAD(), stt: new deepgram.STT(), llm: new openai.LLM({ model: 'gpt-4.1', temperature: 0.45 }), tts: new elevenlabs.TTS(), diff --git a/examples/src/elevenlabs_scribe_v2.ts b/examples/src/elevenlabs_scribe_v2.ts index d0574c02c..dbb4ac1fa 100644 --- a/examples/src/elevenlabs_scribe_v2.ts +++ b/examples/src/elevenlabs_scribe_v2.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -11,13 +10,9 @@ import { voice, } from '@livekit/agents'; import * as elevenlabs from '@livekit/agents-plugin-elevenlabs'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const stt = new elevenlabs.STT({ useRealtime: true, @@ -32,7 +27,6 @@ export default defineAgent({ const session = new voice.AgentSession({ voiceOptions: { allowInterruptions: true }, - vad: ctx.proc.userData.vad! 
as silero.VAD, stt, llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), tts: new inference.TTS({ model: 'cartesia/sonic-3' }), diff --git a/examples/src/frontdesk/frontdesk_agent.ts b/examples/src/frontdesk/frontdesk_agent.ts index d5d2e1ab1..316d64474 100644 --- a/examples/src/frontdesk/frontdesk_agent.ts +++ b/examples/src/frontdesk/frontdesk_agent.ts @@ -3,10 +3,10 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, + inference, llm, voice, } from '@livekit/agents'; @@ -14,7 +14,6 @@ import * as deepgram from '@livekit/agents-plugin-deepgram'; import * as elevenlabs from '@livekit/agents-plugin-elevenlabs'; import * as livekit from '@livekit/agents-plugin-livekit'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -196,9 +195,6 @@ You must infer the appropriate range implicitly from the conversational context } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const timezone = 'UTC'; @@ -220,7 +216,9 @@ export default defineAgent({ const userdata: Userdata = { cal }; const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad! as silero.VAD, + // VAD lazy-loads the bundled silero model on first stream, so no + // prewarm hook is needed. 
+ vad: new inference.VAD(), stt: new deepgram.STT(), llm: new openai.LLM({ model: 'gpt-4.1', diff --git a/examples/src/gemini_realtime_agent.ts b/examples/src/gemini_realtime_agent.ts index 60cbb443e..3b82c8ec0 100644 --- a/examples/src/gemini_realtime_agent.ts +++ b/examples/src/gemini_realtime_agent.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, dedent, @@ -12,7 +11,6 @@ import { voice, } from '@livekit/agents'; import * as google from '@livekit/agents-plugin-google'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -117,14 +115,10 @@ class StoryAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const userdata: StoryData = {}; const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad! as silero.VAD, llm: new google.realtime.RealtimeModel({ thinkingConfig: { // Making the thoughts false to speed up the realtime response diff --git a/examples/src/hume_tts.ts b/examples/src/hume_tts.ts index fbf05c689..00425722d 100644 --- a/examples/src/hume_tts.ts +++ b/examples/src/hume_tts.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -13,14 +12,10 @@ import { } from '@livekit/agents'; import * as hume from '@livekit/agents-plugin-hume'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: @@ -39,7 +34,6 @@ 
export default defineAgent({ stt: 'deepgram/nova-3', llm: 'openai/gpt-4.1-mini', tts, - vad: ctx.proc.userData.vad! as silero.VAD, turnDetection: new livekit.turnDetector.MultilingualModel(), voiceOptions: { preemptiveGeneration: true, diff --git a/examples/src/idle_user_timeout_example.ts b/examples/src/idle_user_timeout_example.ts index 47b9d2643..d326c881b 100644 --- a/examples/src/idle_user_timeout_example.ts +++ b/examples/src/idle_user_timeout_example.ts @@ -8,7 +8,6 @@ */ import { type JobContext, - type JobProcess, ServerOptions, Task, cli, @@ -19,21 +18,15 @@ import { log, voice, } from '@livekit/agents'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; initializeLogger({ pretty: true }); export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const logger = log(); - const vad = ctx.proc.userData.vad! as silero.VAD; const session = new voice.AgentSession({ - vad, llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), tts: new inference.TTS({ diff --git a/examples/src/instructions_per_modality.ts b/examples/src/instructions_per_modality.ts index 71f2f3f04..793f9819c 100644 --- a/examples/src/instructions_per_modality.ts +++ b/examples/src/instructions_per_modality.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -12,7 +11,6 @@ import { log, voice, } from '@livekit/agents'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -79,12 +77,8 @@ class SchedulingAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const session = new 
voice.AgentSession({ - vad: ctx.proc.userData.vad! as silero.VAD, stt: new inference.STT({ model: 'deepgram/nova-3' }), llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), tts: new inference.TTS({ diff --git a/examples/src/inworld_tts.ts b/examples/src/inworld_tts.ts index 9a4ddf2a4..2854a3660 100644 --- a/examples/src/inworld_tts.ts +++ b/examples/src/inworld_tts.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -13,14 +12,10 @@ import { } from '@livekit/agents'; import * as inworld from '@livekit/agents-plugin-inworld'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: @@ -71,7 +66,6 @@ export default defineAgent({ tts, // VAD and turn detection are used to determine when the user is speaking and when the agent should respond // See more at https://docs.livekit.io/agents/build/turns - vad: ctx.proc.userData.vad! 
as silero.VAD, turnDetection: new livekit.turnDetector.MultilingualModel(), // to use realtime model, replace the stt, llm, tts and vad with the following // llm: new openai.realtime.RealtimeModel(), diff --git a/examples/src/lemonslice_realtime_avatar.ts b/examples/src/lemonslice_realtime_avatar.ts index b2afc544b..c33f2ad51 100644 --- a/examples/src/lemonslice_realtime_avatar.ts +++ b/examples/src/lemonslice_realtime_avatar.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -13,15 +12,11 @@ import { } from '@livekit/agents'; import * as lemonslice from '@livekit/agents-plugin-lemonslice'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; initializeLogger({ pretty: true }); export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { try { const agent = new voice.Agent({ @@ -41,7 +36,6 @@ export default defineAgent({ voice: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', }), turnDetection: new livekit.turnDetector.MultilingualModel(), - vad: ctx.proc.userData.vad! 
as silero.VAD, turnHandling: { interruption: { resumeFalseInterruption: false, diff --git a/examples/src/liveavatar_avatar.ts b/examples/src/liveavatar_avatar.ts index fe502d3b3..c5fc8f8e7 100644 --- a/examples/src/liveavatar_avatar.ts +++ b/examples/src/liveavatar_avatar.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -14,13 +13,9 @@ import { } from '@livekit/agents'; import * as liveavatar from '@livekit/agents-plugin-liveavatar'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const logger = log().child({ example: 'liveavatar_avatar' }); @@ -40,7 +35,6 @@ export default defineAgent({ voice: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', }), turnDetection: new livekit.turnDetector.MultilingualModel(), - vad: ctx.proc.userData.vad! 
as silero.VAD, voiceOptions: { preemptiveGeneration: true, }, diff --git a/examples/src/llm_fallback_adapter.ts b/examples/src/llm_fallback_adapter.ts index d053464dc..d3d407214 100644 --- a/examples/src/llm_fallback_adapter.ts +++ b/examples/src/llm_fallback_adapter.ts @@ -16,26 +16,14 @@ * - Configurable timeouts and retry behavior * - Event emission when provider availability changes */ -import { - type JobContext, - type JobProcess, - ServerOptions, - cli, - defineAgent, - llm, - voice, -} from '@livekit/agents'; +import { type JobContext, ServerOptions, cli, defineAgent, llm, voice } from '@livekit/agents'; import * as deepgram from '@livekit/agents-plugin-deepgram'; import * as elevenlabs from '@livekit/agents-plugin-elevenlabs'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { // Create multiple LLM instances for fallback // The FallbackAdapter will try them in order: primary -> secondary -> tertiary @@ -85,7 +73,6 @@ export default defineAgent({ }); const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad! 
as silero.VAD, stt: new deepgram.STT(), tts: new elevenlabs.TTS(), llm: fallbackLLM, // Use the FallbackAdapter instead of a single LLM diff --git a/examples/src/manual_shutdown.ts b/examples/src/manual_shutdown.ts index 96bedb901..a4efe110d 100644 --- a/examples/src/manual_shutdown.ts +++ b/examples/src/manual_shutdown.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -12,15 +11,11 @@ import { voice, } from '@livekit/agents'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const agent = new voice.Agent({ instructions: @@ -66,7 +61,6 @@ export default defineAgent({ model: 'cartesia/sonic-3', voice: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', }), - vad: ctx.proc.userData.vad! 
as silero.VAD, turnDetection: new livekit.turnDetector.MultilingualModel(), voiceOptions: { preemptiveGeneration: true, diff --git a/examples/src/multi_agent.ts b/examples/src/multi_agent.ts index 7f4819bed..ad0abd1f6 100644 --- a/examples/src/multi_agent.ts +++ b/examples/src/multi_agent.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, dedent, @@ -13,7 +12,6 @@ import { voice, } from '@livekit/agents'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -72,14 +70,10 @@ class StoryAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const userdata: StoryData = {}; const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad! as silero.VAD, stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), tts: new inference.TTS({ model: 'cartesia/sonic-3', diff --git a/examples/src/push_to_talk.ts b/examples/src/push_to_talk.ts index ecba61363..06dbba40a 100644 --- a/examples/src/push_to_talk.ts +++ b/examples/src/push_to_talk.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -11,7 +10,6 @@ import { initializeLogger, voice, } from '@livekit/agents'; -import * as silero from '@livekit/agents-plugin-silero'; import type { ChatContext, ChatMessage } from 'agents/dist/llm/chat_context.js'; import { fileURLToPath } from 'node:url'; @@ -25,14 +23,10 @@ class MyAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { initializeLogger({ pretty: true }); const session = new voice.AgentSession({ - vad: 
ctx.proc.userData.vad! as silero.VAD, stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en' }), llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), tts: new inference.TTS({ diff --git a/examples/src/raw_function_description.ts b/examples/src/raw_function_description.ts index 6548fd011..8a407f05c 100644 --- a/examples/src/raw_function_description.ts +++ b/examples/src/raw_function_description.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -12,7 +11,6 @@ import { voice, } from '@livekit/agents'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; function createRawFunctionAgent() { @@ -48,14 +46,8 @@ function createRawFunctionAgent() { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { - const vad = ctx.proc.userData.vad! as silero.VAD; - const session = new voice.AgentSession({ - vad, stt: new inference.STT({ model: 'deepgram/nova-3', language: 'en', diff --git a/examples/src/realtime_agent.ts b/examples/src/realtime_agent.ts index b30171776..a6879262b 100644 --- a/examples/src/realtime_agent.ts +++ b/examples/src/realtime_agent.ts @@ -1,17 +1,8 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. 
// // SPDX-License-Identifier: Apache-2.0 -import { - type JobContext, - type JobProcess, - ServerOptions, - cli, - defineAgent, - llm, - voice, -} from '@livekit/agents'; +import { type JobContext, ServerOptions, cli, defineAgent, llm, voice } from '@livekit/agents'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { readFileSync } from 'node:fs'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -19,9 +10,6 @@ import { z } from 'zod'; const roomNameSchema = z.enum(['bedroom', 'living room', 'kitchen', 'bathroom', 'office']); export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const getWeather = llm.tool({ description: ' Called when the user asks about the weather.', diff --git a/examples/src/realtime_turn_detector.ts b/examples/src/realtime_turn_detector.ts index 6e6ff90dd..949b732f8 100644 --- a/examples/src/realtime_turn_detector.ts +++ b/examples/src/realtime_turn_detector.ts @@ -1,28 +1,16 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { - type JobContext, - type JobProcess, - ServerOptions, - cli, - defineAgent, - voice, -} from '@livekit/agents'; +import { type JobContext, ServerOptions, cli, defineAgent, voice } from '@livekit/agents'; import * as deepgram from '@livekit/agents-plugin-deepgram'; import * as elevenlabs from '@livekit/agents-plugin-elevenlabs'; import * as livekit from '@livekit/agents-plugin-livekit'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad! 
as silero.VAD, stt: new deepgram.STT(), tts: new elevenlabs.TTS(), // To use OpenAI Realtime API diff --git a/examples/src/realtime_with_tts.ts b/examples/src/realtime_with_tts.ts index d87db7853..05df047be 100644 --- a/examples/src/realtime_with_tts.ts +++ b/examples/src/realtime_with_tts.ts @@ -1,27 +1,14 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. // // SPDX-License-Identifier: Apache-2.0 -import { - type JobContext, - type JobProcess, - ServerOptions, - cli, - defineAgent, - llm, - log, - voice, -} from '@livekit/agents'; +import { type JobContext, ServerOptions, cli, defineAgent, llm, log, voice } from '@livekit/agents'; import * as cartesia from '@livekit/agents-plugin-cartesia'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { BackgroundVoiceCancellation } from '@livekit/noise-cancellation-node'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const logger = log(); diff --git a/examples/src/restaurant_agent.ts b/examples/src/restaurant_agent.ts index d9faaf9a5..081552c1a 100644 --- a/examples/src/restaurant_agent.ts +++ b/examples/src/restaurant_agent.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, dedent, @@ -12,7 +11,6 @@ import { llm, voice, } from '@livekit/agents'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -358,9 +356,6 @@ function createCheckoutAgent(menu: string) { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const menu = 'Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2'; const userData = createUserData({ @@ -370,9 +365,10 @@ 
export default defineAgent({ checkout: createCheckoutAgent(menu), }); - const vad = ctx.proc.userData.vad! as silero.VAD; const session = new voice.AgentSession({ - vad, + // VAD is auto-provisioned by AgentSession (bundled silero via + // @livekit/local-inference). Pass `vad: null` to opt out, or pass + // your own `new inference.VAD({ ... })` to customise. stt: new inference.STT({ model: 'deepgram/nova-3' }), llm: new inference.LLM({ model: 'openai/gpt-4.1-mini' }), tts: new inference.TTS({ model: 'cartesia/sonic-3' }), diff --git a/examples/src/runway_avatar.ts b/examples/src/runway_avatar.ts index 3d3cde0bd..dd0c3aaf5 100644 --- a/examples/src/runway_avatar.ts +++ b/examples/src/runway_avatar.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -13,17 +12,12 @@ import { } from '@livekit/agents'; import * as google from '@livekit/agents-plugin-google'; import * as runway from '@livekit/agents-plugin-runway'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const logger = log(); const session = new voice.AgentSession({ - vad: ctx.proc.userData.vad! 
as silero.VAD, llm: new google.realtime.RealtimeModel({ thinkingConfig: { includeThoughts: false }, }), diff --git a/examples/src/telephony_amd.ts b/examples/src/telephony_amd.ts index 424221935..3ab61d150 100644 --- a/examples/src/telephony_amd.ts +++ b/examples/src/telephony_amd.ts @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 import { type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -12,7 +11,6 @@ import { voice, } from '@livekit/agents'; import * as livekit from '@livekit/agents-plugin-livekit'; -import * as silero from '@livekit/agents-plugin-silero'; import { TrackKind } from '@livekit/rtc-node'; import { RoomServiceClient, SipClient } from 'livekit-server-sdk'; import { fileURLToPath } from 'node:url'; @@ -41,9 +39,6 @@ class MyAgent extends voice.Agent { * SIP_PARTICIPANT_IDENTITY — identity to assign the dialed participant */ export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { const logger = log().child({ room: ctx.room.name }); @@ -60,7 +55,6 @@ export default defineAgent({ turnHandling: { turnDetection: new livekit.turnDetector.MultilingualModel(), }, - vad: ctx.proc.userData.vad! 
as silero.VAD, preemptiveGeneration: true, }); diff --git a/examples/src/tool_call_disfluency.ts b/examples/src/tool_call_disfluency.ts index 8f92183a8..bc2368164 100644 --- a/examples/src/tool_call_disfluency.ts +++ b/examples/src/tool_call_disfluency.ts @@ -4,7 +4,6 @@ import { AutoSubscribe, type JobContext, - type JobProcess, ServerOptions, cli, defineAgent, @@ -14,7 +13,6 @@ import { import * as elevenlabs from '@livekit/agents-plugin-elevenlabs'; import * as livekit from '@livekit/agents-plugin-livekit'; import * as openai from '@livekit/agents-plugin-openai'; -import * as silero from '@livekit/agents-plugin-silero'; import { fileURLToPath } from 'node:url'; import { z } from 'zod'; @@ -30,13 +28,9 @@ class VoiceAgent extends voice.Agent { } export default defineAgent({ - prewarm: async (proc: JobProcess) => { - proc.userData.vad = await silero.VAD.load(); - }, entry: async (ctx: JobContext) => { await ctx.connect(undefined, AutoSubscribe.AUDIO_ONLY, undefined); await ctx.waitForParticipant(); - const vad = ctx.proc.userData.vad! as silero.VAD; const getWeather = llm.tool({ description: ' Called when the user asks about the weather.', @@ -61,7 +55,6 @@ export default defineAgent({ }); const session = new voice.AgentSession({ - vad, llm: new openai.realtime.RealtimeModel(), tts: new elevenlabs.TTS(), turnDetection: new livekit.turnDetector.MultilingualModel(), diff --git a/plugins/livekit/src/turn_detector/base.ts b/plugins/livekit/src/turn_detector/base.ts index 93ecdd7f9..3fa1dd139 100644 --- a/plugins/livekit/src/turn_detector/base.ts +++ b/plugins/livekit/src/turn_detector/base.ts @@ -231,8 +231,11 @@ export abstract class EOUModel { return (await this.unlikelyThreshold(language)) !== undefined; } + // `_timeoutMs` is part of the unified `_TurnDetector` contract (milliseconds, + // matching the audio EOT detector). Text-based inference is bounded by the IPC + // executor itself, so this detector does not use the value. 
// eslint-disable-next-line @typescript-eslint/no-unused-vars - async predictEndOfTurn(chatCtx: llm.ChatContext, timeout: number = 3): Promise { + async predictEndOfTurn(chatCtx: llm.ChatContext, _timeoutMs?: number): Promise { let messages: RawChatItem[] = []; for (const message of chatCtx.items) { diff --git a/plugins/livekit/src/turn_detector/index.ts b/plugins/livekit/src/turn_detector/index.ts index 8ffad4c1b..b0234e8c0 100644 --- a/plugins/livekit/src/turn_detector/index.ts +++ b/plugins/livekit/src/turn_detector/index.ts @@ -6,6 +6,13 @@ import { extname } from 'node:path'; import { INFERENCE_METHOD_EN } from './english.js'; import { INFERENCE_METHOD_MULTILINGUAL } from './multilingual.js'; +console.warn( + 'The text-based turn detector from @livekit/agents-plugin-livekit is deprecated. ' + + 'The audio EOT detector in `@livekit/agents` inference (AudioTurnDetector) replaces ' + + 'it and runs natively on-device via @livekit/local-inference. ' + + 'This text-based path will be removed in a future release.', +); + export { EOUModel } from './base.js'; export { EnglishModel } from './english.js'; export { MultilingualModel } from './multilingual.js'; diff --git a/plugins/livekit/src/turn_detector/multilingual.ts b/plugins/livekit/src/turn_detector/multilingual.ts index 57e94ba8d..cd0423913 100644 --- a/plugins/livekit/src/turn_detector/multilingual.ts +++ b/plugins/livekit/src/turn_detector/multilingual.ts @@ -68,10 +68,10 @@ export class MultilingualModel extends EOUModel { return threshold; } - async predictEndOfTurn(chatCtx: llm.ChatContext, timeout: number = 3): Promise { + async predictEndOfTurn(chatCtx: llm.ChatContext, timeoutMs?: number): Promise { const url = remoteInferenceUrl(); if (!url) { - return await super.predictEndOfTurn(chatCtx, timeout); + return await super.predictEndOfTurn(chatCtx, timeoutMs); } // Copy and process chat context similar to Python implementation diff --git a/plugins/silero/src/index.ts b/plugins/silero/src/index.ts index 
2b5b67fb6..41a4dc96e 100644 --- a/plugins/silero/src/index.ts +++ b/plugins/silero/src/index.ts @@ -5,6 +5,14 @@ import { Plugin } from '@livekit/agents'; export { VAD, VADStream } from './vad.js'; +console.warn( + '@livekit/agents-plugin-silero is deprecated and will be removed in v2.0. ' + + 'AgentSession now defaults to the bundled silero VAD (via @livekit/local-inference); ' + + 'drop the explicit `vad=` argument entirely, pass `vad: null` to opt out, or use ' + + "`import { inference } from '@livekit/agents'; new inference.VAD({ model: 'silero', ... })` " + + 'to customise options.', +); + class SileroPlugin extends Plugin { constructor() { super({ diff --git a/plugins/silero/src/vad.ts b/plugins/silero/src/vad.ts index e71702aa3..66b96382d 100644 --- a/plugins/silero/src/vad.ts +++ b/plugins/silero/src/vad.ts @@ -103,6 +103,10 @@ export class VAD extends baseVAD { return new VAD(session, mergedOpts); } + override get minSilenceDuration(): number { + return this.#opts.minSilenceDuration; + } + stream(): VADStream { const stream = new VADStream( this, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 51c33a91e..92c6d1af8 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -112,12 +112,15 @@ importers: '@ffmpeg-installer/ffmpeg': specifier: ^1.1.0 version: 1.1.0 + '@livekit/local-inference': + specifier: ^0.2.5 + version: 0.2.5 '@livekit/mutex': specifier: ^1.1.1 version: 1.1.1 '@livekit/protocol': - specifier: ^1.45.7 - version: 1.45.7 + specifier: ^1.46.2 + version: 1.46.2 '@livekit/throws-transformer': specifier: 0.1.8 version: 0.1.8(typescript@5.9.3) @@ -2219,6 +2222,35 @@ packages: '@livekit/changesets-changelog-github@0.0.4': resolution: {integrity: sha512-MXaiLYwgkYciZb8G2wkVtZ1pJJzZmVx5cM30Q+ClslrIYyAqQhRbPmZDM79/5CGxb1MTemR/tfOM25tgJgAK0g==} + '@livekit/local-inference-darwin-arm64@0.2.5': + resolution: {integrity: sha512-tdAGJRiYwko0rOmeI/dXf7Mo5TF+oeWDsK55Ga/2PZ/SHuYZ8jkJAPRaG1k78ePsJ119lySWZsxnJdVnOJowRA==} + cpu: [arm64] + os: [darwin] + + 
'@livekit/local-inference-darwin-x64@0.2.5': + resolution: {integrity: sha512-FeJUHbx1swyAssS/X9CoI8s4OqeSrYJy/xhKhL0VnH1b5tlVfc6V5OjkLNZl55Jw9JYj0YkYpt0m0OIg3SvYRw==} + cpu: [x64] + os: [darwin] + + '@livekit/local-inference-linux-arm64-gnu@0.2.5': + resolution: {integrity: sha512-hXigtVBLS55wT6oOfpDl2Xh6mhfzsrMxvkLftFFfttjFfFjSouuxkxG5NgQTGP01DGAvYO6mnIP8ASK6livr1w==} + cpu: [arm64] + os: [linux] + + '@livekit/local-inference-linux-x64-gnu@0.2.5': + resolution: {integrity: sha512-3unNMNNc9rLCvGH6f3W6DKd4AlF5Z63mdOh9bGtEDZdPon/h7O3oWo9+6N/sHgULfHyD/vZn2NtT4MLtuhoJIw==} + cpu: [x64] + os: [linux] + + '@livekit/local-inference-win32-x64-msvc@0.2.5': + resolution: {integrity: sha512-3s9paiOPwU+TQYPHNLzMxm/xCoZ8swzt8GF2BZSofI/jL2ao4SK1J3D23JEZuQfuZF4iLZm2dlIxMqAodQ9TCA==} + cpu: [x64] + os: [win32] + + '@livekit/local-inference@0.2.5': + resolution: {integrity: sha512-0n2m4pld1jMqgeZyHs4+3q9gPzq0ousrx3wA8kULAoia/464uIsJ3JqrVGnH8yD4P/yrGeK11VpZ87S+hKeMAQ==} + engines: {node: '>=18.0.0'} + '@livekit/mutex@1.1.1': resolution: {integrity: sha512-EsshAucklmpuUAfkABPxJNhzj9v2sG7JuzFDL4ML1oJQSV14sqrpTYnsaOudMAw9yOaW53NU3QQTlUQoRs4czw==} @@ -2250,8 +2282,8 @@ packages: cpu: [x64] os: [win32] - '@livekit/protocol@1.45.7': - resolution: {integrity: sha512-UVYtWQohAwowygFFglMKfgjVZMQncCEmHmsQX2yJDhgBf1nZQdfANgUJg+ifxZDTfVpNnQWQjikWMHViq5fh2Q==} + '@livekit/protocol@1.46.2': + resolution: {integrity: sha512-qTL4TIxAguYFG8PAqBaCGch7X0VOSCnMs0GJp5fc0GiXtvMwfh+yWaoWTBMPX3RWguSZlvN4/Jpu1RZZyXHakA==} '@livekit/rtc-ffi-bindings-darwin-arm64@0.12.52-patch.0': resolution: {integrity: sha512-IKUir6goV8yVRR7E2qrAP0JtH7gUyMkO0TG8G+dopO/fkXAsPpSealgI9fLcBJl0zhKK+eGCr741r6xR+xxsVw==} @@ -4423,6 +4455,10 @@ packages: resolution: {integrity: sha512-dRB78srN/l6gqWulah9SrxeYnxeddIG30+GOqK/9OlLVyLg3HPnr6SqOWTWOXKRwC2eGYCkZ59NNuSgvSrpgOA==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + node-gyp-build@4.8.4: + resolution: {integrity: 
sha512-LA4ZjwlnUblHVgq0oBF3Jl/6h/Nvs5fzBLwdEF4nuxnFdsfajde4WfxtJr3CaiH+F6ewcIB/q4jQ4UzPyid+CQ==} + hasBin: true + npm-run-path@5.3.0: resolution: {integrity: sha512-ppwTtiJZq0O/ai0z7yfudtBpWIoxM8yE6nHi1X47eFR2EWORqfbu6CnPlNsjeN683eT0qG6H/Pyf9fCcvjnnnQ==} engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} @@ -6266,6 +6302,31 @@ snapshots: transitivePeerDependencies: - encoding + '@livekit/local-inference-darwin-arm64@0.2.5': + optional: true + + '@livekit/local-inference-darwin-x64@0.2.5': + optional: true + + '@livekit/local-inference-linux-arm64-gnu@0.2.5': + optional: true + + '@livekit/local-inference-linux-x64-gnu@0.2.5': + optional: true + + '@livekit/local-inference-win32-x64-msvc@0.2.5': + optional: true + + '@livekit/local-inference@0.2.5': + dependencies: + node-gyp-build: 4.8.4 + optionalDependencies: + '@livekit/local-inference-darwin-arm64': 0.2.5 + '@livekit/local-inference-darwin-x64': 0.2.5 + '@livekit/local-inference-linux-arm64-gnu': 0.2.5 + '@livekit/local-inference-linux-x64-gnu': 0.2.5 + '@livekit/local-inference-win32-x64-msvc': 0.2.5 + '@livekit/mutex@1.1.1': {} '@livekit/noise-cancellation-darwin-arm64@0.1.9': @@ -6294,7 +6355,7 @@ snapshots: '@livekit/noise-cancellation-win32-x64@0.1.9': optional: true - '@livekit/protocol@1.45.7': + '@livekit/protocol@1.46.2': dependencies: '@bufbuild/protobuf': 1.10.1 @@ -8624,7 +8685,7 @@ snapshots: livekit-server-sdk@2.14.1: dependencies: '@bufbuild/protobuf': 1.10.1 - '@livekit/protocol': 1.45.7 + '@livekit/protocol': 1.46.2 camelcase-keys: 9.1.3 jose: 5.2.4 @@ -8781,6 +8842,8 @@ snapshots: fetch-blob: 3.2.0 formdata-polyfill: 4.0.10 + node-gyp-build@4.8.4: {} + npm-run-path@5.3.0: dependencies: path-key: 4.0.0