def get_current_span_and_trace_ids() -> tuple[str, str]:
    """Return the active OTel (span_id, trace_id) pair as zero-padded hex.

    Falls back to a pair of empty strings when the opentelemetry package is
    not installed, or when the current context has no valid span.
    """
    try:
        from opentelemetry import trace
    except ImportError:
        return "", ""

    ctx = trace.get_current_span().get_span_context()
    if not ctx.is_valid:
        return "", ""
    # W3C trace-context widths: 8-byte span id, 16-byte trace id.
    return format(ctx.span_id, "016x"), format(ctx.trace_id, "032x")


def set_current_span_error(error: BaseException) -> None:
    """Attach *error* to the current OTel span and flag the span as failed.

    No-op when opentelemetry is absent or the current span is not recording.
    """
    try:
        from opentelemetry import trace
        from opentelemetry.trace import StatusCode
    except ImportError:
        return

    span = trace.get_current_span()
    if not span.is_recording():
        return
    span.record_exception(error)
    span.set_status(StatusCode.ERROR, str(error))
validate_by_name=True) + + threshold: float = Field(default=0.0, ge=0.0, le=1.0) + search_mode: SearchMode = Field(default=SearchMode.Hybrid, alias="searchMode") + field_settings: list[EscalationMemoryFieldSetting] | None = Field( + default=None, + alias="fieldSettings", + ) + + @field_validator("search_mode", mode="before") + @classmethod + def _normalize_search_mode(cls, value: Any) -> Any: + if isinstance(value, str): + normalized = value.lower() + if normalized == "hybrid": + return SearchMode.Hybrid + if normalized == "semantic": + return SearchMode.Semantic + return value + + +class EscalationMemoryCachedResult(BaseModel): + """Cached escalation output returned by memory search.""" + + output: Any = None + outcome: str | None = None + + +class EscalationMemoryRetriever: + """Retrieves previously resolved escalation outcomes from UiPath memory.""" + + def __init__( + self, + memory_space_id: str, + *, + folder_path: str | None = None, + memory_settings: EscalationMemorySettings | None = None, + uipath_sdk: UiPath | None = None, + ) -> None: + self.memory_space_id = memory_space_id + self.folder_path = folder_path + self.memory_settings = memory_settings or EscalationMemorySettings() + self._uipath_sdk = uipath_sdk + + async def aretrieve( + self, + serialized_input: dict[str, Any], + ) -> EscalationMemoryCachedResult | None: + """Search escalation memory and return the first cached answer.""" + request = self._build_search_request(serialized_input) + sdk = self._uipath_sdk if self._uipath_sdk is not None else UiPath() + try: + response = await sdk.memory.escalation_search_async( + memory_space_id=self.memory_space_id, + request=request, + folder_path=self.folder_path, + ) + except ValidationError: + response = await self._raw_escalation_search(sdk, request) + + return _cached_result_from_search_response(response) + + def _build_search_request( + self, + serialized_input: dict[str, Any], + ) -> MemorySearchRequest: + fields = 
_build_search_fields(serialized_input, self.memory_settings) + return MemorySearchRequest( + fields=fields, + settings=SearchSettings( + threshold=self.memory_settings.threshold, + result_count=1, + search_mode=self.memory_settings.search_mode, + ), + definition_system_prompt="", + ) + + async def _raw_escalation_search( + self, + sdk: UiPath, + request: MemorySearchRequest, + ) -> Any: + spec = sdk.memory._escalation_search_spec( + self.memory_space_id, + folder_path=self.folder_path, + ) + response = await sdk.memory.request_async( + spec.method, + spec.endpoint, + json=request.model_dump(by_alias=True, exclude_none=True), + headers=spec.headers, + ) + return response.json() + + +def _get_escalation_memory_space_id( + resource: AgentEscalationResourceConfig, + agent: Any | None = None, +) -> str | None: + """Resolve memory space ID from escalation resource or agent memory feature.""" + if not _is_escalation_memory_enabled(resource): + return None + + memory = _get_escalation_memory_properties(resource) + memory_space_id = _read_first_value( + (resource, memory), + "memorySpaceId", + "memory_space_id", + ) + if memory_space_id: + return str(memory_space_id) + + memory_space_name = _read_first_value( + (resource, memory), + "memorySpaceName", + "memory_space_name", + ) + folder_path = _read_value(memory, "folderPath", "folder_path") + if not memory_space_name: + feature = _get_agent_memory_space_feature(agent) + memory_space_id = _read_value(feature, "memorySpaceId", "memory_space_id") + if memory_space_id: + return str(memory_space_id) + memory_space_name = _read_value( + feature, + "memorySpaceName", + "memory_space_name", + ) + folder_path = _read_value(feature, "folderPath", "folder_path") or folder_path + + if not memory_space_name: + return None + + return _resolve_memory_space_id_by_name(str(memory_space_name), folder_path) + + +def _get_escalation_memory_folder_path( + resource: AgentEscalationResourceConfig, + agent: Any | None = None, +) -> str | None: + 
"""Resolve folder path to use for escalation memory API calls.""" + if not _is_escalation_memory_enabled(resource): + return None + + memory = _get_escalation_memory_properties(resource) + memory_space_name = _read_first_value( + (resource, memory), + "memorySpaceName", + "memory_space_name", + ) + folder_path = _read_value(memory, "folderPath", "folder_path") + if not memory_space_name and not folder_path: + feature = _get_agent_memory_space_feature(agent) + memory_space_name = _read_value( + feature, + "memorySpaceName", + "memory_space_name", + ) + folder_path = _read_value(feature, "folderPath", "folder_path") or folder_path + + return _resolve_memory_folder_path( + folder_path, str(memory_space_name) if memory_space_name else None + ) + + +def _get_escalation_memory_settings( + resource: AgentEscalationResourceConfig, +) -> EscalationMemorySettings | None: + """Extract memory settings from escalation resource properties.""" + if not _is_escalation_memory_enabled(resource): + return None + + memory = _get_escalation_memory_properties(resource) + if memory is None: + return None + return _coerce_memory_settings(memory) + + +def _is_escalation_memory_enabled(resource: AgentEscalationResourceConfig) -> bool: + memory = _get_escalation_memory_properties(resource) + memory_enabled = _read_value(memory, "isEnabled", "is_enabled") + if memory_enabled is not None: + return bool(memory_enabled) + return bool( + _read_value(resource, "isAgentMemoryEnabled", "is_agent_memory_enabled") + ) + + +def _get_escalation_memory_properties(resource: AgentEscalationResourceConfig) -> Any: + properties = _read_value(resource, "properties") + return _read_value(properties, "memory") if properties is not None else None + + +def _get_agent_memory_space_feature(agent: Any | None) -> Any: + features = _read_value(agent, "features") or [] + for feature in features: + feature_type = _read_value( + feature, "$featureType", "featureType", "feature_type" + ) + if feature_type != 
"memorySpace": + continue + is_enabled = _read_value(feature, "isEnabled", "is_enabled") + if is_enabled is False: + continue + if _read_value(feature, "memorySpaceId", "memory_space_id") or _read_value( + feature, + "memorySpaceName", + "memory_space_name", + ): + return feature + return None + + +def _resolve_memory_space_id_by_name( + memory_space_name: str, + folder_path: Any, +) -> str | None: + resolved_folder_path = _resolve_memory_folder_path(folder_path, memory_space_name) + try: + escaped_name = memory_space_name.replace("'", "''") + spaces = UiPath().memory.list( + filter=f"name eq '{escaped_name}'", + folder_path=resolved_folder_path, + ) + except Exception: + logger.warning( + "Failed to resolve escalation memory space '%s'", + memory_space_name, + exc_info=True, + ) + return None + + if not spaces.value: + logger.warning( + "Escalation memory space '%s' was not found", + memory_space_name, + ) + return None + return str(spaces.value[0].id) + + +def _resolve_memory_folder_path( + folder_path: Any, + memory_space_name: str | None = None, +) -> str | None: + if memory_space_name: + folder_path = ( + _get_memory_space_folder_override(memory_space_name) or folder_path + ) + if folder_path in (None, "", ".", "solution_folder"): + return get_execution_folder_path() + return str(folder_path) + + +def _get_memory_space_folder_override(memory_space_name: str) -> str | None: + overwrites = _resource_overwrites.get() + if not overwrites: + return None + + overwrite = overwrites.get(f"memorySpace.{memory_space_name}") + if not overwrite: + return None + + folder_identifier = getattr(overwrite, "folder_identifier", None) + if not folder_identifier: + return None + + logger.info( + "Memory space '%s' folder_path overwritten: '%s'", + memory_space_name, + folder_identifier, + ) + return str(folder_identifier) + + +def _get_user_email(user: Any) -> str | None: + """Extract an email address from an Action Center user payload.""" + if user is None: + return None + + for 
key in ("emailAddress", "email", "Email", "userName"): + value = _read_value(user, key) + if value: + return str(value) + + return None + + +def _get_user_id(user: Any) -> str | None: + """Extract a LLMOps-compatible reviewer ID from an Action Center user payload.""" + if user is None: + return None + + for key in ("identifier", "userId", "userGlobalId", "id"): + user_id = _normalize_user_id(_read_value(user, key)) + if user_id is not None: + return user_id + + return None + + +async def _resolve_user_id(user: Any) -> str | None: + """Resolve the Action Center reviewer to the directory ID expected by LLMOps.""" + user_id = _get_user_id(user) + if user_id: + return user_id + + email = _get_user_email(user) + if not email: + return None + + org_id = UiPathConfig.organization_id + if not org_id: + return None + + try: + response = await UiPath().api_client.request_async( + "GET", + f"/identity_/api/Directory/Search/{org_id}", + scoped="org", + params={ + "startsWith": email, + "sourceFilter": ["directoryUsers", "localUsers"], + }, + ) + except Exception: + logger.warning("Failed to resolve reviewer '%s'", email, exc_info=True) + return None + + for entry in response.json() or []: + if _get_user_email(entry) != email: + continue + user_id = _get_user_id(entry) + if user_id is not None: + return user_id + + return None + + +async def _check_escalation_memory_cache( + memory_space_id: str, + serialized_input: dict[str, Any], + folder_path: str | None = None, + memory_settings: EscalationMemorySettings | None = None, +) -> EscalationMemoryCachedResult | None: + """Check escalation memory for a cached answer.""" + retriever = EscalationMemoryRetriever( + memory_space_id, + folder_path=folder_path, + memory_settings=memory_settings, + ) + + try: + cached_result = await retriever.aretrieve(serialized_input) + except ValueError as error: + logger.warning( + "Skipping escalation memory search for space '%s': %s", + memory_space_id, + error, + ) + 
_record_custom_metric(MEMORY_CACHE_MISS_METRIC, memory_space_id) + return None + except Exception as error: + set_current_span_error(error) + logger.warning( + "Escalation memory search failed for space '%s'", + memory_space_id, + exc_info=True, + ) + return None + + if cached_result is None: + _record_custom_metric(MEMORY_CACHE_MISS_METRIC, memory_space_id) + return None + + _record_custom_metric(MEMORY_CACHE_HIT_METRIC, memory_space_id) + logger.info("Escalation memory cache hit for space '%s'", memory_space_id) + set_span_attribute("fromMemory", True) + return cached_result + + +async def _ingest_escalation_memory( + memory_space_id: str, + answer: str, + attributes: str, + parent_span_id: str, + trace_id: str, + user_id: str | None = None, + folder_path: str | None = None, +) -> None: + """Persist a resolved escalation outcome into memory.""" + set_span_attribute("fromMemory", False) + normalized_user_id = _normalize_user_id(user_id) + if user_id is not None and normalized_user_id is None: + logger.info( + "Skipping escalation memory reviewer user ID because it is not a GUID: %s", + user_id, + ) + + try: + request = EscalationMemoryIngestRequest( + span_id=parent_span_id, + trace_id=trace_id, + answer=answer, + attributes=attributes, + user_id=normalized_user_id, + ) + sdk = UiPath() + await sdk.memory.escalation_ingest_async( + memory_space_id=memory_space_id, + request=request, + folder_path=folder_path, + ) + set_span_attribute("savedToMemory", True) + logger.info( + "Ingested escalation outcome into memory space '%s'", memory_space_id + ) + except Exception as error: + set_span_attribute("savedToMemory", False) + set_current_span_error(error) + logger.warning( + "Failed to ingest escalation outcome into memory space '%s': %s", + memory_space_id, + error, + exc_info=True, + ) + + +def _build_search_fields( + serialized_input: dict[str, Any], + memory_settings: EscalationMemorySettings, +) -> list[SearchField]: + field_settings = 
memory_settings.field_settings + field_settings_lookup = ( + {field_setting.name: field_setting for field_setting in field_settings} + if field_settings is not None + else None + ) + + fields: list[SearchField] = [] + for name, value in serialized_input.items(): + value_str = _stringify_search_value(value) + if not value_str: + continue + if field_settings_lookup is not None and name not in field_settings_lookup: + continue + settings = FieldSettings() + if field_settings_lookup is not None: + settings = FieldSettings(weight=field_settings_lookup[name].weight) + fields.append( + SearchField( + key_path=["escalation-input", name], + value=value_str, + settings=settings, + ) + ) + + if not fields: + raise ValueError( + "Escalation memory search requires at least one configured input field." + ) + return fields + + +def _record_custom_metric(metric_name: str, memory_space_id: str) -> None: + attributes = {"memorySpaceId": memory_space_id} + try: + from opentelemetry import metrics, trace + + counter = _metric_counters.get(metric_name) + if counter is None: + counter = metrics.get_meter( + "uipath_langchain.escalation_memory" + ).create_counter(metric_name) + _metric_counters[metric_name] = counter + counter.add(1, attributes) + + span = trace.get_current_span() + if span.is_recording(): + span.add_event( + "customMetric", + { + "name": metric_name, + "value": 1, + **attributes, + }, + ) + except Exception: + logger.debug("Failed to record metric '%s'", metric_name, exc_info=True) + + +def _cached_result_from_search_response( + response: Any, +) -> EscalationMemoryCachedResult | None: + results = _read_value(response, "results") or [] + if not results: + return None + + answer = _read_value(results[0], "answer") + if not answer: + return None + + if isinstance(answer, str): + try: + answer = json.loads(answer) + except json.JSONDecodeError: + logger.warning("Escalation memory cache entry answer is not valid JSON") + return None + + output = _read_value(answer, 
"output", "Output") + if output is None: + logger.warning( + "Escalation memory cache entry has no 'output' property; treating as cache miss." + ) + return None + + return EscalationMemoryCachedResult( + output=output, + outcome=_read_value(answer, "outcome", "Outcome"), + ) + + +def _coerce_memory_settings(memory: Any) -> EscalationMemorySettings: + if isinstance(memory, EscalationMemorySettings): + return memory + if isinstance(memory, BaseModel): + memory = memory.model_dump(by_alias=True, exclude_none=True) + elif not isinstance(memory, dict): + memory = { + key: getattr(memory, key) + for key in ( + "threshold", + "searchMode", + "search_mode", + "fieldSettings", + "field_settings", + ) + if hasattr(memory, key) + } + return EscalationMemorySettings.model_validate(memory) + + +def _read_value(source: Any, *keys: str) -> Any: + if source is None: + return None + if isinstance(source, dict): + value = _read_mapping_value(source, keys) + return None if value is _MISSING_VALUE else value + if isinstance(source, BaseModel): + value = _read_mapping_value(source.model_extra or {}, keys) + if value is not _MISSING_VALUE: + return value + return _read_attribute_value(source, keys) + + +def _read_first_value(sources: tuple[Any, ...], *keys: str) -> Any: + for source in sources: + value = _read_value(source, *keys) + if value is not None: + return value + return None + + +def _read_mapping_value(source: dict[str, Any], keys: tuple[str, ...]) -> Any: + for key in keys: + if key in source: + return source[key] + return _MISSING_VALUE + + +def _read_attribute_value(source: Any, keys: tuple[str, ...]) -> Any: + for key in keys: + value = getattr(source, key, _MISSING_VALUE) + if value is not _MISSING_VALUE: + return value + return None + + +def _normalize_user_id(value: Any) -> str | None: + if value is None: + return None + try: + return str(UUID(str(value))) + except (TypeError, ValueError): + return None + + +def _stringify_search_value(value: Any) -> str: + if value is 
None: + return "" + if isinstance(value, str): + return value + if isinstance(value, bool | int | float | list | dict): + return json.dumps(value, sort_keys=True) + return str(value) diff --git a/src/uipath_langchain/agent/tools/escalation_tool.py b/src/uipath_langchain/agent/tools/escalation_tool.py index f59ba421e..b9da0f6e8 100644 --- a/src/uipath_langchain/agent/tools/escalation_tool.py +++ b/src/uipath_langchain/agent/tools/escalation_tool.py @@ -1,11 +1,14 @@ """Escalation tool creation for Action Center integration.""" +import json +import logging +import os from enum import Enum from typing import Any, Literal from langchain_core.messages.tool import ToolCall from langchain_core.tools import BaseTool, StructuredTool -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel from uipath.agent.models.agent import ( AgentEscalationChannel, AgentEscalationRecipient, @@ -14,16 +17,20 @@ ArgumentEmailRecipient, ArgumentGroupNameRecipient, AssetRecipient, + LowCodeAgentDefinition, StandardRecipient, ) from uipath.agent.utils.text_tokens import safe_get_nested from uipath.eval.mocks import mockable from uipath.platform import UiPath -from uipath.platform.action_center.tasks import TaskRecipient, TaskRecipientType +from uipath.platform.action_center.tasks import Task, TaskRecipient, TaskRecipientType from uipath.platform.common import WaitEscalation from uipath.runtime.errors import UiPathErrorCategory -from uipath_langchain._utils import get_execution_folder_path +from uipath_langchain._utils import ( + get_current_span_and_trace_ids, + get_execution_folder_path, +) from uipath_langchain._utils.durable_interrupt import durable_interrupt from uipath_langchain.agent.react.jsonschema_pydantic_converter import create_model from uipath_langchain.agent.tools.structured_tool_with_argument_properties import ( @@ -32,6 +39,15 @@ from ..exceptions import AgentRuntimeError, AgentRuntimeErrorCode from ..react.types import AgentGraphState +from 
.escalation_memory import ( + EscalationMemorySettings, + _check_escalation_memory_cache, + _get_escalation_memory_folder_path, + _get_escalation_memory_settings, + _get_escalation_memory_space_id, + _ingest_escalation_memory, + _resolve_user_id, +) from .tool_node import ToolWrapperReturnType from .utils import ( resolve_task_title, @@ -39,6 +55,8 @@ sanitize_tool_name, ) +_escalation_logger = logging.getLogger(__name__) + class EscalationAction(str, Enum): """Actions that can be taken after an escalation completes.""" @@ -110,15 +128,6 @@ async def resolve_asset(asset_name: str, folder_path: str | None) -> str | None: ) from e -def _get_user_email(user: Any) -> str | None: - """Extract email from user object/dict.""" - if user is None: - return None - if isinstance(user, dict): - return user.get("emailAddress") - return getattr(user, "emailAddress", None) - - def _parse_task_data( data: dict[str, Any], input_schema: dict[str, Any], @@ -161,8 +170,88 @@ def _parse_task_data( return filtered_fields +def _resolve_escalation_action( + outcome: str | None, + outcome_mapping: dict[str, str] | None, +) -> EscalationAction: + outcome_action = ( + outcome_mapping.get(outcome) if outcome_mapping and outcome else None + ) + return ( + EscalationAction(outcome_action) + if outcome_action + else EscalationAction.CONTINUE + ) + + +def _build_escalation_memory_payload( + serialized_input: dict[str, Any], + escalation_output: dict[str, Any], + outcome: str | None, +) -> tuple[dict[str, Any], dict[str, Any]]: + answer = {"output": escalation_output, "outcome": outcome} + attributes = {"arguments": serialized_input} + return answer, attributes + + +def _pop_escalation_memory_span_context( + metadata: dict[str, Any] | None, +) -> tuple[str | None, str | None]: + span_context = (metadata or {}).get("_span_context") + if not isinstance(span_context, dict): + _escalation_logger.debug( + "Escalation memory span context missing _span_context metadata" + ) + return None, None + + 
parent_span_id = _format_otel_id(span_context.pop("parent_span_id", None), 16) + trace_id = _format_otel_id(span_context.pop("trace_id", None), 32) + _escalation_logger.debug( + "Escalation memory span context: %s", + json.dumps( + { + "parentSpanId": parent_span_id, + "traceId": trace_id, + "remainingContext": span_context, + }, + default=str, + ), + ) + return parent_span_id, trace_id + + +def _format_otel_id(value: Any, width: int) -> str | None: + if value in (None, ""): + return None + if isinstance(value, int): + return f"{value:0{width}x}" + return str(value) + + +def _normalize_trace_id(value: str) -> str: + normalized = value.replace("-", "").lower() + if len(normalized) != 32: + raise ValueError(f"Invalid trace ID format: {value}") + return normalized + + +def _get_exported_trace_id(trace_id: str | None) -> str | None: + trace_id_override = os.environ.get("UIPATH_TRACE_ID") + if trace_id_override: + try: + return _normalize_trace_id(trace_id_override) + except ValueError: + _escalation_logger.warning( + "Ignoring invalid UIPATH_TRACE_ID override: %s", + trace_id_override, + ) + + return trace_id + + def create_escalation_tool( resource: AgentEscalationResourceConfig, + agent: LowCodeAgentDefinition | None = None, ) -> StructuredTool: """Uses interrupt() for Action Center human-in-the-loop.""" @@ -177,7 +266,15 @@ class EscalationToolOutput(BaseModel): data: output_model is_deleted: bool = False + _span_context: dict[str, Any] = {} _bts_context: dict[str, Any] = {} + _memory_space_id: str | None = _get_escalation_memory_space_id(resource, agent) + _memory_folder_path: str | None = _get_escalation_memory_folder_path( + resource, agent + ) + _memory_settings: EscalationMemorySettings | None = _get_escalation_memory_settings( + resource + ) async def escalation_tool_fn(**kwargs: Any) -> dict[str, Any]: agent_input: dict[str, Any] = ( @@ -198,6 +295,24 @@ async def escalation_tool_fn(**kwargs: Any) -> dict[str, Any]: serialized_data = 
input_model.model_validate(kwargs).model_dump(mode="json") + # --- Escalation memory: check cache before creating HITL task --- + if _memory_space_id: + cached_result = await _check_escalation_memory_cache( + _memory_space_id, + serialized_data, + folder_path=_memory_folder_path or folder_path, + memory_settings=_memory_settings, + ) + if cached_result is not None: + return { + "action": _resolve_escalation_action( + cached_result.outcome, + channel.outcome_mapping, + ), + "output": cached_result.output, + "outcome": cached_result.outcome, + } + @mockable( name=tool_name.lower(), description=resource.description, @@ -235,7 +350,7 @@ async def create_escalation_task(): result = await escalate(**kwargs) if isinstance(result, dict): - result = TypeAdapter(EscalationToolOutput).validate_python(result) + result = Task.model_validate(result) if result.is_deleted: return { @@ -253,15 +368,62 @@ async def create_escalation_task(): output_schema=output_model.model_json_schema(), ) - outcome_str = ( - channel.outcome_mapping.get(outcome) - if channel.outcome_mapping and outcome - else None - ) - escalation_action = ( - EscalationAction(outcome_str) if outcome_str else EscalationAction.CONTINUE + escalation_action = _resolve_escalation_action( + outcome, + channel.outcome_mapping, ) + # --- Escalation memory: persist outcome for future recall --- + if _memory_space_id: + user_id = await _resolve_user_id(result.completed_by_user) + parent_span_id, trace_id = _pop_escalation_memory_span_context( + tool.metadata + ) + if not parent_span_id or not trace_id: + fallback_span_id, fallback_trace_id = get_current_span_and_trace_ids() + _escalation_logger.debug( + "Escalation memory span context fallback: %s", + json.dumps( + { + "fallbackSpanId": fallback_span_id, + "fallbackTraceId": fallback_trace_id, + "hadParentSpanId": bool(parent_span_id), + "hadTraceId": bool(trace_id), + }, + default=str, + ), + ) + parent_span_id = parent_span_id or fallback_span_id + trace_id = trace_id or 
_get_exported_trace_id(fallback_trace_id) + if not parent_span_id or not trace_id: + _escalation_logger.warning( + "Skipping escalation memory ingest because span provenance is incomplete" + ) + return { + "action": escalation_action, + "output": escalation_output, + "outcome": outcome, + } + answer_payload, attributes_payload = _build_escalation_memory_payload( + serialized_data, + escalation_output, + outcome, + ) + await _ingest_escalation_memory( + _memory_space_id, + answer=json.dumps(answer_payload), + attributes=json.dumps(attributes_payload), + parent_span_id=parent_span_id, + trace_id=trace_id, + user_id=user_id, + folder_path=_memory_folder_path or folder_path, + ) + if user_id is None: + _escalation_logger.info( + "Ingested escalation memory without reviewer user ID " + "because the completed user could not be resolved" + ) + return { "action": escalation_action, "output": escalation_output, @@ -327,6 +489,7 @@ async def escalation_wrapper( "recipient": None, "args_schema": input_model, "output_schema": output_model, + "_span_context": _span_context, "_bts_context": _bts_context, }, ) diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 17708f17c..0cbb0135e 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -106,7 +106,7 @@ async def _build_tool_for_resource( return create_context_tool(resource, llm=llm, agent=agent) elif isinstance(resource, AgentEscalationResourceConfig): - return create_escalation_tool(resource) + return create_escalation_tool(resource, agent=agent) elif isinstance(resource, AgentIntegrationToolResourceConfig): return create_integration_tool(resource) diff --git a/tests/agent/tools/test_escalation_memory.py b/tests/agent/tools/test_escalation_memory.py new file mode 100644 index 000000000..116731aa7 --- /dev/null +++ b/tests/agent/tools/test_escalation_memory.py @@ -0,0 +1,697 @@ +"""Tests for 
escalation memory cache check and ingest.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel, ConfigDict +from uipath.agent.models.agent import AgentEscalationResourceConfig +from uipath.platform.common._bindings import ( + GenericResourceOverwrite, + _resource_overwrites, +) +from uipath.platform.memory import EscalationMemorySearchResponse + +from uipath_langchain.agent.tools.escalation_memory import ( + MEMORY_CACHE_HIT_METRIC, + MEMORY_CACHE_MISS_METRIC, + EscalationMemoryFieldSetting, + EscalationMemoryRetriever, + EscalationMemorySettings, + _build_search_fields, + _check_escalation_memory_cache, + _coerce_memory_settings, + _get_escalation_memory_folder_path, + _get_escalation_memory_settings, + _get_escalation_memory_space_id, + _get_user_email, + _get_user_id, + _ingest_escalation_memory, + _read_value, + _record_custom_metric, + _resolve_user_id, + _stringify_search_value, +) + +USER_GUID = "a543cbbd-f3f3-4868-bccf-f5142d2d3d7e" + + +def _memory_resource(**overrides: object) -> AgentEscalationResourceConfig: + values: dict[str, object] = { + "name": "approval", + "description": "Request approval", + "channels": [], + } + values.update(overrides) + return AgentEscalationResourceConfig(**values) + + +class TestGetEscalationMemorySpaceId: + def test_returns_none_when_disabled(self) -> None: + resource = _memory_resource(is_agent_memory_enabled=False) + assert _get_escalation_memory_space_id(resource) is None + + def test_returns_space_id_from_extra_field(self) -> None: + resource = _memory_resource( + is_agent_memory_enabled=True, + memorySpaceId="space-abc", + ) + assert _get_escalation_memory_space_id(resource) == "space-abc" + + def test_returns_none_when_no_space_id(self) -> None: + resource = _memory_resource(is_agent_memory_enabled=True) + assert _get_escalation_memory_space_id(resource) is None + + def 
test_returns_space_id_when_escalation_memory_enabled_in_properties( + self, + ) -> None: + resource = _memory_resource( + is_agent_memory_enabled=False, + properties={ + "memory": { + "isEnabled": True, + "memorySpaceId": "space-from-memory-properties", + } + }, + ) + + assert ( + _get_escalation_memory_space_id(resource) == "space-from-memory-properties" + ) + + @patch("uipath_langchain.agent.tools.escalation_memory.get_execution_folder_path") + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + def test_resolves_space_id_from_agent_memory_feature( + self, + mock_uipath_cls: MagicMock, + mock_get_execution_folder_path: MagicMock, + ) -> None: + resource = _memory_resource( + is_agent_memory_enabled=False, + properties={"memory": {"isEnabled": True}}, + ) + agent = SimpleNamespace( + features=[ + { + "$featureType": "memorySpace", + "isEnabled": True, + "memorySpaceName": "MemorySpace", + "folderPath": "solution_folder", + "dynamicFewShotSettings": {"isEnabled": False}, + } + ] + ) + mock_get_execution_folder_path.return_value = "/My Workspace" + mock_sdk = MagicMock() + mock_sdk.memory.list.return_value = SimpleNamespace( + value=[SimpleNamespace(id="resolved-space-id")] + ) + mock_uipath_cls.return_value = mock_sdk + + result = _get_escalation_memory_space_id(resource, agent) + + assert result == "resolved-space-id" + mock_sdk.memory.list.assert_called_once_with( + filter="name eq 'MemorySpace'", + folder_path="/My Workspace", + ) + + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + def test_resolves_agent_memory_feature_with_resource_overwrite( + self, + mock_uipath_cls: MagicMock, + ) -> None: + resource = _memory_resource( + is_agent_memory_enabled=False, + properties={"memory": {"isEnabled": True}}, + ) + agent = SimpleNamespace( + features=[ + { + "$featureType": "memorySpace", + "isEnabled": True, + "memorySpaceName": "MemorySpace", + "folderPath": "solution_folder", + "dynamicFewShotSettings": {"isEnabled": False}, + } + ] + 
) + mock_sdk = MagicMock() + mock_sdk.memory.list.return_value = SimpleNamespace( + value=[SimpleNamespace(id="resolved-space-id")] + ) + mock_uipath_cls.return_value = mock_sdk + token = _resource_overwrites.set( + { + "memorySpace.MemorySpace": GenericResourceOverwrite( + resource_type="memorySpace", + name="MemorySpace", + folder_path="/My Workspace/Debug_escs", + ) + } + ) + + try: + result = _get_escalation_memory_space_id(resource, agent) + folder_path = _get_escalation_memory_folder_path(resource, agent) + finally: + _resource_overwrites.reset(token) + + assert result == "resolved-space-id" + assert folder_path == "/My Workspace/Debug_escs" + mock_sdk.memory.list.assert_called_once_with( + filter="name eq 'MemorySpace'", + folder_path="/My Workspace/Debug_escs", + ) + + +class TestGetEscalationMemorySettings: + def test_returns_none_when_disabled(self) -> None: + resource = _memory_resource(is_agent_memory_enabled=False) + assert _get_escalation_memory_settings(resource) is None + + def test_returns_none_when_memory_properties_missing(self) -> None: + resource = _memory_resource(is_agent_memory_enabled=True, properties={}) + assert _get_escalation_memory_settings(resource) is None + + def test_returns_typed_settings_from_properties(self) -> None: + resource = _memory_resource( + is_agent_memory_enabled=True, + properties={ + "memory": { + "threshold": 0.7, + "searchMode": "Semantic", + "fieldSettings": [{"name": "question", "weight": 0.4}], + } + }, + ) + + settings = _get_escalation_memory_settings(resource) + + assert settings is not None + assert settings.threshold == 0.7 + assert settings.search_mode.value == "Semantic" + assert settings.field_settings == [ + EscalationMemoryFieldSetting(name="question", weight=0.4) + ] + + def test_returns_settings_when_escalation_memory_enabled_in_properties( + self, + ) -> None: + resource = _memory_resource( + is_agent_memory_enabled=False, + properties={ + "memory": { + "isEnabled": True, + "threshold": 0.8, + 
"searchMode": "hybrid", + "fieldSettings": [{"name": "request_details", "weight": 1}], + } + }, + ) + + settings = _get_escalation_memory_settings(resource) + + assert settings is not None + assert settings.threshold == 0.8 + assert settings.search_mode.value == "Hybrid" + assert settings.field_settings == [ + EscalationMemoryFieldSetting(name="request_details", weight=1) + ] + + +class TestGetUserEmail: + def test_extracts_email_from_supported_shapes(self) -> None: + assert _get_user_email(None) is None + assert ( + _get_user_email({"emailAddress": "dict@example.com"}) == "dict@example.com" + ) + assert _get_user_email({"email": "email@example.com"}) == "email@example.com" + assert _get_user_email({"Email": "pascal@example.com"}) == "pascal@example.com" + assert _get_user_email({"userName": "user@example.com"}) == "user@example.com" + assert _get_user_email({"name": "Reviewer"}) is None + assert ( + _get_user_email(SimpleNamespace(emailAddress="object@example.com")) + == "object@example.com" + ) + + +class TestGetUserId: + def test_extracts_user_id_from_supported_shapes(self) -> None: + assert _get_user_id(None) is None + assert _get_user_id({"identifier": USER_GUID}) == USER_GUID + assert ( + _get_user_id({"identifier": "aad|cef1337c-3456-4ae9-81c9-30d033dc2bef"}) + is None + ) + assert _get_user_id({"id": "dict-id"}) is None + assert _get_user_id({"id": 4753819}) is None + assert _get_user_id({"id": "4753819"}) is None + assert _get_user_id({"userId": USER_GUID}) == USER_GUID + assert _get_user_id({"userGlobalId": USER_GUID.upper()}) == USER_GUID + assert _get_user_id(SimpleNamespace(identifier=USER_GUID)) == USER_GUID + + +class TestResolveUserId: + @pytest.mark.asyncio + async def test_returns_existing_user_id_without_api_call(self) -> None: + assert await _resolve_user_id({"identifier": USER_GUID}) == USER_GUID + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory.UiPathConfig") + 
@patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_resolves_email_to_guid_identifier( + self, + mock_uipath_cls: MagicMock, + mock_config: MagicMock, + ) -> None: + mock_config.organization_id = "org-123" + mock_response = MagicMock() + mock_response.json.return_value = [ + {"email": "reviewer@example.com", "identifier": USER_GUID} + ] + mock_sdk = MagicMock() + mock_sdk.api_client.request_async = AsyncMock(return_value=mock_response) + mock_uipath_cls.return_value = mock_sdk + + result = await _resolve_user_id( + {"emailAddress": "reviewer@example.com", "id": 4753819} + ) + + assert result == USER_GUID + mock_sdk.api_client.request_async.assert_awaited_once_with( + "GET", + "/identity_/api/Directory/Search/org-123", + scoped="org", + params={ + "startsWith": "reviewer@example.com", + "sourceFilter": ["directoryUsers", "localUsers"], + }, + ) + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory.UiPathConfig") + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_ignores_directory_identifier_that_is_not_guid( + self, + mock_uipath_cls: MagicMock, + mock_config: MagicMock, + ) -> None: + mock_config.organization_id = "org-123" + mock_response = MagicMock() + mock_response.json.return_value = [ + { + "email": "reviewer@example.com", + "identifier": "aad|cef1337c-3456-4ae9-81c9-30d033dc2bef", + } + ] + mock_sdk = MagicMock() + mock_sdk.api_client.request_async = AsyncMock(return_value=mock_response) + mock_uipath_cls.return_value = mock_sdk + + result = await _resolve_user_id({"emailAddress": "reviewer@example.com"}) + + assert result is None + + +class TestCheckEscalationMemoryCache: + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory._record_custom_metric") + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_returns_cached_answer( + self, mock_uipath_cls: MagicMock, mock_record_metric: MagicMock + ) -> None: + mock_sdk = 
MagicMock() + mock_uipath_cls.return_value = mock_sdk + + cached_answer = MagicMock() + cached_answer.output = {"action": "approve", "reason": "meets criteria"} + cached_answer.outcome = "approved" + + mock_match = MagicMock() + mock_match.answer = cached_answer + + mock_response = MagicMock() + mock_response.results = [mock_match] + mock_sdk.memory.escalation_search_async = AsyncMock(return_value=mock_response) + + result = await _check_escalation_memory_cache( + "space-123", {"Content": "Is the sky blue?"} + ) + + assert result is not None + assert result.output == {"action": "approve", "reason": "meets criteria"} + assert result.outcome == "approved" + mock_record_metric.assert_called_once_with(MEMORY_CACHE_HIT_METRIC, "space-123") + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory._record_custom_metric") + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_returns_cached_answer_when_sdk_response_has_string_answer( + self, mock_uipath_cls: MagicMock, mock_record_metric: MagicMock + ) -> None: + mock_sdk = MagicMock() + mock_uipath_cls.return_value = mock_sdk + validation_error: Exception | None = None + try: + EscalationMemorySearchResponse.model_validate( + { + "results": [ + { + "answer": '{"output": {"approved": true}, "outcome": "approved"}' + } + ] + } + ) + except Exception as error: + validation_error = error + assert validation_error is not None + mock_sdk.memory.escalation_search_async = AsyncMock( + side_effect=validation_error + ) + mock_sdk.memory._escalation_search_spec.return_value = SimpleNamespace( + method="POST", + endpoint="/llmopstenant_/api/Agent/memory/space-123/escalation/search", + headers={"x-uipath-folderkey": "folder-key"}, + ) + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"answer": '{"output": {"approved": true}, "outcome": "approved"}'} + ] + } + mock_sdk.memory.request_async = AsyncMock(return_value=mock_response) + + result = await 
_check_escalation_memory_cache( + "space-123", + {"Content": "Is the sky blue?"}, + folder_path="/Memory/Folder", + ) + + assert result is not None + assert result.output == {"approved": True} + assert result.outcome == "approved" + mock_sdk.memory.request_async.assert_awaited_once() + mock_record_metric.assert_called_once_with(MEMORY_CACHE_HIT_METRIC, "space-123") + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory._record_custom_metric") + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_returns_none_on_empty_results( + self, mock_uipath_cls: MagicMock, mock_record_metric: MagicMock + ) -> None: + mock_sdk = MagicMock() + mock_uipath_cls.return_value = mock_sdk + mock_response = MagicMock() + mock_response.results = [] + mock_sdk.memory.escalation_search_async = AsyncMock(return_value=mock_response) + + result = await _check_escalation_memory_cache("space-123", {"key": "val"}) + assert result is None + mock_record_metric.assert_called_once_with( + MEMORY_CACHE_MISS_METRIC, "space-123" + ) + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_returns_none_on_failure(self, mock_uipath_cls: MagicMock) -> None: + mock_sdk = MagicMock() + mock_uipath_cls.return_value = mock_sdk + mock_sdk.memory.escalation_search_async = AsyncMock( + side_effect=Exception("fail") + ) + + result = await _check_escalation_memory_cache("space-123", {"key": "val"}) + assert result is None + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory._record_custom_metric") + async def test_treats_empty_search_fields_as_cache_miss( + self, mock_record_metric: MagicMock + ) -> None: + result = await _check_escalation_memory_cache("space-123", {}) + + assert result is None + mock_record_metric.assert_called_once_with( + MEMORY_CACHE_MISS_METRIC, "space-123" + ) + + @pytest.mark.asyncio + 
@patch("uipath_langchain.agent.tools.escalation_memory._record_custom_metric") + async def test_treats_unmatched_configured_fields_as_cache_miss( + self, mock_record_metric: MagicMock + ) -> None: + settings = EscalationMemorySettings( + fieldSettings=[{"name": "other", "weight": 1.0}] + ) + + result = await _check_escalation_memory_cache( + "space-123", + {"key": "val"}, + memory_settings=settings, + ) + + assert result is None + mock_record_metric.assert_called_once_with( + MEMORY_CACHE_MISS_METRIC, "space-123" + ) + + +class TestBuildSearchFields: + def test_search_request_includes_required_definition_prompt(self) -> None: + request = EscalationMemoryRetriever("space-123")._build_search_request( + {"keep": "value"} + ) + + assert request.definition_system_prompt == "" + assert ( + request.model_dump(by_alias=True, exclude_none=True)[ + "definitionSystemPrompt" + ] + == "" + ) + + def test_filters_empty_and_unconfigured_fields(self) -> None: + settings = EscalationMemorySettings( + fieldSettings=[ + {"name": "keep", "weight": 0.25}, + {"name": "empty", "weight": 1.0}, + ] + ) + + fields = _build_search_fields( + { + "keep": {"answer": True}, + "empty": None, + "ignored": "value", + }, + settings, + ) + + assert len(fields) == 1 + assert fields[0].key_path == ["escalation-input", "keep"] + assert fields[0].value == '{"answer": true}' + assert fields[0].settings.weight == 0.25 + + +class TestIngestEscalationMemory: + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_calls_ingest(self, mock_uipath_cls: MagicMock) -> None: + mock_sdk = MagicMock() + mock_uipath_cls.return_value = mock_sdk + mock_sdk.memory.escalation_ingest_async = AsyncMock() + + await _ingest_escalation_memory( + "space-123", + answer='{"approved": true}', + attributes='{"input": "test"}', + parent_span_id="abc123", + trace_id="def456", + user_id=USER_GUID, + ) + + mock_sdk.memory.escalation_ingest_async.assert_called_once() + request = 
mock_sdk.memory.escalation_ingest_async.call_args.kwargs["request"] + assert request.span_id == "abc123" + assert request.trace_id == "def456" + assert request.user_id == USER_GUID + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_calls_ingest_without_user_id( + self, mock_uipath_cls: MagicMock + ) -> None: + mock_sdk = MagicMock() + mock_uipath_cls.return_value = mock_sdk + mock_sdk.memory.escalation_ingest_async = AsyncMock() + + await _ingest_escalation_memory( + "space-123", + answer='{"approved": true}', + attributes='{"input": "test"}', + parent_span_id="abc123", + trace_id="def456", + ) + + request = mock_sdk.memory.escalation_ingest_async.call_args.kwargs["request"] + assert request.user_id is None + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_calls_ingest_without_invalid_user_id( + self, mock_uipath_cls: MagicMock + ) -> None: + mock_sdk = MagicMock() + mock_uipath_cls.return_value = mock_sdk + mock_sdk.memory.escalation_ingest_async = AsyncMock() + + await _ingest_escalation_memory( + "space-123", + answer='{"approved": true}', + attributes='{"input": "test"}', + parent_span_id="abc123", + trace_id="def456", + user_id="aad|cef1337c-3456-4ae9-81c9-30d033dc2bef", + ) + + request = mock_sdk.memory.escalation_ingest_async.call_args.kwargs["request"] + assert request.user_id is None + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_memory.UiPath") + async def test_graceful_on_failure(self, mock_uipath_cls: MagicMock) -> None: + mock_sdk = MagicMock() + mock_uipath_cls.return_value = mock_sdk + mock_sdk.memory.escalation_ingest_async = AsyncMock( + side_effect=Exception("fail") + ) + + # Should not raise + await _ingest_escalation_memory( + "space-123", + answer="yes", + attributes="{}", + parent_span_id="abc123", + trace_id="def456", + user_id="reviewer@example.com", + ) + + +class TestEscalationMemoryUtilities: + def 
test_record_custom_metric_creates_and_reuses_counter(self, monkeypatch) -> None: + from opentelemetry import metrics, trace + + counters: list[tuple[str, int, dict[str, str]]] = [] + events: list[tuple[str, dict[str, object]]] = [] + + class Counter: + def __init__(self, name: str) -> None: + self.name = name + + def add(self, value: int, attributes: dict[str, str]) -> None: + counters.append((self.name, value, attributes)) + + class Meter: + def __init__(self) -> None: + self.created: list[str] = [] + + def create_counter(self, name: str) -> Counter: + self.created.append(name) + return Counter(name) + + class Span: + def is_recording(self) -> bool: + return True + + def add_event(self, name: str, attributes: dict[str, object]) -> None: + events.append((name, attributes)) + + meter = Meter() + monkeypatch.setattr(metrics, "get_meter", lambda _name: meter) + monkeypatch.setattr(trace, "get_current_span", lambda: Span()) + + from uipath_langchain.agent.tools import escalation_memory + + escalation_memory._metric_counters.clear() + _record_custom_metric(MEMORY_CACHE_HIT_METRIC, "space-123") + _record_custom_metric(MEMORY_CACHE_HIT_METRIC, "space-123") + + assert meter.created == [MEMORY_CACHE_HIT_METRIC] + assert counters == [ + (MEMORY_CACHE_HIT_METRIC, 1, {"memorySpaceId": "space-123"}), + (MEMORY_CACHE_HIT_METRIC, 1, {"memorySpaceId": "space-123"}), + ] + assert events == [ + ( + "customMetric", + { + "name": MEMORY_CACHE_HIT_METRIC, + "value": 1, + "memorySpaceId": "space-123", + }, + ), + ( + "customMetric", + { + "name": MEMORY_CACHE_HIT_METRIC, + "value": 1, + "memorySpaceId": "space-123", + }, + ), + ] + + def test_record_custom_metric_is_best_effort(self, monkeypatch) -> None: + from opentelemetry import metrics + + monkeypatch.setattr( + metrics, + "get_meter", + MagicMock(side_effect=RuntimeError("metrics unavailable")), + ) + + from uipath_langchain.agent.tools import escalation_memory + + escalation_memory._metric_counters.clear() + 
_record_custom_metric(MEMORY_CACHE_MISS_METRIC, "space-123") + + def test_coerce_memory_settings_from_supported_shapes(self) -> None: + class MemoryModel(BaseModel): + threshold: float = 0.6 + searchMode: str = "Semantic" + fieldSettings: list[dict[str, object]] = [ + {"name": "model-field", "weight": 0.5} + ] + + class MemoryObject: + threshold = 0.8 + searchMode = "Hybrid" + fieldSettings = [{"name": "object-field", "weight": 0.9}] + + existing = EscalationMemorySettings(threshold=0.1) + assert _coerce_memory_settings(existing) is existing + assert _coerce_memory_settings(MemoryModel()).field_settings == [ + EscalationMemoryFieldSetting(name="model-field", weight=0.5) + ] + object_settings = _coerce_memory_settings(MemoryObject()) + assert object_settings.threshold == 0.8 + assert object_settings.field_settings == [ + EscalationMemoryFieldSetting(name="object-field", weight=0.9) + ] + + def test_read_value_from_supported_shapes(self) -> None: + class ExtraModel(BaseModel): + model_config = ConfigDict(extra="allow") + + assert _read_value(None, "missing") is None + assert _read_value({"present": "yes"}, "present") == "yes" + assert _read_value({"other": "yes"}, "missing") is None + assert _read_value(ExtraModel(extra_value="yes"), "extra_value") == "yes" + assert _read_value(SimpleNamespace(present="yes"), "present") == "yes" + assert _read_value(SimpleNamespace(), "missing") is None + + def test_stringify_search_value(self) -> None: + assert _stringify_search_value(None) == "" + assert _stringify_search_value("text") == "text" + assert _stringify_search_value({"b": 2, "a": 1}) == '{"a": 1, "b": 2}' + assert _stringify_search_value(("tuple", 1)) == "('tuple', 1)" diff --git a/tests/agent/tools/test_escalation_tool.py b/tests/agent/tools/test_escalation_tool.py index 8106b56f1..510e2c77f 100644 --- a/tests/agent/tools/test_escalation_tool.py +++ b/tests/agent/tools/test_escalation_tool.py @@ -15,8 +15,12 @@ ) from uipath.platform.action_center.tasks import Task, 
TaskRecipient, TaskRecipientType -from uipath_langchain.agent.tools.escalation_tool import ( +from uipath_langchain.agent.tools.escalation_memory import ( + EscalationMemoryCachedResult, _get_user_email, +) +from uipath_langchain.agent.tools.escalation_tool import ( + _build_escalation_memory_payload, _parse_task_data, create_escalation_tool, resolve_asset, @@ -286,6 +290,14 @@ async def test_escalation_tool_metadata_has_channel_type(self, escalation_resour assert tool.metadata is not None assert tool.metadata["channel_type"] == "actionCenter" + @pytest.mark.asyncio + async def test_escalation_tool_metadata_has_span_context(self, escalation_resource): + """Test that metadata contains a span context carrier for memory ingest.""" + tool = create_escalation_tool(escalation_resource) + assert tool.metadata is not None + assert "_span_context" in tool.metadata + assert isinstance(tool.metadata["_span_context"], dict) + @pytest.mark.asyncio @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") @@ -700,6 +712,111 @@ async def test_escalation_tool_with_outcome_mapping_end( assert mock_interrupt.called + @pytest.mark.asyncio + @patch( + "uipath_langchain.agent.tools.escalation_tool._check_escalation_memory_cache" + ) + async def test_cached_escalation_uses_outcome_mapping( + self, mock_check_memory_cache: AsyncMock + ): + """Test cached outcomes follow the same outcome mapping as live results.""" + from uipath_langchain.agent.exceptions import AgentRuntimeError + + mock_check_memory_cache.return_value = EscalationMemoryCachedResult( + output={"approved": True}, + outcome="approve", + ) + + channel_dict = { + "name": "action_center", + "type": "actionCenter", + "description": "Action Center channel", + "inputSchema": {"type": "object", "properties": {}}, + "outputSchema": {"type": "object", "properties": {}}, + "properties": { + "appName": "ApprovalApp", + "appVersion": 1, + "resourceKey": 
"test-key", + }, + "recipients": [], + "outcomeMapping": {"approve": "end", "reject": "continue"}, + } + + resource = AgentEscalationResourceConfig( + name="approval", + description="Request approval", + channels=[AgentEscalationChannel(**channel_dict)], + isAgentMemoryEnabled=True, + memorySpaceId="space-123", + ) + + tool = create_escalation_tool(resource) + call = ToolCall(args={}, id="test-call", name=tool.name) + + with pytest.raises(AgentRuntimeError): + await tool.awrapper(tool, call, {}) # type: ignore[attr-defined] + + @pytest.mark.asyncio + @patch("uipath_langchain.agent.tools.escalation_tool.get_execution_folder_path") + @patch( + "uipath_langchain.agent.tools.escalation_tool._check_escalation_memory_cache" + ) + async def test_cache_lookup_uses_memory_folder_path( + self, + mock_check_memory_cache: AsyncMock, + mock_get_execution_folder_path: MagicMock, + ): + """Test escalation memory calls use the memory folder, not task folder.""" + mock_get_execution_folder_path.return_value = "/Execution/Folder" + mock_check_memory_cache.return_value = EscalationMemoryCachedResult( + output={"approved": True}, + outcome="approve", + ) + + channel_dict = { + "name": "action_center", + "type": "actionCenter", + "description": "Action Center channel", + "inputSchema": {"type": "object", "properties": {}}, + "outputSchema": {"type": "object", "properties": {}}, + "properties": { + "appName": "ApprovalApp", + "appVersion": 1, + "resourceKey": "test-key", + }, + "recipients": [], + } + + resource = AgentEscalationResourceConfig( + name="approval", + description="Request approval", + channels=[AgentEscalationChannel(**channel_dict)], + properties={ + "memory": { + "isEnabled": True, + "memorySpaceId": "space-123", + "folderPath": "/Memory/Folder", + } + }, + ) + + tool = create_escalation_tool(resource) + call = ToolCall(args={}, id="test-call", name=tool.name) + + result = await tool.awrapper(tool, call, {}) # type: ignore[attr-defined] + + assert result == { + "output": 
{"approved": True}, + "outcome": "approve", + "task_id": None, + "assigned_to": None, + } + mock_check_memory_cache.assert_awaited_once() + assert mock_check_memory_cache.await_args is not None + assert ( + mock_check_memory_cache.await_args.kwargs["folder_path"] == "/Memory/Folder" + ) + class TestGetUserEmail: """Test the _get_user_email helper function.""" @@ -906,6 +1023,86 @@ async def test_task_creation_failure_propagates( with pytest.raises(Exception, match="API error"): await tool.awrapper(tool, call, {}) # type: ignore[attr-defined] + @pytest.mark.asyncio + @patch( + "uipath_langchain.agent.tools.escalation_tool.get_current_span_and_trace_ids" + ) + @patch("uipath_langchain.agent.tools.escalation_tool._ingest_escalation_memory") + @patch("uipath_langchain.agent.tools.escalation_tool._resolve_user_id") + @patch( + "uipath_langchain.agent.tools.escalation_tool._check_escalation_memory_cache" + ) + @patch("uipath_langchain.agent.tools.escalation_tool.UiPath") + @patch("uipath_langchain._utils.durable_interrupt.decorator.interrupt") + async def test_memory_ingest_uses_traced_escalation_span_context( + self, + mock_interrupt, + mock_uipath_class, + mock_check_memory_cache, + mock_resolve_user_id, + mock_ingest_memory, + mock_get_current_span_and_trace_ids, + ): + """Escalation memory ingest should use the escalationTool child span.""" + mock_check_memory_cache.return_value = None + mock_resolve_user_id.return_value = "cef1337c-3456-4ae9-81c9-30d033dc2bef" + mock_ingest_memory.return_value = None + mock_get_current_span_and_trace_ids.return_value = ("wrong-span", "wrong-trace") + + task = _make_mock_task(id=555) + mock_client = MagicMock() + mock_client.tasks.create_async = AsyncMock(return_value=task) + mock_uipath_class.return_value = mock_client + + mock_result = MagicMock() + mock_result.action = "approve" + mock_result.data = {} + mock_result.completed_by_user = {"emailAddress": "reviewer@example.com"} + mock_result.is_deleted = False + 
mock_interrupt.return_value = mock_result + + resource = AgentEscalationResourceConfig( + name="approval", + description="Request approval", + channels=[ + AgentEscalationChannel( + name="action_center", + type="actionCenter", + description="Action Center channel", + input_schema={"type": "object", "properties": {}}, + output_schema={"type": "object", "properties": {}}, + properties=AgentEscalationChannelProperties( + app_name="ApprovalApp", + app_version=1, + resource_key="test-key", + ), + recipients=[], + ) + ], + isAgentMemoryEnabled=True, + memorySpaceId="space-123", + ) + + tool = create_escalation_tool(resource) + assert tool.metadata is not None + tool.metadata["_span_context"]["parent_span_id"] = "3a064d559eca5d62" + tool.metadata["_span_context"]["trace_id"] = "5d3feebba60343dfb9364b89ee304a5b" + + call = ToolCall(args={}, id="test-call", name=tool.name) + await tool.awrapper(tool, call, {}) # type: ignore[attr-defined] + + mock_get_current_span_and_trace_ids.assert_not_called() + mock_ingest_memory.assert_awaited_once() + assert mock_ingest_memory.await_args is not None + assert ( + mock_ingest_memory.await_args.kwargs["parent_span_id"] == "3a064d559eca5d62" + ) + assert ( + mock_ingest_memory.await_args.kwargs["trace_id"] + == "5d3feebba60343dfb9364b89ee304a5b" + ) + assert tool.metadata["_span_context"] == {} + class TestParseTaskData: """Test output task data is filtered correctly.""" @@ -940,3 +1137,27 @@ def test_handles_missing_properties_in_schemas(self): # No properties key in schemas result = _parse_task_data(data, {}, None) assert result == {"field": "value"} + + +class TestEscalationMemoryPayload: + """Test escalation memory ingest payload shape.""" + + def test_builds_trace_and_search_payloads(self): + """Test memory ingest matches the escalation memory service contract.""" + serialized_input = { + "request_details": "User requested escalation before answering." 
+ } + escalation_output = {"reviewer_comment": "approve"} + + answer, attributes = _build_escalation_memory_payload( + serialized_input, + escalation_output, + "Approve", + ) + + assert answer == { + "output": {"reviewer_comment": "approve"}, + "outcome": "Approve", + } + assert attributes == {"arguments": serialized_input} + assert "escalation-input" not in attributes diff --git a/tests/utils/test_otel.py b/tests/utils/test_otel.py new file mode 100644 index 000000000..871e28432 --- /dev/null +++ b/tests/utils/test_otel.py @@ -0,0 +1,106 @@ +"""Tests for OpenTelemetry utility helpers.""" + +import builtins + +from uipath_langchain._utils._otel import ( + get_current_span_and_trace_ids, + set_current_span_error, + set_span_attribute, +) + + +class _SpanContext: + is_valid = True + span_id = 0x123 + trace_id = 0x456 + + +class _RecordingSpan: + def __init__(self) -> None: + self.attributes: dict[str, object] = {} + self.exceptions: list[BaseException] = [] + self.status: tuple[object, str] | None = None + + def get_span_context(self) -> _SpanContext: + return _SpanContext() + + def is_recording(self) -> bool: + return True + + def set_attribute(self, name: str, value: object) -> None: + self.attributes[name] = value + + def record_exception(self, error: BaseException) -> None: + self.exceptions.append(error) + + def set_status(self, code: object, description: str) -> None: + self.status = (code, description) + + +class _InvalidSpan(_RecordingSpan): + def get_span_context(self): + class InvalidContext: + is_valid = False + + return InvalidContext() + + +def test_get_current_span_and_trace_ids(monkeypatch) -> None: + from opentelemetry import trace + + monkeypatch.setattr(trace, "get_current_span", lambda: _RecordingSpan()) + + assert get_current_span_and_trace_ids() == ( + "0000000000000123", + "00000000000000000000000000000456", + ) + + +def test_get_current_span_and_trace_ids_returns_empty_for_invalid_context( + monkeypatch, +) -> None: + from opentelemetry import 
trace + + monkeypatch.setattr(trace, "get_current_span", lambda: _InvalidSpan()) + + assert get_current_span_and_trace_ids() == ("", "") + + +def test_set_span_attribute(monkeypatch) -> None: + from opentelemetry import trace + + span = _RecordingSpan() + monkeypatch.setattr(trace, "get_current_span", lambda: span) + + set_span_attribute("savedToMemory", True) + + assert span.attributes == {"savedToMemory": True} + + +def test_set_current_span_error(monkeypatch) -> None: + from opentelemetry import trace + from opentelemetry.trace import StatusCode + + span = _RecordingSpan() + error = RuntimeError("memory failed") + monkeypatch.setattr(trace, "get_current_span", lambda: span) + + set_current_span_error(error) + + assert span.exceptions == [error] + assert span.status == (StatusCode.ERROR, "memory failed") + + +def test_otel_helpers_are_noops_without_opentelemetry(monkeypatch) -> None: + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name.startswith("opentelemetry"): + raise ImportError + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + assert get_current_span_and_trace_ids() == ("", "") + set_span_attribute("fromMemory", False) + set_current_span_error(RuntimeError("memory failed")) diff --git a/uv.lock b/uv.lock index 5739ff5ff..700542e30 100644 --- a/uv.lock +++ b/uv.lock @@ -4375,7 +4375,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.10.23" +version = "0.10.24" source = { editable = "." } dependencies = [ { name = "a2a-sdk" },