Skip to content

Commit 3335c16

Browse files
localai-botteam-coding-agent-1
authored andcommitted
feat: Add WebSocket endpoint support for OpenAI Responses API
- Add WebSocket route handling for /v1/responses and /responses endpoints - Add WebSocket message types to schema (ORWebSocketClientMessage, ORWebSocketServerEvent, etc.) - Add connection-local cache types for response storage - Implement initial WebSocket infrastructure (handler to be added in next commit) Signed-off-by: team-coding-agent-1 <team-coding-agent-1@localai.dev>
1 parent bf4f8da commit 3335c16

5 files changed

Lines changed: 245 additions & 5 deletions

File tree

IMPLEMENTATION_PLAN.md

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# WebSocket Mode Implementation Plan for OpenAI Responses API
2+
3+
## Overview
4+
Implement WebSocket support for LocalAI's OpenAI API-compatible Responses endpoint, enabling persistent WebSocket connections for long-running, tool-call-heavy agentic workflows.
5+
6+
## Technical Requirements
7+
8+
### 1. WebSocket Endpoint
9+
- **Endpoint**: `ws://<host>:<port>/v1/responses`
10+
- **Upgrade**: HTTP upgrade from POST /v1/responses when `Upgrade: websocket` header is present
11+
12+
### 2. Message Types (Client → Server)
13+
14+
#### response.create (Initial Turn)
15+
```json
16+
{
17+
"type": "response.create",
18+
"model": "gpt-4o",
19+
"store": false,
20+
"input": [...],
21+
"tools": []
22+
}
23+
```
24+
25+
#### response.create with Continuation (Subsequent Turns)
26+
```json
27+
{
28+
"type": "response.create",
29+
"model": "gpt-4o",
30+
"store": false,
31+
"previous_response_id": "resp_123",
32+
"input": [...],
33+
"tools": []
34+
}
35+
```
36+
37+
### 3. Response Events (Server → Client)
38+
39+
1. **response.created** - Response object created
40+
2. **response.progress** - Incremental output
41+
3. **response.function_call_arguments.delta** - Streaming function arguments
42+
4. **response.function_call_arguments.done** - Function call complete
43+
5. **response.done** - Final response
44+
45+
### 4. Connection Management
46+
- Track active connections with 60-minute timeout
47+
- Connection-local cache for responses (when store=false)
48+
- One in-flight response at a time per connection
49+
50+
### 5. Error Handling
51+
- `previous_response_not_found` (400)
52+
- `websocket_connection_limit_reached` (400)
53+
54+
## Implementation Steps
55+
56+
### Step 1: Add WebSocket Schema Types
57+
- Add WebSocket message types to `core/schema/openresponses.go`
58+
- Add connection-related types
59+
60+
### Step 2: Add WebSocket Route
61+
- Modify `core/http/routes/openresponses.go` to handle WebSocket upgrade
62+
- Add GET /v1/responses WebSocket endpoint
63+
64+
### Step 3: Create WebSocket Handler
65+
- Create `core/http/endpoints/openresponses/websocket.go`
66+
- Implement connection handling
67+
- Implement message parsing
68+
- Implement event streaming
69+
70+
### Step 4: Add Connection Store
71+
- Implement connection management in store
72+
- Add 60-minute timeout
73+
- Add connection-local cache
74+
75+
## Files to Modify/Create
76+
1. `core/schema/openresponses.go` - Add WebSocket types
77+
2. `core/http/routes/openresponses.go` - Add WebSocket route
78+
3. `core/http/endpoints/openresponses/websocket.go` - New WebSocket handler (create)
79+
4. `core/http/endpoints/openresponses/store.go` - Add connection management

backend/python/nemo/backend.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
"""
3-
gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
3+
GRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
44
"""
55
from concurrent import futures
66
import time
@@ -12,6 +12,14 @@
1212
import backend_pb2_grpc
1313
import torch
1414
import nemo.collections.asr as nemo_asr
15+
import numpy as np
16+
17+
try:
18+
import torchaudio
19+
TORCHAUDIO_AVAILABLE = True
20+
except ImportError:
21+
TORCHAUDIO_AVAILABLE = False
22+
print("[WARNING] torchaudio not available, will use fallback audio loading", file=sys.stderr)
1523

1624
import grpc
1725

@@ -36,6 +44,50 @@ def is_int(s):
3644
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
3745

3846

