From 165a16baefeef77419b0a61050845edf58f9965d Mon Sep 17 00:00:00 2001 From: godququ5-code <256881196+godququ5-code@users.noreply.github.com> Date: Fri, 3 Jul 2026 22:39:33 +0300 Subject: [PATCH] Fix AG-UI approval thread aliases --- .../ag-ui/agent_framework_ag_ui/_agent_run.py | 54 +++++++-- .../ag_ui/test_approval_thread_id_mismatch.py | 107 ++++++++++++++++++ 2 files changed, 150 insertions(+), 11 deletions(-) create mode 100644 python/packages/ag-ui/tests/ag_ui/test_approval_thread_id_mismatch.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 6f1174f0014..bd2ca7e4190 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -8,6 +8,7 @@ import json import logging import uuid +from collections import OrderedDict from collections.abc import AsyncIterable, Awaitable from typing import TYPE_CHECKING, Any, TypedDict, cast @@ -416,13 +417,14 @@ class _PendingApproval(TypedDict): name: str arguments: str | None + keys: list[str] PendingApprovalEntry = _PendingApproval | str -def _make_pending_approval_entry(name: str, arguments: str | None) -> _PendingApproval: - return {"name": name, "arguments": arguments} +def _make_pending_approval_entry(name: str, arguments: str | None, keys: list[str]) -> _PendingApproval: + return {"name": name, "arguments": arguments, "keys": keys} def _pending_approval_name(entry: PendingApprovalEntry) -> str | None: @@ -437,19 +439,45 @@ def _pending_approval_arguments(entry: PendingApprovalEntry) -> str | None: return entry["arguments"] +def _remove_pending_approval(registry: dict[str, PendingApprovalEntry], key: str) -> None: + """Remove one pending approval and all identity aliases that still reference it.""" + entry = registry.get(key) + if entry is None: + return + + keys = [key] if isinstance(entry, str) else entry["keys"] + for alias in keys: + if registry.get(alias) is entry: + del registry[alias] + + +def _register_pending_approval( + registry: dict[str, PendingApprovalEntry], + keys: list[str], + name: str, + arguments: str | None, +) -> None: + """Register one pending approval under each distinct thread identity.""" + unique_keys = list(dict.fromkeys(keys)) + for key in unique_keys: + _remove_pending_approval(registry, key) + + entry = _make_pending_approval_entry(name, arguments, unique_keys) + for key in unique_keys: + registry[key] = entry + + def _evict_oldest_approvals(registry: dict[str, PendingApprovalEntry], max_size: int = 10_000) -> None: """Evict the oldest entries from the pending-approvals registry (LRU). Only effective when *registry* is an ``OrderedDict``; plain dicts are left untouched because insertion-order eviction is unreliable for them. """ - if len(registry) <= max_size: + if len(registry) <= max_size or not isinstance(registry, OrderedDict): return - try: - while len(registry) > max_size: - registry.popitem(last=False) # type: ignore[call-arg] - except (TypeError, KeyError): - pass + while len(registry) > max_size: + oldest_key = next(iter(registry)) + _remove_pending_approval(registry, oldest_key) async def _resolve_approval_responses( @@ -532,8 +560,8 @@ async def _resolve_approval_responses( invalid_ids.add(resp_id) continue - # Valid — consume entry to prevent replay - del pending_approvals[registry_key] + # Valid — consume every identity alias to prevent replay + _remove_pending_approval(pending_approvals, registry_key) if resp.approved: validated.append(resp) else: @@ -855,6 +883,7 @@ async def run_agent_stream( """ # Parse IDs thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4()) + client_thread_id = thread_id run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4()) snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY)) @@ -1092,7 +1121,10 @@ async def run_agent_stream( # Register pending approval requests so we can validate responses later if content_type == "function_approval_request" and pending_approvals is not None: if content.id and content.function_call and content.function_call.name: - pending_approvals[f"{thread_id}:{content.id}"] = _make_pending_approval_entry( + request_id = content.id + _register_pending_approval( + pending_approvals, + [f"{client_thread_id}:{request_id}", f"{thread_id}:{request_id}"], content.function_call.name, canonical_function_arguments(content.function_call), ) diff --git a/python/packages/ag-ui/tests/ag_ui/test_approval_thread_id_mismatch.py b/python/packages/ag-ui/tests/ag_ui/test_approval_thread_id_mismatch.py new file mode 100644 index 00000000000..ed4ec767981 --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/test_approval_thread_id_mismatch.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Regression tests for approval lookup across client and provider thread IDs.""" + +from collections.abc import AsyncIterator, MutableSequence +from typing import Any + +import pytest +from agent_framework import Agent, ChatOptions, ChatResponseUpdate, Content, Message, tool +from agent_framework.ag_ui import AgentFrameworkAgent + + +@pytest.mark.parametrize( + "resume_thread_id", + [ + pytest.param("client-thread", id="client-thread"), + pytest.param("provider-conversation", id="provider-conversation"), + ], +) +async def test_approval_resolves_with_client_or_provider_thread_id( + streaming_chat_client_stub: Any, + resume_thread_id: str, +) -> None: + """A stateful provider approval remains resolvable by either advertised thread identity.""" + execution_count = 0 + + @tool( + name="sensitive_action", + description="A sensitive action requiring approval", + approval_mode="always_require", + ) + def sensitive_action() -> str: + nonlocal execution_count + execution_count += 1 + return "executed" + + async def approval_stream( + messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate( + contents=[ + Content.from_function_call( + name="sensitive_action", + call_id="call_sensitive", + arguments="{}", + ) + ], + conversation_id="provider-conversation", + ) + + wrapper = AgentFrameworkAgent( + agent=Agent( + client=streaming_chat_client_stub(approval_stream), + name="test_agent", + instructions="Test", + tools=[sensitive_action], + ) + ) + + async for _ in wrapper.run({"thread_id": "client-thread", "messages": [{"role": "user", "content": "do it"}]}): + pass + + assert "provider-conversation:call_sensitive" in wrapper._pending_approvals + + async def completion_stream( + messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[Content.from_text(text="Done")]) + + wrapper.agent = Agent( + client=streaming_chat_client_stub(completion_stream), + name="test_agent", + instructions="Test", + tools=[sensitive_action], + ) + + def approval_input(thread_id: str) -> dict[str, Any]: + return { + "thread_id": thread_id, + "messages": [ + { + "role": "user", + "content": "approved", + "function_approvals": [ + { + "id": "call_sensitive", + "call_id": "call_sensitive", + "name": "sensitive_action", + "approved": True, + "arguments": {}, + } + ], + } + ], + } + + async for _ in wrapper.run(approval_input(resume_thread_id)): + pass + + assert execution_count == 1 + assert not any("call_sensitive" in key for key in wrapper._pending_approvals) + + replay_thread_id = "provider-conversation" if resume_thread_id == "client-thread" else "client-thread" + async for _ in wrapper.run(approval_input(replay_thread_id)): + pass + + assert execution_count == 1