diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 3bff80e409b9..50cc9b180bff 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -2014,6 +2014,20 @@ jobs: dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' + #opus + - build-type: '' + cuda-major-version: "" + cuda-minor-version: "" + platforms: 'linux/amd64,linux/arm64' + tag-latest: 'auto' + tag-suffix: '-cpu-opus' + runs-on: 'ubuntu-latest' + base-image: "ubuntu:24.04" + skip-drivers: 'false' + backend: "opus" + dockerfile: "./backend/Dockerfile.golang" + context: "./" + ubuntu-version: '2404' #silero-vad - build-type: '' cuda-major-version: "" @@ -2347,6 +2361,10 @@ jobs: tag-suffix: "-metal-darwin-arm64-piper" build-type: "metal" lang: "go" + - backend: "opus" + tag-suffix: "-metal-darwin-arm64-opus" + build-type: "metal" + lang: "go" - backend: "silero-vad" tag-suffix: "-metal-darwin-arm64-silero-vad" build-type: "metal" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 61e954733b94..4a37c3b50c73 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -93,7 +93,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install curl ffmpeg + sudo apt-get install curl ffmpeg libopus-dev - name: Setup Node.js uses: actions/setup-node@v6 with: @@ -195,7 +195,7 @@ jobs: run: go version - name: Dependencies run: | - brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm + brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus pip install --user --no-cache-dir grpcio-tools grpcio - name: Setup Node.js uses: actions/setup-node@v6 diff --git a/.github/workflows/tests-e2e.yml b/.github/workflows/tests-e2e.yml index 5bb69a6acacc..73a9535b664b 100644 --- a/.github/workflows/tests-e2e.yml +++ b/.github/workflows/tests-e2e.yml @@ -43,7 +43,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install -y build-essential + sudo apt-get install -y build-essential libopus-dev - name: Setup Node.js uses: actions/setup-node@v6 with: diff --git a/.gitignore b/.gitignore index 3d7e27f7a96d..3dcb309ca40d 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ test-models/ test-dir/ tests/e2e-aio/backends tests/e2e-aio/models +mock-backend release/ @@ -69,3 +70,6 @@ docs/static/gallery.html # React UI build artifacts (keep placeholder dist/index.html) core/http/react-ui/node_modules/ core/http/react-ui/dist + +# Extracted backend binaries for container-based testing +local-backends/ diff --git a/Dockerfile b/Dockerfile index 99bf2842e598..85492e81a9f7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ apt-get install -y --no-install-recommends \ ca-certificates curl wget espeak-ng libgomp1 \ - ffmpeg libopenblas0 libopenblas-dev sox && \ + ffmpeg libopenblas0 libopenblas-dev libopus0 sox && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -190,6 +190,7 @@ RUN apt-get update && \ curl libssl-dev \ git \ git-lfs \ + libopus-dev pkg-config \ unzip upx-ucl python3 python-is-python3 && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -378,6 +379,9 @@ COPY ./entrypoint.sh . # Copy the binary COPY --from=builder /build/local-ai ./ +# Copy the opus shim if it was built +RUN --mount=from=builder,src=/build/,dst=/mnt/build \ + if [ -f /mnt/build/libopusshim.so ]; then cp /mnt/build/libopusshim.so ./; fi # Make sure the models directory exists RUN mkdir -p /models /backends diff --git a/Makefile b/Makefile index 4a2385a4531e..6a8b639d1ec1 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Disable parallel execution for backend builds -.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral +.NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus GOCMD=go GOTEST=$(GOCMD) test @@ -106,6 +106,7 @@ react-ui-docker: core/http/react-ui/dist: react-ui ## Build: + build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) @@ -163,6 +164,7 @@ test: test-models/testmodel.ggml protogen-go @echo 'Running tests' export GO_TAGS="debug" $(MAKE) prepare-test + OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \ HUGGINGFACE_GRPC=$(abspath ./)/backend/python/transformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS) $(MAKE) test-llama-gguf @@ -250,6 +252,88 @@ test-stablediffusion: prepare-test test-stores: $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration +test-opus: + @echo 'Running opus backend tests' + $(MAKE) -C backend/go/opus libopusshim.so + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./backend/go/opus/... + +test-opus-docker: + @echo 'Running opus backend tests in Docker' + docker build --target builder \ + --build-arg BUILD_TYPE=$(or $(BUILD_TYPE),) \ + --build-arg BASE_IMAGE=$(or $(BASE_IMAGE),ubuntu:24.04) \ + --build-arg BACKEND=opus \ + -t localai-opus-test -f backend/Dockerfile.golang . + docker run --rm localai-opus-test \ + bash -c 'cd /LocalAI && go run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./backend/go/opus/...' + +test-realtime: build-mock-backend + @echo 'Running realtime e2e tests (mock backend)' + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime && !real-models" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e + +# Real-model realtime tests. Set REALTIME_TEST_MODEL to use your own pipeline, +# or leave unset to auto-build one from the component env vars below. +REALTIME_VAD?=silero-vad-ggml +REALTIME_STT?=whisper-1 +REALTIME_LLM?=qwen3-0.6b +REALTIME_TTS?=tts-1 +REALTIME_BACKENDS_PATH?=$(abspath ./)/backends + +test-realtime-models: build-mock-backend + @echo 'Running realtime e2e tests (real models)' + REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \ + REALTIME_VAD=$(REALTIME_VAD) \ + REALTIME_STT=$(REALTIME_STT) \ + REALTIME_LLM=$(REALTIME_LLM) \ + REALTIME_TTS=$(REALTIME_TTS) \ + REALTIME_BACKENDS_PATH=$(REALTIME_BACKENDS_PATH) \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e + +# --- Container-based real-model testing --- + +REALTIME_BACKEND_NAMES ?= silero-vad whisper llama-cpp kokoro +REALTIME_MODELS_DIR ?= $(abspath ./models) +REALTIME_BACKENDS_DIR ?= $(abspath ./local-backends) +REALTIME_DOCKER_FLAGS ?= --gpus all + +local-backends: + mkdir -p local-backends + +extract-backend-%: docker-build-% local-backends + @echo "Extracting backend $*..." + @CID=$$(docker create local-ai-backend:$*) && \ + rm -rf local-backends/$* && mkdir -p local-backends/$* && \ + docker cp $$CID:/ - | tar -xf - -C local-backends/$* && \ + docker rm $$CID > /dev/null + +extract-realtime-backends: $(addprefix extract-backend-,$(REALTIME_BACKEND_NAMES)) + +test-realtime-models-docker: build-mock-backend + docker build --target build-requirements \ + --build-arg BUILD_TYPE=$(or $(BUILD_TYPE),cublas) \ + --build-arg CUDA_MAJOR_VERSION=$(or $(CUDA_MAJOR_VERSION),13) \ + --build-arg CUDA_MINOR_VERSION=$(or $(CUDA_MINOR_VERSION),0) \ + -t localai-test-runner . + docker run --rm \ + $(REALTIME_DOCKER_FLAGS) \ + -v $(abspath ./):/build \ + -v $(REALTIME_MODELS_DIR):/models:ro \ + -v $(REALTIME_BACKENDS_DIR):/backends \ + -v localai-go-cache:/root/go/pkg/mod \ + -v localai-go-build-cache:/root/.cache/go-build \ + -e REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \ + -e REALTIME_VAD=$(REALTIME_VAD) \ + -e REALTIME_STT=$(REALTIME_STT) \ + -e REALTIME_LLM=$(REALTIME_LLM) \ + -e REALTIME_TTS=$(REALTIME_TTS) \ + -e REALTIME_BACKENDS_PATH=/backends \ + -e REALTIME_MODELS_PATH=/models \ + -w /build \ + localai-test-runner \ + bash -c 'git config --global --add safe.directory /build && \ + make protogen-go && make build-mock-backend && \ + go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e' + test-container: docker build --target requirements -t local-ai-test-container . docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container @@ -477,6 +561,7 @@ BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|tr BACKEND_WHISPER = whisper|golang|.|false|true BACKEND_VOXTRAL = voxtral|golang|.|false|true BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true +BACKEND_OPUS = opus|golang|.|false|true # Python backends with root context BACKEND_RERANKERS = rerankers|python|.|false|true @@ -534,6 +619,7 @@ $(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD))) $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML))) $(eval $(call generate-docker-build-target,$(BACKEND_WHISPER))) $(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL))) +$(eval $(call generate-docker-build-target,$(BACKEND_OPUS))) $(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS))) $(eval $(call generate-docker-build-target,$(BACKEND_TRANSFORMERS))) $(eval $(call generate-docker-build-target,$(BACKEND_OUTETTS))) diff --git a/backend/Dockerfile.golang b/backend/Dockerfile.golang index fce4c77242a1..3bf15c508ea7 100644 --- a/backend/Dockerfile.golang +++ b/backend/Dockerfile.golang @@ -180,6 +180,11 @@ RUN < options = 4; +} + +message AudioEncodeResult { + repeated bytes frames = 1; + int32 sample_rate = 2; + int32 samples_per_frame = 3; +} + +message AudioDecodeRequest { + repeated bytes frames = 1; + map options = 2; +} + +message AudioDecodeResult { + bytes pcm_data = 1; + int32 sample_rate = 2; + int32 samples_per_frame = 3; +} + message ModelMetadataResponse { bool supports_thinking = 1; string rendered_template = 2; // The rendered chat template with enable_thinking=true (empty if not applicable) diff --git a/backend/go/opus/Makefile b/backend/go/opus/Makefile new file mode 100644 index 000000000000..028b16bc5adb --- /dev/null +++ b/backend/go/opus/Makefile @@ -0,0 +1,19 @@ +GOCMD?=go +GO_TAGS?= + +OPUS_CFLAGS := $(shell pkg-config --cflags opus) +OPUS_LIBS := $(shell pkg-config --libs opus) + +libopusshim.so: csrc/opus_shim.c + $(CC) -shared -fPIC -o $@ $< $(OPUS_CFLAGS) $(OPUS_LIBS) + +opus: libopusshim.so + $(GOCMD) build -tags "$(GO_TAGS)" -o opus ./ + +package: opus + bash package.sh + +build: package + +clean: + rm -f opus libopusshim.so diff --git a/backend/go/opus/codec.go b/backend/go/opus/codec.go new file mode 100644 index 000000000000..8c56a09a84c5 --- /dev/null +++ b/backend/go/opus/codec.go @@ -0,0 +1,256 @@ +package main + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "runtime" + "sync" + + "github.com/ebitengine/purego" +) + +const ( + ApplicationVoIP = 2048 + ApplicationAudio = 2049 + ApplicationRestrictedLowDelay = 2051 +) + +var ( + initOnce sync.Once + initErr error + + opusLib uintptr + shimLib uintptr + + // libopus functions + cEncoderCreate func(fs int32, channels int32, application int32, errPtr *int32) uintptr + cEncode func(st uintptr, pcm *int16, frameSize int32, data *byte, maxBytes int32) int32 + cEncoderDestroy func(st uintptr) + + cDecoderCreate func(fs int32, channels int32, errPtr *int32) uintptr + cDecode func(st uintptr, data *byte, dataLen int32, pcm *int16, frameSize int32, decodeFec int32) int32 + cDecoderDestroy func(st uintptr) + + // shim functions (non-variadic wrappers for opus_encoder_ctl) + cSetBitrate func(st uintptr, bitrate int32) int32 + cSetComplexity func(st uintptr, complexity int32) int32 +) + +func loadLib(names []string) (uintptr, error) { + var firstErr error + for _, name := range names { + h, err := purego.Dlopen(name, purego.RTLD_NOW|purego.RTLD_GLOBAL) + if err == nil { + return h, nil + } + if firstErr == nil { + firstErr = err + } + } + return 0, firstErr +} + +func ensureInit() error { + initOnce.Do(func() { + initErr = doInit() + }) + return initErr +} + +const shimHint = "ensure libopus-dev is installed and rebuild, or set OPUS_LIBRARY / OPUS_SHIM_LIBRARY env vars" + +func doInit() error { + opusNames := opusSearchPaths() + var err error + opusLib, err = loadLib(opusNames) + if err != nil { + return fmt.Errorf("opus: failed to load libopus (%s): %w", shimHint, err) + } + + purego.RegisterLibFunc(&cEncoderCreate, opusLib, "opus_encoder_create") + purego.RegisterLibFunc(&cEncode, opusLib, "opus_encode") + purego.RegisterLibFunc(&cEncoderDestroy, opusLib, "opus_encoder_destroy") + purego.RegisterLibFunc(&cDecoderCreate, opusLib, "opus_decoder_create") + purego.RegisterLibFunc(&cDecode, opusLib, "opus_decode") + purego.RegisterLibFunc(&cDecoderDestroy, opusLib, "opus_decoder_destroy") + + shimNames := shimSearchPaths() + shimLib, err = loadLib(shimNames) + if err != nil { + return fmt.Errorf("opus: failed to load libopusshim (%s): %w", shimHint, err) + } + + purego.RegisterLibFunc(&cSetBitrate, shimLib, "opus_shim_encoder_set_bitrate") + purego.RegisterLibFunc(&cSetComplexity, shimLib, "opus_shim_encoder_set_complexity") + + return nil +} + +func opusSearchPaths() []string { + var paths []string + + if env := os.Getenv("OPUS_LIBRARY"); env != "" { + paths = append(paths, env) + } + + if exe, err := os.Executable(); err == nil { + dir := filepath.Dir(exe) + paths = append(paths, filepath.Join(dir, "libopus.so.0"), filepath.Join(dir, "libopus.so")) + if runtime.GOOS == "darwin" { + paths = append(paths, filepath.Join(dir, "libopus.dylib")) + } + } + + paths = append(paths, "libopus.so.0", "libopus.so", "libopus.dylib", "opus.dll") + + if runtime.GOOS == "darwin" { + paths = append(paths, + "/opt/homebrew/lib/libopus.dylib", + "/usr/local/lib/libopus.dylib", + ) + } + + return paths +} + +func shimSearchPaths() []string { + var paths []string + + if env := os.Getenv("OPUS_SHIM_LIBRARY"); env != "" { + paths = append(paths, env) + } + + if exe, err := os.Executable(); err == nil { + dir := filepath.Dir(exe) + paths = append(paths, filepath.Join(dir, "libopusshim.so")) + if runtime.GOOS == "darwin" { + paths = append(paths, filepath.Join(dir, "libopusshim.dylib")) + } + } + + paths = append(paths, "./libopusshim.so", "libopusshim.so") + if runtime.GOOS == "darwin" { + paths = append(paths, "./libopusshim.dylib", "libopusshim.dylib") + } + return paths +} + +// Encoder wraps a libopus OpusEncoder via purego. +type Encoder struct { + st uintptr +} + +func NewEncoder(sampleRate, channels, application int) (*Encoder, error) { + if err := ensureInit(); err != nil { + return nil, err + } + + var opusErr int32 + st := cEncoderCreate(int32(sampleRate), int32(channels), int32(application), &opusErr) + if opusErr != 0 || st == 0 { + return nil, fmt.Errorf("opus_encoder_create failed: error %d", opusErr) + } + return &Encoder{st: st}, nil +} + +// Encode encodes a frame of PCM int16 samples. It returns the number of bytes +// written to out, or a negative error code. +func (e *Encoder) Encode(pcm []int16, frameSize int, out []byte) (int, error) { + if len(pcm) == 0 || len(out) == 0 { + return 0, errors.New("opus encode: empty input or output buffer") + } + n := cEncode(e.st, &pcm[0], int32(frameSize), &out[0], int32(len(out))) + if n < 0 { + return 0, fmt.Errorf("opus_encode failed: error %d", n) + } + return int(n), nil +} + +func (e *Encoder) SetBitrate(bitrate int) error { + if ret := cSetBitrate(e.st, int32(bitrate)); ret != 0 { + return fmt.Errorf("opus set bitrate: error %d", ret) + } + return nil +} + +func (e *Encoder) SetComplexity(complexity int) error { + if ret := cSetComplexity(e.st, int32(complexity)); ret != 0 { + return fmt.Errorf("opus set complexity: error %d", ret) + } + return nil +} + +func (e *Encoder) Close() { + if e.st != 0 { + cEncoderDestroy(e.st) + e.st = 0 + } +} + +// Decoder wraps a libopus OpusDecoder via purego. +type Decoder struct { + st uintptr +} + +func NewDecoder(sampleRate, channels int) (*Decoder, error) { + if err := ensureInit(); err != nil { + return nil, err + } + + var opusErr int32 + st := cDecoderCreate(int32(sampleRate), int32(channels), &opusErr) + if opusErr != 0 || st == 0 { + return nil, fmt.Errorf("opus_decoder_create failed: error %d", opusErr) + } + return &Decoder{st: st}, nil +} + +// Decode decodes an Opus packet into pcm. frameSize is the max number of +// samples per channel that pcm can hold. Returns the number of decoded samples +// per channel. +func (d *Decoder) Decode(data []byte, pcm []int16, frameSize int, fec bool) (int, error) { + if len(pcm) == 0 { + return 0, errors.New("opus decode: empty output buffer") + } + + var dataPtr *byte + var dataLen int32 + if len(data) > 0 { + dataPtr = &data[0] + dataLen = int32(len(data)) + } + + decodeFec := int32(0) + if fec { + decodeFec = 1 + } + + n := cDecode(d.st, dataPtr, dataLen, &pcm[0], int32(frameSize), decodeFec) + if n < 0 { + return 0, fmt.Errorf("opus_decode failed: error %d", n) + } + return int(n), nil +} + +func (d *Decoder) Close() { + if d.st != 0 { + cDecoderDestroy(d.st) + d.st = 0 + } +} + +// Init eagerly loads the opus libraries, returning any error. +// Calling this is optional; the libraries are loaded lazily on first use. +func Init() error { + return ensureInit() +} + +// Reset allows re-initialization (for testing). +func Reset() { + initOnce = sync.Once{} + initErr = nil + opusLib = 0 + shimLib = 0 +} diff --git a/backend/go/opus/csrc/opus_shim.c b/backend/go/opus/csrc/opus_shim.c new file mode 100644 index 000000000000..75d3babb4625 --- /dev/null +++ b/backend/go/opus/csrc/opus_shim.c @@ -0,0 +1,9 @@ +#include + +int opus_shim_encoder_set_bitrate(OpusEncoder *st, opus_int32 bitrate) { + return opus_encoder_ctl(st, OPUS_SET_BITRATE(bitrate)); +} + +int opus_shim_encoder_set_complexity(OpusEncoder *st, opus_int32 complexity) { + return opus_encoder_ctl(st, OPUS_SET_COMPLEXITY(complexity)); +} diff --git a/backend/go/opus/main.go b/backend/go/opus/main.go new file mode 100644 index 000000000000..9bdb68a10b78 --- /dev/null +++ b/backend/go/opus/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "flag" + + grpc "github.com/mudler/LocalAI/pkg/grpc" +) + +var addr = flag.String("addr", "localhost:50051", "the address to connect to") + +func main() { + flag.Parse() + if err := grpc.StartServer(*addr, &Opus{}); err != nil { + panic(err) + } +} diff --git a/backend/go/opus/opus.go b/backend/go/opus/opus.go new file mode 100644 index 000000000000..66478d0e2ba6 --- /dev/null +++ b/backend/go/opus/opus.go @@ -0,0 +1,184 @@ +package main + +import ( + "fmt" + "sync" + "time" + + "github.com/mudler/LocalAI/pkg/grpc/base" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/sound" +) + +const ( + opusSampleRate = 48000 + opusChannels = 1 + opusFrameSize = 960 // 20ms at 48kHz + opusMaxPacketSize = 4000 + opusMaxFrameSize = 5760 // 120ms at 48kHz + + decoderIdleTTL = 60 * time.Second + decoderEvictTick = 30 * time.Second +) + +type cachedDecoder struct { + mu sync.Mutex + dec *Decoder + lastUsed time.Time +} + +type Opus struct { + base.Base + + decodersMu sync.Mutex + decoders map[string]*cachedDecoder +} + +func (o *Opus) Load(opts *pb.ModelOptions) error { + o.decoders = make(map[string]*cachedDecoder) + go o.evictLoop() + return Init() +} + +func (o *Opus) evictLoop() { + ticker := time.NewTicker(decoderEvictTick) + defer ticker.Stop() + for range ticker.C { + o.decodersMu.Lock() + now := time.Now() + for id, cd := range o.decoders { + if now.Sub(cd.lastUsed) > decoderIdleTTL { + cd.dec.Close() + delete(o.decoders, id) + } + } + o.decodersMu.Unlock() + } +} + +// getOrCreateDecoder returns a cached decoder for the given session ID, +// creating one if it doesn't exist yet. +func (o *Opus) getOrCreateDecoder(sessionID string) (*cachedDecoder, error) { + o.decodersMu.Lock() + defer o.decodersMu.Unlock() + + if cd, ok := o.decoders[sessionID]; ok { + cd.lastUsed = time.Now() + return cd, nil + } + + dec, err := NewDecoder(opusSampleRate, opusChannels) + if err != nil { + return nil, err + } + cd := &cachedDecoder{dec: dec, lastUsed: time.Now()} + o.decoders[sessionID] = cd + return cd, nil +} + +func (o *Opus) AudioEncode(req *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) { + enc, err := NewEncoder(opusSampleRate, opusChannels, ApplicationAudio) + if err != nil { + return nil, fmt.Errorf("opus encoder create: %w", err) + } + defer enc.Close() + + if err := enc.SetBitrate(64000); err != nil { + return nil, fmt.Errorf("opus set bitrate: %w", err) + } + if err := enc.SetComplexity(10); err != nil { + return nil, fmt.Errorf("opus set complexity: %w", err) + } + + samples := sound.BytesToInt16sLE(req.PcmData) + if len(samples) == 0 { + return &pb.AudioEncodeResult{ + SampleRate: opusSampleRate, + SamplesPerFrame: opusFrameSize, + }, nil + } + + if req.SampleRate != 0 && int(req.SampleRate) != opusSampleRate { + samples = sound.ResampleInt16(samples, int(req.SampleRate), opusSampleRate) + } + + var frames [][]byte + packet := make([]byte, opusMaxPacketSize) + + for offset := 0; offset+opusFrameSize <= len(samples); offset += opusFrameSize { + frame := samples[offset : offset+opusFrameSize] + n, err := enc.Encode(frame, opusFrameSize, packet) + if err != nil { + return nil, fmt.Errorf("opus encode: %w", err) + } + out := make([]byte, n) + copy(out, packet[:n]) + frames = append(frames, out) + } + + return &pb.AudioEncodeResult{ + Frames: frames, + SampleRate: opusSampleRate, + SamplesPerFrame: opusFrameSize, + }, nil +} + +func (o *Opus) AudioDecode(req *pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) { + if len(req.Frames) == 0 { + return &pb.AudioDecodeResult{ + SampleRate: opusSampleRate, + SamplesPerFrame: opusFrameSize, + }, nil + } + + // Use a persistent decoder when a session ID is provided so that Opus + // prediction state carries across batches. Fall back to a fresh decoder + // for backward compatibility. + sessionID := req.Options["session_id"] + + var cd *cachedDecoder + var ownedDec *Decoder + + if sessionID != "" && o.decoders != nil { + var err error + cd, err = o.getOrCreateDecoder(sessionID) + if err != nil { + return nil, fmt.Errorf("opus decoder create: %w", err) + } + cd.mu.Lock() + defer cd.mu.Unlock() + } else { + dec, err := NewDecoder(opusSampleRate, opusChannels) + if err != nil { + return nil, fmt.Errorf("opus decoder create: %w", err) + } + ownedDec = dec + defer ownedDec.Close() + } + + dec := ownedDec + if cd != nil { + dec = cd.dec + } + + var allSamples []int16 + var samplesPerFrame int32 + + pcm := make([]int16, opusMaxFrameSize) + for _, frame := range req.Frames { + n, err := dec.Decode(frame, pcm, opusMaxFrameSize, false) + if err != nil { + return nil, fmt.Errorf("opus decode: %w", err) + } + if samplesPerFrame == 0 { + samplesPerFrame = int32(n) + } + allSamples = append(allSamples, pcm[:n]...) + } + + return &pb.AudioDecodeResult{ + PcmData: sound.Int16toBytesLE(allSamples), + SampleRate: opusSampleRate, + SamplesPerFrame: samplesPerFrame, + }, nil +} diff --git a/backend/go/opus/opus_test.go b/backend/go/opus/opus_test.go new file mode 100644 index 000000000000..b3daf7148072 --- /dev/null +++ b/backend/go/opus/opus_test.go @@ -0,0 +1,1346 @@ +package main + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "math/rand/v2" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/LocalAI/pkg/sound" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" +) + +func TestOpusBackend(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Opus Backend Suite") +} + +// --- helpers --- + +func generateSineWave(freq float64, sampleRate, numSamples int) []int16 { + out := make([]int16, numSamples) + for i := range out { + t := float64(i) / float64(sampleRate) + out[i] = int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t)) + } + return out +} + +func computeRMS(samples []int16) float64 { + if len(samples) == 0 { + return 0 + } + var sum float64 + for _, s := range samples { + v := float64(s) + sum += v * v + } + return math.Sqrt(sum / float64(len(samples))) +} + +func estimateFrequency(samples []int16, sampleRate int) float64 { + if len(samples) < 2 { + return 0 + } + crossings := 0 + for i := 1; i < len(samples); i++ { + if (samples[i-1] >= 0 && samples[i] < 0) || (samples[i-1] < 0 && samples[i] >= 0) { + crossings++ + } + } + duration := float64(len(samples)) / float64(sampleRate) + return float64(crossings) / (2 * duration) +} + +// encodeDecodeRoundtrip uses the Opus backend to encode PCM and decode all +// resulting frames, returning the concatenated decoded samples. +func encodeDecodeRoundtrip(o *Opus, pcmBytes []byte, sampleRate int) []int16 { + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: int32(sampleRate), + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred(), "AudioEncode") + + if len(encResult.Frames) == 0 { + return nil + } + + decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames, + }) + Expect(err).ToNot(HaveOccurred(), "AudioDecode") + + return sound.BytesToInt16sLE(decResult.PcmData) +} + +func extractOpusFramesFromOgg(data []byte) [][]byte { + var frames [][]byte + pos := 0 + pageNum := 0 + + for pos+27 <= len(data) { + Expect(string(data[pos:pos+4])).To(Equal("OggS"), fmt.Sprintf("invalid Ogg page at offset %d", pos)) + + nSegments := int(data[pos+26]) + if pos+27+nSegments > len(data) { + break + } + + segTable := data[pos+27 : pos+27+nSegments] + dataStart := pos + 27 + nSegments + + var totalDataSize int + for _, s := range segTable { + totalDataSize += int(s) + } + + if dataStart+totalDataSize > len(data) { + break + } + + if pageNum >= 2 { + pageData := data[dataStart : dataStart+totalDataSize] + offset := 0 + var packet []byte + for _, segSize := range segTable { + packet = append(packet, pageData[offset:offset+int(segSize)]...) + offset += int(segSize) + if segSize < 255 { + if len(packet) > 0 { + frameCopy := make([]byte, len(packet)) + copy(frameCopy, packet) + frames = append(frames, frameCopy) + } + packet = nil + } + } + if len(packet) > 0 { + frameCopy := make([]byte, len(packet)) + copy(frameCopy, packet) + frames = append(frames, frameCopy) + } + } + + pos = dataStart + totalDataSize + pageNum++ + } + + return frames +} + +func parseTestWAV(data []byte) (pcm []byte, sampleRate int) { + if len(data) < 44 || string(data[0:4]) != "RIFF" { + return data, 0 + } + pos := 12 + sr := int(binary.LittleEndian.Uint32(data[24:28])) + for pos+8 <= len(data) { + id := string(data[pos : pos+4]) + sz := int(binary.LittleEndian.Uint32(data[pos+4 : pos+8])) + if id == "data" { + end := pos + 8 + sz + if end > len(data) { + end = len(data) + } + return data[pos+8 : end], sr + } + pos += 8 + sz + if sz%2 != 0 { + pos++ + } + } + return data[44:], sr +} + +func writeOggOpus(path string, frames [][]byte, sampleRate, channels int) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + serial := uint32(0x4C6F6341) // "LocA" + var pageSeq uint32 + const preSkip = 312 + + opusHead := make([]byte, 19) + copy(opusHead[0:8], "OpusHead") + opusHead[8] = 1 + opusHead[9] = byte(channels) + binary.LittleEndian.PutUint16(opusHead[10:12], uint16(preSkip)) + binary.LittleEndian.PutUint32(opusHead[12:16], uint32(sampleRate)) + binary.LittleEndian.PutUint16(opusHead[16:18], 0) + opusHead[18] = 0 + if err := writeOggPage(f, serial, pageSeq, 0, 0x02, [][]byte{opusHead}); err != nil { + return err + } + pageSeq++ + + opusTags := make([]byte, 16) + copy(opusTags[0:8], "OpusTags") + binary.LittleEndian.PutUint32(opusTags[8:12], 0) + binary.LittleEndian.PutUint32(opusTags[12:16], 0) + if err := writeOggPage(f, serial, pageSeq, 0, 0x00, [][]byte{opusTags}); err != nil { + return err + } + pageSeq++ + + var granulePos uint64 + for i, frame := range frames { + granulePos += 960 + headerType := byte(0x00) + if i == len(frames)-1 { + headerType = 0x04 + } + if err := writeOggPage(f, serial, pageSeq, granulePos, headerType, [][]byte{frame}); err != nil { + return err + } + pageSeq++ + } + + return nil +} + +func writeOggPage(w io.Writer, serial, pageSeq uint32, granulePos uint64, headerType byte, packets [][]byte) error { + var segments []byte + var pageData []byte + for _, pkt := range packets { + remaining := len(pkt) + for remaining >= 255 { + segments = append(segments, 255) + remaining -= 255 + } + segments = append(segments, byte(remaining)) + pageData = append(pageData, pkt...) + } + + hdr := make([]byte, 27+len(segments)) + copy(hdr[0:4], "OggS") + hdr[4] = 0 + hdr[5] = headerType + binary.LittleEndian.PutUint64(hdr[6:14], granulePos) + binary.LittleEndian.PutUint32(hdr[14:18], serial) + binary.LittleEndian.PutUint32(hdr[18:22], pageSeq) + hdr[26] = byte(len(segments)) + copy(hdr[27:], segments) + + crc := oggCRC32(hdr, pageData) + binary.LittleEndian.PutUint32(hdr[22:26], crc) + + if _, err := w.Write(hdr); err != nil { + return err + } + _, err := w.Write(pageData) + return err +} + +func oggCRC32(header, data []byte) uint32 { + var crc uint32 + for _, b := range header { + crc = (crc << 8) ^ oggCRCTable[byte(crc>>24)^b] + } + for _, b := range data { + crc = (crc << 8) ^ oggCRCTable[byte(crc>>24)^b] + } + return crc +} + +var oggCRCTable = func() [256]uint32 { + var t [256]uint32 + for i := range 256 { + r := uint32(i) << 24 + for range 8 { + if r&0x80000000 != 0 { + r = (r << 1) ^ 0x04C11DB7 + } else { + r <<= 1 + } + } + t[i] = r + } + return t +}() + +func goertzel(samples []int16, targetFreq float64, sampleRate int) float64 { + N := len(samples) + if N == 0 { + return 0 + } + k := 0.5 + float64(N)*targetFreq/float64(sampleRate) + w := 2 * math.Pi * k / float64(N) + coeff := 2 * math.Cos(w) + var s1, s2 float64 + for _, sample := range samples { + s0 := float64(sample) + coeff*s1 - s2 + s2 = s1 + s1 = s0 + } + return s1*s1 + s2*s2 - coeff*s1*s2 +} + +func computeTHD(samples []int16, fundamentalHz float64, sampleRate, numHarmonics int) float64 { + fundPower := goertzel(samples, fundamentalHz, sampleRate) + if fundPower <= 0 { + return 0 + } + var harmonicSum float64 + for h := 2; h <= numHarmonics; h++ { + harmonicSum += goertzel(samples, fundamentalHz*float64(h), sampleRate) + } + return math.Sqrt(harmonicSum/fundPower) * 100 +} + +// --- Opus specs --- + +var _ = Describe("Opus", func() { + var o *Opus + + BeforeEach(func() { + o = &Opus{} + Expect(o.Load(&pb.ModelOptions{})).To(Succeed()) + }) + + It("decodes Chrome-like VoIP frames", func() { + enc, err := NewEncoder(48000, 1, ApplicationVoIP) + Expect(err).ToNot(HaveOccurred()) + defer enc.Close() + Expect(enc.SetBitrate(32000)).To(Succeed()) + Expect(enc.SetComplexity(5)).To(Succeed()) + + sine := generateSineWave(440, 48000, 48000) + packet := make([]byte, 4000) + + var opusFrames [][]byte + for offset := 0; offset+opusFrameSize <= len(sine); offset += opusFrameSize { + frame := sine[offset : offset+opusFrameSize] + n, err := enc.Encode(frame, opusFrameSize, packet) + Expect(err).ToNot(HaveOccurred(), "VoIP encode") + out := make([]byte, n) + copy(out, packet[:n]) + opusFrames = append(opusFrames, out) + } + + result, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: opusFrames}) + Expect(err).ToNot(HaveOccurred()) + + allDecoded := sound.BytesToInt16sLE(result.PcmData) + Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples from VoIP encoder") + + skip := min(len(allDecoded)/4, 48000*100/1000) + tail := allDecoded[skip:] + rms := computeRMS(tail) + + GinkgoWriter.Printf("VoIP/SILK roundtrip: %d decoded samples, RMS=%.1f\n", len(allDecoded), rms) + Expect(rms).To(BeNumerically(">=", 50), "VoIP decoded RMS is too low; SILK decoder may be broken") + }) + + It("decodes stereo-encoded Opus with a mono decoder", func() { + enc, err := NewEncoder(48000, 2, ApplicationVoIP) + Expect(err).ToNot(HaveOccurred()) + defer enc.Close() + Expect(enc.SetBitrate(32000)).To(Succeed()) + + mono := generateSineWave(440, 48000, 48000) + stereo := make([]int16, len(mono)*2) + for i, s := range mono { + stereo[i*2] = s + stereo[i*2+1] = s + } + + packet := make([]byte, 4000) + var opusFrames [][]byte + for offset := 0; offset+opusFrameSize*2 <= len(stereo); offset += opusFrameSize * 2 { + frame := stereo[offset : offset+opusFrameSize*2] + n, err := enc.Encode(frame, opusFrameSize, packet) + Expect(err).ToNot(HaveOccurred(), "Stereo encode") + out := make([]byte, n) + copy(out, packet[:n]) + opusFrames = append(opusFrames, out) + } + + result, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: opusFrames}) + Expect(err).ToNot(HaveOccurred()) + + allDecoded := sound.BytesToInt16sLE(result.PcmData) + Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples from stereo encoder") + + skip := min(len(allDecoded)/4, 48000*100/1000) + tail := allDecoded[skip:] + rms := computeRMS(tail) + + GinkgoWriter.Printf("Stereo->Mono: %d decoded samples, RMS=%.1f\n", len(allDecoded), rms) + Expect(rms).To(BeNumerically(">=", 50), "Stereo->Mono decoded RMS is too low") + }) + + Describe("decoding libopus-encoded audio", func() { + var ffmpegPath string + var tmpDir string + var pcmPath string + var sine []int16 + + BeforeEach(func() { + var err error + ffmpegPath, err = exec.LookPath("ffmpeg") + if err != nil { + Skip("ffmpeg not found") + } + + tmpDir = GinkgoT().TempDir() + + sine = generateSineWave(440, 48000, 48000) + pcmBytes := sound.Int16toBytesLE(sine) + pcmPath = filepath.Join(tmpDir, "input.raw") + Expect(os.WriteFile(pcmPath, pcmBytes, 0644)).To(Succeed()) + }) + + for _, tc := range []struct { + name string + bitrate string + app string + }{ + {"voip_32k", "32000", "voip"}, + {"voip_64k", "64000", "voip"}, + {"audio_64k", "64000", "audio"}, + {"audio_128k", "128000", "audio"}, + } { + tc := tc + It(tc.name, func() { + oggPath := filepath.Join(tmpDir, fmt.Sprintf("libopus_%s_%s.ogg", tc.app, tc.bitrate)) + cmd := exec.Command(ffmpegPath, + "-y", + "-f", "s16le", "-ar", "48000", "-ac", "1", "-i", pcmPath, + "-c:a", "libopus", + "-b:a", tc.bitrate, + "-application", tc.app, + "-frame_duration", "20", + "-vbr", "on", + oggPath, + ) + out, err := cmd.CombinedOutput() + Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("ffmpeg encode: %s", out)) + + oggData, err := os.ReadFile(oggPath) + Expect(err).ToNot(HaveOccurred()) + + opusFrames := extractOpusFramesFromOgg(oggData) + Expect(opusFrames).ToNot(BeEmpty(), "no Opus frames extracted from Ogg container") + GinkgoWriter.Printf("Extracted %d Opus frames from libopus encoder (first frame %d bytes)\n", len(opusFrames), len(opusFrames[0])) + + result, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: opusFrames}) + Expect(err).ToNot(HaveOccurred()) + + allDecoded := sound.BytesToInt16sLE(result.PcmData) + Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples from libopus-encoded Opus") + + skip := min(len(allDecoded)/4, 48000*100/1000) + tail := allDecoded[skip:] + rms := computeRMS(tail) + freq := estimateFrequency(tail, 48000) + + GinkgoWriter.Printf("libopus->opus-go: %d decoded samples, RMS=%.1f, freq≈%.0f Hz\n", len(allDecoded), rms, freq) + + Expect(rms).To(BeNumerically(">=", 50), "RMS is too low — opus-go cannot decode libopus output") + Expect(freq).To(BeNumerically("~", 440, 30), fmt.Sprintf("frequency %.0f Hz deviates from expected 440 Hz", freq)) + }) + } + }) + + It("roundtrips at 48kHz", func() { + sine := generateSineWave(440, 48000, 48000) + pcmBytes := sound.Int16toBytesLE(sine) + + decoded := encodeDecodeRoundtrip(o, pcmBytes, 48000) + Expect(decoded).ToNot(BeEmpty()) + + decodedSR := 48000 + skipDecoded := decodedSR * 50 / 1000 + if skipDecoded > len(decoded)/2 { + skipDecoded = len(decoded) / 4 + } + tail := decoded[skipDecoded:] + + rms := computeRMS(tail) + GinkgoWriter.Printf("48kHz roundtrip: %d decoded samples, RMS=%.1f\n", len(decoded), rms) + + Expect(rms).To(BeNumerically(">=", 50), "decoded audio RMS is too low; signal appears silent") + }) + + It("roundtrips at 16kHz", func() { + sine16k := generateSineWave(440, 16000, 16000) + pcmBytes := sound.Int16toBytesLE(sine16k) + + decoded := encodeDecodeRoundtrip(o, pcmBytes, 16000) + Expect(decoded).ToNot(BeEmpty()) + + decoded16k := sound.ResampleInt16(decoded, 48000, 16000) + + skip := min(len(decoded16k)/4, 16000*50/1000) + tail := decoded16k[skip:] + + rms := computeRMS(tail) + GinkgoWriter.Printf("16kHz roundtrip: %d decoded@48k -> %d resampled@16k, RMS=%.1f\n", + len(decoded), len(decoded16k), rms) + + Expect(rms).To(BeNumerically(">=", 50), "decoded audio RMS is too low; signal appears silent") + }) + + It("returns empty frames for empty input", func() { + result, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: []byte{}, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(result.Frames).To(BeEmpty()) + }) + + It("silently drops sub-frame input", func() { + sine := generateSineWave(440, 48000, 500) // < 960 + pcmBytes := sound.Int16toBytesLE(sine) + + result, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(result.Frames).To(BeEmpty(), fmt.Sprintf("expected 0 frames for %d samples (< 960)", len(sine))) + }) + + It("encodes multiple frames", func() { + sine := generateSineWave(440, 48000, 2880) // exactly 3 frames + pcmBytes := sound.Int16toBytesLE(sine) + + result, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(result.Frames).To(HaveLen(3)) + }) + + It("produces expected decoded frame size", func() { + sine := generateSineWave(440, 48000, 960) + pcmBytes := sound.Int16toBytesLE(sine) + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(encResult.Frames).To(HaveLen(1)) + + decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames, + }) + Expect(err).ToNot(HaveOccurred()) + + decoded := sound.BytesToInt16sLE(decResult.PcmData) + GinkgoWriter.Printf("Encoder input: 960 samples (20ms @ 48kHz)\n") + GinkgoWriter.Printf("Decoder output: %d samples (%.1fms @ 48kHz)\n", + len(decoded), float64(len(decoded))/48.0) + + Expect(len(decoded)).To(SatisfyAny(Equal(960), Equal(480)), + fmt.Sprintf("unexpected decoded frame size %d", len(decoded))) + }) + + It("handles the full WebRTC output path", func() { + sine16k := generateSineWave(440, 16000, 16000) + pcmBytes := sound.Int16toBytesLE(sine16k) + + decoded := encodeDecodeRoundtrip(o, pcmBytes, 16000) + Expect(decoded).ToNot(BeEmpty()) + + rms := computeRMS(decoded) + GinkgoWriter.Printf("WebRTC output path: %d decoded samples at 48kHz, RMS=%.1f\n", len(decoded), rms) + + Expect(rms).To(BeNumerically(">=", 50), "decoded audio RMS is too low") + }) + + It("handles the full WebRTC input path", func() { + sine48k := generateSineWave(440, 48000, 48000) + pcmBytes := sound.Int16toBytesLE(sine48k) + + decoded48k := encodeDecodeRoundtrip(o, pcmBytes, 48000) + Expect(decoded48k).ToNot(BeEmpty()) + + step24k := sound.ResampleInt16(decoded48k, 48000, 24000) + webrtcPath := sound.ResampleInt16(step24k, 24000, 16000) + + rms := computeRMS(webrtcPath) + GinkgoWriter.Printf("WebRTC input path: %d decoded@48k -> %d@24k -> %d@16k, RMS=%.1f\n", + len(decoded48k), len(step24k), len(webrtcPath), rms) + + Expect(rms).To(BeNumerically(">=", 50), "WebRTC input path signal lost in pipeline") + }) + + Context("bug documentation", func() { + It("documents trailing sample loss", func() { + sine := generateSineWave(440, 48000, 1000) + pcmBytes := sound.Int16toBytesLE(sine) + + result, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(result.Frames).To(HaveLen(1)) + + decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: result.Frames}) + Expect(err).ToNot(HaveOccurred()) + + decoded := sound.BytesToInt16sLE(decResult.PcmData) + GinkgoWriter.Printf("Input: 1000 samples, Encoded: 1 frame, Decoded: %d samples (40 samples lost)\n", len(decoded)) + Expect(len(decoded)).To(BeNumerically("<=", 960), + fmt.Sprintf("decoded more samples (%d) than the encoder consumed (960)", len(decoded))) + }) + + It("documents TTS sample rate mismatch", func() { + sine24k := generateSineWave(440, 24000, 24000) + pcmBytes := sound.Int16toBytesLE(sine24k) + + decodedBug := encodeDecodeRoundtrip(o, pcmBytes, 16000) + decodedCorrect := encodeDecodeRoundtrip(o, pcmBytes, 24000) + + skipBug := min(len(decodedBug)/4, 48000*100/1000) + skipCorrect := min(len(decodedCorrect)/4, 48000*100/1000) + + bugTail := decodedBug[skipBug:] + correctTail := decodedCorrect[skipCorrect:] + + bugFreq := estimateFrequency(bugTail, 48000) + correctFreq := estimateFrequency(correctTail, 48000) + + GinkgoWriter.Printf("Bug path: %d decoded samples, freq≈%.0f Hz (expected ~660 Hz = 440*1.5)\n", len(decodedBug), bugFreq) + GinkgoWriter.Printf("Correct path: %d decoded samples, freq≈%.0f Hz (expected ~440 Hz)\n", len(decodedCorrect), correctFreq) + + if len(decodedBug) > 0 && len(decodedCorrect) > 0 { + ratio := float64(len(decodedBug)) / float64(len(decodedCorrect)) + GinkgoWriter.Printf("Sample count ratio (bug/correct): %.2f (expected ~1.5)\n", ratio) + Expect(ratio).To(BeNumerically(">=", 1.1), + "expected bug path to produce significantly more samples due to wrong resample ratio") + } + }) + }) + + Context("batch boundary discontinuity", func() { + // These tests simulate the exact production pipeline: + // Browser encodes → RTP → batch 15 frames (300ms) → decode → resample 48k→16k → append + // They test both with and without persistent decoders to verify + // that the session_id persistent decoder path works correctly. + + It("batched decode+resample with persistent decoder matches one-shot", func() { + // Encode 3 seconds of 440Hz at 48kHz — enough for 10 batches + sine := generateSineWave(440, 48000, 48000*3) + pcmBytes := sound.Int16toBytesLE(sine) + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + GinkgoWriter.Printf("Encoded %d frames (%.0fms)\n", len(encResult.Frames), + float64(len(encResult.Frames))*20.0) + + // Ground truth: decode ALL frames with one decoder, resample in one shot + decAll, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames, + Options: map[string]string{"session_id": "ground-truth"}, + }) + Expect(err).ToNot(HaveOccurred()) + allSamples := sound.BytesToInt16sLE(decAll.PcmData) + oneShotResampled := sound.ResampleInt16(allSamples, 48000, 16000) + + // Production path: decode in 15-frame batches with persistent decoder, + // resample each batch independently, concatenate + const framesPerBatch = 15 + sessionID := "batch-test" + var batchedResampled []int16 + batchCount := 0 + for i := 0; i < len(encResult.Frames); i += framesPerBatch { + end := min(i+framesPerBatch, len(encResult.Frames)) + + decBatch, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames[i:end], + Options: map[string]string{"session_id": sessionID}, + }) + Expect(err).ToNot(HaveOccurred()) + + batchSamples := sound.BytesToInt16sLE(decBatch.PcmData) + batchResampled := sound.ResampleInt16(batchSamples, 48000, 16000) + batchedResampled = append(batchedResampled, batchResampled...) + batchCount++ + } + + GinkgoWriter.Printf("Decoded in %d batches, oneshot=%d samples, batched=%d samples\n", + batchCount, len(oneShotResampled), len(batchedResampled)) + + // Skip codec startup transient (first 100ms) + skip := 16000 * 100 / 1000 + oneShotTail := oneShotResampled[skip:] + batchedTail := batchedResampled[skip:] + minLen := min(len(oneShotTail), len(batchedTail)) + + // With persistent decoder, batched decode should be nearly identical + // to one-shot (only difference is resampler batch boundaries). + var maxDiff float64 + var sumDiffSq float64 + for i := 0; i < minLen; i++ { + diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i])) + if diff > maxDiff { + maxDiff = diff + } + sumDiffSq += diff * diff + } + rmsDiff := math.Sqrt(sumDiffSq / float64(minLen)) + + GinkgoWriter.Printf("Persistent decoder: maxDiff=%.0f, rmsDiff=%.1f\n", maxDiff, rmsDiff) + + // Tight threshold: with persistent decoder and fixed resampler, + // the output should be very close to one-shot + Expect(maxDiff).To(BeNumerically("<", 500), + "persistent decoder batched path diverges too much from one-shot") + Expect(rmsDiff).To(BeNumerically("<", 50), + "RMS deviation too high between batched and one-shot") + }) + + It("fresh decoder per batch produces worse quality than persistent", func() { + // This test proves the value of persistent decoders by showing + // that fresh decoders produce larger deviations at batch boundaries. + sine := generateSineWave(440, 48000, 48000*2) + pcmBytes := sound.Int16toBytesLE(sine) + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + + // Ground truth: one-shot decode + decAll, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames, + Options: map[string]string{"session_id": "ref"}, + }) + Expect(err).ToNot(HaveOccurred()) + refSamples := sound.BytesToInt16sLE(decAll.PcmData) + + const framesPerBatch = 15 + + // Path A: persistent decoder + var persistentSamples []int16 + for i := 0; i < len(encResult.Frames); i += framesPerBatch { + end := min(i+framesPerBatch, len(encResult.Frames)) + dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames[i:end], + Options: map[string]string{"session_id": "persistent"}, + }) + Expect(err).ToNot(HaveOccurred()) + persistentSamples = append(persistentSamples, sound.BytesToInt16sLE(dec.PcmData)...) + } + + // Path B: fresh decoder per batch (no session_id) + var freshSamples []int16 + for i := 0; i < len(encResult.Frames); i += framesPerBatch { + end := min(i+framesPerBatch, len(encResult.Frames)) + dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames[i:end], + }) + Expect(err).ToNot(HaveOccurred()) + freshSamples = append(freshSamples, sound.BytesToInt16sLE(dec.PcmData)...) + } + + // Compare both to reference + skip := 48000 * 100 / 1000 + refTail := refSamples[skip:] + persistentTail := persistentSamples[skip:] + freshTail := freshSamples[skip:] + minLen := min(len(refTail), min(len(persistentTail), len(freshTail))) + + var persistentMaxDiff, freshMaxDiff float64 + for i := 0; i < minLen; i++ { + pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i])) + fd := math.Abs(float64(refTail[i]) - float64(freshTail[i])) + if pd > persistentMaxDiff { + persistentMaxDiff = pd + } + if fd > freshMaxDiff { + freshMaxDiff = fd + } + } + + GinkgoWriter.Printf("vs reference: persistent maxDiff=%.0f, fresh maxDiff=%.0f\n", + persistentMaxDiff, freshMaxDiff) + + // Persistent decoder should be closer to reference than fresh + Expect(persistentMaxDiff).To(BeNumerically("<=", freshMaxDiff), + "persistent decoder should match reference at least as well as fresh decoder") + }) + + It("checks for PCM discontinuities at batch boundaries", func() { + // Encode 2 seconds, decode in batches, resample, and check + // for anomalous jumps at the exact batch boundaries in the output + sine := generateSineWave(440, 48000, 48000*2) + pcmBytes := sound.Int16toBytesLE(sine) + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + + const framesPerBatch = 15 + sessionID := "boundary-check" + var batchedOutput []int16 + var batchBoundaries []int // indices where batch boundaries fall in output + for i := 0; i < len(encResult.Frames); i += framesPerBatch { + end := min(i+framesPerBatch, len(encResult.Frames)) + + dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames[i:end], + Options: map[string]string{"session_id": sessionID}, + }) + Expect(err).ToNot(HaveOccurred()) + + batchSamples := sound.BytesToInt16sLE(dec.PcmData) + batchResampled := sound.ResampleInt16(batchSamples, 48000, 16000) + + if len(batchedOutput) > 0 { + batchBoundaries = append(batchBoundaries, len(batchedOutput)) + } + batchedOutput = append(batchedOutput, batchResampled...) + } + + GinkgoWriter.Printf("Output: %d samples, %d batch boundaries\n", + len(batchedOutput), len(batchBoundaries)) + + // For each batch boundary, check if the sample-to-sample jump + // is anomalously large compared to neighboring deltas + for bIdx, boundary := range batchBoundaries { + if boundary < 10 || boundary+10 >= len(batchedOutput) { + continue + } + + jump := math.Abs(float64(batchedOutput[boundary]) - float64(batchedOutput[boundary-1])) + + // Compute average delta in the 20-sample neighborhood (excluding boundary) + var avgDelta float64 + count := 0 + for i := boundary - 10; i < boundary+10; i++ { + if i == boundary-1 || i == boundary { + continue + } + if i+1 < len(batchedOutput) { + avgDelta += math.Abs(float64(batchedOutput[i+1]) - float64(batchedOutput[i])) + count++ + } + } + if count > 0 { + avgDelta /= float64(count) + } + + ratio := 0.0 + if avgDelta > 0 { + ratio = jump / avgDelta + } + + GinkgoWriter.Printf("Boundary %d (idx %d): jump=%.0f, avg_delta=%.0f, ratio=%.1f\n", + bIdx, boundary, jump, avgDelta, ratio) + + // The boundary jump should not be more than 5x the average + // (with codec artifacts, some variation is expected) + Expect(jump).To(BeNumerically("<=", avgDelta*5+1), + fmt.Sprintf("discontinuity at batch boundary %d: jump=%.0f vs avg=%.0f (ratio=%.1f)", + bIdx, jump, avgDelta, ratio)) + } + }) + + It("maintains sine wave phase continuity across batches", func() { + sine := generateSineWave(440, 48000, 48000*2) // 2 seconds + pcmBytes := sound.Int16toBytesLE(sine) + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + + // Decode in batches with persistent decoder, resample each + const framesPerBatch = 15 + sessionID := "phase-test" + var fullOutput []int16 + for i := 0; i < len(encResult.Frames); i += framesPerBatch { + end := min(i+framesPerBatch, len(encResult.Frames)) + dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames[i:end], + Options: map[string]string{"session_id": sessionID}, + }) + Expect(err).ToNot(HaveOccurred()) + samples := sound.BytesToInt16sLE(dec.PcmData) + resampled := sound.ResampleInt16(samples, 48000, 16000) + fullOutput = append(fullOutput, resampled...) + } + + // Check zero-crossing regularity after startup transient + skip := 16000 * 200 / 1000 // skip first 200ms + tail := fullOutput[skip:] + + var crossingPositions []int + for i := 1; i < len(tail); i++ { + if (tail[i-1] >= 0 && tail[i] < 0) || (tail[i-1] < 0 && tail[i] >= 0) { + crossingPositions = append(crossingPositions, i) + } + } + Expect(crossingPositions).ToNot(BeEmpty(), "no zero crossings found") + + var intervals []float64 + for i := 1; i < len(crossingPositions); i++ { + intervals = append(intervals, float64(crossingPositions[i]-crossingPositions[i-1])) + } + + var sum float64 + for _, v := range intervals { + sum += v + } + mean := sum / float64(len(intervals)) + + var variance float64 + for _, v := range intervals { + d := v - mean + variance += d * d + } + stddev := math.Sqrt(variance / float64(len(intervals))) + + GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n", + mean, stddev, stddev/mean, 16000.0/440.0/2.0) + + Expect(stddev / mean).To(BeNumerically("<", 0.15), + fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean)) + + // Also check frequency is correct + freq := estimateFrequency(tail, 16000) + GinkgoWriter.Printf("Estimated frequency: %.0f Hz (expected 440)\n", freq) + Expect(freq).To(BeNumerically("~", 440, 20)) + }) + + It("produces identical resampled output for batched vs one-shot resample", func() { + // Isolate the resampler from the codec: decode once, then compare + // one-shot resample vs batched resample of the same PCM. + sine := generateSineWave(440, 48000, 48000*3) + pcmBytes := sound.Int16toBytesLE(sine) + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + + decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames, + Options: map[string]string{"session_id": "resample-test"}, + }) + Expect(err).ToNot(HaveOccurred()) + allSamples := sound.BytesToInt16sLE(decResult.PcmData) + + // One-shot resample + oneShot := sound.ResampleInt16(allSamples, 48000, 16000) + + // Batched resample (300ms chunks at 48kHz = 14400 samples) + batchSize := 48000 * 300 / 1000 + var batched []int16 + for offset := 0; offset < len(allSamples); offset += batchSize { + end := min(offset+batchSize, len(allSamples)) + chunk := sound.ResampleInt16(allSamples[offset:end], 48000, 16000) + batched = append(batched, chunk...) + } + + Expect(len(batched)).To(Equal(len(oneShot)), + fmt.Sprintf("length mismatch: batched=%d oneshot=%d", len(batched), len(oneShot))) + + // Every sample must be identical — the resampler is deterministic + var maxDiff float64 + for i := 0; i < len(oneShot); i++ { + diff := math.Abs(float64(oneShot[i]) - float64(batched[i])) + if diff > maxDiff { + maxDiff = diff + } + } + + GinkgoWriter.Printf("Resample-only: batched vs one-shot maxDiff=%.0f\n", maxDiff) + Expect(maxDiff).To(BeNumerically("==", 0), + "batched resample should produce identical output to one-shot resample") + }) + + It("writes WAV files for manual inspection", func() { + // This test writes WAV files of the batched vs one-shot pipeline + // so you can visually/audibly inspect for discontinuities. + tmpDir := GinkgoT().TempDir() + + sine := generateSineWave(440, 48000, 48000*3) // 3 seconds + pcmBytes := sound.Int16toBytesLE(sine) + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + + // One-shot path (reference) + decAll, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames, + Options: map[string]string{"session_id": "wav-ref"}, + }) + Expect(err).ToNot(HaveOccurred()) + refSamples := sound.BytesToInt16sLE(decAll.PcmData) + refResampled := sound.ResampleInt16(refSamples, 48000, 16000) + + // Batched path (production simulation) + const framesPerBatch = 15 + var batchedResampled []int16 + for i := 0; i < len(encResult.Frames); i += framesPerBatch { + end := min(i+framesPerBatch, len(encResult.Frames)) + dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ + Frames: encResult.Frames[i:end], + Options: map[string]string{"session_id": "wav-batched"}, + }) + Expect(err).ToNot(HaveOccurred()) + samples := sound.BytesToInt16sLE(dec.PcmData) + resampled := sound.ResampleInt16(samples, 48000, 16000) + batchedResampled = append(batchedResampled, resampled...) + } + + // Write WAV files + writeWAV := func(path string, samples []int16, sampleRate int) { + dataLen := len(samples) * 2 + hdr := make([]byte, 44) + copy(hdr[0:4], "RIFF") + binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen)) + copy(hdr[8:12], "WAVE") + copy(hdr[12:16], "fmt ") + binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size + binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM + binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono + binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate + binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate + binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align + binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample + copy(hdr[36:40], "data") + binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen)) + + f, err := os.Create(path) + Expect(err).ToNot(HaveOccurred()) + defer f.Close() + _, err = f.Write(hdr) + Expect(err).ToNot(HaveOccurred()) + _, err = f.Write(sound.Int16toBytesLE(samples)) + Expect(err).ToNot(HaveOccurred()) + } + + refPath := filepath.Join(tmpDir, "oneshot_16k.wav") + batchedPath := filepath.Join(tmpDir, "batched_16k.wav") + writeWAV(refPath, refResampled, 16000) + writeWAV(batchedPath, batchedResampled, 16000) + + GinkgoWriter.Printf("WAV files written for manual inspection:\n") + GinkgoWriter.Printf(" Reference: %s\n", refPath) + GinkgoWriter.Printf(" Batched: %s\n", batchedPath) + GinkgoWriter.Printf(" Ref samples: %d, Batched samples: %d\n", + len(refResampled), len(batchedResampled)) + }) + }) + + It("produces frames decodable by ffmpeg (cross-library compat)", func() { + ffmpegPath, err := exec.LookPath("ffmpeg") + if err != nil { + Skip("ffmpeg not found") + } + + sine := generateSineWave(440, 48000, 48000) + pcmBytes := sound.Int16toBytesLE(sine) + + result, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcmBytes, + SampleRate: 48000, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(result.Frames).ToNot(BeEmpty()) + GinkgoWriter.Printf("opus-go produced %d frames (first frame %d bytes)\n", len(result.Frames), len(result.Frames[0])) + + tmpDir := GinkgoT().TempDir() + oggPath := filepath.Join(tmpDir, "opus_go_output.ogg") + Expect(writeOggOpus(oggPath, result.Frames, 48000, 1)).To(Succeed()) + + decodedWavPath := filepath.Join(tmpDir, "ffmpeg_decoded.wav") + cmd := exec.Command(ffmpegPath, "-y", "-i", oggPath, "-ar", "48000", "-ac", "1", "-c:a", "pcm_s16le", decodedWavPath) + out, err := cmd.CombinedOutput() + Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("ffmpeg failed to decode opus-go output: %s", out)) + + decodedData, err := os.ReadFile(decodedWavPath) + Expect(err).ToNot(HaveOccurred()) + + decodedPCM, sr := parseTestWAV(decodedData) + Expect(sr).ToNot(BeZero(), "ffmpeg output has no WAV header") + decodedSamples := sound.BytesToInt16sLE(decodedPCM) + + skip := min(len(decodedSamples)/4, sr*100/1000) + if skip >= len(decodedSamples) { + skip = 0 + } + tail := decodedSamples[skip:] + rms := computeRMS(tail) + + GinkgoWriter.Printf("ffmpeg decoded opus-go output: %d samples at %dHz, RMS=%.1f\n", len(decodedSamples), sr, rms) + + Expect(rms).To(BeNumerically(">=", 50), + "ffmpeg decoded RMS is too low — opus-go frames are likely incompatible with standard decoders") + }) + + It("delivers audio through a full WebRTC pipeline", func() { + const ( + toneFreq = 440.0 + toneSampleRate = 24000 + toneDuration = 1 + toneAmplitude = 16000 + toneNumSamples = toneSampleRate * toneDuration + ) + + pcm := make([]byte, toneNumSamples*2) + for i := 0; i < toneNumSamples; i++ { + sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate))) + binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) + } + + encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ + PcmData: pcm, + SampleRate: toneSampleRate, + Channels: 1, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(encResult.Frames).ToNot(BeEmpty()) + GinkgoWriter.Printf("Encoded %d Opus frames from %d PCM samples at %dHz\n", len(encResult.Frames), toneNumSamples, toneSampleRate) + + // Create sender PeerConnection + senderME := &webrtc.MediaEngine{} + Expect(senderME.RegisterDefaultCodecs()).To(Succeed()) + senderAPI := webrtc.NewAPI(webrtc.WithMediaEngine(senderME)) + senderPC, err := senderAPI.NewPeerConnection(webrtc.Configuration{}) + Expect(err).ToNot(HaveOccurred()) + defer senderPC.Close() + + audioTrack, err := webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + }, + "audio", "test", + ) + Expect(err).ToNot(HaveOccurred()) + + rtpSender, err := senderPC.AddTrack(audioTrack) + Expect(err).ToNot(HaveOccurred()) + go func() { + buf := make([]byte, 1500) + for { + if _, _, err := rtpSender.Read(buf); err != nil { + return + } + } + }() + + // Create receiver PeerConnection + receiverME := &webrtc.MediaEngine{} + Expect(receiverME.RegisterDefaultCodecs()).To(Succeed()) + receiverAPI := webrtc.NewAPI(webrtc.WithMediaEngine(receiverME)) + receiverPC, err := receiverAPI.NewPeerConnection(webrtc.Configuration{}) + Expect(err).ToNot(HaveOccurred()) + defer receiverPC.Close() + + type receivedPacket struct { + seqNum uint16 + timestamp uint32 + marker bool + payload []byte + } + var ( + receivedMu sync.Mutex + receivedPackets []receivedPacket + trackDone = make(chan struct{}) + ) + + receiverPC.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + defer close(trackDone) + for { + pkt, _, err := track.ReadRTP() + if err != nil { + return + } + payload := make([]byte, len(pkt.Payload)) + copy(payload, pkt.Payload) + receivedMu.Lock() + receivedPackets = append(receivedPackets, receivedPacket{ + seqNum: pkt.Header.SequenceNumber, + timestamp: pkt.Header.Timestamp, + marker: pkt.Header.Marker, + payload: payload, + }) + receivedMu.Unlock() + } + }) + + // Exchange SDP + offer, err := senderPC.CreateOffer(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(senderPC.SetLocalDescription(offer)).To(Succeed()) + senderGatherDone := webrtc.GatheringCompletePromise(senderPC) + Eventually(senderGatherDone, 5*time.Second).Should(BeClosed()) + + Expect(receiverPC.SetRemoteDescription(*senderPC.LocalDescription())).To(Succeed()) + answer, err := receiverPC.CreateAnswer(nil) + Expect(err).ToNot(HaveOccurred()) + Expect(receiverPC.SetLocalDescription(answer)).To(Succeed()) + receiverGatherDone := webrtc.GatheringCompletePromise(receiverPC) + Eventually(receiverGatherDone, 5*time.Second).Should(BeClosed()) + + Expect(senderPC.SetRemoteDescription(*receiverPC.LocalDescription())).To(Succeed()) + + // Wait for connection + connected := make(chan struct{}) + senderPC.OnConnectionStateChange(func(s webrtc.PeerConnectionState) { + if s == webrtc.PeerConnectionStateConnected { + select { + case <-connected: + default: + close(connected) + } + } + }) + Eventually(connected, 5*time.Second).Should(BeClosed()) + + // Send test tone via RTP + const samplesPerFrame = 960 + seqNum := uint16(rand.UintN(65536)) + timestamp := rand.Uint32() + marker := true + + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + + for i, frame := range encResult.Frames { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: marker, + SequenceNumber: seqNum, + Timestamp: timestamp, + }, + Payload: frame, + } + seqNum++ + timestamp += samplesPerFrame + marker = false + + Expect(audioTrack.WriteRTP(pkt)).To(Succeed(), fmt.Sprintf("WriteRTP frame %d", i)) + if i < len(encResult.Frames)-1 { + <-ticker.C + } + } + + // Wait for packets to arrive + time.Sleep(500 * time.Millisecond) + + senderPC.Close() + + select { + case <-trackDone: + case <-time.After(2 * time.Second): + } + + // Decode received Opus frames via the backend + receivedMu.Lock() + pkts := make([]receivedPacket, len(receivedPackets)) + copy(pkts, receivedPackets) + receivedMu.Unlock() + + Expect(pkts).ToNot(BeEmpty(), "no RTP packets received") + + var receivedFrames [][]byte + for _, pkt := range pkts { + receivedFrames = append(receivedFrames, pkt.payload) + } + + decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: receivedFrames}) + Expect(err).ToNot(HaveOccurred()) + + allDecoded := sound.BytesToInt16sLE(decResult.PcmData) + Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples") + + // Analyse RTP packet delivery + frameLoss := len(encResult.Frames) - len(pkts) + seqGaps := 0 + for i := 1; i < len(pkts); i++ { + expected := pkts[i-1].seqNum + 1 + if pkts[i].seqNum != expected { + seqGaps++ + } + } + markerCount := 0 + for _, pkt := range pkts { + if pkt.marker { + markerCount++ + } + } + + GinkgoWriter.Println("── RTP Delivery ──") + GinkgoWriter.Printf(" Frames sent: %d\n", len(encResult.Frames)) + GinkgoWriter.Printf(" Packets recv: %d\n", len(pkts)) + GinkgoWriter.Printf(" Frame loss: %d\n", frameLoss) + GinkgoWriter.Printf(" Sequence gaps: %d\n", seqGaps) + GinkgoWriter.Printf(" Marker packets: %d (expect 1)\n", markerCount) + + // Audio quality metrics + skip := 48000 * 100 / 1000 + if skip > len(allDecoded)/2 { + skip = len(allDecoded) / 4 + } + tail := allDecoded[skip:] + + rms := computeRMS(tail) + freq := estimateFrequency(tail, 48000) + thd := computeTHD(tail, toneFreq, 48000, 10) + + GinkgoWriter.Println("── Audio Quality ──") + GinkgoWriter.Printf(" Decoded samples: %d (%.1f ms at 48kHz)\n", len(allDecoded), float64(len(allDecoded))/48.0) + GinkgoWriter.Printf(" RMS level: %.1f\n", rms) + GinkgoWriter.Printf(" Peak frequency: %.0f Hz (expected %.0f Hz)\n", freq, toneFreq) + GinkgoWriter.Printf(" THD (h2-h10): %.1f%%\n", thd) + + Expect(frameLoss).To(BeZero(), "lost frames in localhost transport") + Expect(seqGaps).To(BeZero(), "sequence number gaps detected") + Expect(markerCount).To(Equal(1), "expected exactly 1 marker packet") + Expect(rms).To(BeNumerically(">=", 50), "signal appears silent or severely attenuated") + Expect(freq).To(BeNumerically("~", toneFreq, 20), fmt.Sprintf("peak frequency %.0f Hz deviates from expected", freq)) + Expect(thd).To(BeNumerically("<", 50), "signal is severely distorted") + }) +}) diff --git a/backend/go/opus/package.sh b/backend/go/opus/package.sh new file mode 100644 index 000000000000..a55834f3e0b9 --- /dev/null +++ b/backend/go/opus/package.sh @@ -0,0 +1,47 @@ +#!/bin/bash +set -e + +CURDIR=$(dirname "$(realpath $0)") + +mkdir -p $CURDIR/package/lib + +cp -avf $CURDIR/opus $CURDIR/package/ +cp -avf $CURDIR/run.sh $CURDIR/package/ + +# Copy the opus shim library +cp -avf $CURDIR/libopusshim.so $CURDIR/package/lib/ + +# Copy system libopus +if command -v pkg-config >/dev/null 2>&1 && pkg-config --exists opus; then + LIBOPUS_DIR=$(pkg-config --variable=libdir opus) + cp -avfL $LIBOPUS_DIR/libopus.so* $CURDIR/package/lib/ 2>/dev/null || true +fi + +# Detect architecture and copy appropriate libraries +if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then + echo "Detected x86_64 architecture, copying x86_64 libraries..." + cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so + cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 + cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 + cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 + cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 + cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 + cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 + cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 +elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then + echo "Detected ARM64 architecture, copying ARM64 libraries..." + cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so + cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 + cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 + cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 + cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 + cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 + cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 + cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 +else + echo "Warning: Could not detect architecture for system library bundling" +fi + +echo "Packaging completed successfully" +ls -liah $CURDIR/package/ +ls -liah $CURDIR/package/lib/ diff --git a/backend/go/opus/run.sh b/backend/go/opus/run.sh new file mode 100644 index 000000000000..d926c57d03f0 --- /dev/null +++ b/backend/go/opus/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -ex + +CURDIR=$(dirname "$(realpath $0)") + +export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH +export OPUS_SHIM_LIBRARY=$CURDIR/lib/libopusshim.so + +# If there is a lib/ld.so, use it +if [ -f $CURDIR/lib/ld.so ]; then + echo "Using lib/ld.so" + exec $CURDIR/lib/ld.so $CURDIR/opus "$@" +fi + +exec $CURDIR/opus "$@" diff --git a/backend/index.yaml b/backend/index.yaml index 619c17a82ba7..e5a4d3bbe5e1 100644 --- a/backend/index.yaml +++ b/backend/index.yaml @@ -724,6 +724,23 @@ tags: - text-to-speech - TTS +- &opus + name: "opus" + uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus" + urls: + - https://opus-codec.org/ + mirrors: + - localai/localai-backends:latest-cpu-opus + license: BSD-3-Clause + description: | + Opus audio codec backend for encoding and decoding audio. + Required for WebRTC transport in the Realtime API. + tags: + - audio-codec + - opus + - WebRTC + - realtime + - CPU - &silero-vad name: "silero-vad" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-silero-vad" @@ -1088,6 +1105,21 @@ uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-local-store" mirrors: - localai/localai-backends:master-metal-darwin-arm64-local-store +- !!merge <<: *opus + name: "opus-development" + uri: "quay.io/go-skynet/local-ai-backends:master-cpu-opus" + mirrors: + - localai/localai-backends:master-cpu-opus +- !!merge <<: *opus + name: "metal-opus" + uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-opus" + mirrors: + - localai/localai-backends:latest-metal-darwin-arm64-opus +- !!merge <<: *opus + name: "metal-opus-development" + uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-opus" + mirrors: + - localai/localai-backends:master-metal-darwin-arm64-opus - !!merge <<: *silero-vad name: "silero-vad-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-silero-vad" diff --git a/core/backend/llm.go b/core/backend/llm.go index 4b8f37bc98e5..db407e5a1f4b 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -84,6 +84,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima } // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + var capturedPredictOpts *proto.PredictOptions fn := func() (LLMResponse, error) { opts := gRPCPredictOpts(*c, loader.ModelPath) // Merge request-level metadata (overrides config defaults) @@ -111,6 +112,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima opts.LogitBias = string(logitBiasJSON) } } + capturedPredictOpts = opts tokenUsage := TokenUsage{} @@ -245,16 +247,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima trace.InitBackendTracingIfEnabled(o.TracingMaxItems) traceData := map[string]any{ - "prompt": s, - "use_tokenizer_template": c.TemplateConfig.UseTokenizerTemplate, - "chat_template": c.TemplateConfig.Chat, - "function_template": c.TemplateConfig.Functions, - "grammar": c.Grammar, - "stop_words": c.StopWords, - "streaming": tokenCallback != nil, - "images_count": len(images), - "videos_count": len(videos), - "audios_count": len(audios), + "chat_template": c.TemplateConfig.Chat, + "function_template": c.TemplateConfig.Functions, + "streaming": tokenCallback != nil, + "images_count": len(images), + "videos_count": len(videos), + "audios_count": len(audios), } if len(messages) > 0 { @@ -262,12 +260,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima traceData["messages"] = string(msgJSON) } } - if tools != "" { - traceData["tools"] = tools - } - if toolChoice != "" { - traceData["tool_choice"] = toolChoice - } if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil { traceData["reasoning_config"] = string(reasoningJSON) } @@ -277,15 +269,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima "mixed_mode": c.FunctionsConfig.GrammarConfig.MixedMode, "xml_format_preset": c.FunctionsConfig.XMLFormatPreset, } - if c.Temperature != nil { - traceData["temperature"] = *c.Temperature - } - if c.TopP != nil { - traceData["top_p"] = *c.TopP - } - if c.Maxtokens != nil { - traceData["max_tokens"] = *c.Maxtokens - } startTime := time.Now() originalFn := fn @@ -299,6 +282,42 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima "completion": resp.Usage.Completion, } + if len(resp.ChatDeltas) > 0 { + chatDeltasInfo := map[string]any{ + "total_deltas": len(resp.ChatDeltas), + } + var contentParts, reasoningParts []string + toolCallCount := 0 + for _, d := range resp.ChatDeltas { + if d.Content != "" { + contentParts = append(contentParts, d.Content) + } + if d.ReasoningContent != "" { + reasoningParts = append(reasoningParts, d.ReasoningContent) + } + toolCallCount += len(d.ToolCalls) + } + if len(contentParts) > 0 { + chatDeltasInfo["content"] = strings.Join(contentParts, "") + } + if len(reasoningParts) > 0 { + chatDeltasInfo["reasoning_content"] = strings.Join(reasoningParts, "") + } + if toolCallCount > 0 { + chatDeltasInfo["tool_call_count"] = toolCallCount + } + traceData["chat_deltas"] = chatDeltasInfo + } + + if capturedPredictOpts != nil { + if optsJSON, err := json.Marshal(capturedPredictOpts); err == nil { + var optsMap map[string]any + if err := json.Unmarshal(optsJSON, &optsMap); err == nil { + traceData["predict_options"] = optsMap + } + } + } + errStr := "" if err != nil { errStr = err.Error() diff --git a/core/backend/transcript.go b/core/backend/transcript.go index dbbf718a3a48..7568e4e40706 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -3,11 +3,12 @@ package backend import ( "context" "fmt" + "maps" "time" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" @@ -30,9 +31,12 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt } var startTime time.Time + var audioSnippet map[string]any if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() + // Capture audio before the backend call — the backend may delete the file. + audioSnippet = trace.AudioSnippet(audio) } r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ @@ -45,6 +49,16 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt }) if err != nil { if appConfig.EnableTracing { + errData := map[string]any{ + "audio_file": audio, + "language": language, + "translate": translate, + "diarize": diarize, + "prompt": prompt, + } + if audioSnippet != nil { + maps.Copy(errData, audioSnippet) + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -53,13 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt Backend: modelConfig.Backend, Summary: trace.TruncateString(audio, 200), Error: err.Error(), - Data: map[string]any{ - "audio_file": audio, - "language": language, - "translate": translate, - "diarize": diarize, - "prompt": prompt, - }, + Data: errData, }) } return nil, err @@ -84,6 +92,18 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt } if appConfig.EnableTracing { + data := map[string]any{ + "audio_file": audio, + "language": language, + "translate": translate, + "diarize": diarize, + "prompt": prompt, + "result_text": tr.Text, + "segments_count": len(tr.Segments), + } + if audioSnippet != nil { + maps.Copy(data, audioSnippet) + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -91,15 +111,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(audio+" -> "+tr.Text, 200), - Data: map[string]any{ - "audio_file": audio, - "language": language, - "translate": translate, - "diarize": diarize, - "prompt": prompt, - "result_text": tr.Text, - "segments_count": len(tr.Segments), - }, + Data: data, }) } diff --git a/core/backend/tts.go b/core/backend/tts.go index 7859cd67cb71..69193db12a5d 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "maps" "os" "path/filepath" "time" @@ -84,6 +85,16 @@ func ModelTTS( errStr = fmt.Sprintf("TTS error: %s", res.Message) } + data := map[string]any{ + "text": text, + "voice": voice, + "language": language, + } + if err == nil && res.Success { + if snippet := trace.AudioSnippet(filePath); snippet != nil { + maps.Copy(data, snippet) + } + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -92,11 +103,7 @@ func ModelTTS( Backend: modelConfig.Backend, Summary: trace.TruncateString(text, 200), Error: errStr, - Data: map[string]any{ - "text": text, - "voice": voice, - "language": language, - }, + Data: data, }) } @@ -158,6 +165,11 @@ func ModelTTSStream( headerSent := false var callbackErr error + // Collect up to 30s of audio for tracing + var snippetPCM []byte + var totalPCMBytes int + snippetCapped := false + err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{ Text: text, Model: modelPath, @@ -166,7 +178,7 @@ func ModelTTSStream( }, func(reply *proto.Reply) { // First message contains sample rate info if !headerSent && len(reply.Message) > 0 { - var info map[string]interface{} + var info map[string]any if json.Unmarshal(reply.Message, &info) == nil { if sr, ok := info["sample_rate"].(float64); ok { sampleRate = uint32(sr) @@ -207,6 +219,22 @@ func ModelTTSStream( if writeErr := audioCallback(reply.Audio); writeErr != nil { callbackErr = writeErr } + // Accumulate PCM for tracing snippet + totalPCMBytes += len(reply.Audio) + if appConfig.EnableTracing && !snippetCapped { + maxBytes := int(sampleRate) * 2 * trace.MaxSnippetSeconds // 16-bit mono + if len(snippetPCM)+len(reply.Audio) <= maxBytes { + snippetPCM = append(snippetPCM, reply.Audio...) + } else { + remaining := maxBytes - len(snippetPCM) + if remaining > 0 { + // Align to sample boundary (2 bytes per sample) + remaining = remaining &^ 1 + snippetPCM = append(snippetPCM, reply.Audio[:remaining]...) + } + snippetCapped = true + } + } } }) @@ -221,6 +249,17 @@ func ModelTTSStream( errStr = resultErr.Error() } + data := map[string]any{ + "text": text, + "voice": voice, + "language": language, + "streaming": true, + } + if resultErr == nil && len(snippetPCM) > 0 { + if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes); snippet != nil { + maps.Copy(data, snippet) + } + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -229,12 +268,7 @@ func ModelTTSStream( Backend: modelConfig.Backend, Summary: trace.TruncateString(text, 200), Error: errStr, - Data: map[string]any{ - "text": text, - "voice": voice, - "language": language, - "streaming": true, - }, + Data: data, }) } diff --git a/core/http/endpoints/openai/inpainting_test.go b/core/http/endpoints/openai/inpainting_test.go index de4678d347e8..69a80b6deee7 100644 --- a/core/http/endpoints/openai/inpainting_test.go +++ b/core/http/endpoints/openai/inpainting_test.go @@ -7,17 +7,17 @@ import ( "net/http/httptest" "os" "path/filepath" - "testing" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" model "github.com/mudler/LocalAI/pkg/model" - "github.com/stretchr/testify/require" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" ) -func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) { +func makeMultipartRequest(fields map[string]string, files map[string][]byte) (*http.Request, string) { b := &bytes.Buffer{} w := multipart.NewWriter(b) for k, v := range fields { @@ -25,83 +25,73 @@ func makeMultipartRequest(t *testing.T, fields map[string]string, files map[stri } for fname, content := range files { fw, err := w.CreateFormFile(fname, fname+".png") - require.NoError(t, err) + Expect(err).ToNot(HaveOccurred()) _, err = fw.Write(content) - require.NoError(t, err) + Expect(err).ToNot(HaveOccurred()) } - require.NoError(t, w.Close()) + Expect(w.Close()).To(Succeed()) req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b) req.Header.Set("Content-Type", w.FormDataContentType()) return req, w.FormDataContentType() } -func TestInpainting_MissingFiles(t *testing.T) { - e := echo.New() - // handler requires cl, ml, appConfig but this test verifies missing files early - h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig()) +var _ = Describe("Inpainting", func() { + It("returns error for missing files", func() { + e := echo.New() + h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig()) - req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) + req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - err := h(c) - require.Error(t, err) -} + err := h(c) + Expect(err).To(HaveOccurred()) + }) -func TestInpainting_HappyPath(t *testing.T) { - // Setup temp generated content dir - tmpDir, err := os.MkdirTemp("", "gencontent") - require.NoError(t, err) - defer os.RemoveAll(tmpDir) + It("handles the happy path", func() { + tmpDir, err := os.MkdirTemp("", "gencontent") + Expect(err).ToNot(HaveOccurred()) + DeferCleanup(func() { os.RemoveAll(tmpDir) }) - appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir)) + appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir)) - // stub the backend.ImageGenerationFunc - orig := backend.ImageGenerationFunc - backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { - fn := func() error { - // write a fake png file to dst - return os.WriteFile(dst, []byte("PNGDATA"), 0644) + orig := backend.ImageGenerationFunc + backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { + fn := func() error { + return os.WriteFile(dst, []byte("PNGDATA"), 0644) + } + return fn, nil } - return fn, nil - } - defer func() { backend.ImageGenerationFunc = orig }() - - // prepare multipart request with image and mask - fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"} - files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")} - reqBuf, _ := makeMultipartRequest(t, fields, files) - - rec := httptest.NewRecorder() - e := echo.New() - c := e.NewContext(reqBuf, rec) - - // set a minimal model config in context as handler expects - c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"}) - - h := InpaintingEndpoint(nil, nil, appConf) - - // call handler - err = h(c) - require.NoError(t, err) - require.Equal(t, http.StatusOK, rec.Code) - - // verify response body contains generated-images path - body := rec.Body.String() - require.Contains(t, body, "generated-images") - - // confirm the file was created in tmpDir - // parse out filename from response (naive search) - // find "generated-images/" and extract until closing quote or brace - idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/")) - require.True(t, idx >= 0) - rest := rec.Body.Bytes()[idx:] - end := bytes.IndexAny(rest, "\",}\n") - if end == -1 { - end = len(rest) - } - fname := string(rest[len("generated-images/"):end]) - // ensure file exists - _, err = os.Stat(filepath.Join(tmpDir, fname)) - require.NoError(t, err) -} + DeferCleanup(func() { backend.ImageGenerationFunc = orig }) + + fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"} + files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")} + reqBuf, _ := makeMultipartRequest(fields, files) + + rec := httptest.NewRecorder() + e := echo.New() + c := e.NewContext(reqBuf, rec) + + c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, &config.ModelConfig{Backend: "diffusers"}) + + h := InpaintingEndpoint(nil, nil, appConf) + + err = h(c) + Expect(err).ToNot(HaveOccurred()) + Expect(rec.Code).To(Equal(http.StatusOK)) + + body := rec.Body.String() + Expect(body).To(ContainSubstring("generated-images")) + + idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/")) + Expect(idx).To(BeNumerically(">=", 0)) + rest := rec.Body.Bytes()[idx:] + end := bytes.IndexAny(rest, "\",}\n") + if end == -1 { + end = len(rest) + } + fname := string(rest[len("generated-images/"):end]) + _, err = os.Stat(filepath.Join(tmpDir, fname)) + Expect(err).ToNot(HaveOccurred()) + }) +}) diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 415e75b18f62..32923a4aca75 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -3,8 +3,10 @@ package openai import ( "context" "encoding/base64" + "encoding/binary" "encoding/json" "fmt" + "math" "os" "sync" "time" @@ -23,6 +25,7 @@ import ( "github.com/mudler/LocalAI/core/templates" laudio "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/reasoning" @@ -40,23 +43,17 @@ const ( maxAudioBufferSize = 100 * 1024 * 1024 // Maximum WebSocket message size in bytes (10MB) to prevent DoS attacks maxWebSocketMessageSize = 10 * 1024 * 1024 + + defaultInstructions = "You are a helpful voice assistant. " + + "Your responses will be spoken aloud using text-to-speech, so keep them concise and conversational. " + + "Do not use markdown formatting, bullet points, numbered lists, code blocks, or special characters. " + + "Speak naturally as you would in a phone conversation. " + + "Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized." ) // A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result // If the model support instead audio-to-audio, we will use the specific gRPC calls instead -// LockedWebsocket wraps a websocket connection with a mutex for safe concurrent writes -type LockedWebsocket struct { - *websocket.Conn - sync.Mutex -} - -func (l *LockedWebsocket) WriteMessage(messageType int, data []byte) error { - l.Lock() - defer l.Unlock() - return l.Conn.WriteMessage(messageType, data) -} - // Session represents a single WebSocket connection and its state type Session struct { ID string @@ -72,13 +69,55 @@ type Session struct { Conversations map[string]*Conversation InputAudioBuffer []byte AudioBufferLock sync.Mutex + OpusFrames [][]byte + OpusFramesLock sync.Mutex Instructions string DefaultConversationID string ModelInterface Model // The pipeline model config or the config for an any-to-any model ModelConfig *config.ModelConfig - InputSampleRate int - MaxOutputTokens types.IntOrInf + InputSampleRate int + OutputSampleRate int + MaxOutputTokens types.IntOrInf + + // Response cancellation: protects activeResponseCancel/activeResponseDone + responseMu sync.Mutex + activeResponseCancel context.CancelFunc + activeResponseDone chan struct{} +} + +// cancelActiveResponse cancels any in-flight response and waits for its +// goroutine to exit. This ensures we never have overlapping responses and +// that interrupted responses are fully cleaned up before starting a new one. +func (s *Session) cancelActiveResponse() { + s.responseMu.Lock() + cancel := s.activeResponseCancel + done := s.activeResponseDone + s.responseMu.Unlock() + + if cancel != nil { + cancel() + } + if done != nil { + <-done + } +} + +// startResponse cancels any active response and returns a new context for +// the replacement response. The caller MUST close the returned done channel +// when the response goroutine exits. +func (s *Session) startResponse(parent context.Context) (context.Context, chan struct{}) { + s.cancelActiveResponse() + + ctx, cancel := context.WithCancel(parent) + done := make(chan struct{}) + + s.responseMu.Lock() + s.activeResponseCancel = cancel + s.activeResponseDone = done + s.responseMu.Unlock() + + return ctx, done } func (s *Session) FromClient(session *types.SessionUnion) { @@ -187,378 +226,431 @@ func Realtime(application *application.Application) echo.HandlerFunc { func registerRealtime(application *application.Application, model string) func(c *websocket.Conn) { return func(conn *websocket.Conn) { - c := &LockedWebsocket{Conn: conn} - + t := NewWebSocketTransport(conn) evaluator := application.TemplatesEvaluator() - xlog.Debug("Realtime WebSocket connection established", "address", c.RemoteAddr().String(), "model", model) + xlog.Debug("Realtime WebSocket connection established", "address", conn.RemoteAddr().String(), "model", model) + runRealtimeSession(application, t, model, evaluator) + } +} - // TODO: Allow any-to-any model to be specified - cl := application.ModelConfigLoader() - cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(model, application.ApplicationConfig()) - if err != nil { - xlog.Error("failed to load model config", "error", err) - sendError(c, "model_load_error", "Failed to load model config", "", "") - return - } +// runRealtimeSession runs the main event loop for a realtime session. +// It is transport-agnostic and works with both WebSocket and WebRTC. +func runRealtimeSession(application *application.Application, t Transport, model string, evaluator *templates.Evaluator) { + // TODO: Allow any-to-any model to be specified + cl := application.ModelConfigLoader() + cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(model, application.ApplicationConfig()) + if err != nil { + xlog.Error("failed to load model config", "error", err) + sendError(t, "model_load_error", "Failed to load model config", "", "") + return + } - if cfg == nil || (cfg.Pipeline.VAD == "" && cfg.Pipeline.Transcription == "" && cfg.Pipeline.TTS == "" && cfg.Pipeline.LLM == "") { - xlog.Error("model is not a pipeline", "model", model) - sendError(c, "invalid_model", "Model is not a pipeline model", "", "") - return - } + if cfg == nil || (cfg.Pipeline.VAD == "" && cfg.Pipeline.Transcription == "" && cfg.Pipeline.TTS == "" && cfg.Pipeline.LLM == "") { + xlog.Error("model is not a pipeline", "model", model) + sendError(t, "invalid_model", "Model is not a pipeline model", "", "") + return + } - sttModel := cfg.Pipeline.Transcription - - sessionID := generateSessionID() - session := &Session{ - ID: sessionID, - TranscriptionOnly: false, - Model: model, - Voice: cfg.TTSConfig.Voice, - ModelConfig: cfg, - TurnDetection: &types.TurnDetectionUnion{ - ServerVad: &types.ServerVad{ - Threshold: 0.5, - PrefixPaddingMs: 300, - SilenceDurationMs: 500, - CreateResponse: true, - }, + sttModel := cfg.Pipeline.Transcription + + sessionID := generateSessionID() + session := &Session{ + ID: sessionID, + TranscriptionOnly: false, + Model: model, + Voice: cfg.TTSConfig.Voice, + Instructions: defaultInstructions, + ModelConfig: cfg, + TurnDetection: &types.TurnDetectionUnion{ + ServerVad: &types.ServerVad{ + Threshold: 0.5, + PrefixPaddingMs: 300, + SilenceDurationMs: 500, + CreateResponse: true, }, - InputAudioTranscription: &types.AudioTranscription{ - Model: sttModel, - }, - Conversations: make(map[string]*Conversation), - InputSampleRate: defaultRemoteSampleRate, - } + }, + InputAudioTranscription: &types.AudioTranscription{ + Model: sttModel, + }, + Conversations: make(map[string]*Conversation), + InputSampleRate: defaultRemoteSampleRate, + OutputSampleRate: defaultRemoteSampleRate, + } - // Create a default conversation - conversationID := generateConversationID() - conversation := &Conversation{ - ID: conversationID, - // TODO: We need to truncate the conversation items when a new item is added and we have run out of space. There are multiple places where items - // can be added so we could use a datastructure here that enforces truncation upon addition - Items: []*types.MessageItemUnion{}, - } - session.Conversations[conversationID] = conversation - session.DefaultConversationID = conversationID - - m, err := newModel( - &cfg.Pipeline, - application.ModelConfigLoader(), - application.ModelLoader(), - application.ApplicationConfig(), - evaluator, - ) - if err != nil { - xlog.Error("failed to load model", "error", err) - sendError(c, "model_load_error", "Failed to load model", "", "") - return - } - session.ModelInterface = m + // Create a default conversation + conversationID := generateConversationID() + conversation := &Conversation{ + ID: conversationID, + // TODO: We need to truncate the conversation items when a new item is added and we have run out of space. There are multiple places where items + // can be added so we could use a datastructure here that enforces truncation upon addition + Items: []*types.MessageItemUnion{}, + } + session.Conversations[conversationID] = conversation + session.DefaultConversationID = conversationID + + m, err := newModel( + &cfg.Pipeline, + application.ModelConfigLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + evaluator, + ) + if err != nil { + xlog.Error("failed to load model", "error", err) + sendError(t, "model_load_error", "Failed to load model", "", "") + return + } + session.ModelInterface = m - // Store the session - sessionLock.Lock() - sessions[sessionID] = session - sessionLock.Unlock() + // Store the session and notify the transport (for WebRTC audio track handling) + sessionLock.Lock() + sessions[sessionID] = session + sessionLock.Unlock() + + // For WebRTC, inbound audio arrives as Opus (48kHz) and is decoded+resampled + // to localSampleRate in handleIncomingAudioTrack. Set InputSampleRate to + // match so handleVAD doesn't needlessly double-resample. + if _, ok := t.(*WebRTCTransport); ok { + session.InputSampleRate = localSampleRate + } - sendEvent(c, types.SessionCreatedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Session: session.ToServer(), - }) + if sn, ok := t.(interface{ SetSession(*Session) }); ok { + sn.SetSession(session) + } - var ( - msg []byte - wg sync.WaitGroup - done = make(chan struct{}) - ) + sendEvent(t, types.SessionCreatedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Session: session.ToServer(), + }) - vadServerStarted := false - toggleVAD := func() { - if session.TurnDetection.ServerVad != nil && !vadServerStarted { - xlog.Debug("Starting VAD goroutine...") - wg.Add(1) - go func() { - defer wg.Done() - conversation := session.Conversations[session.DefaultConversationID] - handleVAD(session, conversation, c, done) - }() - vadServerStarted = true - } else if session.TurnDetection.ServerVad == nil && vadServerStarted { - xlog.Debug("Stopping VAD goroutine...") + var ( + msg []byte + wg sync.WaitGroup + done = make(chan struct{}) + ) - go func() { - done <- struct{}{} - }() - vadServerStarted = false - } + vadServerStarted := false + toggleVAD := func() { + if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil && !vadServerStarted { + xlog.Debug("Starting VAD goroutine...") + done = make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + conversation := session.Conversations[session.DefaultConversationID] + handleVAD(session, conversation, t, done) + }() + vadServerStarted = true + } else if (session.TurnDetection == nil || session.TurnDetection.ServerVad == nil) && vadServerStarted { + xlog.Debug("Stopping VAD goroutine...") + close(done) + vadServerStarted = false } + } - toggleVAD() + // For WebRTC sessions, start the Opus decode loop before VAD so that + // decoded PCM is already flowing when VAD's first tick fires. + var decodeDone chan struct{} + if wt, ok := t.(*WebRTCTransport); ok { + decodeDone = make(chan struct{}) + go decodeOpusLoop(session, wt.opusBackend, decodeDone) + } - for { - if _, msg, err = c.ReadMessage(); err != nil { - xlog.Error("read error", "error", err) - break - } + toggleVAD() - // Parse the incoming message - event, err := types.UnmarshalClientEvent(msg) - if err != nil { - xlog.Error("invalid json", "error", err) - sendError(c, "invalid_json", "Invalid JSON format", "", "") - continue - } + for { + msg, err = t.ReadEvent() + if err != nil { + xlog.Error("read error", "error", err) + break + } - switch e := event.(type) { - case types.SessionUpdateEvent: - xlog.Debug("recv", "message", string(msg)) - - // Handle transcription session update - if e.Session.Transcription != nil { - if err := updateTransSession( - session, - &e.Session, - application.ModelConfigLoader(), - application.ModelLoader(), - application.ApplicationConfig(), - ); err != nil { - xlog.Error("failed to update session", "error", err) - sendError(c, "session_update_error", "Failed to update session", "", "") - continue - } + // Handle diagnostic events that aren't part of the OpenAI protocol + var rawType struct { + Type string `json:"type"` + } + if json.Unmarshal(msg, &rawType) == nil && rawType.Type == "test_tone" { + if _, ok := t.(*WebSocketTransport); ok { + sendError(t, "not_supported", "test_tone is only supported on WebRTC connections", "", "") + } else { + xlog.Debug("Generating test tone") + go sendTestTone(t) + } + continue + } - toggleVAD() + // Parse the incoming message + event, err := types.UnmarshalClientEvent(msg) + if err != nil { + xlog.Error("invalid json", "error", err) + sendError(t, "invalid_json", "Invalid JSON format", "", "") + continue + } - sendEvent(c, types.SessionUpdatedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Session: session.ToServer(), - }) + switch e := event.(type) { + case types.SessionUpdateEvent: + xlog.Debug("recv", "message", string(msg)) + + // Handle transcription session update + if e.Session.Transcription != nil { + if err := updateTransSession( + session, + &e.Session, + application.ModelConfigLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + ); err != nil { + xlog.Error("failed to update session", "error", err) + sendError(t, "session_update_error", "Failed to update session", "", "") + continue } - // Handle realtime session update - if e.Session.Realtime != nil { - if err := updateSession( - session, - &e.Session, - application.ModelConfigLoader(), - application.ModelLoader(), - application.ApplicationConfig(), - evaluator, - ); err != nil { - xlog.Error("failed to update session", "error", err) - sendError(c, "session_update_error", "Failed to update session", "", "") - continue - } - - toggleVAD() + toggleVAD() - sendEvent(c, types.SessionUpdatedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Session: session.ToServer(), - }) - } + sendEvent(t, types.SessionUpdatedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Session: session.ToServer(), + }) + } - case types.InputAudioBufferAppendEvent: - // Handle 'input_audio_buffer.append' - if e.Audio == "" { - xlog.Error("Audio data is missing in 'input_audio_buffer.append'") - sendError(c, "missing_audio_data", "Audio data is missing", "", "") + // Handle realtime session update + if e.Session.Realtime != nil { + if err := updateSession( + session, + &e.Session, + application.ModelConfigLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + evaluator, + ); err != nil { + xlog.Error("failed to update session", "error", err) + sendError(t, "session_update_error", "Failed to update session", "", "") continue } - // Decode base64 audio data - decodedAudio, err := base64.StdEncoding.DecodeString(e.Audio) - if err != nil { - xlog.Error("failed to decode audio data", "error", err) - sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "") - continue - } + toggleVAD() - // Check buffer size limits before appending - session.AudioBufferLock.Lock() - newSize := len(session.InputAudioBuffer) + len(decodedAudio) - if newSize > maxAudioBufferSize { - session.AudioBufferLock.Unlock() - xlog.Error("audio buffer size limit exceeded", "current_size", len(session.InputAudioBuffer), "incoming_size", len(decodedAudio), "limit", maxAudioBufferSize) - sendError(c, "buffer_size_exceeded", fmt.Sprintf("Audio buffer size limit exceeded (max %d bytes)", maxAudioBufferSize), "", "") - continue - } + sendEvent(t, types.SessionUpdatedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Session: session.ToServer(), + }) + } + + case types.InputAudioBufferAppendEvent: + // Handle 'input_audio_buffer.append' + if e.Audio == "" { + xlog.Error("Audio data is missing in 'input_audio_buffer.append'") + sendError(t, "missing_audio_data", "Audio data is missing", "", "") + continue + } + + // Decode base64 audio data + decodedAudio, err := base64.StdEncoding.DecodeString(e.Audio) + if err != nil { + xlog.Error("failed to decode audio data", "error", err) + sendError(t, "invalid_audio_data", "Failed to decode audio data", "", "") + continue + } - // Append to InputAudioBuffer - session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...) + // Check buffer size limits before appending + session.AudioBufferLock.Lock() + newSize := len(session.InputAudioBuffer) + len(decodedAudio) + if newSize > maxAudioBufferSize { session.AudioBufferLock.Unlock() + xlog.Error("audio buffer size limit exceeded", "current_size", len(session.InputAudioBuffer), "incoming_size", len(decodedAudio), "limit", maxAudioBufferSize) + sendError(t, "buffer_size_exceeded", fmt.Sprintf("Audio buffer size limit exceeded (max %d bytes)", maxAudioBufferSize), "", "") + continue + } - case types.InputAudioBufferCommitEvent: - xlog.Debug("recv", "message", string(msg)) + // Append to InputAudioBuffer + session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...) + session.AudioBufferLock.Unlock() - sessionLock.Lock() - isServerVAD := session.TurnDetection.ServerVad != nil - sessionLock.Unlock() + case types.InputAudioBufferCommitEvent: + xlog.Debug("recv", "message", string(msg)) - // TODO: At the least need to check locking and timer state in the VAD Go routine before allowing this - if isServerVAD { - sendNotImplemented(c, "input_audio_buffer.commit in conjunction with VAD") - continue - } + sessionLock.Lock() + isServerVAD := session.TurnDetection != nil && session.TurnDetection.ServerVad != nil + sessionLock.Unlock() - session.AudioBufferLock.Lock() - allAudio := make([]byte, len(session.InputAudioBuffer)) - copy(allAudio, session.InputAudioBuffer) - session.InputAudioBuffer = nil - session.AudioBufferLock.Unlock() + // TODO: At the least need to check locking and timer state in the VAD Go routine before allowing this + if isServerVAD { + sendNotImplemented(t, "input_audio_buffer.commit in conjunction with VAD") + continue + } - go commitUtterance(context.TODO(), allAudio, session, conversation, c) + session.AudioBufferLock.Lock() + allAudio := make([]byte, len(session.InputAudioBuffer)) + copy(allAudio, session.InputAudioBuffer) + session.InputAudioBuffer = nil + session.AudioBufferLock.Unlock() - case types.ConversationItemCreateEvent: - xlog.Debug("recv", "message", string(msg)) - // Add the item to the conversation - item := e.Item - // Ensure IDs are present - if item.User != nil && item.User.ID == "" { - item.User.ID = generateItemID() - } - if item.Assistant != nil && item.Assistant.ID == "" { - item.Assistant.ID = generateItemID() - } - if item.System != nil && item.System.ID == "" { - item.System.ID = generateItemID() - } - if item.FunctionCall != nil && item.FunctionCall.ID == "" { - item.FunctionCall.ID = generateItemID() - } - if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { - item.FunctionCallOutput.ID = generateItemID() - } + sendEvent(t, types.InputAudioBufferCommittedEvent{ + ServerEventBase: types.ServerEventBase{}, + ItemID: generateItemID(), + }) - conversation.Lock.Lock() - conversation.Items = append(conversation.Items, &item) - conversation.Lock.Unlock() + respCtx, respDone := session.startResponse(context.Background()) + go func() { + defer close(respDone) + commitUtterance(respCtx, allAudio, session, conversation, t) + }() + + case types.ConversationItemCreateEvent: + xlog.Debug("recv", "message", string(msg)) + // Add the item to the conversation + item := e.Item + // Ensure IDs are present + if item.User != nil && item.User.ID == "" { + item.User.ID = generateItemID() + } + if item.Assistant != nil && item.Assistant.ID == "" { + item.Assistant.ID = generateItemID() + } + if item.System != nil && item.System.ID == "" { + item.System.ID = generateItemID() + } + if item.FunctionCall != nil && item.FunctionCall.ID == "" { + item.FunctionCall.ID = generateItemID() + } + if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { + item.FunctionCallOutput.ID = generateItemID() + } - sendEvent(c, types.ConversationItemAddedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: e.EventID, - }, - PreviousItemID: e.PreviousItemID, - Item: item, - }) + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, &item) + conversation.Lock.Unlock() - case types.ConversationItemDeleteEvent: - sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO") + sendEvent(t, types.ConversationItemAddedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: e.EventID, + }, + PreviousItemID: e.PreviousItemID, + Item: item, + }) - case types.ConversationItemRetrieveEvent: - xlog.Debug("recv", "message", string(msg)) + case types.ConversationItemDeleteEvent: + sendError(t, "not_implemented", "Deleting items not implemented", "", "event_TODO") - if e.ItemID == "" { - sendError(c, "invalid_item_id", "Need item_id, but none specified", "", "event_TODO") - continue + case types.ConversationItemRetrieveEvent: + xlog.Debug("recv", "message", string(msg)) + + if e.ItemID == "" { + sendError(t, "invalid_item_id", "Need item_id, but none specified", "", "event_TODO") + continue + } + + conversation.Lock.Lock() + var retrievedItem types.MessageItemUnion + for _, item := range conversation.Items { + // We need to check ID in the union + var id string + if item.System != nil { + id = item.System.ID + } else if item.User != nil { + id = item.User.ID + } else if item.Assistant != nil { + id = item.Assistant.ID + } else if item.FunctionCall != nil { + id = item.FunctionCall.ID + } else if item.FunctionCallOutput != nil { + id = item.FunctionCallOutput.ID } + if id == e.ItemID { + retrievedItem = *item + break + } + } + conversation.Lock.Unlock() + + sendEvent(t, types.ConversationItemRetrievedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Item: retrievedItem, + }) + + case types.ResponseCreateEvent: + xlog.Debug("recv", "message", string(msg)) + + // Handle optional items to add to context + if len(e.Response.Input) > 0 { conversation.Lock.Lock() - var retrievedItem types.MessageItemUnion - for _, item := range conversation.Items { - // We need to check ID in the union - var id string - if item.System != nil { - id = item.System.ID - } else if item.User != nil { - id = item.User.ID - } else if item.Assistant != nil { - id = item.Assistant.ID - } else if item.FunctionCall != nil { - id = item.FunctionCall.ID - } else if item.FunctionCallOutput != nil { - id = item.FunctionCallOutput.ID + for _, item := range e.Response.Input { + // Ensure IDs are present + if item.User != nil && item.User.ID == "" { + item.User.ID = generateItemID() } - - if id == e.ItemID { - retrievedItem = *item - break + if item.Assistant != nil && item.Assistant.ID == "" { + item.Assistant.ID = generateItemID() } + if item.System != nil && item.System.ID == "" { + item.System.ID = generateItemID() + } + if item.FunctionCall != nil && item.FunctionCall.ID == "" { + item.FunctionCall.ID = generateItemID() + } + if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { + item.FunctionCallOutput.ID = generateItemID() + } + + conversation.Items = append(conversation.Items, &item) } conversation.Lock.Unlock() + } - sendEvent(c, types.ConversationItemRetrievedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Item: retrievedItem, - }) - - case types.ResponseCreateEvent: - xlog.Debug("recv", "message", string(msg)) - - // Handle optional items to add to context - if len(e.Response.Input) > 0 { - conversation.Lock.Lock() - for _, item := range e.Response.Input { - // Ensure IDs are present - if item.User != nil && item.User.ID == "" { - item.User.ID = generateItemID() - } - if item.Assistant != nil && item.Assistant.ID == "" { - item.Assistant.ID = generateItemID() - } - if item.System != nil && item.System.ID == "" { - item.System.ID = generateItemID() - } - if item.FunctionCall != nil && item.FunctionCall.ID == "" { - item.FunctionCall.ID = generateItemID() - } - if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { - item.FunctionCallOutput.ID = generateItemID() - } - - conversation.Items = append(conversation.Items, &item) - } - conversation.Lock.Unlock() - } + respCtx, respDone := session.startResponse(context.Background()) + go func() { + defer close(respDone) + triggerResponse(respCtx, session, conversation, t, &e.Response) + }() - go triggerResponse(session, conversation, c, &e.Response) + case types.ResponseCancelEvent: + xlog.Debug("recv", "message", string(msg)) + session.cancelActiveResponse() - case types.ResponseCancelEvent: - xlog.Debug("recv", "message", string(msg)) + default: + xlog.Error("unknown message type") + // sendError(t, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "") + } + } - // Handle cancellation of ongoing responses - // Implement cancellation logic as needed - sendNotImplemented(c, "response.cancel") + // Cancel any in-flight response before tearing down + session.cancelActiveResponse() - default: - xlog.Error("unknown message type") - // sendError(c, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "") - } - } + // Stop the Opus decode goroutine (if running) + if decodeDone != nil { + close(decodeDone) + } - // Close the done channel to signal goroutines to exit + // Signal any running VAD goroutine to exit. + if vadServerStarted { close(done) - wg.Wait() - - // Remove the session from the sessions map - sessionLock.Lock() - delete(sessions, sessionID) - sessionLock.Unlock() } + wg.Wait() + + // Remove the session from the sessions map + sessionLock.Lock() + delete(sessions, sessionID) + sessionLock.Unlock() } -// Helper function to send events to the client -func sendEvent(c *LockedWebsocket, event types.ServerEvent) { - eventBytes, err := json.Marshal(event) - if err != nil { - xlog.Error("failed to marshal event", "error", err) - return - } - if err = c.WriteMessage(websocket.TextMessage, eventBytes); err != nil { +// sendEvent sends a server event via the transport, logging any errors. +func sendEvent(t Transport, event types.ServerEvent) { + if err := t.SendEvent(event); err != nil { xlog.Error("write error", "error", err) } } -// Helper function to send errors to the client -func sendError(c *LockedWebsocket, code, message, param, eventID string) { +// sendError sends an error event to the client. +func sendError(t Transport, code, message, param, eventID string) { errorEvent := types.ErrorEvent{ ServerEventBase: types.ServerEventBase{ EventID: eventID, @@ -572,11 +664,35 @@ func sendError(c *LockedWebsocket, code, message, param, eventID string) { }, } - sendEvent(c, errorEvent) + sendEvent(t, errorEvent) +} + +func sendNotImplemented(t Transport, message string) { + sendError(t, "not_implemented", message, "", "event_TODO") } -func sendNotImplemented(c *LockedWebsocket, message string) { - sendError(c, "not_implemented", message, "", "event_TODO") +// sendTestTone generates a 1-second 440 Hz sine wave and sends it through +// the transport's audio path. This exercises the full Opus encode → RTP → +// browser decode pipeline without involving TTS. +func sendTestTone(t Transport) { + const ( + freq = 440.0 + sampleRate = 24000 + duration = 1 // seconds + amplitude = 16000 + numSamples = sampleRate * duration + ) + + pcm := make([]byte, numSamples*2) // 16-bit samples = 2 bytes each + for i := 0; i < numSamples; i++ { + sample := int16(amplitude * math.Sin(2*math.Pi*freq*float64(i)/sampleRate)) + binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) + } + + xlog.Debug("Sending test tone", "samples", numSamples, "sample_rate", sampleRate, "freq", freq) + if err := t.SendAudio(context.Background(), pcm, sampleRate); err != nil { + xlog.Error("test tone send failed", "error", err) + } } func updateTransSession(session *Session, update *types.SessionUnion, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { @@ -616,7 +732,7 @@ func updateTransSession(session *Session, update *types.SessionUnion, cl *config trCur.Prompt = trUpd.Prompt } - if update.Transcription.Audio.Input.TurnDetection != nil { + if update.Transcription.Audio.Input.TurnDetectionSet { session.TurnDetection = update.Transcription.Audio.Input.TurnDetection } @@ -675,7 +791,7 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode session.ModelInterface = m } - if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.TurnDetection != nil { + if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.TurnDetectionSet { session.TurnDetection = rt.Audio.Input.TurnDetection } @@ -685,6 +801,12 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode } } + if rt.Audio != nil && rt.Audio.Output != nil && rt.Audio.Output.Format != nil && rt.Audio.Output.Format.PCM != nil { + if rt.Audio.Output.Format.PCM.Rate > 0 { + session.OutputSampleRate = rt.Audio.Output.Format.PCM.Rate + } + } + if rt.Instructions != "" { session.Instructions = rt.Instructions } @@ -703,9 +825,64 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode return nil } +// decodeOpusLoop runs a ticker that drains buffered raw Opus frames from the +// session, decodes them in a single batched gRPC call, and appends the +// resulting PCM to InputAudioBuffer. This gives ~3 gRPC calls/sec instead of +// 50 (one per RTP packet) and keeps decode diagnostics once-per-batch. +func decodeOpusLoop(session *Session, opusBackend grpc.Backend, done chan struct{}) { + ticker := time.NewTicker(300 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + session.OpusFramesLock.Lock() + frames := session.OpusFrames + session.OpusFrames = nil + session.OpusFramesLock.Unlock() + if len(frames) == 0 { + continue + } + + result, err := opusBackend.AudioDecode(context.Background(), &proto.AudioDecodeRequest{ + Frames: frames, + Options: map[string]string{ + "session_id": session.ID, + }, + }) + if err != nil { + xlog.Warn("opus decode batch error", "error", err, "frames", len(frames)) + continue + } + + samples := sound.BytesToInt16sLE(result.PcmData) + + xlog.Debug("opus decode batch", + "frames", len(frames), + "decoded_samples", len(samples), + "sample_rate", result.SampleRate, + ) + + // Resample from 48kHz to session input rate (16kHz) if needed + if result.SampleRate != int32(session.InputSampleRate) { + samples = sound.ResampleInt16(samples, int(result.SampleRate), session.InputSampleRate) + } + + pcmBytes := sound.Int16toBytesLE(samples) + session.AudioBufferLock.Lock() + newSize := len(session.InputAudioBuffer) + len(pcmBytes) + if newSize <= maxAudioBufferSize { + session.InputAudioBuffer = append(session.InputAudioBuffer, pcmBytes...) + } + session.AudioBufferLock.Unlock() + case <-done: + return + } + } +} + // handleVAD is a goroutine that listens for audio data from the client, // runs VAD on the audio data, and commits utterances to the conversation -func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done chan struct{}) { +func handleVAD(session *Session, conv *Conversation, t Transport, done chan struct{}) { vadContext, cancel := context.WithCancel(context.Background()) go func() { <-done @@ -713,7 +890,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch }() silenceThreshold := 0.5 // Default 500ms - if session.TurnDetection.ServerVad != nil { + if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil { silenceThreshold = float64(session.TurnDetection.ServerVad.SilenceDurationMs) / 1000 } @@ -734,7 +911,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch session.AudioBufferLock.Unlock() aints := sound.BytesToInt16sLE(allAudio) - if len(aints) == 0 || len(aints) < int(silenceThreshold)*session.InputSampleRate { + if len(aints) == 0 || len(aints) < int(silenceThreshold*float64(session.InputSampleRate)) { continue } @@ -748,7 +925,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch continue } xlog.Error("failed to process audio", "error", err) - sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "") + sendError(t, "processing_error", "Failed to process audio: "+err.Error(), "", "") continue } @@ -760,21 +937,17 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch session.InputAudioBuffer = nil session.AudioBufferLock.Unlock() - // NOTE: OpenAI doesn't send this message unless the client requests it - // xlog.Debug("Detected silence for a while, clearing audio buffer") - // sendEvent(c, types.InputAudioBufferClearedEvent{ - // ServerEventBase: types.ServerEventBase{ - // EventID: "event_TODO", - // }, - // }) - continue } else if len(segments) == 0 { continue } if !speechStarted { - sendEvent(c, types.InputAudioBufferSpeechStartedEvent{ + // Barge-in: cancel any in-flight response so we stop + // sending audio and don't keep the interrupted reply in history. + session.cancelActiveResponse() + + sendEvent(t, types.InputAudioBufferSpeechStartedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -795,7 +968,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch session.InputAudioBuffer = nil session.AudioBufferLock.Unlock() - sendEvent(c, types.InputAudioBufferSpeechStoppedEvent{ + sendEvent(t, types.InputAudioBufferSpeechStoppedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -803,7 +976,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch }) speechStarted = false - sendEvent(c, types.InputAudioBufferCommittedEvent{ + sendEvent(t, types.InputAudioBufferCommittedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -813,13 +986,17 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch abytes := sound.Int16toBytesLE(aints) // TODO: Remove prefix silence that is is over TurnDetectionParams.PrefixPaddingMs - go commitUtterance(vadContext, abytes, session, conv, c) + respCtx, respDone := session.startResponse(vadContext) + go func() { + defer close(respDone) + commitUtterance(respCtx, abytes, session, conv, t) + }() } } } } -func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *LockedWebsocket) { +func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, t Transport) { if len(utt) == 0 { return } @@ -851,15 +1028,15 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co if session.InputAudioTranscription != nil { tr, err := session.ModelInterface.Transcribe(ctx, f.Name(), session.InputAudioTranscription.Language, false, false, session.InputAudioTranscription.Prompt) if err != nil { - sendError(c, "transcription_failed", err.Error(), "", "event_TODO") + sendError(t, "transcription_failed", err.Error(), "", "event_TODO") return } else if tr == nil { - sendError(c, "transcription_failed", "trancribe result is nil", "", "event_TODO") + sendError(t, "transcription_failed", "trancribe result is nil", "", "event_TODO") return } transcript = tr.Text - sendEvent(c, types.ConversationItemInputAudioTranscriptionCompletedEvent{ + sendEvent(t, types.ConversationItemInputAudioTranscriptionCompletedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -871,12 +1048,12 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co Transcript: transcript, }) } else { - sendNotImplemented(c, "any-to-any models") + sendNotImplemented(t, "any-to-any models") return } if !session.TranscriptionOnly { - generateResponse(session, utt, transcript, conv, c, websocket.TextMessage) + generateResponse(ctx, session, utt, transcript, conv, t) } } @@ -901,7 +1078,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS } // Function to generate a response based on the conversation -func generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) { +func generateResponse(ctx context.Context, session *Session, utt []byte, transcript string, conv *Conversation, t Transport) { xlog.Debug("Generating realtime response...") // Create user message item @@ -922,14 +1099,14 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con conv.Items = append(conv.Items, &item) conv.Lock.Unlock() - sendEvent(c, types.ConversationItemAddedEvent{ + sendEvent(t, types.ConversationItemAddedEvent{ Item: item, }) - triggerResponse(session, conv, c, nil) + triggerResponse(ctx, session, conv, t, nil) } -func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, overrides *types.ResponseCreateParams) { +func triggerResponse(ctx context.Context, session *Session, conv *Conversation, t Transport, overrides *types.ResponseCreateParams) { config := session.ModelInterface.PredictConfig() // Default values @@ -1077,7 +1254,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o } responseID := generateUniqueID() - sendEvent(c, types.ResponseCreatedEvent{ + sendEvent(t, types.ResponseCreatedEvent{ ServerEventBase: types.ServerEventBase{}, Response: types.Response{ ID: responseID, @@ -1086,15 +1263,29 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o }, }) - predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil) + predFunc, err := session.ModelInterface.Predict(ctx, conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil) if err != nil { - sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here + sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here return } pred, err := predFunc() if err != nil { - sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "") + sendError(t, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "") + return + } + + // Check for cancellation after LLM inference (barge-in may have fired) + if ctx.Err() != nil { + xlog.Debug("Response cancelled after LLM inference (barge-in)") + sendEvent(t, types.ResponseDoneEvent{ + ServerEventBase: types.ServerEventBase{}, + Response: types.Response{ + ID: responseID, + Object: "realtime.response", + Status: types.ResponseStatusCancelled, + }, + }) return } @@ -1194,14 +1385,14 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o conv.Items = append(conv.Items, &item) conv.Lock.Unlock() - sendEvent(c, types.ResponseOutputItemAddedEvent{ + sendEvent(t, types.ResponseOutputItemAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: 0, Item: item, }) - sendEvent(c, types.ResponseContentPartAddedEvent{ + sendEvent(t, types.ResponseContentPartAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1210,15 +1401,54 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Part: item.Assistant.Content[0], }) - audioFilePath, res, err := session.ModelInterface.TTS(context.TODO(), finalSpeech, session.Voice, session.InputAudioTranscription.Language) + // removeItemFromConv removes the last occurrence of an item with + // the given assistant ID from conversation history. + removeItemFromConv := func(assistantID string) { + conv.Lock.Lock() + for i := len(conv.Items) - 1; i >= 0; i-- { + if conv.Items[i].Assistant != nil && conv.Items[i].Assistant.ID == assistantID { + conv.Items = append(conv.Items[:i], conv.Items[i+1:]...) + break + } + } + conv.Lock.Unlock() + } + + // sendCancelledResponse emits the cancelled status and cleans up the + // assistant item so the interrupted reply is not in chat history. + sendCancelledResponse := func() { + removeItemFromConv(item.Assistant.ID) + sendEvent(t, types.ResponseDoneEvent{ + ServerEventBase: types.ServerEventBase{}, + Response: types.Response{ + ID: responseID, + Object: "realtime.response", + Status: types.ResponseStatusCancelled, + }, + }) + } + + // Check for cancellation before TTS + if ctx.Err() != nil { + xlog.Debug("Response cancelled before TTS (barge-in)") + sendCancelledResponse() + return + } + + audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language) if err != nil { + if ctx.Err() != nil { + xlog.Debug("TTS cancelled (barge-in)") + sendCancelledResponse() + return + } xlog.Error("TTS failed", "error", err) - sendError(c, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID) + sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID) return } if !res.Success { xlog.Error("TTS failed", "message", res.Message) - sendError(c, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID) + sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID) return } defer os.Remove(audioFilePath) @@ -1226,21 +1456,47 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o audioBytes, err := os.ReadFile(audioFilePath) if err != nil { xlog.Error("failed to read TTS file", "error", err) - sendError(c, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID) + sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID) return } - // Strip WAV header (44 bytes) to get raw PCM data - // The OpenAI Realtime API expects raw PCM, not WAV files - const wavHeaderSize = 44 - pcmData := audioBytes - if len(audioBytes) > wavHeaderSize { - pcmData = audioBytes[wavHeaderSize:] + // Parse WAV header to get raw PCM and the actual sample rate from the TTS backend. + pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes) + if ttsSampleRate == 0 { + ttsSampleRate = localSampleRate + } + xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate) + + // SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the + // Opus encoder, which resamples to 48kHz internally. This avoids a + // lossy intermediate resample through 16kHz. + // XXX: This is a noop in websocket mode; it's included in the JSON instead + if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil { + if ctx.Err() != nil { + xlog.Debug("Audio playback cancelled (barge-in)") + sendCancelledResponse() + return + } + xlog.Error("failed to send audio via transport", "error", err) } - audioString := base64.StdEncoding.EncodeToString(pcmData) + _, isWebRTC := t.(*WebRTCTransport) + + // For WebSocket clients, resample to the session's output rate and + // deliver audio as base64 in JSON events. WebRTC clients already + // received audio over the RTP track, so skip the base64 payload. + var audioString string + if !isWebRTC { + wsPCM := pcmData + if ttsSampleRate != session.OutputSampleRate { + samples := sound.BytesToInt16sLE(pcmData) + resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate) + wsPCM = sound.Int16toBytesLE(resampled) + } + audioString = base64.StdEncoding.EncodeToString(wsPCM) + } - sendEvent(c, types.ResponseOutputAudioTranscriptDeltaEvent{ + sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1248,7 +1504,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o ContentIndex: 0, Delta: finalSpeech, }) - sendEvent(c, types.ResponseOutputAudioTranscriptDoneEvent{ + sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1257,23 +1513,25 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Transcript: finalSpeech, }) - sendEvent(c, types.ResponseOutputAudioDeltaEvent{ - ServerEventBase: types.ServerEventBase{}, - ResponseID: responseID, - ItemID: item.Assistant.ID, - OutputIndex: 0, - ContentIndex: 0, - Delta: audioString, - }) - sendEvent(c, types.ResponseOutputAudioDoneEvent{ - ServerEventBase: types.ServerEventBase{}, - ResponseID: responseID, - ItemID: item.Assistant.ID, - OutputIndex: 0, - ContentIndex: 0, - }) + if !isWebRTC { + sendEvent(t, types.ResponseOutputAudioDeltaEvent{ + ServerEventBase: types.ServerEventBase{}, + ResponseID: responseID, + ItemID: item.Assistant.ID, + OutputIndex: 0, + ContentIndex: 0, + Delta: audioString, + }) + sendEvent(t, types.ResponseOutputAudioDoneEvent{ + ServerEventBase: types.ServerEventBase{}, + ResponseID: responseID, + ItemID: item.Assistant.ID, + OutputIndex: 0, + ContentIndex: 0, + }) + } - sendEvent(c, types.ResponseContentPartDoneEvent{ + sendEvent(t, types.ResponseContentPartDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1284,10 +1542,12 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o conv.Lock.Lock() item.Assistant.Status = types.ItemStatusCompleted - item.Assistant.Content[0].Audio = audioString + if !isWebRTC { + item.Assistant.Content[0].Audio = audioString + } conv.Lock.Unlock() - sendEvent(c, types.ResponseOutputItemDoneEvent{ + sendEvent(t, types.ResponseOutputItemDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: 0, @@ -1321,14 +1581,14 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o outputIndex++ } - sendEvent(c, types.ResponseOutputItemAddedEvent{ + sendEvent(t, types.ResponseOutputItemAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: outputIndex, Item: fcItem, }) - sendEvent(c, types.ResponseFunctionCallArgumentsDeltaEvent{ + sendEvent(t, types.ResponseFunctionCallArgumentsDeltaEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: toolCallID, @@ -1337,7 +1597,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Delta: tc.Arguments, }) - sendEvent(c, types.ResponseFunctionCallArgumentsDoneEvent{ + sendEvent(t, types.ResponseFunctionCallArgumentsDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: toolCallID, @@ -1347,7 +1607,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Name: tc.Name, }) - sendEvent(c, types.ResponseOutputItemDoneEvent{ + sendEvent(t, types.ResponseOutputItemDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: outputIndex, @@ -1355,7 +1615,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o }) } - sendEvent(c, types.ResponseDoneEvent{ + sendEvent(t, types.ResponseDoneEvent{ ServerEventBase: types.ServerEventBase{}, Response: types.Response{ ID: responseID, diff --git a/core/http/endpoints/openai/realtime_transport.go b/core/http/endpoints/openai/realtime_transport.go new file mode 100644 index 000000000000..5ffcb0ba917e --- /dev/null +++ b/core/http/endpoints/openai/realtime_transport.go @@ -0,0 +1,23 @@ +package openai + +import ( + "context" + + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" +) + +// Transport abstracts event and audio I/O so the same session logic +// can serve both WebSocket and WebRTC connections. +type Transport interface { + // SendEvent marshals and sends a server event to the client. + SendEvent(event types.ServerEvent) error + // ReadEvent reads the next raw client event (JSON bytes). + ReadEvent() ([]byte, error) + // SendAudio sends raw PCM audio to the client at the given sample rate. + // For WebSocket this is a no-op (audio is sent via JSON events). + // For WebRTC this encodes to Opus and writes to the media track. + // The context allows cancellation for barge-in support. + SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error + // Close tears down the underlying connection. + Close() error +} diff --git a/core/http/endpoints/openai/realtime_transport_webrtc.go b/core/http/endpoints/openai/realtime_transport_webrtc.go new file mode 100644 index 000000000000..af7aa046bcff --- /dev/null +++ b/core/http/endpoints/openai/realtime_transport_webrtc.go @@ -0,0 +1,251 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "math/rand/v2" + "sync" + "time" + + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + "github.com/mudler/LocalAI/pkg/grpc" + pb "github.com/mudler/LocalAI/pkg/grpc/proto" + "github.com/mudler/xlog" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" +) + +// WebRTCTransport implements Transport over a pion/webrtc PeerConnection. +// Events travel via the "oai-events" DataChannel; audio goes over an RTP track. +type WebRTCTransport struct { + pc *webrtc.PeerConnection + dc *webrtc.DataChannel + audioTrack *webrtc.TrackLocalStaticRTP + opusBackend grpc.Backend + inEvents chan []byte + outEvents chan []byte // buffered outbound event queue + closed chan struct{} + closeOnce sync.Once + flushed chan struct{} // closed when sender goroutine has drained outEvents + dcReady chan struct{} // closed when data channel is open + dcReadyOnce sync.Once + sessionCh chan *Session // delivers session from runRealtimeSession to handleIncomingAudioTrack + + // RTP state for outbound audio — protected by rtpMu + rtpMu sync.Mutex + rtpSeqNum uint16 + rtpTimestamp uint32 + rtpMarker bool // true → next packet gets marker bit set +} + +func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocalStaticRTP, opusBackend grpc.Backend) *WebRTCTransport { + t := &WebRTCTransport{ + pc: pc, + audioTrack: audioTrack, + opusBackend: opusBackend, + inEvents: make(chan []byte, 256), + outEvents: make(chan []byte, 256), + closed: make(chan struct{}), + flushed: make(chan struct{}), + dcReady: make(chan struct{}), + sessionCh: make(chan *Session, 1), + rtpSeqNum: uint16(rand.UintN(65536)), + rtpTimestamp: rand.Uint32(), + rtpMarker: true, // first packet of the stream gets marker + } + + // The client creates the "oai-events" data channel (so m=application is + // included in the SDP offer). We receive it here via OnDataChannel. + pc.OnDataChannel(func(dc *webrtc.DataChannel) { + if dc.Label() != "oai-events" { + return + } + t.dc = dc + dc.OnOpen(func() { + t.dcReadyOnce.Do(func() { close(t.dcReady) }) + }) + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + select { + case t.inEvents <- msg.Data: + case <-t.closed: + } + }) + // The channel may already be open by the time OnDataChannel fires + if dc.ReadyState() == webrtc.DataChannelStateOpen { + t.dcReadyOnce.Do(func() { close(t.dcReady) }) + } + }) + + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + xlog.Debug("WebRTC connection state", "state", state.String()) + if state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateClosed || + state == webrtc.PeerConnectionStateDisconnected { + t.closeOnce.Do(func() { close(t.closed) }) + } + }) + + go t.sendLoop() + + return t +} + +// sendLoop is a dedicated goroutine that drains outEvents and sends them +// over the data channel. It waits for the data channel to open before +// sending, and drains any remaining events when closed is signalled. +func (t *WebRTCTransport) sendLoop() { + defer close(t.flushed) + + // Wait for data channel to be ready + select { + case <-t.dcReady: + case <-t.closed: + return + } + + for { + select { + case data, ok := <-t.outEvents: + if !ok { + return + } + if err := t.dc.SendText(string(data)); err != nil { + xlog.Error("data channel send failed", "error", err) + return + } + case <-t.closed: + // Drain any remaining queued events before exiting + for { + select { + case data := <-t.outEvents: + if err := t.dc.SendText(string(data)); err != nil { + return + } + default: + return + } + } + } + } +} + +func (t *WebRTCTransport) SendEvent(event types.ServerEvent) error { + data, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("marshal event: %w", err) + } + + select { + case t.outEvents <- data: + return nil + case <-t.closed: + return fmt.Errorf("transport closed") + } +} + +func (t *WebRTCTransport) ReadEvent() ([]byte, error) { + select { + case msg := <-t.inEvents: + return msg, nil + case <-t.closed: + return nil, fmt.Errorf("transport closed") + } +} + +// SendAudio encodes raw PCM int16 LE to Opus and writes RTP packets to the +// audio track. The encoder resamples from the given sampleRate to 48kHz +// internally. Frames are paced at real-time intervals (20ms per frame) to +// avoid overwhelming the browser's jitter buffer with a burst of packets. +// +// The context allows callers to cancel mid-stream for barge-in support. +// When cancelled, the marker bit is set so the next audio segment starts +// cleanly in the browser's jitter buffer. +// +// RTP packets are constructed manually (rather than via WriteSample) so we +// can control the marker bit. pion's WriteSample sets the marker bit on +// every Opus packet, which causes Chrome's NetEq jitter buffer to reset +// its timing estimation for each frame, producing severe audio distortion. +func (t *WebRTCTransport) SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error { + result, err := t.opusBackend.AudioEncode(ctx, &pb.AudioEncodeRequest{ + PcmData: pcmData, + SampleRate: int32(sampleRate), + Channels: 1, + }) + if err != nil { + return fmt.Errorf("opus encode: %w", err) + } + + frames := result.Frames + const frameDuration = 20 * time.Millisecond + const samplesPerFrame = 960 // 20ms at 48kHz + + ticker := time.NewTicker(frameDuration) + defer ticker.Stop() + + for i, frame := range frames { + t.rtpMu.Lock() + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: t.rtpMarker, + SequenceNumber: t.rtpSeqNum, + Timestamp: t.rtpTimestamp, + // SSRC and PayloadType are overridden by pion's writeRTP + }, + Payload: frame, + } + t.rtpSeqNum++ + t.rtpTimestamp += samplesPerFrame + t.rtpMarker = false // only the first packet gets marker + t.rtpMu.Unlock() + + if err := t.audioTrack.WriteRTP(pkt); err != nil { + return fmt.Errorf("write rtp: %w", err) + } + + // Pace output at ~real-time so the browser's jitter buffer + // receives packets at the expected rate. Skip wait after last frame. + if i < len(frames)-1 { + select { + case <-ticker.C: + case <-ctx.Done(): + // Barge-in: mark the next packet so the browser knows + // a new audio segment is starting after the interruption. + t.rtpMu.Lock() + t.rtpMarker = true + t.rtpMu.Unlock() + return ctx.Err() + case <-t.closed: + return fmt.Errorf("transport closed during audio send") + } + } + } + return nil +} + +// SetSession delivers the session to any goroutine waiting in WaitForSession. +func (t *WebRTCTransport) SetSession(s *Session) { + select { + case t.sessionCh <- s: + case <-t.closed: + } +} + +// WaitForSession blocks until the session is available or the transport closes. +func (t *WebRTCTransport) WaitForSession() *Session { + select { + case s := <-t.sessionCh: + return s + case <-t.closed: + return nil + } +} + +func (t *WebRTCTransport) Close() error { + // Signal no more events and unblock the sender if it's waiting + t.closeOnce.Do(func() { close(t.closed) }) + // Wait for the sender to drain any remaining queued events + <-t.flushed + return t.pc.Close() +} diff --git a/core/http/endpoints/openai/realtime_transport_ws.go b/core/http/endpoints/openai/realtime_transport_ws.go new file mode 100644 index 000000000000..6621f2ca6b82 --- /dev/null +++ b/core/http/endpoints/openai/realtime_transport_ws.go @@ -0,0 +1,47 @@ +package openai + +import ( + "context" + "encoding/json" + "sync" + + "github.com/gorilla/websocket" + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + "github.com/mudler/xlog" +) + +// WebSocketTransport implements Transport over a gorilla/websocket connection. +type WebSocketTransport struct { + conn *websocket.Conn + mu sync.Mutex +} + +func NewWebSocketTransport(conn *websocket.Conn) *WebSocketTransport { + return &WebSocketTransport{conn: conn} +} + +func (t *WebSocketTransport) SendEvent(event types.ServerEvent) error { + eventBytes, err := json.Marshal(event) + if err != nil { + xlog.Error("failed to marshal event", "error", err) + return err + } + t.mu.Lock() + defer t.mu.Unlock() + return t.conn.WriteMessage(websocket.TextMessage, eventBytes) +} + +func (t *WebSocketTransport) ReadEvent() ([]byte, error) { + _, msg, err := t.conn.ReadMessage() + return msg, err +} + +// SendAudio is a no-op for WebSocket — audio is delivered via JSON events +// (base64-encoded in response.audio.delta). +func (t *WebSocketTransport) SendAudio(_ context.Context, _ []byte, _ int) error { + return nil +} + +func (t *WebSocketTransport) Close() error { + return t.conn.Close() +} diff --git a/core/http/endpoints/openai/realtime_webrtc.go b/core/http/endpoints/openai/realtime_webrtc.go new file mode 100644 index 000000000000..864c67b19862 --- /dev/null +++ b/core/http/endpoints/openai/realtime_webrtc.go @@ -0,0 +1,206 @@ +package openai + +import ( + "net/http" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/xlog" + "github.com/pion/webrtc/v4" +) + +// RealtimeCallRequest is the JSON body for POST /v1/realtime/calls. +type RealtimeCallRequest struct { + SDP string `json:"sdp"` + Model string `json:"model"` +} + +// RealtimeCallResponse is the JSON response for POST /v1/realtime/calls. +type RealtimeCallResponse struct { + SDP string `json:"sdp"` + SessionID string `json:"session_id"` +} + +// RealtimeCalls handles POST /v1/realtime/calls for WebRTC signaling. +func RealtimeCalls(application *application.Application) echo.HandlerFunc { + return func(c echo.Context) error { + var req RealtimeCallRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"}) + } + if req.SDP == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "sdp is required"}) + } + if req.Model == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "model is required"}) + } + + // Create a MediaEngine with Opus support + m := &webrtc.MediaEngine{} + if err := m.RegisterDefaultCodecs(); err != nil { + xlog.Error("failed to register codecs", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "codec registration failed"}) + } + + api := webrtc.NewAPI(webrtc.WithMediaEngine(m)) + + pc, err := api.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + xlog.Error("failed to create peer connection", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create peer connection"}) + } + + // Create outbound audio track (Opus, 48kHz). + // We use TrackLocalStaticRTP (not TrackLocalStaticSample) so that + // SendAudio can construct RTP packets directly and control the marker + // bit. pion's WriteSample sets the marker bit on every Opus packet, + // which causes Chrome's NetEq jitter buffer to reset for each frame. + audioTrack, err := webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, // Opus in WebRTC is always signaled as 2 channels per RFC 7587 + }, + "audio", + "localai", + ) + if err != nil { + pc.Close() + xlog.Error("failed to create audio track", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create audio track"}) + } + + rtpSender, err := pc.AddTrack(audioTrack) + if err != nil { + pc.Close() + xlog.Error("failed to add audio track", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to add audio track"}) + } + + // Drain RTCP (control protocol) packets we don't have anyting useful to do with + go func() { + buf := make([]byte, 1500) + for { + if _, _, err := rtpSender.Read(buf); err != nil { + return + } + } + }() + + // Load the Opus backend + opusBackend, err := application.ModelLoader().Load( + model.WithBackendString("opus"), + model.WithModelID("__opus_codec__"), + model.WithModel("opus"), + ) + if err != nil { + pc.Close() + xlog.Error("failed to load opus backend", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "opus backend not available"}) + } + + // Create the transport (the data channel is created by the client and + // received via pc.OnDataChannel inside NewWebRTCTransport) + transport := NewWebRTCTransport(pc, audioTrack, opusBackend) + + // Handle incoming audio track from the client + pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + codec := track.Codec() + if codec.MimeType != webrtc.MimeTypeOpus { + xlog.Warn("unexpected track codec, ignoring", "mime", codec.MimeType) + return + } + xlog.Debug("Received audio track from client", + "codec", codec.MimeType, + "clock_rate", codec.ClockRate, + "channels", codec.Channels, + "sdp_fmtp", codec.SDPFmtpLine, + "payload_type", codec.PayloadType, + ) + + handleIncomingAudioTrack(track, transport) + }) + + // Set the remote SDP (client's offer) + if err := pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: req.SDP, + }); err != nil { + transport.Close() + xlog.Error("failed to set remote description", "error", err) + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid SDP offer"}) + } + + // Create answer + answer, err := pc.CreateAnswer(nil) + if err != nil { + transport.Close() + xlog.Error("failed to create answer", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create answer"}) + } + + if err := pc.SetLocalDescription(answer); err != nil { + transport.Close() + xlog.Error("failed to set local description", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to set local description"}) + } + + // Wait for ICE gathering to complete (with timeout) + gatherDone := webrtc.GatheringCompletePromise(pc) + select { + case <-gatherDone: + case <-time.After(10 * time.Second): + xlog.Warn("ICE gathering timed out, using partial candidates") + } + + localDesc := pc.LocalDescription() + if localDesc == nil { + transport.Close() + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "no local description"}) + } + + sessionID := generateSessionID() + + // Start the realtime session in a goroutine + evaluator := application.TemplatesEvaluator() + go func() { + defer transport.Close() + runRealtimeSession(application, transport, req.Model, evaluator) + }() + + return c.JSON(http.StatusCreated, RealtimeCallResponse{ + SDP: localDesc.SDP, + SessionID: sessionID, + }) + } +} + +// handleIncomingAudioTrack reads RTP packets from a remote WebRTC track +// and buffers the raw Opus payloads on the session. Decoding is done in +// batches by decodeOpusLoop in realtime.go. +func handleIncomingAudioTrack(track *webrtc.TrackRemote, transport *WebRTCTransport) { + session := transport.WaitForSession() + if session == nil { + xlog.Error("could not find session for incoming audio track (transport closed)") + sendError(transport, "session_error", "Session failed to start — check server logs", "", "") + return + } + + for { + pkt, _, err := track.ReadRTP() + if err != nil { + xlog.Debug("audio track read ended", "error", err) + return + } + + // Copy the payload — pion's ReadRTP may back it by a reusable buffer + payload := make([]byte, len(pkt.Payload)) + copy(payload, pkt.Payload) + + session.OpusFramesLock.Lock() + session.OpusFrames = append(session.OpusFrames, payload) + session.OpusFramesLock.Unlock() + } +} diff --git a/core/http/endpoints/openai/types/types.go b/core/http/endpoints/openai/types/types.go index 751e79b6fbd5..2f75486adcc3 100644 --- a/core/http/endpoints/openai/types/types.go +++ b/core/http/endpoints/openai/types/types.go @@ -712,17 +712,39 @@ type SessionAudioInput struct { // Configuration for input audio noise reduction. This can be set to null to turn off. Noise reduction filters audio added to the input audio buffer before it is sent to VAD and the model. Filtering the audio can improve VAD and turn detection accuracy (reducing false positives) and model performance by improving perception of the input audio. NoiseReduction *AudioNoiseReduction `json:"noise_reduction,omitempty"` - // Configuration for input audio transcription, defaults to off and can be set to null to turn off once on. Input audio transcription is not native to the model, since the model consumes audio directly. Transcription runs asynchronously through the /audio/transcriptions endpoint and should be treated as guidance of input audio content rather than precisely what the model heard. The client can optionally set the language and prompt for transcription, these offer additional guidance to the transcription service. + // Configuration for turn detection: Server VAD or Semantic VAD. Set to null + // to turn off, in which case the client must manually trigger model response. TurnDetection *TurnDetectionUnion `json:"turn_detection,omitempty"` - // Configuration for turn detection, ether Server VAD or Semantic VAD. This can be set to null to turn off, in which case the client must manually trigger model response. - // - // Server VAD means that the model will detect the start and end of speech based on audio volume and respond at the end of user speech. - // - // Semantic VAD is more advanced and uses a turn detection model (in conjunction with VAD) to semantically estimate whether the user has finished speaking, then dynamically sets a timeout based on this probability. For example, if user audio trails off with "uhhm", the model will score a low probability of turn end and wait longer for the user to continue speaking. This can be useful for more natural conversations, but may have a higher latency. + // True when the JSON payload explicitly included "turn_detection" (even as null). + // Standard Go JSON can't distinguish absent from null for pointer fields. + TurnDetectionSet bool `json:"-"` + + // Configuration for input audio transcription, defaults to off and can be + // set to null to turn off once on. Transcription *AudioTranscription `json:"transcription,omitempty"` } +func (s *SessionAudioInput) UnmarshalJSON(data []byte) error { + // Check whether turn_detection key exists in the raw JSON. + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + type alias SessionAudioInput + var a alias + if err := json.Unmarshal(data, &a); err != nil { + return err + } + *s = SessionAudioInput(a) + + if _, ok := raw["turn_detection"]; ok { + s.TurnDetectionSet = true + } + return nil +} + type SessionAudioOutput struct { Format *AudioFormatUnion `json:"format,omitempty"` Speed float32 `json:"speed,omitempty"` @@ -1012,10 +1034,13 @@ func (r *SessionUnion) UnmarshalJSON(data []byte) error { return err } switch SessionType(t.Type) { - case SessionTypeRealtime: - return json.Unmarshal(data, &r.Realtime) + case SessionTypeRealtime, "": + // Default to realtime when no type field is present (e.g. session.update events). + r.Realtime = &RealtimeSession{} + return json.Unmarshal(data, r.Realtime) case SessionTypeTranscription: - return json.Unmarshal(data, &r.Transcription) + r.Transcription = &TranscriptionSession{} + return json.Unmarshal(data, r.Transcription) default: return fmt.Errorf("unknown session type: %s", t.Type) } diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go index 15bc970a965e..2d7bcf16719d 100644 --- a/core/http/middleware/trace.go +++ b/core/http/middleware/trace.go @@ -158,7 +158,7 @@ func GetTraces() []APIExchange { mu.Unlock() sort.Slice(traces, func(i, j int) bool { - return traces[i].Timestamp.Before(traces[j].Timestamp) + return traces[i].Timestamp.After(traces[j].Timestamp) }) return traces diff --git a/core/http/react-ui/src/pages/Settings.jsx b/core/http/react-ui/src/pages/Settings.jsx index b112c91215ad..9ac3c00a9cf0 100644 --- a/core/http/react-ui/src/pages/Settings.jsx +++ b/core/http/react-ui/src/pages/Settings.jsx @@ -55,6 +55,7 @@ const SECTIONS = [ { id: 'memory', icon: 'fa-memory', color: 'var(--color-accent)', label: 'Memory' }, { id: 'backends', icon: 'fa-cogs', color: 'var(--color-accent)', label: 'Backends' }, { id: 'performance', icon: 'fa-gauge-high', color: 'var(--color-success)', label: 'Performance' }, + { id: 'tracing', icon: 'fa-bug', color: 'var(--color-warning)', label: 'Tracing' }, { id: 'api', icon: 'fa-globe', color: 'var(--color-warning)', label: 'API & CORS' }, { id: 'p2p', icon: 'fa-network-wired', color: 'var(--color-accent)', label: 'P2P' }, { id: 'galleries', icon: 'fa-images', color: 'var(--color-accent)', label: 'Galleries' }, @@ -327,10 +328,19 @@ export default function Settings() { update('debug', v)} /> - + + + + {/* Tracing */} +
sectionRefs.current.tracing = el} style={{ marginBottom: 'var(--spacing-xl)' }}> +