47+
def load_audio_np(audio_path, target_sample_rate=16000):
48+
"""Load audio file as numpy array using available methods."""
49+
if TORCHAUDIO_AVAILABLE:
50+
try:
51+
waveform, sample_rate = torchaudio.load(audio_path)
52+
# Convert to mono if stereo
53+
if waveform.shape[0] > 1:
54+
waveform = waveform.mean(dim=0, keepdim=True)
55+
# Resample if needed
56+
if sample_rate != target_sample_rate:
57+
resampler = torchaudio.transforms.Resample(sample_rate, target_sample_rate)
58+
waveform = resampler(waveform)
59+
# Convert to numpy
60+
audio_np = waveform.squeeze().numpy()
61+
return audio_np, target_sample_rate
62+
except Exception as e:
63+
print(f"[WARNING] torchaudio loading failed: {e}, trying fallback", file=sys.stderr)
64+
65+
# Fallback: try using scipy or soundfile
66+
try:
67+
import soundfile as sf
68+
audio_np, sample_rate = sf.read(audio_path)
69+
if audio_np.ndim > 1:
70+
audio_np = audio_np.mean(axis=1)
71+
if sample_rate != target_sample_rate:
72+
from scipy.signal import resample
73+
num_samples = int(len(audio_np) * target_sample_rate / sample_rate)
74+
audio_np = resample(audio_np, num_samples)
75+
return audio_np, target_sample_rate
76+
except ImportError:
77+
pass
78+
79+
try:
80+
from scipy.io import wavfile
81+
sample_rate, audio_np = wavfile.read(audio_path)
82+
if audio_np.ndim > 1:
83+
audio_np = audio_np.mean(axis=1)
84+
return audio_np, sample_rate
85+
except ImportError:
86+
pass
87+
88+
raise RuntimeError("No audio loading library available (torchaudio, soundfile, scipy)")
89+
90+
3991
class BackendServicer(backend_pb2_grpc.BackendServicer):
4092
def Health(self, request, context):
4193
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
@@ -89,14 +141,37 @@ def AudioTranscription(self, request, context):
89141
print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
90142
return backend_pb2.TranscriptResult(segments=[], text="")
91143

92-
# NEMO's transcribe method accepts a list of audio paths and returns a list of transcripts
93-
results = self.model.transcribe([audio_path])
94-
144+
# Load audio as numpy array to avoid lhotse dataloader issues
145+
audio_np, sample_rate = load_audio_np(audio_path, target_sample_rate=16000)
146+
147+
# Convert to torch tensor
148+
audio_tensor = torch.from_numpy(audio_np).float()
149+
audio_tensor = audio_tensor.unsqueeze(0) # Add batch dimension
150+
151+
# Use the model's transcribe method with the tensor directly
152+
# Some NEMO models accept audio tensors directly
153+
try:
154+
# Try passing the waveform tensor directly
155+
results = self.model.transcribe(audio_tensor, return_char_alignments=False)
156+
except TypeError:
157+
# Fallback: try with dict format
158+
results = self.model.transcribe(
159+
[{"audio_file": audio_path}],
160+
return_char_alignments=False
161+
)
162+
95163
if not results or len(results) == 0:
164+
print("[WARNING] No transcription results returned", file=sys.stderr)
96165
return backend_pb2.TranscriptResult(segments=[], text="")
97166

