-
Notifications
You must be signed in to change notification settings - Fork 53
Expand file tree
/
Copy pathopenai_realtime_client.py
More file actions
171 lines (149 loc) · 6.91 KB
/
openai_realtime_client.py
File metadata and controls
171 lines (149 loc) · 6.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import websockets
import json
import base64
import logging
import time
from typing import Optional, Callable, Dict, List
import asyncio
import os
from prompts import PROMPTS
from config import OPENAI_REALTIME_MODEL
from realtime_client_base import RealtimeClientBase
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class OpenAIRealtimeAudioTextClient(RealtimeClientBase):
def __init__(self, api_key: str, model: str = OPENAI_REALTIME_MODEL):
super().__init__(api_key)
self.model = model
self.base_url = "wss://api.openai.com/v1/realtime"
self.last_audio_time = None
self.auto_commit_interval = 5
async def connect(self, modalities: List[str] = ["text"], session_mode: str = "conversation"):
"""Connect to OpenAI's realtime API and configure the session"""
headers = {
"Authorization": f"Bearer {self.api_key}",
}
# Support both websockets param names across versions: extra_headers (older) and additional_headers (newer)
try:
self.ws = await websockets.connect(
f"{self.base_url}?model={self.model}",
extra_headers=headers,
)
except TypeError:
# Fallback for newer versions where the kwarg is 'additional_headers'
self.ws = await websockets.connect(
f"{self.base_url}?model={self.model}",
additional_headers=headers,
)
# Wait for session creation
response = await self.ws.recv()
response_data = json.loads(response)
if response_data["type"] == "session.created":
self.session_id = response_data["session"]["id"]
logger.info(f"Session created with ID: {self.session_id}")
session_config_payload = {
"type": "realtime" if session_mode == "conversation" else "transcription",
"output_modalities": modalities,
"audio": {
"input": {
"format": {"type": "audio/pcm", "rate": 24000},
"transcription": {"model": "gpt-4o-transcribe"},
"turn_detection": None,
}
},
}
if session_mode == "transcription":
logger.info("Configuring session for transcription mode.")
else:
session_config_payload["instructions"] = PROMPTS['paraphrase-gpt-realtime-enhanced']
logger.info("Configuring session for conversation mode with transcription and no turn detection.")
# Configure session
await self.ws.send(json.dumps({
"type": "session.update",
"session": session_config_payload
}, ensure_ascii=False))
# Register the default handler
self.register_handler("default", self.default_handler)
# Start the receiver coroutine
self.receive_task = asyncio.create_task(self.receive_messages())
async def send_instructions_audio(self):
"""Send the instructions.wav file as audio input to be appended to current buffer"""
instructions_path = "instructions.wav"
if not os.path.exists(instructions_path):
logger.warning(f"Instructions audio file not found: {instructions_path}")
return
try:
with open(instructions_path, "rb") as f:
audio_data = f.read()
# Send the instructions audio to the buffer (appends to existing user audio)
await self.send_audio(audio_data)
logger.info("Sent instructions audio to OpenAI buffer (appended to user audio)")
except Exception as e:
logger.error(f"Error sending instructions audio: {e}")
async def receive_messages(self):
try:
async for message in self.ws:
data = json.loads(message)
message_type = data.get("type", "default")
handler = self.handlers.get(message_type, self.handlers.get("default"))
if handler:
await handler(data)
else:
logger.warning(f"No handler for message type: {message_type}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(f"OpenAI WebSocket connection closed: {e}")
except Exception as e:
logger.error(f"Error in receive_messages: {e}", exc_info=True)
async def default_handler(self, data: dict):
"""Override default handler for OpenAI-specific logging"""
message_type = data.get("type", "unknown")
logger.warning(f"Unhandled message type received from OpenAI: {message_type}")
async def send_audio(self, audio_data: bytes):
if self._is_ws_open():
await self.ws.send(json.dumps({
"type": "input_audio_buffer.append",
"audio": base64.b64encode(audio_data).decode('utf-8')
}))
else:
logger.error("WebSocket is not open. Cannot send audio.")
async def commit_audio(self):
"""Commit the audio buffer and notify OpenAI"""
if self._is_ws_open():
commit_message = json.dumps({"type": "input_audio_buffer.commit"})
await self.ws.send(commit_message)
logger.info("Sent input_audio_buffer.commit message to OpenAI")
# No recv call here. The receive_messages coroutine handles incoming messages.
else:
logger.error("WebSocket is not open. Cannot commit audio.")
async def clear_audio_buffer(self):
"""Clear the audio buffer"""
if self._is_ws_open():
clear_message = json.dumps({"type": "input_audio_buffer.clear"})
await self.ws.send(clear_message)
logger.info("Sent input_audio_buffer.clear message to OpenAI")
else:
logger.error("WebSocket is not open. Cannot clear audio buffer.")
async def start_response(self, instructions: str):
"""Start a new response with given instructions"""
if self._is_ws_open():
await self.ws.send(json.dumps({
"type": "response.create",
"response": {
"output_modalities": ["text"],
"instructions": instructions
}
}))
logger.info(f"Started response with instructions: {instructions}")
else:
logger.error("WebSocket is not open. Cannot start response.")
async def close(self):
"""Close the WebSocket connection"""
if self.ws:
await self.ws.close()
logger.info("Closed OpenAI WebSocket connection")
if self.receive_task:
self.receive_task.cancel()
try:
await self.receive_task
except asyncio.CancelledError:
pass