Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel
from uipath.platform.entities import Entity, QueryRoutingOverrideContext
from uipath.platform.entities import Entity

from ..datafabric_query_tool import DataFabricQueryTool
from . import datafabric_prompt_builder
Expand All @@ -42,18 +42,23 @@ class DataFabricSubgraphState(BaseModel):
class QueryExecutor:
"""Executes SQL queries against Data Fabric."""

def __init__(self, routing_context: QueryRoutingOverrideContext) -> None:
def __init__(self, folders_map: dict[str, str]) -> None:
from uipath.platform import UiPath

self._sdk = UiPath()
self._routing_context = routing_context
from uipath.platform.entities import EntitiesService

sdk = UiPath()
self._entities = EntitiesService(
config=sdk._config,
execution_context=sdk._execution_context,
folders_service=sdk.folders,
folders_map=folders_map,
)

async def __call__(self, sql_query: str) -> dict[str, Any]:
logger.debug("execute_sql called with SQL: %s", sql_query)
try:
records = await self._sdk.entities.query_entity_records_async(
records = await self._entities.query_entity_records_async(
sql_query=sql_query,
routing_context=self._routing_context,
)
return {
"records": records,
Expand Down Expand Up @@ -81,15 +86,13 @@ def __init__(
self,
llm: BaseChatModel,
entities: list[Entity],
routing_context: QueryRoutingOverrideContext,
folders_map: dict[str, str],
max_iterations: int = 25,
resource_description: str = "",
base_system_prompt: str = "",
) -> None:
self._max_iterations = max_iterations
self._execute_sql_tool = self._create_execute_sql_tool(
routing_context, entities
)
self._execute_sql_tool = self._create_execute_sql_tool(folders_map, entities)
self._system_message = SystemMessage(
content=datafabric_prompt_builder.build(
entities, resource_description, base_system_prompt
Expand Down Expand Up @@ -175,7 +178,7 @@ def router(self, state: DataFabricSubgraphState) -> str:

def _create_execute_sql_tool(
self,
routing_context: QueryRoutingOverrideContext,
folders_map: dict[str, str],
entities: list[Entity],
) -> BaseTool:
"""Create the inner ``execute_sql`` tool."""
Expand All @@ -188,15 +191,15 @@ def _create_execute_sql_tool(
"tables and columns. Retry with a corrected query on errors."
),
args_schema=DataFabricExecuteSqlInput,
coroutine=QueryExecutor(routing_context),
coroutine=QueryExecutor(folders_map),
metadata={"tool_type": "datafabric_sql"},
)

@staticmethod
def create(
llm: BaseChatModel,
entities: list[Entity],
routing_context: QueryRoutingOverrideContext,
folders_map: dict[str, str],
max_iterations: int = 25,
resource_description: str = "",
base_system_prompt: str = "",
Expand All @@ -205,7 +208,7 @@ def create(
graph = DataFabricGraph(
llm,
entities,
routing_context,
folders_map,
max_iterations,
resource_description,
base_system_prompt,
Expand Down
41 changes: 25 additions & 16 deletions src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from langchain_core.tools import BaseTool
from langgraph.graph.state import CompiledStateGraph
from uipath.agent.models.agent import AgentContextResourceConfig
from uipath.platform.entities import Entity, EntityRouting, QueryRoutingOverrideContext
from uipath.platform.entities import Entity

from ..base_uipath_structured_tool import BaseUiPathStructuredTool
from .models import DataFabricQueryInput
Expand All @@ -41,13 +41,13 @@ class DataFabricTextQueryHandler:
def __init__(
self,
entity_identifiers: list[str],
routing_context: QueryRoutingOverrideContext,
folders_map: dict[str, str],
llm: BaseChatModel,
resource_description: str = "",
base_system_prompt: str = "",
) -> None:
self._entity_identifiers = entity_identifiers
self._routing_context = routing_context
self._folders_map = folders_map
self._llm = llm
self._resource_description = resource_description
self._base_system_prompt = base_system_prompt
Expand Down Expand Up @@ -78,7 +78,7 @@ async def _ensure_datafabric_graph(self) -> CompiledStateGraph[Any]:
self._compiled = DataFabricGraph.create(
llm=self._llm,
entities=entities,
routing_context=self._routing_context,
folders_map=self._folders_map,
resource_description=self._resource_description,
base_system_prompt=self._base_system_prompt,
)
Expand Down Expand Up @@ -120,20 +120,29 @@ async def fetch_entity_schemas(entity_identifiers: list[str]) -> list[Entity]:
return [e for e in results if e is not None]


def _build_routing_context(
def _build_folders_map(
resource: AgentContextResourceConfig,
) -> QueryRoutingOverrideContext:
"""Build query routing context from entity set items.
) -> dict[str, str]:
"""Build an entity-name-to-folder-id map from entity set items.

Maps each entity to its folder so the backend resolves
entities at folder level instead of tenant level.
Keys are always ``item.name`` (the entity name used in SQL generation)
so that ``with_folders_map()`` can match generated SQL table names.
When entity resource overwrites are active, the overwritten folder is
used instead of the original ``item.folder_id``.
"""
return QueryRoutingOverrideContext(
entity_routings=[
EntityRouting(entity_name=item.name, folder_id=item.folder_id)
for item in (resource.entity_set or [])
]
)
from uipath.platform.common._bindings import _resource_overwrites

context_overwrites = _resource_overwrites.get() or {}

folders_map: dict[str, str] = {}
for item in resource.entity_set or []:
overwrite = context_overwrites.get(f"entity.{item.id}")
if overwrite is not None:
folders_map[item.name] = overwrite.folder_identifier
else:
folders_map[item.name] = item.folder_id

return folders_map


def create_datafabric_query_tool(
Expand All @@ -154,7 +163,7 @@ def create_datafabric_query_tool(
config = agent_config or {}
handler = DataFabricTextQueryHandler(
entity_identifiers=resource.datafabric_entity_identifiers,
routing_context=_build_routing_context(resource),
folders_map=_build_folders_map(resource),
llm=llm,
resource_description=resource.description or "",
base_system_prompt=config.get(BASE_SYSTEM_PROMPT, ""),
Expand Down
Loading