98167
# Get the transcript text from the first result
99-
text = results[0]
168+
if isinstance(results, list) and len(results) > 0:
169+
text = results[0]
170+
elif isinstance(results, dict) and "text" in results:
171+
text = results["text"]
172+
else:
173+
text = str(results) if results else ""
174+
100175
if text:
101176
# Create a single segment with the full transcription
102177
result_segments.append(backend_pb2.TranscriptSegment(

backend/python/nemo/requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@ certifi
44
packaging==24.1
55
setuptools
66
pyarrow==20.0.0
7+
torchaudio
8+
soundfile
9+
scipy
10+
numpy

core/http/routes/openresponses.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ func RegisterOpenResponsesRoutes(app *echo.Echo,
5252
cancelResponseHandler := openresponses.CancelResponseEndpoint()
5353
app.POST("/v1/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application))
5454
app.POST("/responses/:id/cancel", cancelResponseHandler, middleware.TraceMiddleware(application))
55+
56+
// WebSocket endpoint for OpenAI Responses API WebSocket Mode
57+
websocketHandler := openresponses.WebSocketEndpoint(
58+
application.ModelConfigLoader(),
59+
application.ModelLoader(),
60+
application.TemplatesEvaluator(),
61+
application.ApplicationConfig(),
62+
)
63+
64+
// WebSocket at /v1/responses (GET method for upgrade)
65+
app.GET("/v1/responses", websocketHandler, middleware.TraceMiddleware(application))
66+
app.GET("/responses", websocketHandler, middleware.TraceMiddleware(application))
5567
}
5668

5769
// setOpenResponsesRequestContext sets up the context and cancel function for Open Responses requests

core/schema/openresponses.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package schema
22

33
import (
4+
"time"
45
"context"
56
)
67

@@ -317,3 +318,72 @@ func ORContentPartWithLogprobs(text string, logprobs *Logprobs) ORContentPart {
317318
Logprobs: orLogprobs, // REQUIRED - must always be present as array (empty if none)
318319
}
319320
}
321+
322+
// WebSocket message types for Open Responses API WebSocket Mode
323+
// https://developers.openai.com/api/docs/guides/websocket-mode
324+
325+
// ORWebSocketMessage represents a WebSocket message (client -> server or server -> client)
326+
type ORWebSocketMessage struct {
327+
Type string `json:"type"` // response.create, response.created, response.progress, etc.
328+
}
329+
330+
// ORWebSocketClientMessage represents a client message to the WebSocket endpoint
331+
type ORWebSocketClientMessage struct {
332+
Type string `json:"type"` // "response.create"
333+
Model string `json:"model,omitempty"`
334+
Input interface{} `json:"input,omitempty"`
335+
Tools []ORFunctionTool `json:"tools,omitempty"`
336+
ToolChoice interface{} `json:"tool_choice,omitempty"`
337+
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
338+
Temperature *float64 `json:"temperature,omitempty"`
339+
TopP *float64 `json:"top_p,omitempty"`
340+
Truncation string `json:"truncation,omitempty"`
341+
Instructions string `json:"instructions,omitempty"`
342+
Reasoning *ORReasoningParam `json:"reasoning,omitempty"`
343+
Metadata map[string]string `json:"metadata,omitempty"`
344+
PreviousResponseID string `json:"previous_response_id,omitempty"`
345+
Store *bool `json:"store,omitempty"`
346+
TextFormat interface{} `json:"text_format,omitempty"`
347+
ServiceTier string `json:"service_tier,omitempty"`
348+
AllowedTools []string `json:"allowed_tools,omitempty"`
349+
ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
350+
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
351+
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
352+
TopLogprobs *int `json:"top_logprobs,omitempty"`
353+
MaxToolCalls *int `json:"max_tool_calls,omitempty"`
354+
Generate *bool `json:"generate,omitempty"` // If false, just warm up and return response_id
355+
}
356+
357+
// ORWebSocketServerEvent represents a server event to the WebSocket
358+
type ORWebSocketServerEvent struct {
359+
Type string `json:"type"` // response.created, response.progress, etc.
360+
ResponseID string `json:"response_id,omitempty"`
361+
Response *ORResponseResource `json:"response,omitempty"`
362+
OutputIndex *int `json:"output_index,omitempty"`
363+
Output []ORItemField `json:"output,omitempty"`
364+
ItemID string `json:"item_id,omitempty"`
365+
Item *ORItemField `json:"item,omitempty"`
366+
ContentIndex *int `json:"content_index,omitempty"`
367+
Delta *string `json:"delta,omitempty"`
368+
Text *string `json:"text,omitempty"`
369+
CallID string `json:"call_id,omitempty"`
370+
Arguments *string `json:"arguments,omitempty"`
371+
Error *ORError `json:"error,omitempty"`
372+
}
373+
374+
// ORWebSocketError represents a WebSocket error event
375+
type ORWebSocketError struct {
376+
Type string `json:"type"` // error
377+
Code string `json:"code,omitempty"` // previous_response_not_found, websocket_connection_limit_reached, etc.
378+
Message string `json:"message"`
379+
Param string `json:"param,omitempty"`
380+
}
381+
382+
// ConnectionLocalCacheEntry represents a cached response in connection-local storage
383+
type ConnectionLocalCacheEntry struct {
384+
ResponseID string
385+
Response *ORResponseResource
386+
Input *ORWebSocketClientMessage
387+
CachedAt time.Time
388+
ExpiresAt *time.Time
389+
}

0 commit comments

Comments
 (0)