-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathcua_handler.py
More file actions
610 lines (532 loc) · 24.9 KB
/
cua_handler.py
File metadata and controls
610 lines (532 loc) · 24.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
import asyncio
import base64
from typing import Any, Optional
from ..types.agent import (
ActionExecutionResult,
AgentAction,
)
class StagehandFunctionName:
    """Namespace for the log-category names used by the agent handler."""

    # Category tag passed as ``category=`` to the logger calls in this file.
    AGENT = "agent"
class CUAHandler: # Computer Use Agent Handler
"""Handles Computer Use Agent tasks by executing actions on the page."""
def __init__(
self,
stagehand,
page,
logger,
):
self.stagehand = stagehand
self.logger = logger
self.page = page
async def get_screenshot_base64(self) -> str:
    """Take a PNG screenshot of the page and return it base64-encoded.

    The screenshot is requested with ``full_page=False``.

    Returns:
        The PNG bytes encoded as a base64 string.
    """
    self.logger.debug(
        "Capturing screenshot for CUA client", category=StagehandFunctionName.AGENT
    )
    png_bytes = await self.page.screenshot(full_page=False, type="png")
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
async def perform_action(self, action: AgentAction) -> ActionExecutionResult:
    """Execute a single agent action on the page.

    Dispatches on ``action.action_type`` and drives the page's mouse,
    keyboard and navigation APIs accordingly. The click, double_click, type,
    keypress, key and goto actions are followed by a call to
    ``handle_page_navigation``.

    Args:
        action: Agent action whose ``action`` attribute holds the concrete
            action model and whose ``action_type`` names the kind of action.

    Returns:
        ``{"success": True}`` when the action ran, otherwise
        ``{"success": False, "error": <message>}``. Exceptions raised while
        executing the action are caught, logged and reported through the
        ``error`` field instead of being propagated.
    """
    specific_action_model = action.action
    self.logger.info(
        f"Performing action: {specific_action_model or ''}",
        category=StagehandFunctionName.AGENT,
    )
    action_type = action.action_type

    if not specific_action_model:
        self.logger.error(
            f"No specific action model found for action type {action_type}",
            category=StagehandFunctionName.AGENT,
        )
        return {
            "success": False,
            "error": f"No specific action model for {action_type}",
        }

    try:
        # Store initial URL to detect navigation
        initial_url = self.page.url

        if action_type == "click":
            x, y = specific_action_model.x, specific_action_model.y
            button = getattr(specific_action_model, "button", "left")
            # "back"/"forward" are treated as history navigation, not clicks.
            if button == "back":
                await self.page.go_back()
            elif button == "forward":
                await self.page.go_forward()
            else:
                await self._update_cursor_position(x, y)
                await self._animate_click(x, y)
                await asyncio.sleep(0.1)  # Ensure animation is visible
                await self.page.mouse.click(x, y, button=button)
            await self.handle_page_navigation("click", initial_url)
            return {"success": True}

        elif action_type == "double_click":
            x, y = specific_action_model.x, specific_action_model.y
            await self._update_cursor_position(x, y)
            # Two click animations to visualise the double click.
            await self._animate_click(x, y)
            await asyncio.sleep(0.1)
            await self._animate_click(x, y)
            await asyncio.sleep(0.1)
            await self.page.mouse.dblclick(x, y)
            await self.handle_page_navigation("double_click", initial_url)
            return {"success": True}

        elif action_type == "type":
            clear_before = getattr(
                specific_action_model, "clear_before_typing", False
            )
            target_x = getattr(specific_action_model, "x", None)
            target_y = getattr(specific_action_model, "y", None)
            if target_x is not None and target_y is not None:
                await self._update_cursor_position(target_x, target_y)
                if clear_before:
                    # Triple-click to select all text in the field, then type to replace
                    await self.page.mouse.click(target_x, target_y, click_count=3)
                    await asyncio.sleep(0.05)  # Brief pause for selection to register
                else:
                    # Single click to position cursor
                    await self.page.mouse.click(target_x, target_y)
            elif clear_before:
                # No coordinates but clear requested: try both select-all
                # shortcuts as a best effort, then delete the selection.
                for shortcut in ("Meta+a", "Control+a"):
                    try:
                        await self.page.keyboard.press(shortcut)
                    except Exception as shortcut_error:
                        self.logger.debug(
                            f"Select-all shortcut {shortcut} failed: {shortcut_error}",
                            category=StagehandFunctionName.AGENT,
                        )
                await self.page.keyboard.press("Backspace")

            await self.page.keyboard.type(specific_action_model.text)
            if getattr(specific_action_model, "press_enter_after", False):
                await self.page.keyboard.press("Enter")
            await self.handle_page_navigation("type", initial_url)
            return {"success": True}

        elif action_type == "keypress":
            # Each key in the list is pressed on its own, in order.
            for key_str in specific_action_model.keys:
                await self.page.keyboard.press(self._convert_key_name(key_str))
            # Keys like Enter can cause navigation
            await self.handle_page_navigation("keypress", initial_url)
            return {"success": True}

        elif action_type == "scroll":
            x, y = specific_action_model.x, specific_action_model.y
            scroll_x = getattr(specific_action_model, "scroll_x", 0)
            scroll_y = getattr(specific_action_model, "scroll_y", 0)
            await self.page.mouse.move(x, y)
            await self.page.mouse.wheel(scroll_x, scroll_y)
            return {"success": True}

        elif action_type == "function":
            name = specific_action_model.name
            args = getattr(specific_action_model, "arguments", None)
            # ``arguments`` may be absent, None, a plain dict or an object
            # exposing ``url``; read the URL without assuming which.
            if isinstance(args, dict):
                url = args.get("url")
            else:
                url = getattr(args, "url", None)

            if name == "goto":
                if not url:
                    self.logger.error(
                        "Function call 'goto' is missing a url argument",
                        category=StagehandFunctionName.AGENT,
                    )
                    return {
                        "success": False,
                        "error": "Function 'goto' requires a url argument",
                    }
                await self.page.goto(url)
                return {"success": True}
            elif name == "navigate_back":
                await self.page.go_back()
                return {"success": True}

            self.logger.error(
                f"Unsupported function call: {name}",
                category=StagehandFunctionName.AGENT,
            )
            return {"success": False, "error": f"Unsupported function: {name}"}

        elif action_type == "key":
            # Anthropic-style key action. _convert_key_name maps names such
            # as "Return"/"Enter"/"Tab" case-insensitively, so no special
            # cases are needed here.
            text = specific_action_model.text
            await self.page.keyboard.press(self._convert_key_name(text))
            # Enter and other keys may navigate
            await self.handle_page_navigation("key", initial_url)
            return {"success": True}

        elif action_type == "wait":
            # The field has historically been read as "miliseconds" (sic);
            # NOTE(review): the correctly spelled name is accepted as a
            # fallback - confirm against the action model definition.
            wait_ms = getattr(specific_action_model, "miliseconds", None)
            if wait_ms is None:
                wait_ms = getattr(specific_action_model, "milliseconds", None)
            # A missing/None duration is treated as "no wait" instead of
            # failing with a TypeError on the division.
            await asyncio.gather(
                asyncio.sleep((wait_ms or 0) / 1000),
                self.inject_cursor(),
            )
            return {"success": True}

        elif action_type == "move":
            x, y = specific_action_model.x, specific_action_model.y
            await self._update_cursor_position(x, y)
            return {"success": True}

        elif action_type == "screenshot":
            # Nothing to execute here; the action is simply acknowledged.
            return {"success": True}

        elif action_type == "goto":
            await self.page.goto(specific_action_model.url)
            await self.handle_page_navigation("goto", initial_url)
            return {"success": True}

        else:
            self.logger.error(
                f"Unsupported action type: {action_type}",
                category=StagehandFunctionName.AGENT,
            )
            return {
                "success": False,
                "error": f"Unsupported action type: {action_type}",
            }

    except Exception as e:
        # Report failures through the result dict rather than raising.
        self.logger.error(
            f"Error executing action {action_type}: {e}",
            category=StagehandFunctionName.AGENT,
        )
        return {"success": False, "error": str(e)}
