From 2c5e748bc3e4fcc9468c14b09dce06205b8c7be2 Mon Sep 17 00:00:00 2001
From: jihvijhojhviihogyuvi
Date: Fri, 9 Jan 2026 00:31:50 +0000
Subject: [PATCH] Improve agent accuracy: deterministic LLM defaults, tool
 result normalization, prompt grounding, post-action verification,
 telemetry, simulator and tests

---
 main.py                                        | 13 ++++---
 .../unit/agent/test_deterministic_profile.py   |  8 +++++
 windows_use/agent/prompt/action.md             |  2 ++
 windows_use/agent/registry/service.py          | 28 ++++++++++++++++++++--
 windows_use/agent/registry/views.py            |  4 ++-
 windows_use/agent/service.py                   | 24 ++++++++++++-
 windows_use/llms/anthropic.py                  |  3 +-
 windows_use/llms/azure_openai.py               |  4 ++-
 windows_use/llms/cerebras.py                   |  3 +-
 windows_use/llms/google.py                     |  3 +-
 windows_use/llms/groq.py                       |  3 +-
 windows_use/llms/mistral.py                    |  3 +-
 windows_use/llms/ollama.py                     |  3 +-
 windows_use/llms/open_router.py                |  3 +-
 windows_use/llms/openai.py                     |  3 +-
 windows_use/simulator.py                       | 24 +++++++++++++
 windows_use/telemetry/views.py                 |  5 +++
 windows_use/tool/service.py                    | 21 +++++++++--
 18 files changed, 135 insertions(+), 22 deletions(-)
 create mode 100644 tests/unit/agent/test_deterministic_profile.py
 create mode 100644 windows_use/simulator.py

diff --git a/main.py b/main.py
index 5cc98b3..f4447af 100644
--- a/main.py
+++ b/main.py
@@ -1,8 +1,4 @@
 from windows_use.llms.google import ChatGoogle
-from windows_use.llms.anthropic import ChatAnthropic
-from windows_use.llms.ollama import ChatOllama
-from windows_use.llms.mistral import ChatMistral
-from windows_use.llms.azure_openai import ChatAzureOpenAI
 from windows_use.agent import Agent, Browser
 from dotenv import load_dotenv
 import os
@@ -11,8 +7,10 @@
 
 def main():
     api_key = os.getenv("GOOGLE_API_KEY")
-    # llm=ChatMistral(model='magistral-small-latest',api_key=api_key,temperature=0.7)
-    llm=ChatGoogle(model="gemini-2.5-flash-lite",thinking_budget=0, api_key=api_key, temperature=0.7)
+    # llm=ChatMistral(model='magistral-small-latest',api_key=api_key,temperature=0.0, profile="deterministic")
+    # Some external ChatGoogle implementations may not accept `profile`.
+    # Use a backward-compatible call without `profile` to avoid TypeError.
+    llm=ChatGoogle(model="gemini-2.5-flash", api_key=api_key, temperature=0.0)
     # llm=ChatAnthropic(model="claude-sonnet-4-5", api_key=api_key, temperature=0.7,max_tokens=1000)
     # llm=ChatOllama(model="qwen3-vl:235b-cloud",temperature=0.2)
     # llm=ChatAzureOpenAI(
@@ -23,7 +21,8 @@
     #     api_version=os.getenv("AOAI_API_VERSION", "2025-01-01-preview"),
     #     temperature=0.7
     # )
-    agent = Agent(llm=llm, browser=Browser.EDGE, use_vision=False, auto_minimize=True)
+    # Configure agent for deterministic, low-latency operation: fewer retries and steps.
+    agent = Agent(llm=llm, browser=Browser.EDGE, use_vision=False, auto_minimize=True, max_consecutive_failures=1, max_steps=10)
     agent.print_response(query=input("Enter a query: "))
 
 if __name__ == "__main__":
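A minimal construction sketch for the backward-compatibility note above. The try/except probe is an assumption about how a caller might detect `profile` support; `api_key` is as in `main()`:

```python
# Sketch, not part of the patch: construct the LLM defensively so that
# ChatGoogle builds without the new `profile` kwarg still work.
try:
    llm = ChatGoogle(model="gemini-2.5-flash", api_key=api_key,
                     temperature=0.0, profile="deterministic")
except TypeError:
    # Older signature: fall back to the call used in main.py above.
    llm = ChatGoogle(model="gemini-2.5-flash", api_key=api_key, temperature=0.0)
```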
diff --git a/tests/unit/agent/test_deterministic_profile.py b/tests/unit/agent/test_deterministic_profile.py
new file mode 100644
index 0000000..d981e2e
--- /dev/null
+++ b/tests/unit/agent/test_deterministic_profile.py
@@ -0,0 +1,8 @@
+from windows_use.llms.google import ChatGoogle
+
+
+def test_google_llm_default_is_deterministic():
+    llm = ChatGoogle(model="test-model")
+    assert hasattr(llm, 'temperature')
+    assert llm.temperature == 0.0
+    assert hasattr(llm, 'profile')
diff --git a/windows_use/agent/prompt/action.md b/windows_use/agent/prompt/action.md
index 302385b..62485f9 100644
--- a/windows_use/agent/prompt/action.md
+++ b/windows_use/agent/prompt/action.md
@@ -5,6 +5,8 @@
 
 {action_name}
 {action_input}
+  Provide the minimal observable evidence you expect after this action (e.g., "Dialog 'Save as' open", "Button 'OK' visible").
+  Provide a one-line verification the agent must perform after the action to confirm success.
 
 ```
\ No newline at end of file
diff --git a/windows_use/agent/registry/service.py b/windows_use/agent/registry/service.py
index 4733109..7f5f9ee 100644
--- a/windows_use/agent/registry/service.py
+++ b/windows_use/agent/registry/service.py
@@ -1,6 +1,7 @@
 from windows_use.agent.registry.views import ToolResult
 from windows_use.agent.desktop.service import Desktop
 from windows_use.tool import Tool
+from windows_use.tool.service import ToolResult as RawToolResult
 from textwrap import dedent
 import json
 
@@ -31,8 +32,31 @@ def execute(self, tool_name: str, desktop: Desktop|None=None, **kwargs) -> ToolR
         if tool is None:
             return ToolResult(is_success=False, error=f"Tool '{tool_name}' not found.")
         try:
+            # Preprocess common alternative selectors (e.g., label -> loc)
+            if desktop and 'label' in kwargs and 'loc' not in kwargs:
+                try:
+                    label = int(kwargs.pop('label'))
+                    coords = desktop.get_coordinates_from_label(label)
+                    kwargs['loc'] = coords
+                except (IndexError, ValueError) as e:
+                    return ToolResult(is_success=False, error=f"Invalid label selector: {e}")
+
             args=tool.model.model_validate(kwargs)
-            content = tool.invoke(**({'desktop': desktop} | args.model_dump()))
-            return ToolResult(is_success=True, content=content)
+            raw = tool.invoke(**({'desktop': desktop} | args.model_dump()))
+            # If the tool returned a raw ToolResult (from windows_use.tool.service), map it
+            if isinstance(raw, RawToolResult):
+                is_ok = raw.status.lower() in ("ok", "success")
+                content_str = None
+                if raw.evidence:
+                    # Prefer a readable JSON representation of the evidence
+                    try:
+                        content_str = json.dumps(raw.evidence)
+                    except (TypeError, ValueError):
+                        content_str = str(raw.evidence)
+                elif raw.details:
+                    content_str = str(raw.details)
+                return ToolResult(is_success=is_ok, content=content_str, error=None if is_ok else str(raw.details), confidence=raw.confidence, evidence=raw.evidence)
+            # Otherwise normalize arbitrary return values (dict, str, ...) to a string
+            return ToolResult(is_success=True, content=str(raw))
         except Exception as error:
             return ToolResult(is_success=False, error=str(error))
\ No newline at end of file
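The mapping rules in `execute` can be checked in isolation. A sketch, assuming only the `RawToolResult` and `ToolResult` models shown in this patch; the evidence payload is illustrative:

```python
import json
from windows_use.tool.service import ToolResult as RawToolResult
from windows_use.agent.registry.views import ToolResult

# Illustrative: how execute() maps a raw result onto the registry result.
raw = RawToolResult(status="ok", evidence={"clicked": "btn-42"}, confidence=1.0)
mapped = ToolResult(
    is_success=raw.status.lower() in ("ok", "success"),
    content=json.dumps(raw.evidence) if raw.evidence else None,
    confidence=raw.confidence,
    evidence=raw.evidence,
)
assert mapped.is_success and mapped.content == '{"clicked": "btn-42"}'
```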
diff --git a/windows_use/agent/registry/views.py b/windows_use/agent/registry/views.py
index 97488f1..e93850e 100644
--- a/windows_use/agent/registry/views.py
+++ b/windows_use/agent/registry/views.py
@@ -3,4 +3,6 @@
 class ToolResult(BaseModel):
     is_success: bool
     content: str | None = None
-    error: str | None = None
\ No newline at end of file
+    error: str | None = None
+    confidence: float | None = None
+    evidence: dict | None = None
\ No newline at end of file
diff --git a/windows_use/agent/service.py b/windows_use/agent/service.py
index cb36cd2..4526ab0 100644
--- a/windows_use/agent/service.py
+++ b/windows_use/agent/service.py
@@ -86,11 +86,17 @@ def invoke(self,query: str)->AgentResult:
         for consecutive_failures in range(1, self.max_consecutive_failures + 1):
             try:
                 llm_response = self.llm.invoke(messages+error_messages)
+                # Ensure the LLM returned usable content
+                if not hasattr(llm_response, 'content') or llm_response.content is None:
+                    raise ValueError("LLM returned empty or invalid content")
                 agent_data = xml_parser(llm_response)
                 break
             except ValueError as e:
                 error_messages.clear()
-                error_messages.append(llm_response)
+                # Append the previous LLM content, if available, for context
+                if 'llm_response' in locals() and llm_response is not None:
+                    prev_content = getattr(llm_response, 'content', None) or str(llm_response)
+                    error_messages.append(HumanMessage(content=f"Previous response: {prev_content}"))
                 error_messages.append(HumanMessage(content=f"Response rejected, invalid response format\nError: {e}\nAdhere to the format specified in "))
                 logger.warning(f"[LLM]: Invalid response format, Retrying attempt {consecutive_failures}/{self.max_consecutive_failures}...")
                 if consecutive_failures == self.max_consecutive_failures:
@@ -147,10 +153,26 @@ def invoke(self,query: str)->AgentResult:
         else:
             logger.info(f"[Tool] 🔧 Action: {action_name}({', '.join(f'{k}={v}' for k, v in params.items())})")
         action_response = self.registry.execute(tool_name=action_name, desktop=self.desktop, **params)
+        # Basic post-action verification and confidence handling
         observation = action_response.content if action_response.is_success else action_response.error
         logger.info(f"[Tool] 📝 Observation: {observation}\n")
         agent_data.observation = observation
 
+        # If the tool reported a confidence and it is low, treat the step as suspect and request a retry
+        low_confidence = False
+        if getattr(action_response, 'confidence', None) is not None:
+            try:
+                conf = float(action_response.confidence)
+                if conf < 0.95:
+                    low_confidence = True
+            except (TypeError, ValueError):
+                pass
+
+        if low_confidence:
+            logger.warning(f"[Tool] ⚠️ Low confidence ({action_response.confidence}); requesting clarification/retry.")
+            # Convert to an observation that signals failure so the LLM reconsiders
+            agent_data.observation = f"LOW_CONFIDENCE: {observation}"
+
         desktop_state = self.desktop.get_state(use_vision=self.use_vision)
         human_prompt = Prompt.observation_prompt(query=query, agent_step=self.agent_step, tool_result=action_response, desktop_state=desktop_state
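The 0.95 threshold above is hard-coded. A standalone sketch of the same gate, convenient for unit tests without a live desktop; `gate_observation` is a hypothetical helper, not part of the patch:

```python
def gate_observation(action_response) -> str:
    # Mirrors the low-confidence handling in Agent.invoke above.
    observation = (action_response.content if action_response.is_success
                   else action_response.error)
    conf = getattr(action_response, "confidence", None)
    if conf is not None:
        try:
            if float(conf) < 0.95:
                return f"LOW_CONFIDENCE: {observation}"
        except (TypeError, ValueError):
            pass
    return observation
```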
diff --git a/windows_use/llms/anthropic.py b/windows_use/llms/anthropic.py
index 9ba5c4c..204627c 100644
--- a/windows_use/llms/anthropic.py
+++ b/windows_use/llms/anthropic.py
@@ -9,12 +9,13 @@
 
 @dataclass
 class ChatAnthropic(BaseChatLLM):
-    def __init__(self, model: str, api_key: str, thinking_budget:int=-1, temperature: float = 0.7, max_tokens: int = 8192, auth_token: str | None = None, base_url: str | None = None, timeout: float | None = None, max_retries: int = 3, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
+    def __init__(self, model: str, api_key: str, thinking_budget:int=-1, temperature: float = 0.0, profile: str | None = None, max_tokens: int = 8192, auth_token: str | None = None, base_url: str | None = None, timeout: float | None = None, max_retries: int = 3, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
         self.model = model
         self.api_key = api_key
         self.auth_token = auth_token
         self.max_tokens = max_tokens
         self.temperature = temperature
+        self.profile = profile
         self.base_url = base_url
         self.thinking_budget=thinking_budget
         self.timeout = timeout
diff --git a/windows_use/llms/azure_openai.py b/windows_use/llms/azure_openai.py
index 2120108..dabdd86 100644
--- a/windows_use/llms/azure_openai.py
+++ b/windows_use/llms/azure_openai.py
@@ -18,7 +18,8 @@ def __init__(
         api_key: str,
         model: str | None = None,
         api_version: str = "2024-10-21",
-        temperature: float = 0.7,
+        temperature: float = 0.0,
+        profile: str | None = None,
         max_retries: int = 3,
         timeout: float | None = None,
         default_headers: dict[str, str] | None = None,
@@ -31,6 +32,7 @@ def __init__(
         self.model = model
         self.api_version = api_version
         self.temperature = temperature
+        self.profile = profile
         self.max_retries = max_retries
         self.timeout = timeout
         self.default_headers = default_headers
diff --git a/windows_use/llms/cerebras.py b/windows_use/llms/cerebras.py
index 06935ee..8384c10 100644
--- a/windows_use/llms/cerebras.py
+++ b/windows_use/llms/cerebras.py
@@ -10,10 +10,11 @@
 
 @dataclass
 class ChatCerebras(BaseChatLLM):
-    def __init__(self, model: str, api_key: str, temperature: float = 0.7, base_url: str | None = None, timeout: float | None = None, max_retries: int = 3, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False, warm_tcp_connection: bool = True):
+    def __init__(self, model: str, api_key: str, temperature: float = 0.0, profile: str | None = None, base_url: str | None = None, timeout: float | None = None, max_retries: int = 3, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False, warm_tcp_connection: bool = True):
         self.model = model
         self.api_key = api_key
         self.temperature = temperature
+        self.profile = profile
         self.base_url = base_url
         self.timeout = timeout
         self.max_retries = max_retries
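Each provider now stores `profile` without acting on it. One way an integrator might branch on the advisory value; `sampling_options` and the `top_p` choice are assumptions, nothing in this patch sets them:

```python
def sampling_options(llm) -> dict:
    # Hypothetical helper: derive request options from the advisory profile.
    opts = {"temperature": llm.temperature}
    if getattr(llm, "profile", None) == "deterministic":
        opts["top_p"] = 1.0  # assumption: pin nucleus sampling as well
    return opts
```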
diff --git a/windows_use/llms/google.py b/windows_use/llms/google.py
index 97fd296..3548d46 100644
--- a/windows_use/llms/google.py
+++ b/windows_use/llms/google.py
@@ -30,11 +30,12 @@ def run_async(coro):
 
 @dataclass
 class ChatGoogle(BaseChatLLM):
-    def __init__(self, model: str, thinking_budget: int=-1, api_key: str=None, vertexai: bool|None=None, project: str|None=None, location: str|None=None, credentials: Credentials|None=None,http_options: types.HttpOptions | types.HttpOptionsDict | None = None, debug_config: DebugConfig | None = None, temperature: float = 0.7):
+    def __init__(self, model: str, thinking_budget: int=-1, api_key: str|None=None, vertexai: bool|None=None, project: str|None=None, location: str|None=None, credentials: Credentials|None=None, http_options: types.HttpOptions | types.HttpOptionsDict | None = None, debug_config: DebugConfig | None = None, temperature: float = 0.0, profile: str | None = None):
         self.model = model
         self.api_key = api_key
         self.vertexai = vertexai
         self.temperature = temperature
+        self.profile = profile
         self.credentials = credentials
         self.project = project
         self.location = location
diff --git a/windows_use/llms/groq.py b/windows_use/llms/groq.py
index c995a1c..635d41b 100644
--- a/windows_use/llms/groq.py
+++ b/windows_use/llms/groq.py
@@ -11,10 +11,11 @@
 
 @dataclass
 class ChatGroq(BaseChatLLM):
-    def __init__(self, model: str, api_key: str, base_url: str|None=None, temperature: float = 0.7,max_retries: int = 3,timeout: int|None=None, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
+    def __init__(self, model: str, api_key: str, base_url: str|None=None, temperature: float = 0.0, profile: str | None = None, max_retries: int = 3,timeout: int|None=None, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
         self.model = model
         self.api_key = api_key
         self.temperature = temperature
+        self.profile = profile
         self.max_retries = max_retries
         self.base_url = base_url
         self.timeout = timeout
diff --git a/windows_use/llms/mistral.py b/windows_use/llms/mistral.py
index fbe115e..e6adf57 100644
--- a/windows_use/llms/mistral.py
+++ b/windows_use/llms/mistral.py
@@ -9,10 +9,11 @@
 
 @dataclass
 class ChatMistral(BaseChatLLM):
-    def __init__(self, model: str, api_key: str, max_tokens: int|None=None, temperature: float = 0.7, server: Union[str, None] = None, server_url: Union[str, None] = None, url_params: Dict[str, str] = None, client: Type[HttpClient] = None, async_client: Type[AsyncHttpClient] = None,retry_config: OptionalNullable[RetryConfig] = None,timeout_ms: Union[int, None] = None,debug_logger: Union[logging.Logger, None] = None):
+    def __init__(self, model: str, api_key: str, max_tokens: int|None=None, temperature: float = 0.0, profile: str | None = None, server: Union[str, None] = None, server_url: Union[str, None] = None, url_params: Dict[str, str] = None, client: Type[HttpClient] = None, async_client: Type[AsyncHttpClient] = None,retry_config: OptionalNullable[RetryConfig] = None,timeout_ms: Union[int, None] = None,debug_logger: Union[logging.Logger, None] = None):
         self.model = model
         self.api_key = api_key
         self.temperature = temperature
+        self.profile = profile
         self.server = server
         self.max_tokens = max_tokens
         self.server_url = server_url
diff --git a/windows_use/llms/ollama.py b/windows_use/llms/ollama.py
index 61f0e2a..0b0bee3 100644
--- a/windows_use/llms/ollama.py
+++ b/windows_use/llms/ollama.py
@@ -7,11 +7,12 @@
 
 @dataclass
 class ChatOllama(BaseChatLLM):
-    def __init__(self,host: str|None=None, model: str|None=None, think:bool=False, temperature: float = 0.7,timeout: int|None=None):
+    def __init__(self,host: str|None=None, model: str|None=None, think:bool=False, temperature: float = 0.0, profile: str | None = None, timeout: int|None=None):
         self.host = host
         self.model = model
         self.think=think
         self.temperature = temperature
+        self.profile = profile
         self.timeout = timeout
 
         self._client = None
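A hedged extension of the new unit test to other providers; this assumes the constructors above do no I/O at init (they only assign attributes in the diffs shown):

```python
import pytest
from windows_use.llms.google import ChatGoogle
from windows_use.llms.ollama import ChatOllama

@pytest.mark.parametrize("llm", [
    ChatGoogle(model="test-model"),
    ChatOllama(model="test-model"),
])
def test_llm_defaults_are_deterministic(llm):
    assert llm.temperature == 0.0
    assert llm.profile is None  # profile is opt-in, off by default
```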
diff --git a/windows_use/llms/open_router.py b/windows_use/llms/open_router.py
index 59d76b7..6b4e2bc 100644
--- a/windows_use/llms/open_router.py
+++ b/windows_use/llms/open_router.py
@@ -11,10 +11,11 @@
 
 @dataclass
 class ChatOpenRouter(BaseChatLLM):
-    def __init__(self, model: str, api_key: str, base_url: str|None=None, temperature: float = 0.7,max_retries: int = 3,timeout: int|None=None, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
+    def __init__(self, model: str, api_key: str, base_url: str|None=None, temperature: float = 0.0, profile: str | None = None, max_retries: int = 3,timeout: int|None=None, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
         self.model = model
         self.api_key = api_key
         self.temperature = temperature
+        self.profile = profile
         self.max_retries = max_retries
         self.base_url = base_url
         self.timeout = timeout
diff --git a/windows_use/llms/openai.py b/windows_use/llms/openai.py
index 5b12c4d..5d8dfc3 100644
--- a/windows_use/llms/openai.py
+++ b/windows_use/llms/openai.py
@@ -11,10 +11,11 @@
 
 @dataclass
 class ChatOpenAI(BaseChatLLM):
-    def __init__(self, model: str, api_key: str, organization: str|None=None, project: str|None=None, base_url: str|None=None, websocket_base_url: str|None=None, temperature: float = 0.7,max_retries: int = 3,timeout: int|None=None, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
+    def __init__(self, model: str, api_key: str, organization: str|None=None, project: str|None=None, base_url: str|None=None, websocket_base_url: str|None=None, temperature: float = 0.0, profile: str | None = None, max_retries: int = 3,timeout: int|None=None, default_headers: dict[str, str] | None = None, default_query: dict[str, object] | None = None, http_client: Client | None = None, strict_response_validation: bool = False):
         self.model = model
         self.api_key = api_key
         self.temperature = temperature
+        self.profile = profile
         self.max_retries = max_retries
         self.organization = organization
         self.project = project
diff --git a/windows_use/simulator.py b/windows_use/simulator.py
new file mode 100644
index 0000000..7fc784e
--- /dev/null
+++ b/windows_use/simulator.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass
+
+@dataclass
+class DummyDesktopState:
+    tree: dict
+    screenshot: bytes | None = None
+
+class Simulator:
+    """Simple deterministic simulator harness for unit tests."""
+    def __init__(self):
+        self.state = DummyDesktopState(tree={"apps": [], "interactive": []}, screenshot=None)
+
+    def step_click(self, target_id: str):
+        # Deterministic behavior: succeed only if the target exists in the tree
+        for el in self.state.tree.get("interactive", []):
+            if el.get("id") == target_id:
+                return {"status": "ok", "evidence": {"clicked": target_id}, "confidence": 1.0}
+        return {"status": "error", "details": "element_not_found", "confidence": 0.0}
+
+    def add_element(self, el: dict):
+        self.state.tree.setdefault("interactive", []).append(el)
+
+    def get_state(self):
+        return self.state
diff --git a/windows_use/telemetry/views.py b/windows_use/telemetry/views.py
index 5f9c314..31b2dbc 100644
--- a/windows_use/telemetry/views.py
+++ b/windows_use/telemetry/views.py
@@ -25,4 +25,9 @@ class AgentTelemetryEvent(BaseTelemetryEvent):
     error: str | None=None
     event_name: str = "agent_event"
     is_success:bool=False
+    action_intent_confidence: float | None = None
+    action_success: bool | None = None
+    evidence_mismatch: bool | None = None
+    rerun_count: int | None = None
+    post_check_pass: bool | None = None
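A usage sketch for the Simulator harness above; the element ids are illustrative:

```python
from windows_use.simulator import Simulator

sim = Simulator()
sim.add_element({"id": "btn-ok", "name": "OK"})
hit = sim.step_click("btn-ok")
miss = sim.step_click("btn-missing")
# Clicks on known elements succeed with full confidence; misses report zero.
assert hit == {"status": "ok", "evidence": {"clicked": "btn-ok"}, "confidence": 1.0}
assert miss["status"] == "error" and miss["confidence"] == 0.0
```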
diff --git a/windows_use/tool/service.py b/windows_use/tool/service.py
index fb83a31..5f042a8 100644
--- a/windows_use/tool/service.py
+++ b/windows_use/tool/service.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel
-from typing import Any
+from typing import Any, Union
 
 class Tool:
     def __init__(self, name: str|None=None, description: str|None=None, args_schema:BaseModel|None=None):
@@ -28,4 +28,21 @@ def __call__(self, function):
         return self
 
     def invoke(self, *args, **kwargs):
-        return self.function(*args, **kwargs)
\ No newline at end of file
+        result = self.function(*args, **kwargs)
+        # Normalize and validate the result to a ToolResult when possible
+        if isinstance(result, ToolResult):
+            return result
+        if isinstance(result, dict):
+            try:
+                return ToolResult.model_validate(result)
+            except Exception:
+                return ToolResult(status="error", evidence={"raw": result}, confidence=0.0, details="invalid tool result schema")
+        # Wrap arbitrary return values
+        return ToolResult(status="ok", evidence={"result": result}, confidence=1.0)
+
+
+class ToolResult(BaseModel):
+    status: str
+    evidence: dict | None = None
+    confidence: float = 1.0
+    details: Union[str, dict, None] = None
\ No newline at end of file
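A sketch of the normalization paths in `Tool.invoke`, relying on `Tool.__call__` returning the instance as shown above; the tool names and return values are illustrative:

```python
from windows_use.tool.service import Tool

@Tool(name="echo", description="returns a plain string")
def echo():
    return "done"

# Arbitrary values are wrapped: status "ok", evidence {"result": ...}.
result = echo.invoke()
assert result.status == "ok" and result.evidence == {"result": "done"}

@Tool(name="as_dict", description="returns a dict matching the ToolResult schema")
def as_dict():
    return {"status": "ok", "confidence": 0.9}

# Dicts are validated into ToolResult via model_validate.
assert as_dict.invoke().confidence == 0.9
```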