-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvoice_control.py
More file actions
230 lines (201 loc) · 9.48 KB
/
voice_control.py
File metadata and controls
230 lines (201 loc) · 9.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import asyncio
import base64
import io
import json
import logging
import os
import sys
import tempfile
import wave

import sounddevice as sd
import soundfile as sf

from modules.api_client import APIClient
from modules.audio_manager import AudioManager
from modules.computer_actions import ComputerActions
from modules.config import RECEIVE_SAMPLE_RATE
# Shared logger for this module; every record is emitted under the "VoiceControl" name.
logger = logging.getLogger(name="VoiceControl")
class VoiceControl:
    """Drive the computer-control assistant from voice or text input.

    Connects through ``APIClient`` and forwards user input: microphone audio
    via ``AudioManager`` in ``'voice'`` mode, typed commands otherwise. It
    prints and plays the assistant's replies and executes the
    ``execute_computer_action`` tool calls the model sends back.
    """

    def __init__(self, mode='voice'):
        """Create the controller.

        Args:
            mode: ``'voice'`` to capture microphone audio and play audio
                replies; any other value (the entry point uses ``'text'``)
                reads typed commands instead.
        """
        self.running = True
        self.mode = mode
        self.audio_manager = AudioManager() if mode == 'voice' else None
        self.computer_actions = ComputerActions()
        self.api_client = APIClient()
        # Cleared permanently once sending a screenshot fails.
        self.send_screenshots = True

    async def handle_response(self, response: dict):
        """Dispatch one decoded API message.

        ``serverContent`` messages have their parts printed or played. When the
        turn completes, a screenshot is sent if screenshots are still enabled.
        ``toolCall`` messages have every ``execute_computer_action`` call
        executed and answered with a ``tool_response``.

        Errors are logged, never raised.
        """
        try:
            if "serverContent" in response:
                await self._handle_server_content(response["serverContent"])
            elif "toolCall" in response:
                for fc in response["toolCall"].get("functionCalls", []):
                    if fc.get("name") == "execute_computer_action":
                        await self._handle_computer_action(fc)
        except Exception as e:
            logger.error(f"Error handling response: {str(e)}")

    async def _handle_server_content(self, content: dict):
        """Print/play the parts of a model turn and send a screenshot at turn end."""
        if "modelTurn" in content:
            for part in content["modelTurn"].get("parts", []):
                if "text" in part:
                    print("Assistant:", part["text"])
                elif "inlineData" in part and self.mode == 'voice':
                    audio_data = base64.b64decode(part["inlineData"]["data"])
                    await self.play_audio(audio_data)
        if "turnComplete" in content and self.send_screenshots:
            try:
                await self.api_client.send_screenshot()
            except Exception as e:
                logger.warning(f"Failed to send screenshot: {str(e)}")
                self.send_screenshots = False

    async def _handle_computer_action(self, fc: dict):
        """Execute one ``execute_computer_action`` call and send its tool response.

        Unparseable ``args`` or an exception raised by the action are reported
        as ``"failed"``, so the call is still answered.
        """
        args = fc.get("args", {})
        try:
            if isinstance(args, str):
                args = json.loads(args)
            success = await self.computer_actions.execute_action(args)
        except Exception as e:
            logger.error(f"Computer action failed: {str(e)}")
            success = False
        self.api_client.update_context({"action": args, "success": success})
        try:
            msg = {
                "tool_response": {
                    "function_responses": [{
                        "name": fc["name"],
                        "id": fc["id"],
                        "response": {
                            "result": "ok" if success else "failed",
                            "context": self.api_client.context,
                        },
                    }]
                }
            }
            await self.api_client.ws.send(json.dumps(msg))
        except Exception as e:
            logger.error(f"Failed to send tool response: {str(e)}")

    async def play_audio(self, audio_data: bytes):
        """Play one chunk of audio returned by the model.

        The bytes are treated as mono, 16-bit samples at ``RECEIVE_SAMPLE_RATE``.
        Playback runs in a worker thread so the event loop keeps servicing the
        other tasks while the chunk plays. Errors are logged, never raised.
        """
        # NOTE(review): input capture now keeps running during playback; if the
        # microphone picks up the speaker output, consider pausing capture here.
        try:
            await asyncio.to_thread(self._play_pcm_blocking, audio_data)
        except Exception as e:
            logger.error(f"Error playing audio: {str(e)}")

    @staticmethod
    def _play_pcm_blocking(audio_data: bytes):
        """Wrap raw PCM in an in-memory WAV, decode it and play it, blocking until done."""
        # An in-memory buffer avoids reopening a temp file (fails on Windows) and leaking it on errors.
        buffer = io.BytesIO()
        with wave.open(buffer, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(RECEIVE_SAMPLE_RATE)
            wf.writeframes(audio_data)
        buffer.seek(0)
        data, fs = sf.read(buffer, dtype='float32')
        sd.play(data, fs)
        sd.wait()

    async def text_input_loop(self):
        """Read typed commands and send them to the API until the user quits.

        Typing ``exit`` stops the application. So do closing stdin and an
        unexpected error. In each case ``_request_stop`` also closes the
        websocket, so the receive loop in ``run`` wakes up instead of waiting
        for another message. A failed send triggers ``reconnect``.
        """
        print("\nComputer Control Assistant")
        print("Available commands:")
        print("- Normal commands: The assistant will help you control your computer")
        print("- 'exit': Quit the program")
        loop = asyncio.get_running_loop()
        while self.running:
            try:
                try:
                    user_input = await loop.run_in_executor(
                        None, lambda: input("\nEnter command (or 'exit' to quit): ").strip()
                    )
                except EOFError:
                    # stdin was closed: no further commands can arrive, so quit.
                    user_input = 'exit'
                if user_input.lower() == 'exit':
                    await self._request_stop()
                    break
                if user_input:
                    try:
                        await self.api_client.send_text(user_input)
                    except Exception as e:
                        print(f"\nError: {str(e)}")
                        print("Attempting to reconnect...")
                        if await self.reconnect():
                            print("Successfully reconnected. Please try your command again.")
                        else:
                            print("Failed to reconnect. Please restart the application.")
            except Exception as e:
                logger.error(f"Error processing text input: {str(e)}")
                # Without this loop text mode has no input left, so stop instead of idling.
                await self._request_stop()
                break

    async def _request_stop(self):
        """Clear ``running`` and close the websocket so a pending ``recv()`` returns."""
        self.running = False
        ws = self.api_client.ws
        if ws:
            try:
                await ws.close()
            except Exception as e:
                logger.debug(f"Error closing websocket: {str(e)}")

    async def reconnect(self, max_attempts=3):
        """Try up to ``max_attempts`` times to reconnect and redo setup.

        Waits one second between failed attempts, whether the attempt returned
        a falsy result or raised.

        Returns:
            True on the first successful connect+setup, False once all attempts fail.
        """
        for attempt in range(max_attempts):
            try:
                if await self.api_client.connect() and await self.api_client.setup():
                    return True
                logger.warning(f"Reconnection attempt {attempt + 1} failed")
            except Exception as e:
                logger.error(f"Reconnection attempt {attempt + 1} failed: {str(e)}")
            if attempt + 1 < max_attempts:
                await asyncio.sleep(1)
        return False

    async def run(self):
        """Connect, start the input task and process server messages until stopped.

        The input task is audio capture in ``'voice'`` mode and
        ``text_input_loop`` otherwise. When the receive loop ends, the input
        task is cancelled so the ``TaskGroup`` can exit. Resources are always
        released in ``finally``.
        """
        try:
            logger.info("Starting main loop")
            if not await self.api_client.connect():
                logger.error("Failed to connect")
                return
            if not await self.api_client.setup():
                logger.error("Failed to setup connection")
                return
            async with asyncio.TaskGroup() as tg:
                # Create appropriate input task based on mode
                if self.mode == 'voice':
                    input_task = tg.create_task(
                        self.audio_manager.capture_audio(self.api_client.send_text)
                    )
                    logger.info("Audio capture started")
                else:
                    input_task = tg.create_task(self.text_input_loop())
                    logger.info("Text input mode started")
                try:
                    await self._receive_loop()
                finally:
                    # Otherwise the TaskGroup would wait forever for a still-running input task.
                    self.running = False
                    input_task.cancel()
        except Exception as e:
            logger.error(f"Fatal error: {str(e)}")
        finally:
            self.running = False
            if self.audio_manager:
                try:
                    self.audio_manager.cleanup()
                except Exception as e:
                    logger.error(f"Error cleaning up audio: {str(e)}")
            if self.api_client.ws:
                try:
                    await self.api_client.ws.close()
                except Exception as e:
                    logger.debug(f"Error closing websocket during cleanup: {str(e)}")
            logger.info("Cleanup complete")

    async def _receive_loop(self):
        """Receive and dispatch server messages until stopped or the connection is lost for good."""
        while True:
            if not self.running:
                logger.info("Received stop signal")
                return
            try:
                msg = await self.api_client.ws.recv()
            except asyncio.CancelledError:
                logger.info("Task cancelled")
                # Re-raise so the TaskGroup and asyncio.run can finish cancelling properly.
                raise
            except Exception as e:
                if not self.running:
                    # The websocket was closed on purpose to stop; the check above exits.
                    continue
                if "invalid frame payload data" in str(e):
                    logger.warning("WebSocket error with image data, continuing without screenshots")
                    self.send_screenshots = False
                    continue
                if "connection" in str(e).lower():
                    if self.mode == 'text':
                        print("\nConnection lost. Attempting to reconnect...")
                    if await self.reconnect():
                        if self.mode == 'text':
                            print("Reconnected successfully. Please try your command again.")
                        continue
                    if self.mode == 'text':
                        print("Failed to reconnect. Please restart the application.")
                logger.error(f"WebSocket error: {str(e)}")
                return
            try:
                response = json.loads(msg)
                await self.handle_response(response)
            except json.JSONDecodeError:
                logger.error("Failed to decode response")
            except Exception as e:
                logger.error(f"Error handling response: {str(e)}")
                if "invalid frame payload data" in str(e):
                    logger.warning("Disabled screenshot sending due to compatibility issues")
                    self.send_screenshots = False
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    # Strip surrounding whitespace so answers like " voice " still select a mode.
    mode = input("Choose input mode (voice/text): ").strip().lower()
    if mode not in ['voice', 'text']:
        print("Invalid mode. Defaulting to text mode.")
        mode = 'text'
    voice_control = VoiceControl(mode=mode)
    try:
        asyncio.run(voice_control.run())
    except KeyboardInterrupt:
        # On Ctrl-C asyncio.run cancels run(), whose finally-block cleanup runs
        # before the interrupt reaches here; exit without a traceback.
        logger.info("Interrupted by user")