+ Tracing +

+
+ update('enable_tracing', v)} /> - + update('tracing_max_items', parseInt(e.target.value) || 0)} placeholder="100" disabled={!settings.enable_tracing} />
diff --git a/core/http/react-ui/src/pages/Talk.jsx b/core/http/react-ui/src/pages/Talk.jsx index 590b89bda32d..fa9a784fad09 100644 --- a/core/http/react-ui/src/pages/Talk.jsx +++ b/core/http/react-ui/src/pages/Talk.jsx @@ -1,196 +1,688 @@ -import { useState, useRef, useCallback } from 'react' +import { useState, useRef, useEffect, useCallback } from 'react' import { useOutletContext } from 'react-router-dom' -import ModelSelector from '../components/ModelSelector' -import LoadingSpinner from '../components/LoadingSpinner' -import { chatApi, ttsApi, audioApi } from '../utils/api' +import { realtimeApi } from '../utils/api' + +const STATUS_STYLES = { + disconnected: { icon: 'fa-solid fa-circle', color: 'var(--color-text-secondary)', bg: 'transparent' }, + connecting: { icon: 'fa-solid fa-spinner fa-spin', color: 'var(--color-primary)', bg: 'var(--color-primary-light)' }, + connected: { icon: 'fa-solid fa-circle', color: 'var(--color-success)', bg: 'rgba(34,197,94,0.1)' }, + listening: { icon: 'fa-solid fa-microphone', color: 'var(--color-success)', bg: 'rgba(34,197,94,0.1)' }, + thinking: { icon: 'fa-solid fa-brain fa-beat', color: 'var(--color-primary)', bg: 'var(--color-primary-light)' }, + speaking: { icon: 'fa-solid fa-volume-high fa-beat-fade', color: 'var(--color-accent)', bg: 'rgba(168,85,247,0.1)' }, + error: { icon: 'fa-solid fa-circle', color: 'var(--color-error)', bg: 'var(--color-error-light)' }, +} export default function Talk() { const { addToast } = useOutletContext() - const [llmModel, setLlmModel] = useState('') - const [whisperModel, setWhisperModel] = useState('') - const [ttsModel, setTtsModel] = useState('') - const [isRecording, setIsRecording] = useState(false) - const [loading, setLoading] = useState(false) - const [status, setStatus] = useState('Press the record button to start talking.') - const [audioUrl, setAudioUrl] = useState(null) - const [conversationHistory, setConversationHistory] = useState([]) - const mediaRecorderRef = useRef(null) - const chunksRef = useRef([]) + + // Pipeline models + const [pipelineModels, setPipelineModels] = useState([]) + const [selectedModel, setSelectedModel] = useState('') + const [modelsLoading, setModelsLoading] = useState(true) + + // Connection state + const [status, setStatus] = useState('disconnected') + const [statusText, setStatusText] = useState('Disconnected') + const [isConnected, setIsConnected] = useState(false) + + // Transcript + const [transcript, setTranscript] = useState([]) + const streamingRef = useRef(null) // tracks the index of the in-progress assistant message + + // Session settings + const [instructions, setInstructions] = useState( + 'You are a helpful voice assistant. Your responses will be spoken aloud using text-to-speech, so keep them concise and conversational. Do not use markdown formatting, bullet points, numbered lists, code blocks, or special characters. Speak naturally as you would in a phone conversation.' + ) + const [voice, setVoice] = useState('') + const [voiceEdited, setVoiceEdited] = useState(false) + const [language, setLanguage] = useState('') + + // Diagnostics + const [diagVisible, setDiagVisible] = useState(false) + + // Refs for WebRTC / audio + const pcRef = useRef(null) + const dcRef = useRef(null) + const localStreamRef = useRef(null) const audioRef = useRef(null) + const hasErrorRef = useRef(false) + + // Diagnostics refs + const audioCtxRef = useRef(null) + const analyserRef = useRef(null) + const diagFrameRef = useRef(null) + const statsIntervalRef = useRef(null) + const waveCanvasRef = useRef(null) + const specCanvasRef = useRef(null) + const transcriptEndRef = useRef(null) + + // Diagnostics stats (not worth re-rendering for every frame) + const [diagStats, setDiagStats] = useState({ + peakFreq: '--', thd: '--', rms: '--', sampleRate: '--', + packetsRecv: '--', packetsLost: '--', jitter: '--', concealed: '--', raw: '', + }) + + // Fetch pipeline models on mount + useEffect(() => { + realtimeApi.pipelineModels() + .then(models => { + setPipelineModels(models || []) + if (models?.length > 0) { + setSelectedModel(models[0].name) + if (!voiceEdited) setVoice(models[0].voice || '') + } + }) + .catch(err => addToast(`Failed to load pipeline models: ${err.message}`, 'error')) + .finally(() => setModelsLoading(false)) + }, []) + + // Auto-scroll transcript + useEffect(() => { + transcriptEndRef.current?.scrollIntoView({ behavior: 'smooth' }) + }, [transcript]) + + const selectedModelInfo = pipelineModels.find(m => m.name === selectedModel) + + // ── Status helper ── + const updateStatus = useCallback((state, text) => { + setStatus(state) + setStatusText(text || state) + }, []) + + // ── Session update ── + const sendSessionUpdate = useCallback(() => { + const dc = dcRef.current + if (!dc || dc.readyState !== 'open') return + if (!instructions.trim() && !voice.trim() && !language.trim()) return + + const session = {} + if (instructions.trim()) session.instructions = instructions.trim() + if (voice.trim() || language.trim()) { + session.audio = {} + if (voice.trim()) session.audio.output = { voice: voice.trim() } + if (language.trim()) session.audio.input = { transcription: { language: language.trim() } } + } + + dc.send(JSON.stringify({ type: 'session.update', session })) + }, [instructions, voice, language]) + + // ── Server event handler ── + const handleServerEvent = useCallback((event) => { + switch (event.type) { + case 'session.created': + sendSessionUpdate() + updateStatus('listening', 'Listening...') + break + case 'session.updated': + break + case 'input_audio_buffer.speech_started': + updateStatus('listening', 'Hearing you speak...') + break + case 'input_audio_buffer.speech_stopped': + updateStatus('thinking', 'Processing...') + break + case 'conversation.item.input_audio_transcription.completed': + if (event.transcript) { + streamingRef.current = null + setTranscript(prev => [...prev, { role: 'user', text: event.transcript }]) + } + updateStatus('thinking', 'Generating response...') + break + case 'response.output_audio_transcript.delta': + if (event.delta) { + setTranscript(prev => { + if (streamingRef.current !== null) { + const updated = [...prev] + updated[streamingRef.current] = { + ...updated[streamingRef.current], + text: updated[streamingRef.current].text + event.delta, + } + return updated + } + streamingRef.current = prev.length + return [...prev, { role: 'assistant', text: event.delta }] + }) + } + break + case 'response.output_audio_transcript.done': + if (event.transcript) { + setTranscript(prev => { + if (streamingRef.current !== null) { + const updated = [...prev] + updated[streamingRef.current] = { ...updated[streamingRef.current], text: event.transcript } + return updated + } + return [...prev, { role: 'assistant', text: event.transcript }] + }) + } + streamingRef.current = null + break + case 'response.output_audio.delta': + updateStatus('speaking', 'Speaking...') + break + case 'response.done': + updateStatus('listening', 'Listening...') + break + case 'error': + hasErrorRef.current = true + updateStatus('error', 'Error: ' + (event.error?.message || 'Unknown error')) + break + } + }, [sendSessionUpdate, updateStatus]) - const startRecording = async () => { - if (!navigator.mediaDevices) { - addToast('MediaDevices API not supported', 'error') + // ── Connect ── + const connect = useCallback(async () => { + if (!selectedModel) { + addToast('Please select a pipeline model first.', 'warning') return } + if (!navigator.mediaDevices?.getUserMedia) { + updateStatus('error', 'Microphone access requires HTTPS or localhost.') + return + } + + updateStatus('connecting', 'Connecting...') + setIsConnected(true) + try { - const stream = await navigator.mediaDevices.getUserMedia({ audio: true }) - const recorder = new MediaRecorder(stream) - chunksRef.current = [] - recorder.ondataavailable = (e) => chunksRef.current.push(e.data) - recorder.start() - mediaRecorderRef.current = recorder - setIsRecording(true) - setStatus('Recording... Click to stop.') + const localStream = await navigator.mediaDevices.getUserMedia({ audio: true }) + localStreamRef.current = localStream + + const pc = new RTCPeerConnection({}) + pcRef.current = pc + + for (const track of localStream.getAudioTracks()) { + pc.addTrack(track, localStream) + } + + pc.ontrack = (event) => { + if (audioRef.current) audioRef.current.srcObject = event.streams[0] + if (diagVisible) startDiagnostics() + } + + const dc = pc.createDataChannel('oai-events') + dcRef.current = dc + dc.onmessage = (msg) => { + try { + const text = typeof msg.data === 'string' ? msg.data : new TextDecoder().decode(msg.data) + handleServerEvent(JSON.parse(text)) + } catch (e) { + console.error('Failed to parse server event:', e) + } + } + dc.onclose = () => console.log('Data channel closed') + + pc.onconnectionstatechange = () => { + if (pc.connectionState === 'connected') { + updateStatus('connected', 'Connected, waiting for session...') + } else if (pc.connectionState === 'failed' || pc.connectionState === 'closed') { + disconnect() + } + } + + const offer = await pc.createOffer() + await pc.setLocalDescription(offer) + + await new Promise((resolve) => { + if (pc.iceGatheringState === 'complete') return resolve() + pc.onicegatheringstatechange = () => { + if (pc.iceGatheringState === 'complete') resolve() + } + setTimeout(resolve, 5000) + }) + + const data = await realtimeApi.call({ + sdp: pc.localDescription.sdp, + model: selectedModel, + }) + + await pc.setRemoteDescription({ type: 'answer', sdp: data.sdp }) } catch (err) { - addToast(`Microphone error: ${err.message}`, 'error') + hasErrorRef.current = true + updateStatus('error', 'Connection failed: ' + err.message) + disconnect() + } + }, [selectedModel, diagVisible, handleServerEvent, updateStatus, addToast]) + + // ── Disconnect ── + const disconnect = useCallback(() => { + stopDiagnostics() + if (dcRef.current) { dcRef.current.close(); dcRef.current = null } + if (pcRef.current) { pcRef.current.close(); pcRef.current = null } + if (localStreamRef.current) { + localStreamRef.current.getTracks().forEach(t => t.stop()) + localStreamRef.current = null } + if (audioRef.current) audioRef.current.srcObject = null + + if (!hasErrorRef.current) updateStatus('disconnected', 'Disconnected') + hasErrorRef.current = false + setIsConnected(false) + }, [updateStatus]) + + // Cleanup on unmount + useEffect(() => { + return () => { + stopDiagnostics() + if (dcRef.current) dcRef.current.close() + if (pcRef.current) pcRef.current.close() + if (localStreamRef.current) localStreamRef.current.getTracks().forEach(t => t.stop()) + } + }, []) + + // ── Test tone ── + const sendTestTone = useCallback(() => { + const dc = dcRef.current + if (!dc || dc.readyState !== 'open') return + dc.send(JSON.stringify({ type: 'test_tone' })) + setTranscript(prev => [...prev, { role: 'assistant', text: '(Test tone requested)' }]) + }, []) + + // ── Diagnostics ── + function startDiagnostics() { + const audioEl = audioRef.current + if (!audioEl?.srcObject) return + + if (!audioCtxRef.current) { + const ctx = new AudioContext() + const source = ctx.createMediaStreamSource(audioEl.srcObject) + const analyser = ctx.createAnalyser() + analyser.fftSize = 8192 + analyser.smoothingTimeConstant = 0.3 + source.connect(analyser) + audioCtxRef.current = ctx + analyserRef.current = analyser + setDiagStats(prev => ({ ...prev, sampleRate: ctx.sampleRate + ' Hz' })) + } + + if (!diagFrameRef.current) drawDiagnostics() + if (!statsIntervalRef.current) { + pollWebRTCStats() + statsIntervalRef.current = setInterval(pollWebRTCStats, 1000) + } + } + + function stopDiagnostics() { + if (diagFrameRef.current) { cancelAnimationFrame(diagFrameRef.current); diagFrameRef.current = null } + if (statsIntervalRef.current) { clearInterval(statsIntervalRef.current); statsIntervalRef.current = null } + if (audioCtxRef.current) { audioCtxRef.current.close(); audioCtxRef.current = null; analyserRef.current = null } } - const stopRecording = useCallback(() => { - if (!mediaRecorderRef.current) return - - mediaRecorderRef.current.onstop = async () => { - setIsRecording(false) - setLoading(true) - - const audioBlob = new Blob(chunksRef.current, { type: 'audio/webm' }) - - try { - // 1. Transcribe - setStatus('Transcribing audio...') - const formData = new FormData() - formData.append('file', audioBlob) - formData.append('model', whisperModel) - const transcription = await audioApi.transcribe(formData) - const userText = transcription.text - - setStatus(`You said: "${userText}". Generating response...`) - - // 2. Chat completion - const newHistory = [...conversationHistory, { role: 'user', content: userText }] - const chatResponse = await chatApi.complete({ - model: llmModel, - messages: newHistory, - }) - const assistantText = chatResponse?.choices?.[0]?.message?.content || '' - const updatedHistory = [...newHistory, { role: 'assistant', content: assistantText }] - setConversationHistory(updatedHistory) - - setStatus(`Response: "${assistantText}". Generating speech...`) - - // 3. TTS - const ttsBlob = await ttsApi.generateV1({ input: assistantText, model: ttsModel }) - const url = URL.createObjectURL(ttsBlob) - setAudioUrl(url) - setStatus('Press the record button to continue.') - - // Auto-play - setTimeout(() => audioRef.current?.play(), 100) - } catch (err) { - addToast(`Error: ${err.message}`, 'error') - setStatus('Error occurred. Try again.') - } finally { - setLoading(false) + function drawDiagnostics() { + const analyser = analyserRef.current + if (!analyser) { diagFrameRef.current = null; return } + + diagFrameRef.current = requestAnimationFrame(drawDiagnostics) + + // Waveform + const waveCanvas = waveCanvasRef.current + if (waveCanvas) { + const wCtx = waveCanvas.getContext('2d') + const timeData = new Float32Array(analyser.fftSize) + analyser.getFloatTimeDomainData(timeData) + const w = waveCanvas.width, h = waveCanvas.height + wCtx.fillStyle = '#000'; wCtx.fillRect(0, 0, w, h) + wCtx.strokeStyle = '#0f0'; wCtx.lineWidth = 1; wCtx.beginPath() + const sliceWidth = w / timeData.length + let x = 0 + for (let i = 0; i < timeData.length; i++) { + const y = (1 - timeData[i]) * h / 2 + i === 0 ? wCtx.moveTo(x, y) : wCtx.lineTo(x, y) + x += sliceWidth } + wCtx.stroke() + + let sumSq = 0 + for (let i = 0; i < timeData.length; i++) sumSq += timeData[i] * timeData[i] + const rms = Math.sqrt(sumSq / timeData.length) + const rmsDb = rms > 0 ? (20 * Math.log10(rms)).toFixed(1) : '-Inf' + setDiagStats(prev => ({ ...prev, rms: rmsDb + ' dBFS' })) } - mediaRecorderRef.current.stop() - mediaRecorderRef.current.stream?.getTracks().forEach(t => t.stop()) - }, [whisperModel, llmModel, ttsModel, conversationHistory]) + // Spectrum + const specCanvas = specCanvasRef.current + if (specCanvas && audioCtxRef.current) { + const sCtx = specCanvas.getContext('2d') + const freqData = new Float32Array(analyser.frequencyBinCount) + analyser.getFloatFrequencyData(freqData) + const sw = specCanvas.width, sh = specCanvas.height + sCtx.fillStyle = '#000'; sCtx.fillRect(0, 0, sw, sh) + + const sampleRate = audioCtxRef.current.sampleRate + const binHz = sampleRate / analyser.fftSize + const maxFreqDisplay = 4000 + const maxBin = Math.min(Math.ceil(maxFreqDisplay / binHz), freqData.length) + const barWidth = sw / maxBin + + sCtx.fillStyle = '#0cf' + let peakBin = 0, peakVal = -Infinity + for (let i = 0; i < maxBin; i++) { + const db = freqData[i] + if (db > peakVal) { peakVal = db; peakBin = i } + const barH = Math.max(0, ((db + 100) / 100) * sh) + sCtx.fillRect(i * barWidth, sh - barH, Math.max(1, barWidth - 0.5), barH) + } + + // Frequency labels + sCtx.fillStyle = '#888'; sCtx.font = '10px monospace' + for (let f = 500; f <= maxFreqDisplay; f += 500) { + sCtx.fillText(f + '', (f / binHz) * barWidth - 10, sh - 2) + } + + // 440 Hz marker + const bin440 = Math.round(440 / binHz) + const x440 = bin440 * barWidth + sCtx.strokeStyle = '#f00'; sCtx.lineWidth = 1 + sCtx.beginPath(); sCtx.moveTo(x440, 0); sCtx.lineTo(x440, sh); sCtx.stroke() + sCtx.fillStyle = '#f00'; sCtx.fillText('440', x440 + 2, 10) - const resetConversation = () => { - setConversationHistory([]) - setAudioUrl(null) - setStatus('Conversation reset. Press record to start.') - addToast('Conversation reset', 'info') + const peakFreq = peakBin * binHz + const fundamentalBin = Math.round(440 / binHz) + const fundamentalPower = Math.pow(10, freqData[fundamentalBin] / 10) + let harmonicPower = 0 + for (let h = 2; h <= 10; h++) { + const hBin = Math.round(440 * h / binHz) + if (hBin < freqData.length) harmonicPower += Math.pow(10, freqData[hBin] / 10) + } + const thd = fundamentalPower > 0 + ? (Math.sqrt(harmonicPower / fundamentalPower) * 100).toFixed(1) + '%' + : '--%' + + setDiagStats(prev => ({ + ...prev, + peakFreq: peakFreq.toFixed(0) + ' Hz (' + peakVal.toFixed(1) + ' dB)', + thd, + })) + } } - const allModelsSet = llmModel && whisperModel && ttsModel + async function pollWebRTCStats() { + const pc = pcRef.current + if (!pc) return + try { + const stats = await pc.getStats() + const raw = [] + stats.forEach((report) => { + if (report.type === 'inbound-rtp' && report.kind === 'audio') { + setDiagStats(prev => ({ + ...prev, + packetsRecv: report.packetsReceived ?? '--', + packetsLost: report.packetsLost ?? '--', + jitter: report.jitter !== undefined ? (report.jitter * 1000).toFixed(1) + ' ms' : '--', + concealed: report.concealedSamples ?? '--', + })) + raw.push('-- inbound-rtp (audio) --') + raw.push(' packetsReceived: ' + report.packetsReceived) + raw.push(' packetsLost: ' + report.packetsLost) + raw.push(' jitter: ' + (report.jitter !== undefined ? (report.jitter * 1000).toFixed(2) + ' ms' : 'N/A')) + raw.push(' bytesReceived: ' + report.bytesReceived) + raw.push(' concealedSamples: ' + report.concealedSamples) + raw.push(' totalSamplesReceived: ' + report.totalSamplesReceived) + } + }) + setDiagStats(prev => ({ ...prev, raw: raw.join('\n') })) + } catch (_e) { /* stats polling error */ } + } + + const toggleDiagnostics = useCallback(() => { + setDiagVisible(prev => { + const next = !prev + if (next) { + setTimeout(startDiagnostics, 0) + } else { + stopDiagnostics() + } + return next + }) + }, []) + + const statusStyle = STATUS_STYLES[status] || STATUS_STYLES.disconnected + // ── Render ── return (
-
+

Talk

-

Voice conversation with AI

+

Real-time voice conversation via WebRTC

- {/* Main interaction area */} -
- {/* Big record button */} - - - {/* Status */} -

- {loading ? : null} - {' '}{status} -

- - {/* Recording indicator */} - {isRecording && ( +
+ {/* Connection status */} +
+ + {statusText} +
+ + {/* Info note */} +
+ +

+ Note: Select a pipeline model and click Connect. + Your microphone streams continuously; the server detects speech and responds automatically. +

+
+ + {/* Pipeline model selector */} +
+ + +
+ + {/* Pipeline details */} + {selectedModelInfo && (
- - Recording... + {[ + { label: 'VAD', value: selectedModelInfo.vad }, + { label: 'Transcription', value: selectedModelInfo.transcription }, + { label: 'LLM', value: selectedModelInfo.llm }, + { label: 'TTS', value: selectedModelInfo.tts }, + ].map(item => ( +
+
{item.label}
+
{item.value}
+
+ ))}
)} - {/* Audio playback */} - {audioUrl && ( -
-
- - - + + + + + +
- + {{template "views/partials/footer" .}}
diff --git a/core/http/views/traces.html b/core/http/views/traces.html index 3e66c82b41f8..6287cc47782f 100644 --- a/core/http/views/traces.html +++ b/core/http/views/traces.html @@ -254,12 +254,54 @@

Response

+ + +