diff --git a/src/uipath_langchain/agent/tools/client_side_tool.py b/src/uipath_langchain/agent/tools/client_side_tool.py new file mode 100644 index 000000000..6a0b76790 --- /dev/null +++ b/src/uipath_langchain/agent/tools/client_side_tool.py @@ -0,0 +1,85 @@ +"""Factory for creating client-side tools that execute on the client SDK.""" + +import json +from typing import Annotated, Any + +from langchain_core.messages import ToolMessage +from langchain_core.tools import InjectedToolCallId, StructuredTool +from uipath.agent.models.agent import AgentClientSideToolResourceConfig +from uipath.eval.mocks import mockable + +from uipath_langchain._utils.durable_interrupt import durable_interrupt +from uipath_langchain.agent.react.jsonschema_pydantic_converter import ( + create_model as create_model_from_schema, +) +from uipath_langchain.chat.hitl import CLIENT_SIDE_TOOL_MARKER + +from .utils import sanitize_tool_name + + +def create_client_side_tool( + resource: AgentClientSideToolResourceConfig, +) -> StructuredTool: + """Create a client-side tool that pauses the graph and waits for the client to execute it. + + The tool uses @durable_interrupt to suspend the graph. The client SDK receives + an executingToolCall event, runs its registered handler, and sends endToolCall + back through CAS. The bridge routes that endToolCall to wait_for_resume(), + which unblocks the graph with the client's result. + """ + tool_name = sanitize_tool_name(resource.name) + input_model = create_model_from_schema(resource.input_schema) + + async def client_side_tool_fn( + *, tool_call_id: Annotated[str, InjectedToolCallId], **kwargs: Any + ) -> Any: + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=(resource.output_schema or {}), + example_calls=getattr(resource.properties, "example_calls", None), + ) + async def execute_tool() -> dict[str, Any]: + """Execute client-side tool, pausing for client response.""" + + @durable_interrupt + async def wait_for_client_execution() -> dict[str, Any]: + return { + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "input": kwargs, + "is_execution_phase": True, + } + + result = await wait_for_client_execution() + return result.get("output", result) if isinstance(result, dict) else result + + result = await execute_tool() + + if isinstance(result, dict): + try: + content = json.dumps(result) + except TypeError: + content = str(result) + else: + content = str(result) if result is not None else "" + + return ToolMessage( + content=content, + tool_call_id=tool_call_id, + response_metadata={CLIENT_SIDE_TOOL_MARKER: True}, + ) + + tool = StructuredTool( + name=tool_name, + description=resource.description or f"Client-side tool: {tool_name}", + args_schema=input_model, + coroutine=client_side_tool_fn, + metadata={ + CLIENT_SIDE_TOOL_MARKER: True, + "output_schema": resource.output_schema, + }, + ) + + return tool diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 17708f17c..2a28cbf49 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -5,6 +5,7 @@ from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool from uipath.agent.models.agent import ( + AgentClientSideToolResourceConfig, AgentContextResourceConfig, AgentEscalationResourceConfig, AgentIntegrationToolResourceConfig, @@ -18,6 +19,7 @@ from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION +from .client_side_tool import create_client_side_tool from .context_tool import create_context_tool from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool @@ -120,4 +122,7 @@ async def _build_tool_for_resource( elif isinstance(resource, AgentIxpVsEscalationResourceConfig): return create_ixp_escalation_tool(resource) + elif isinstance(resource, AgentClientSideToolResourceConfig): + return create_client_side_tool(resource) + return None diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 88480c5a3..03028e73f 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -23,6 +23,7 @@ find_latest_ai_message, ) from uipath_langchain.chat.hitl import ( + CLIENT_SIDE_TOOL_MARKER, REQUIRE_CONVERSATIONAL_CONFIRMATION, request_conversational_tool_confirmation, ) @@ -279,10 +280,13 @@ async def _afunc(state: AgentGraphState) -> OutputType: tool = getattr(tool_node, "tool", None) - # Preserve tool ref so the runtime can discover which tools need confirmation - # (see runtime.py _get_tool_confirmation_info) + # Preserve tool ref so the runtime can discover tool metadata + # (confirmation requirements, client-side markers, etc.) metadata = getattr(tool, "metadata", None) or {} - if isinstance(tool, BaseTool) and metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION): + if isinstance(tool, BaseTool) and ( + metadata.get(REQUIRE_CONVERSATIONAL_CONFIRMATION) + or metadata.get(CLIENT_SIDE_TOOL_MARKER) + ): return RunnableCallableWithTool( func=_func, afunc=_afunc, name=tool_name, tool=tool ) diff --git a/src/uipath_langchain/chat/hitl.py b/src/uipath_langchain/chat/hitl.py index 72a99800e..161296867 100644 --- a/src/uipath_langchain/chat/hitl.py +++ b/src/uipath_langchain/chat/hitl.py @@ -14,6 +14,7 @@ CANCELLED_MESSAGE = "Cancelled by user" ARGS_MODIFIED_MESSAGE = "User has modified the tool arguments" +CLIENT_SIDE_TOOL_MARKER = "uipath_client_tool" CONVERSATIONAL_APPROVED_TOOL_ARGS = "conversational_approved_tool_args" REQUIRE_CONVERSATIONAL_CONFIRMATION = "require_conversational_confirmation" @@ -126,12 +127,18 @@ def request_approval( """ tool_call_id: str = tool_args.pop("tool_call_id") + # If this is a server-side tool (not client-side), execution follows immediately + # after confirmation — mark this as the execution trigger so the bridge emits + # executingToolCall. For client-side tools, the execution interrupt sets this instead. + is_execution_trigger = not (tool.metadata or {}).get(CLIENT_SIDE_TOOL_MARKER, False) + @durable_interrupt def ask_confirmation(): return { "tool_call_id": tool_call_id, "tool_name": tool.name, "input": tool_args, + "is_execution_phase": is_execution_trigger, } response = ask_confirmation() diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 5d7d63aa8..0e36c34ff 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -24,6 +24,7 @@ UiPathConversationContentPartEndEvent, UiPathConversationContentPartEvent, UiPathConversationContentPartStartEvent, + UiPathConversationExecutingToolCallEvent, UiPathConversationMessage, UiPathConversationMessageData, UiPathConversationMessageEndEvent, @@ -39,6 +40,8 @@ ) from uipath.runtime import UiPathRuntimeStorageProtocol +from uipath_langchain.chat.hitl import CLIENT_SIDE_TOOL_MARKER + from ._citations import CitationStreamProcessor, extract_citations_from_text logger = logging.getLogger(__name__) @@ -60,6 +63,7 @@ def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None self.storage = storage self.current_message: AIMessageChunk | AIMessage self.tools_requiring_confirmation: dict[str, Any] = {} + self.client_side_tools: dict[str, Any] = {} # {tool_name: output_schema} self.seen_message_ids: set[str] = set() self._storage_lock = asyncio.Lock() self._citation_stream_processor = CitationStreamProcessor() @@ -436,15 +440,40 @@ async def map_current_message_to_start_tool_call_events(self): tool_name in self.tools_requiring_confirmation ) input_schema = self.tools_requiring_confirmation.get(tool_name) + is_client_side = tool_name in self.client_side_tools + output_schema = ( + self.client_side_tools.get(tool_name) + if is_client_side + else None + ) events.append( self.map_tool_call_to_tool_call_start_event( self.current_message.id, tool_call, require_confirmation=require_confirmation or None, input_schema=input_schema, + is_client_side_tool=is_client_side or None, + output_schema=output_schema, ) ) + # Emit executingToolCall from MessageMapper for tools without + # a durable interrupt. Tools with interrupts (client-side, HITL) + # get executingToolCall from the bridge instead. + if not require_confirmation and not is_client_side: + events.append( + UiPathConversationMessageEvent( + message_id=self.current_message.id, + tool_call=UiPathConversationToolCallEvent( + tool_call_id=tool_call["id"], + executing=UiPathConversationExecutingToolCallEvent( + tool_name=tool_call["name"], + input=tool_call["args"], + ), + ), + ) + ) + if self.storage is not None: await self.storage.set_value( self.runtime_id, @@ -476,19 +505,24 @@ async def map_tool_message_to_events( # Keep as string if not valid JSON pass - events = [ - UiPathConversationMessageEvent( - message_id=message_id, - tool_call=UiPathConversationToolCallEvent( - tool_call_id=message.tool_call_id, - end=UiPathConversationToolCallEndEvent( - timestamp=self.get_timestamp(), - output=content_value, - is_error=message.status == "error", + # Suppress endToolCall for client-side tools — the client already has the result (it produced it). + is_client_side = message.response_metadata.get(CLIENT_SIDE_TOOL_MARKER, False) + events: list[UiPathConversationMessageEvent] = [] + + if not is_client_side: + events.append( + UiPathConversationMessageEvent( + message_id=message_id, + tool_call=UiPathConversationToolCallEvent( + tool_call_id=message.tool_call_id, + end=UiPathConversationToolCallEndEvent( + timestamp=self.get_timestamp(), + output=content_value, + is_error=message.status == "error", + ), ), - ), + ) ) - ] if is_last_tool_call: events.append(self.map_to_message_end_event(message_id)) @@ -546,6 +580,8 @@ def map_tool_call_to_tool_call_start_event( *, require_confirmation: bool | None = None, input_schema: Any | None = None, + is_client_side_tool: bool | None = None, + output_schema: Any | None = None, ) -> UiPathConversationMessageEvent: return UiPathConversationMessageEvent( message_id=message_id, @@ -557,6 +593,8 @@ def map_tool_call_to_tool_call_start_event( input=tool_call["args"], require_confirmation=require_confirmation, input_schema=input_schema, + is_client_side_tool=is_client_side_tool, + output_schema=output_schema, ), ), ) diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index da8d90918..e418648ab 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -1,5 +1,6 @@ import logging import os +from collections.abc import Iterator from typing import Any, AsyncGenerator from uuid import uuid4 @@ -31,7 +32,7 @@ from uipath.runtime.schema import UiPathRuntimeSchema from uipath_langchain.agent.tools.tool_node import RunnableCallableWithTool -from uipath_langchain.chat.hitl import get_confirmation_schema +from uipath_langchain.chat.hitl import CLIENT_SIDE_TOOL_MARKER, get_confirmation_schema from uipath_langchain.runtime.errors import LangGraphErrorCode, LangGraphRuntimeError from uipath_langchain.runtime.messages import UiPathChatMessagesMapper from uipath_langchain.runtime.schema import get_entrypoints_schema, get_graph_schema @@ -68,6 +69,7 @@ def __init__( self.callbacks: list[BaseCallbackHandler] = callbacks or [] self.chat = UiPathChatMessagesMapper(self.runtime_id, storage) self.chat.tools_requiring_confirmation = self._get_tool_confirmation_info() + self.chat.client_side_tools = self._get_client_side_tools() self._middleware_node_names: set[str] = self._detect_middleware_nodes() async def execute( @@ -490,38 +492,42 @@ def _detect_middleware_nodes(self) -> set[str]: return middleware_nodes - def _get_tool_confirmation_info(self) -> dict[str, Any]: - """Build {tool_name: input_schema} for tools requiring confirmation. - - Walks compiled graph nodes once at runtime init. This is needed because coded agents - (create_agent) export a compiled graph as the only artifact — there's no side channel - to pass confirmation metadata from the build step to the runtime. - """ - schemas: dict[str, Any] = {} + def _iter_graph_tools(self) -> Iterator[BaseTool]: + """Yield all BaseTool instances from compiled graph nodes.""" for node_spec in self.graph.nodes.values(): bound = getattr(node_spec, "bound", None) if bound is None: continue - # Coded agents: one tool per node - if isinstance(bound, RunnableCallableWithTool): - schema = get_confirmation_schema(bound.tool) - if schema is not None: - schemas[bound.tool.name] = schema + tool = getattr(bound, "tool", None) + if isinstance(tool, BaseTool): + yield tool continue - # Low-code agents: multiple tools in one node tools_by_name = getattr(bound, "tools_by_name", None) if isinstance(tools_by_name, dict): - for tool in tools_by_name.values(): - if not isinstance(tool, BaseTool): - continue - schema = get_confirmation_schema(tool) - if schema is not None: - schemas[tool.name] = schema + for t in tools_by_name.values(): + if isinstance(t, BaseTool): + yield t + def _get_tool_confirmation_info(self) -> dict[str, Any]: + """Build {tool_name: input_schema} for tools requiring confirmation.""" + schemas: dict[str, Any] = {} + for tool in self._iter_graph_tools(): + schema = get_confirmation_schema(tool) + if schema is not None: + schemas[tool.name] = schema return schemas + def _get_client_side_tools(self) -> dict[str, Any]: + """Build {tool_name: output_schema} for client-side tools.""" + tools: dict[str, Any] = {} + for tool in self._iter_graph_tools(): + metadata = getattr(tool, "metadata", None) or {} + if metadata.get(CLIENT_SIDE_TOOL_MARKER): + tools[tool.name] = metadata.get("output_schema") + return tools + def _is_middleware_node(self, node_name: str) -> bool: """Check if a node name represents a middleware node.""" return node_name in self._middleware_node_names diff --git a/tests/runtime/test_chat_message_mapper.py b/tests/runtime/test_chat_message_mapper.py index d2bdee792..0233ad373 100644 --- a/tests/runtime/test_chat_message_mapper.py +++ b/tests/runtime/test_chat_message_mapper.py @@ -2102,3 +2102,151 @@ async def test_mixed_tools_only_confirmation_has_metadata(self): assert "confirm_tool" in tool_starts assert tool_starts["normal_tool"].require_confirmation is None assert tool_starts["confirm_tool"].require_confirmation is True + + +class TestExecutingToolCallEmission: + """Tests for executingToolCall event emission from MessageMapper.""" + + @pytest.mark.asyncio + async def test_emits_executing_for_normal_tool(self): + """Should emit executingToolCall for a server tool without confirmation or client-side marker.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[{"id": "tc-1", "name": "server_tool", "args": {"x": 1}}], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + executing_events = [ + e for e in result + if e.tool_call is not None and e.tool_call.executing is not None + ] + assert len(executing_events) == 1 + assert executing_events[0].tool_call.executing.tool_name == "server_tool" + + @pytest.mark.asyncio + async def test_no_executing_for_confirmation_tool(self): + """Should NOT emit executingToolCall for a tool that requires confirmation.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.tools_requiring_confirmation = {"confirm_tool": {}} + + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[{"id": "tc-1", "name": "confirm_tool", "args": {}}], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + executing_events = [ + e for e in result + if e.tool_call is not None and e.tool_call.executing is not None + ] + assert len(executing_events) == 0 + + @pytest.mark.asyncio + async def test_no_executing_for_client_side_tool(self): + """Should NOT emit executingToolCall for a client-side tool (bridge handles it).""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + mapper.client_side_tools = {"client_tool": {"type": "object"}} + + first_chunk = AIMessageChunk( + content="", + id="msg-1", + tool_calls=[{"id": "tc-1", "name": "client_tool", "args": {"title": "Avatar"}}], + ) + await mapper.map_event(first_chunk) + + last_chunk = AIMessageChunk(content="", id="msg-1") + object.__setattr__(last_chunk, "chunk_position", "last") + result = await mapper.map_event(last_chunk) + + assert result is not None + executing_events = [ + e for e in result + if e.tool_call is not None and e.tool_call.executing is not None + ] + assert len(executing_events) == 0 + + +class TestClientSideToolEndSuppression: + """Tests for suppressing endToolCall for client-side tools.""" + + @pytest.mark.asyncio + async def test_client_side_tool_suppresses_end_event(self): + """ToolMessage with CLIENT_SIDE_TOOL_MARKER should NOT emit endToolCall.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"rating": 9}', + tool_call_id="tool-1", + response_metadata={"uipath_client_tool": True}, + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + tool_end_events = [ + e for e in result + if e.tool_call is not None and e.tool_call.end is not None + ] + assert len(tool_end_events) == 0 + + @pytest.mark.asyncio + async def test_client_side_tool_still_emits_message_end(self): + """ToolMessage with CLIENT_SIDE_TOOL_MARKER should still emit message end when it's the last tool.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"rating": 9}', + tool_call_id="tool-1", + response_metadata={"uipath_client_tool": True}, + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + message_end_events = [e for e in result if e.end is not None] + assert len(message_end_events) == 1 + + @pytest.mark.asyncio + async def test_normal_tool_emits_end_event(self): + """ToolMessage without CLIENT_SIDE_TOOL_MARKER should emit endToolCall normally.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"result": "success"}', + tool_call_id="tool-1", + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + tool_end_events = [ + e for e in result + if e.tool_call is not None and e.tool_call.end is not None + ] + assert len(tool_end_events) == 1 diff --git a/tests/runtime/test_client_side_tool_discovery.py b/tests/runtime/test_client_side_tool_discovery.py new file mode 100644 index 000000000..d74922942 --- /dev/null +++ b/tests/runtime/test_client_side_tool_discovery.py @@ -0,0 +1,91 @@ +"""Tests that _get_client_side_tools discovers client-side tools through RunnableCallableWithTool wrappers. + +Integration guard: if the wrapping pipeline changes and stops preserving the +BaseTool reference for client-side tools, these tests will fail. +""" + +from typing import Any + +from langchain_core.tools import BaseTool +from langgraph.constants import END, START +from langgraph.graph import StateGraph +from pydantic import BaseModel, Field + +from uipath_langchain.agent.tools.tool_node import ( + UiPathToolNode, + wrap_tools_with_error_handling, +) +from uipath_langchain.chat.hitl import CLIENT_SIDE_TOOL_MARKER +from uipath_langchain.runtime.runtime import UiPathLangGraphRuntime + + +class _ClientSideInput(BaseModel): + title: str = Field(description="Movie title") + + +class _ClientSideTool(BaseTool): + name: str = "client_tool" + description: str = "A client-side tool" + args_schema: type[BaseModel] = _ClientSideInput + metadata: dict[str, Any] = { + CLIENT_SIDE_TOOL_MARKER: True, + "output_schema": {"type": "object", "properties": {"rating": {"type": "number"}}}, + } + + def _run(self, title: str) -> str: + return f"result for {title}" + + +class _NormalTool(BaseTool): + name: str = "normal_tool" + description: str = "A normal server tool" + + def _run(self) -> str: + return "done" + + +class _MinimalState(BaseModel): + value: str = "" + + +def _compile_graph_with_wrapped_tools(tools: list[BaseTool]): + """Build and compile a minimal graph with tools wrapped through the standard pipeline.""" + tool_nodes = {t.name: UiPathToolNode(t) for t in tools} + wrapped = wrap_tools_with_error_handling(tool_nodes) + + builder: StateGraph[_MinimalState] = StateGraph(_MinimalState) + names = list(wrapped.keys()) + for name, node in wrapped.items(): + builder.add_node(name, node) + + builder.add_edge(START, names[0]) + for i in range(len(names) - 1): + builder.add_edge(names[i], names[i + 1]) + builder.add_edge(names[-1], END) + + return builder.compile() + + +class TestClientSideToolDiscovery: + def test_discovers_client_side_tool_through_wrapper(self): + graph = _compile_graph_with_wrapped_tools([_ClientSideTool(), _NormalTool()]) + runtime = UiPathLangGraphRuntime(graph) + + client_tools = runtime.chat.client_side_tools + assert "client_tool" in client_tools + assert "normal_tool" not in client_tools + + def test_output_schema_is_preserved(self): + graph = _compile_graph_with_wrapped_tools([_ClientSideTool()]) + runtime = UiPathLangGraphRuntime(graph) + + schema = runtime.chat.client_side_tools["client_tool"] + assert schema is not None + assert "properties" in schema + assert "rating" in schema["properties"] + + def test_empty_when_no_client_side_tools(self): + graph = _compile_graph_with_wrapped_tools([_NormalTool()]) + runtime = UiPathLangGraphRuntime(graph) + + assert runtime.chat.client_side_tools == {}