async def inject_cursor(self) -> None:
    """Ask the page to draw the visual cursor via ``window.__stagehandInjectCursor``.

    Failures of the JavaScript call are logged as errors and not raised.
    """
    agent_category = StagehandFunctionName.AGENT
    self.logger.debug(
        "Attempting to inject cursor via window.__stagehandInjectCursor",
        category=agent_category,
    )
    try:
        await self.page.evaluate("window.__stagehandInjectCursor()")
        self.logger.debug(
            "Cursor injection via JS function initiated.",
            category=agent_category,
        )
    except Exception as exc:
        self.logger.error(
            f"Failed to call window.__stagehandInjectCursor: {exc}",
            category=agent_category,
        )
async def _update_cursor_position(self, x: int, y: int) -> None:
"""Update the cursor position on the page by calling the JS function."""
try:
await self.page.evaluate(
f"window.__stagehandUpdateCursorPosition({x}, {y})"
)
except Exception as e:
self.logger.debug(
f"Failed to call window.__stagehandUpdateCursorPosition: {e}",
category=StagehandFunctionName.AGENT,
)
async def _animate_click(self, x: int, y: int) -> None:
"""Animate a click at the given position by calling the JS function."""
try:
await self.page.evaluate(f"window.__stagehandAnimateClick({x}, {y})")
except Exception as e:
self.logger.debug(
f"Failed to call window.__stagehandAnimateClick: {e}",
category=StagehandFunctionName.AGENT,
)
async def _wait_for_settled_dom(self, timeout_ms: Optional[int] = None) -> None:
timeout = (
timeout_ms if timeout_ms is not None else 10000
) # Default to 10s, can be configured via stagehand options
loop = asyncio.get_event_loop()
future = loop.create_future()
cdp_session = None
try:
cdp_session = await self.page.context.new_cdp_session(self.page)
# Check if document exists, similar to TypeScript version's hasDoc
try:
await self.page.title()
except Exception:
await self.page.wait_for_load_state("domcontentloaded", timeout=timeout)
await cdp_session.send("Network.enable")
await cdp_session.send("Page.enable")
await cdp_session.send(
"Target.setAutoAttach",
{
"autoAttach": True,
"waitForDebuggerOnStart": False,
"flatten": True,
},
)
inflight_requests: set[str] = set()
request_meta: dict[str, dict[str, Any]] = (
{}
) # {requestId: {url: string, start: float}}
doc_by_frame: dict[str, str] = {} # {frameId: requestId}
quiet_timer_handle: Optional[asyncio.TimerHandle] = None
stalled_request_sweep_task: Optional[asyncio.Task] = None
# Helper to clear quiet timer
def clear_quiet_timer():
nonlocal quiet_timer_handle
if quiet_timer_handle:
quiet_timer_handle.cancel()
quiet_timer_handle = None
# Forward declaration for resolve_done
resolve_done_callbacks = [] # To store cleanup actions
def resolve_done_action():
nonlocal quiet_timer_handle, stalled_request_sweep_task
for callback in resolve_done_callbacks:
try:
callback()
except Exception as e:
self.logger.debug(
f"Error during resolve_done callback: {e}", category="dom"
)
clear_quiet_timer()
if stalled_request_sweep_task and not stalled_request_sweep_task.done():
stalled_request_sweep_task.cancel()
if not future.done():
future.set_result(None)
# Helper to potentially resolve if network is quiet
def maybe_quiet():
nonlocal quiet_timer_handle
if (
not inflight_requests
and not quiet_timer_handle
and not future.done()
):
quiet_timer_handle = loop.call_later(
1.0, resolve_done_action
) # Increased to 1000ms (from 0.5)
# Finishes a request
def finish_request(request_id: str):
if request_id not in inflight_requests:
return
inflight_requests.remove(request_id)
request_meta.pop(request_id, None)
frames_to_remove = [
fid for fid, rid in doc_by_frame.items() if rid == request_id
]
for fid in frames_to_remove:
doc_by_frame.pop(fid, None)
clear_quiet_timer()
maybe_quiet()
# Event handlers
def on_request_will_be_sent(params: dict):
request_type = params.get("type")
if request_type == "WebSocket" or request_type == "EventSource":
return
request_id = params["requestId"]
inflight_requests.add(request_id)
request_meta[request_id] = {
"url": params["request"]["url"],
"start": loop.time(),
}
if params.get("type") == "Document" and params.get("frameId"):
doc_by_frame[params["frameId"]] = request_id
clear_quiet_timer()
def on_loading_finished(params: dict):
finish_request(params["requestId"])
def on_loading_failed(params: dict):
finish_request(params["requestId"])
def on_request_served_from_cache(params: dict):
finish_request(params["requestId"])
def on_response_received(params: dict): # For data URLs
response_url = params.get("response", {}).get("url", "")
if response_url.startswith("data:"):
finish_request(params["requestId"])
def on_frame_stopped_loading(params: dict):
frame_id = params["frameId"]
request_id = doc_by_frame.get(frame_id)
if request_id:
finish_request(request_id)
# Attach CDP event listeners
cdp_session.on("Network.requestWillBeSent", on_request_will_be_sent)
cdp_session.on("Network.loadingFinished", on_loading_finished)
cdp_session.on("Network.loadingFailed", on_loading_failed)
cdp_session.on(
"Network.requestServedFromCache", on_request_served_from_cache
)
cdp_session.on(
"Network.responseReceived", on_response_received
) # For data URLs
cdp_session.on("Page.frameStoppedLoading", on_frame_stopped_loading)
resolve_done_callbacks.append(
lambda: cdp_session.remove_listener(
"Network.requestWillBeSent", on_request_will_be_sent
)
)
resolve_done_callbacks.append(
lambda: cdp_session.remove_listener(
"Network.loadingFinished", on_loading_finished
)
)
resolve_done_callbacks.append(
lambda: cdp_session.remove_listener(
"Network.loadingFailed", on_loading_failed
)
)
resolve_done_callbacks.append(
lambda: cdp_session.remove_listener(
"Network.requestServedFromCache", on_request_served_from_cache
)
)
resolve_done_callbacks.append(
lambda: cdp_session.remove_listener(
"Network.responseReceived", on_response_received
)
)
resolve_done_callbacks.append(
lambda: cdp_session.remove_listener(
"Page.frameStoppedLoading", on_frame_stopped_loading
)
)
# Stalled request sweeper
async def sweep_stalled_requests():
while not future.done():
await asyncio.sleep(0.5) # 500ms interval
now = loop.time()
stalled_ids_to_remove = []
for req_id, meta in list(
request_meta.items()
): # Iterate over a copy for safe modification
if (
now - meta["start"] > 4.0
): # Increased to 4 seconds (from 2.0)
stalled_ids_to_remove.append(req_id)
self.logger.debug(
f"DOM Settle: Forcing completion of stalled request {req_id}, URL: {meta['url'][:120]}",
category="dom", # Using "dom" as a category for these logs
)
if stalled_ids_to_remove:
for req_id in stalled_ids_to_remove:
if (
req_id in inflight_requests
): # Ensure it's still considered inflight
inflight_requests.remove(req_id)
request_meta.pop(req_id, None)
clear_quiet_timer() # State changed
maybe_quiet() # Re-evaluate if network is quiet
stalled_request_sweep_task = loop.create_task(sweep_stalled_requests())
# Overall timeout guard
guard_handle = loop.call_later(
timeout / 1000.0, lambda: {resolve_done_action()}
)
resolve_done_callbacks.append(lambda: guard_handle.cancel())
maybe_quiet() # Initial check if already quiet
await future # Wait for the future to be resolved
except Exception as e:
self.logger.error(f"Error in _wait_for_settled_dom: {e}", category="dom")
if not future.done():
future.set_exception(e) # Propagate error if future not done
finally:
if (
"resolve_done_action" in locals()
and callable(resolve_done_action)
and not future.done()
):
# If future isn't done but we are exiting, ensure cleanup happens.
# This might happen on an unexpected early exit from the try block.
# However, guard_handle or quiet_timer should eventually call resolve_done_action.
# If an unhandled exception caused early exit before guard/quiet timers, this is a fallback.
self.logger.debug(
"Ensuring resolve_done_action is called in finally due to early exit",
category="dom",
)
# resolve_done_action() # Be cautious calling it directly here, might lead to double calls or race conditions
# Rely on the guard and quiet timers mostly.
if stalled_request_sweep_task and not stalled_request_sweep_task.done():
stalled_request_sweep_task.cancel()
try:
await stalled_request_sweep_task # Allow cleanup
except asyncio.CancelledError:
pass # Expected
if cdp_session:
try:
await cdp_session.detach()
except Exception as e_detach:
self.logger.debug(
f"Error detaching CDP session: {e_detach}", category="dom"
)
def _convert_key_name(self, key: str) -> str:
"""Convert CUA key names to Playwright key names."""
key_map = {
"ENTER": "Enter",
"RETURN": "Enter", # Added for Anthropic 'key' type if used via this
"ESCAPE": "Escape",
"ESC": "Escape", # Added
"BACKSPACE": "Backspace",
"TAB": "Tab",
"SPACE": " ",
"ARROWUP": "ArrowUp",
"ARROWDOWN": "ArrowDown",
"ARROWLEFT": "ArrowLeft",
"ARROWRIGHT": "ArrowRight",
"UP": "ArrowUp",
"DOWN": "ArrowDown",
"LEFT": "ArrowLeft",
"RIGHT": "ArrowRight",
"SHIFT": "Shift",
"CONTROL": "Control",
"CTRL": "Control", # Added
"ALT": "Alt",
"OPTION": "Alt", # Added
"META": "Meta",
"COMMAND": "Meta",
"CMD": "Meta", # Added
"DELETE": "Delete",
"HOME": "Home",
"END": "End",
"PAGEUP": "PageUp",
"PAGEDOWN": "PageDown",
"CAPSLOCK": "CapsLock",
"INSERT": "Insert",
"/": "Divide",
"\\": "Backslash",
}
# Convert to uppercase for case-insensitive matching then check map,
# default to original key if not found.
return key_map.get(key.upper(), key)
async def handle_page_navigation(
    self,
    action_description: str,
    initial_url: str,
    dom_settle_timeout_ms: int = 5000,
) -> None:
    """Deal with a possible navigation caused by the action that just ran.

    Steps, in order:

    1. Wait up to 1 second for the browser context to report a new page and
       log its URL if one appears; the page is left open.
    2. Wait for the DOM to settle via ``_wait_for_settled_dom``.
    3. Log whether the current page's URL differs from ``initial_url``.
    4. Re-inject the visual cursor.

    Args:
        action_description: Label of the action, used in the log message.
        initial_url: URL of the page before the action was executed.
        dom_settle_timeout_ms: Timeout handed to ``_wait_for_settled_dom``.
    """
    agent_category = StagehandFunctionName.AGENT
    self.logger.info(
        f"{action_description} - checking for page navigation",
        category=agent_category,
    )

    try:
        # The triggering action has already run, so the block body is empty;
        # this only checks whether a new page shows up within the timeout.
        async with self.page.context.expect_page(timeout=1000) as new_page_info:
            pass
        newly_opened_page = await new_page_info.value
        # The new tab is intentionally left open for the context to handle.
        self.logger.debug(
            f"New page detected with URL: {newly_opened_page.url}",
            category=agent_category,
        )
    except Exception:
        # No new page (timeout) or any other failure: nothing to report.
        pass

    await self._wait_for_settled_dom(timeout_ms=dom_settle_timeout_ms)

    final_url = self.page.url
    if final_url == initial_url:
        outcome = f"Finished checking for page navigation. URL remains {initial_url}."
    else:
        outcome = f"Page navigation handled. Initial URL: {initial_url}, Final URL: {final_url}"
    self.logger.debug(outcome, category=agent_category)

    # Ensure cursor is injected after any potential navigation or page reload
    await self.inject_cursor()