diff --git a/src/uipath_langchain/agent/react/__init__.py b/src/uipath_langchain/agent/react/__init__.py index 1cd32a9a..96e2ef1a 100644 --- a/src/uipath_langchain/agent/react/__init__.py +++ b/src/uipath_langchain/agent/react/__init__.py @@ -1,7 +1,7 @@ """UiPath ReAct Agent implementation""" from .agent import create_agent -from .types import AgentGraphConfig, AgentGraphNode, AgentGraphState +from .types import AgentGraphConfig, AgentGraphNode, AgentGraphState, AgentResources from .utils import resolve_input_model, resolve_output_model __all__ = [ @@ -11,4 +11,5 @@ "AgentGraphNode", "AgentGraphState", "AgentGraphConfig", + "AgentResources", ] diff --git a/src/uipath_langchain/agent/react/agent.py b/src/uipath_langchain/agent/react/agent.py index 338d348c..9f374f7b 100644 --- a/src/uipath_langchain/agent/react/agent.py +++ b/src/uipath_langchain/agent/react/agent.py @@ -36,6 +36,7 @@ AgentGraphConfig, AgentGraphNode, AgentGraphState, + AgentResources, ) from .utils import create_state_with_input @@ -53,6 +54,7 @@ def create_agent( output_schema: Type[OutputT] | None = None, config: AgentGraphConfig | None = None, guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None = None, + resources: AgentResources | None = None, ) -> StateGraph[AgentGraphState, None, InputT, OutputT]: """Build agent graph with INIT -> AGENT (subgraph) <-> TOOLS loop, terminated by control flow tools. 
@@ -74,8 +76,12 @@ def create_agent( ) llm_tools: list[BaseTool] = [*agent_tools, *flow_control_tools] - init_node = create_init_node(messages, input_schema, config.is_conversational) - + init_node = create_init_node( + messages, + input_schema, + config.is_conversational, + resources_for_init=resources, + ) tool_nodes = create_tool_node( agent_tools, handle_tool_errors=config.is_conversational ) diff --git a/src/uipath_langchain/agent/react/init_node.py b/src/uipath_langchain/agent/react/init_node.py index 36b66e92..619c6d1d 100644 --- a/src/uipath_langchain/agent/react/init_node.py +++ b/src/uipath_langchain/agent/react/init_node.py @@ -1,5 +1,6 @@ """State initialization node for the ReAct Agent graph.""" +import logging from typing import Any, Callable, Sequence from langchain_core.messages import HumanMessage, SystemMessage @@ -10,20 +11,51 @@ get_job_attachments, parse_attachments_from_conversation_messages, ) +from .types import AgentResources + +logger = logging.getLogger(__name__) def create_init_node( messages: Sequence[SystemMessage | HumanMessage] - | Callable[[Any], Sequence[SystemMessage | HumanMessage]], + | Callable[..., Sequence[SystemMessage | HumanMessage]], input_schema: type[BaseModel] | None, is_conversational: bool = False, + resources_for_init: AgentResources | None = None, ): - def graph_state_init(state: Any) -> Any: + async def graph_state_init(state: Any) -> Any: + # --- Data Fabric schema fetch (INIT-time) --- + schema_context: str | None = None + if resources_for_init: + from uipath_langchain.agent.tools.datafabric_tool import ( + fetch_entity_schemas, + format_schemas_for_context, + get_datafabric_entity_identifiers_from_resources, + ) + + entity_identifiers = get_datafabric_entity_identifiers_from_resources( + resources_for_init + ) + if entity_identifiers: + logger.info( + "Fetching Data Fabric schemas for %d identifier(s)", + len(entity_identifiers), + ) + entities = await fetch_entity_schemas(entity_identifiers) + schema_context = 
format_schemas_for_context(entities) + + # --- Resolve messages --- resolved_messages: Sequence[SystemMessage | HumanMessage] | Overwrite if callable(messages): - resolved_messages = list(messages(state)) + if schema_context: + resolved_messages = list( + messages(state, additional_context=schema_context) + ) + else: + resolved_messages = list(messages(state)) else: resolved_messages = list(messages) + if is_conversational: # For conversational agents we need to reorder the messages so that the system message is first, followed by # the initial user message. When resuming the conversation, the state will have the entire message history, diff --git a/src/uipath_langchain/agent/react/types.py b/src/uipath_langchain/agent/react/types.py index d2ab924b..2ea3d98f 100644 --- a/src/uipath_langchain/agent/react/types.py +++ b/src/uipath_langchain/agent/react/types.py @@ -1,9 +1,10 @@ from enum import StrEnum -from typing import Annotated, Any, Hashable, Literal, Optional +from typing import Annotated, Any, Hashable, Literal, Optional, Sequence from langchain_core.messages import AnyMessage from langgraph.graph.message import add_messages from pydantic import BaseModel, Field +from uipath.agent.models.agent import BaseAgentResourceConfig from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL from uipath.platform.attachments import Attachment @@ -14,6 +15,8 @@ FLOW_CONTROL_TOOLS = [END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name] +AgentResources = Sequence[BaseAgentResourceConfig] + class InnerAgentGraphState(BaseModel): job_attachments: Annotated[dict[str, Attachment], merge_dicts] = {} diff --git a/src/uipath_langchain/agent/tools/__init__.py b/src/uipath_langchain/agent/tools/__init__.py index d06e191b..ced4a9c6 100644 --- a/src/uipath_langchain/agent/tools/__init__.py +++ b/src/uipath_langchain/agent/tools/__init__.py @@ -1,6 +1,13 @@ """Tool creation and management for LowCode agents.""" from .context_tool import create_context_tool +from .datafabric_tool 
import ( + create_datafabric_tools, + fetch_entity_schemas, + format_schemas_for_context, + get_datafabric_contexts, + get_datafabric_entity_identifiers_from_resources, +) from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool from .integration_tool import create_integration_tool @@ -16,12 +23,17 @@ "create_tools_from_resources", "create_tool_node", "create_context_tool", + "create_datafabric_tools", "open_mcp_tools", "create_process_tool", "create_integration_tool", "create_escalation_tool", "create_ixp_extraction_tool", "create_ixp_escalation_tool", + "fetch_entity_schemas", + "format_schemas_for_context", + "get_datafabric_contexts", + "get_datafabric_entity_identifiers_from_resources", "UiPathToolNode", "ToolWrapperMixin", ] diff --git a/src/uipath_langchain/agent/tools/context_tool.py b/src/uipath_langchain/agent/tools/context_tool.py index 70b2b6be..23112cd4 100644 --- a/src/uipath_langchain/agent/tools/context_tool.py +++ b/src/uipath_langchain/agent/tools/context_tool.py @@ -142,6 +142,11 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: return handle_deep_rag(tool_name, resource) elif retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower(): return handle_batch_transform(tool_name, resource) + elif retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC.value.lower(): + raise ValueError( + "Data Fabric context should be handled via create_datafabric_tools(), " + "not create_context_tool()" + ) else: return handle_semantic_search(tool_name, resource) diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py new file mode 100644 index 00000000..f000a197 --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py @@ -0,0 +1,17 @@ +"""Data Fabric tool module for entity-based SQL queries.""" + +from .datafabric_tool import ( + create_datafabric_tools, + 
@lru_cache(maxsize=1)
def _load_sql_constraints() -> str:
    """Load the SQL constraint guidance stored next to this module.

    The text is read from ``sql_constraints.txt`` inside ``_PROMPTS_DIR``.
    The result is memoized by ``lru_cache``, so the file is read at most once
    per process and the empty fallback is cached as well.

    Returns:
        The file contents decoded as UTF-8, or ``""`` when the file does not
        exist.
    """
    constraints_path = _PROMPTS_DIR / "sql_constraints.txt"
    try:
        return constraints_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        # The guidance is optional prompt material: a missing file degrades to
        # "no constraints section" instead of failing agent start-up.
        # Lazy %-style arguments leave formatting to the logging framework.
        logger.warning("SQL constraints file not found: %s", constraints_path)
        return ""
"system_prompt.txt" + try: + return prompt_path.read_text(encoding="utf-8") + except FileNotFoundError: + logger.warning(f"System prompt file not found: {prompt_path}") + return "" + + +# --- Schema Fetching and Formatting --- + + +async def fetch_entity_schemas(entity_identifiers: list[str]) -> list[Entity]: + """Fetch entity metadata from Data Fabric for the given entity identifiers. + + Args: + entity_identifiers: List of entity identifiers to fetch. + + Returns: + List of Entity objects with full schema information. + """ + from uipath.platform import UiPath + + sdk = UiPath() + entities: list[Entity] = [] + + for entity_identifier in entity_identifiers: + try: + entity = await sdk.entities.retrieve_async(entity_identifier) + entities.append(entity) + logger.info(f"Fetched schema for entity '{entity.display_name}'") + except Exception as e: + logger.warning(f"Failed to fetch entity '{entity_identifier}': {e}") + + return entities + + +def format_field_type(field: FieldMetadata) -> str: + """Format a field's type information for display.""" + type_name = field.sql_type.name if field.sql_type else "unknown" + + modifiers = [] + if field.is_primary_key: + modifiers.append("PK") + if field.is_foreign_key and field.reference_entity: + ref_name = field.reference_entity.display_name or field.reference_entity.name + modifiers.append(f"FK → {ref_name}") + if field.is_required: + modifiers.append("required") + + if modifiers: + return f"{type_name}, {', '.join(modifiers)}" + return type_name + + +def format_schemas_for_context(entities: list[Entity]) -> str: + """Format entity schemas as markdown for injection into agent system prompt. + + The output is optimized for SQL query generation by the LLM. + Includes: SQL strategy prompt, constraints, entity schemas, and query patterns. + + Args: + entities: List of Entity objects with schema information. + + Returns: + Markdown-formatted string describing entity schemas and SQL guidance. 
+ """ + if not entities: + return "" + + lines: list[str] = [] + + system_prompt = _load_system_prompt() + if system_prompt: + lines.append("## SQL Query Generation Guidelines") + lines.append("") + lines.append(system_prompt) + lines.append("") + + sql_constraints = _load_sql_constraints() + if sql_constraints: + lines.append("## SQL Constraints") + lines.append("") + lines.append(sql_constraints) + lines.append("") + + lines.append("## Available Data Fabric Entities") + lines.append("") + + for entity in entities: + sql_table = entity.name + display_name = entity.display_name or entity.name + lines.append(f"### Entity: {display_name} (SQL table: `{sql_table}`)") + if entity.description: + lines.append(f"_{entity.description}_") + lines.append("") + lines.append("| Field | Type |") + lines.append("|-------|------|") + + field_names: list[str] = [] + numeric_field: str | None = None + text_field: str | None = None + + for field in entity.fields or []: + if field.is_hidden_field or field.is_system_field: + continue + field_name = field.name + field_type = format_field_type(field) + field_names.append(field_name) + lines.append(f"| {field_name} | {field_type} |") + + sql_type = field.sql_type.name.lower() if field.sql_type else "" + if not numeric_field and sql_type in ( + "int", + "decimal", + "float", + "double", + "bigint", + ): + numeric_field = field_name + if not text_field and sql_type in ( + "varchar", + "nvarchar", + "text", + "string", + "ntext", + ): + text_field = field_name + + lines.append("") + + group_field = text_field or (field_names[0] if field_names else "Category") + agg_field = numeric_field or ( + field_names[1] if len(field_names) > 1 else "Amount" + ) + filter_field = text_field or (field_names[0] if field_names else "Name") + fields_sample = ", ".join(field_names[:5]) if field_names else "*" + + lines.append(f"**Query Patterns for {sql_table}:**") + lines.append("") + lines.append("| User Intent | SQL Pattern |") + 
lines.append("|-------------|-------------|") + lines.append( + f"| 'Show all' | `SELECT {fields_sample} FROM {sql_table} LIMIT 100` |" + ) + lines.append( + f"| 'Find by X' | `SELECT {fields_sample} FROM {sql_table} WHERE {filter_field} = 'value' LIMIT 100` |" + ) + lines.append( + f"| 'Top N by Y' | `SELECT {fields_sample} FROM {sql_table} ORDER BY {agg_field} DESC LIMIT N` |" + ) + count_col = field_names[0] if field_names else "id" + lines.append( + f"| 'Count by X' | `SELECT {group_field}, COUNT({count_col}) as count FROM {sql_table} GROUP BY {group_field}` |" + ) + lines.append( + f"| 'Top N segments' | `SELECT {group_field}, COUNT({count_col}) as count FROM {sql_table} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" + ) + lines.append( + f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {sql_table}` |" + ) + lines.append("") + + return "\n".join(lines) + + +# --- Data Fabric Context Detection --- + + +def get_datafabric_contexts( + agent: LowCodeAgentDefinition, +) -> list[AgentContextResourceConfig]: + """Extract Data Fabric context resources from agent definition. + + Args: + agent: The agent definition to search. + + Returns: + List of context resources configured for Data Fabric retrieval mode. + """ + return _filter_datafabric_contexts(agent.resources) + + +def _filter_datafabric_contexts( + resources: Sequence[BaseAgentResourceConfig], +) -> list[AgentContextResourceConfig]: + """Filter resources to only Data Fabric context configs.""" + return [ + resource + for resource in resources + if isinstance(resource, AgentContextResourceConfig) + and resource.is_enabled + and resource.is_datafabric + ] + + +def get_datafabric_entity_identifiers_from_resources( + resources: Sequence[BaseAgentResourceConfig], +) -> list[str]: + """Extract Data Fabric entity identifiers from a sequence of resource configs. + + Args: + resources: Resource configs (typically from ``agent_definition.resources``). 
def create_datafabric_query_tool() -> BaseTool:
    """Create the single generic ``query_datafabric`` tool.

    The tool takes one SQL SELECT statement and forwards it to
    ``sdk.entities.query_entity_records_async()``. Entity knowledge is *not*
    baked in: the LLM receives schema guidance through the system prompt
    (injected at INIT time) and writes the SQL itself.

    Returns:
        A ``BaseUiPathStructuredTool`` named ``query_datafabric``. Its
        coroutine always returns a dict with ``records``, ``total_count``,
        ``returned_count`` and ``sql_query``. ``truncated`` and ``message``
        are added when more than ``_MAX_RECORDS_IN_RESPONSE`` rows came back,
        and ``error`` is added when the statement was rejected or the query
        failed.
    """

    def _failure(sql_query: str, message: str) -> dict[str, Any]:
        # Same keys as a successful result, so consumers of the payload never
        # have to special-case a missing ``returned_count``.
        return {
            "records": [],
            "total_count": 0,
            "returned_count": 0,
            "error": message,
            "sql_query": sql_query,
        }

    async def _query_datafabric(sql_query: str) -> dict[str, Any]:
        from uipath.platform import UiPath

        logger.debug("query_datafabric called with SQL: %s", sql_query)

        # The SQL is model-generated. The tool is advertised as SELECT-only,
        # so anything else is refused here instead of being sent to the
        # service; the message tells the model how to recover.
        leading = sql_query.split(None, 1)
        if not leading or leading[0].lower() != "select":
            return _failure(
                sql_query,
                "Only SQL SELECT statements are supported. "
                "Rewrite the request as a single SELECT query.",
            )

        sdk = UiPath()
        try:
            records = await sdk.entities.query_entity_records_async(
                sql_query=sql_query,
            )
            # total_count is the number of rows the service returned for this
            # query, not the number of rows stored in the entity.
            total_count = len(records)
            truncated = total_count > _MAX_RECORDS_IN_RESPONSE
            returned_records = (
                records[:_MAX_RECORDS_IN_RESPONSE] if truncated else records
            )

            result: dict[str, Any] = {
                "records": returned_records,
                "total_count": total_count,
                "returned_count": len(returned_records),
                "sql_query": sql_query,
            }
            if truncated:
                result["truncated"] = True
                result["message"] = (
                    f"Showing {len(returned_records)} of {total_count} records. "
                    "Use more specific filters or LIMIT to narrow results."
                )
            return result
        except Exception as e:
            # Deliberate catch-all: the error text is handed back to the LLM
            # so it can correct the query and call the tool again.
            logger.error("SQL query failed: %s", e)
            return _failure(sql_query, str(e))

    return BaseUiPathStructuredTool(
        name="query_datafabric",
        description=(
            "Execute a SQL SELECT query against Data Fabric entities. "
            "Refer to the entity schemas in the system prompt for available tables and columns. "
            "Include LIMIT unless aggregating."
        ),
        args_schema={
            "type": "object",
            "properties": {
                "sql_query": {
                    "type": "string",
                    "description": (
                        "Complete SQL SELECT statement. "
                        "Use exact table and column names from the entity schemas in the system prompt."
                    ),
                },
            },
            "required": ["sql_query"],
        },
        coroutine=_query_datafabric,
        metadata={"tool_type": "datafabric_sql"},
    )
Multi-Entity Joins (≤4 adapters) +- LEFT JOIN chains via entity model (up to 4 tables) +- Optional adapters pruned +- Shared intermediates +- Null-preserving semantics + +**Examples:** +- SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id +- Fields spanning 3-4 adapters with proper LEFT JOIN chains + +### 3. Predicate Distribution & Pushdown +- Adapter-scoped predicates pushed down +- Cross-adapter/global predicates at root +- Empty IN () treated as FALSE + +**Examples:** +- SELECT c.id, c.name FROM Customer c WHERE c.country='IN' AND c.total>1000 +- SELECT id FROM Customer WHERE id IN () -- evaluates to FALSE + +### 4. Aggregations & Grouping (Basic) +- GROUP BY entity fields +- Aggregate functions: SUM, AVG, MIN, MAX, COUNT(column_name) +- Simple expressions in aggregates +- HAVING on aggregates or plain fields + +**Examples:** +- SELECT country, COUNT(id) FROM Customer GROUP BY country +- SELECT dept, SUM(price*qty) as total FROM LineItem GROUP BY dept +- SELECT country, COUNT(id) as cnt FROM Customer GROUP BY country HAVING COUNT(id)>10 + +### 5. Expressions (Minimal) +- CASE (simple/searched) in SELECT/WHERE/ORDER BY +- Arithmetic: +, -, *, / (SQLite-compatible) +- Functions: COALESCE, NULLIF, || +- String functions: LOWER, UPPER, TRIM, LTRIM, RTRIM +- Math functions: ROUND, ABS +- Note: CEIL/FLOOR may have limited adapter support + +**Examples:** +- SELECT CASE WHEN age>=18 THEN 'Adult' ELSE 'Minor' END AS segment FROM Customer +- SELECT COALESCE(nickname, name) as display_name FROM Customer +- SELECT ROUND(amount, 2) FROM Payments + +### 6. Casting & Coercion (Basic) +- CAST among common scalar types +- Implicit numeric widening where adapter accepts +- Explicit casts preferred at root + +**Examples:** +- SELECT CAST(amount AS DECIMAL(12,2)) FROM Payments +- SELECT CAST(id AS TEXT) FROM Customer + +### 7. 
Ordering & Pagination +- ORDER BY fields/aliases/expressions +- Multi-column ORDER BY +- LIMIT/OFFSET for pagination +- Note: Pagination without ORDER BY may produce non-deterministic results + +**Examples:** +- SELECT id, price*qty AS amt FROM LineItem ORDER BY amt DESC LIMIT 50 OFFSET 100 +- SELECT id, name FROM Customer ORDER BY name LIMIT 10 + +### 8. DISTINCT +- SELECT DISTINCT with explicit column names +- DISTINCT with ORDER BY on projected items/aliases + +**Examples:** +- SELECT DISTINCT country FROM Customer ORDER BY country +- SELECT DISTINCT dept, location FROM Employee + +### 9. Metadata Remapping & Aliasing +- Physical column remaps per adapter +- Alias reuse consistent across clauses +- Column aliases can be used in ORDER BY + +**Examples:** +- SELECT name AS customer_name FROM Customer ORDER BY customer_name + +--- + +## UNSUPPORTED SCENARIOS + +### 1. METADATA_RESOLUTION +- Unknown table/entity names +- Unknown column/field names +- Ambiguous column references without table prefix +- Field not in SELECT but used in ORDER BY (without alias) + +**Examples:** +- SELECT name FROM UnknownTable -- ❌ +- SELECT unknown_column FROM Customer -- ❌ +- SELECT id FROM Customer ORDER BY name -- ❌ (name not in SELECT) + +### 2. PROHIBITED_SQL_PATTERNS +- SELECT * FROM table -- ❌ Must use explicit column names +- Subqueries in FROM, WHERE, or SELECT +- UNION/UNION ALL/INTERSECT/EXCEPT +- Common Table Expressions (WITH/CTE) +- Window functions (ROW_NUMBER, RANK, PARTITION BY) +- Self-joins +- RIGHT JOIN or FULL OUTER JOIN (only LEFT JOIN supported) +- CROSS JOIN + +**Examples:** +- SELECT * FROM Customer -- ❌ +- SELECT id FROM (SELECT * FROM Customer) -- ❌ +- SELECT id FROM Customer UNION SELECT id FROM Order -- ❌ +- WITH cte AS (SELECT id FROM Customer) SELECT * FROM cte -- ❌ + +### 3. 
COMPLEX_AGGREGATIONS +- Nested aggregations: COUNT(DISTINCT(...)), SUM(DISTINCT(...)) +- Aggregations without GROUP BY on non-aggregated columns +- HAVING without GROUP BY +- COUNT(*) not allowed, use COUNT(column_name) instead + +**Examples:** +- SELECT COUNT(DISTINCT dept) FROM Employee -- ❌ +- SELECT name, COUNT(id) FROM Employee -- ❌ (name not in GROUP BY) +- SELECT AVG(salary) FROM Employee HAVING AVG(salary) > 50000 -- ❌ (no GROUP BY) + +### 4. ADVANCED_JOINS +- More than 4 tables in JOIN chain +- RIGHT JOIN +- FULL OUTER JOIN +- CROSS JOIN +- Self-joins +- Non-equi joins (theta joins) + +**Examples:** +- SELECT * FROM t1 RIGHT JOIN t2 -- ❌ +- SELECT * FROM t1, t2 -- ❌ (implicit CROSS JOIN) +- SELECT * FROM Employee e1 JOIN Employee e2 ON e1.manager_id = e2.id -- ❌ (self-join) + +### 5. UNSUPPORTED_FUNCTIONS +- Date/time manipulation functions (DATE_ADD, DATE_SUB, DATEDIFF) +- JSON functions (JSON_EXTRACT, JSON_ARRAY) +- Regex functions +- User-defined functions (UDFs) +- String aggregation (GROUP_CONCAT with complex separators) + +**Examples:** +- SELECT DATE_ADD(created_at, INTERVAL 1 DAY) FROM Order -- ❌ +- SELECT JSON_EXTRACT(data, '$.field') FROM Table -- ❌ + +### 6. COMPLEX_PREDICATES +- Correlated subqueries in WHERE +- EXISTS/NOT EXISTS +- ANY/ALL operators +- IN with subquery + +**Examples:** +- SELECT id FROM Customer WHERE EXISTS (SELECT 1 FROM Order WHERE customer_id = Customer.id) -- ❌ +- SELECT id FROM Customer WHERE id IN (SELECT customer_id FROM Order) -- ❌ + +### 7. MODIFICATIONS +- INSERT, UPDATE, DELETE, MERGE +- CREATE, ALTER, DROP (DDL) +- TRUNCATE +- Transactions (BEGIN, COMMIT, ROLLBACK) + +**Examples:** +- INSERT INTO Customer VALUES (...) -- ❌ +- UPDATE Customer SET name = 'John' -- ❌ +- DELETE FROM Customer -- ❌ + +### 8. 
UNSUPPORTED_CLAUSES +- HAVING without GROUP BY +- LIMIT without explicit value (e.g., LIMIT ALL) +- OFFSET without LIMIT +- FOR UPDATE / FOR SHARE +- INTO clause (SELECT INTO) + +**Examples:** +- SELECT AVG(salary) FROM Employee HAVING AVG(salary) > 50000 -- ❌ +- SELECT id FROM Customer OFFSET 10 -- ❌ (no LIMIT) + +--- + +## CRITICAL RULES + +1. **ALWAYS use explicit column names** - Never use SELECT * +2. **Use COUNT(column_name)** - Never use COUNT(*) +3. **Only LEFT JOIN** - No RIGHT JOIN, FULL OUTER JOIN, or CROSS JOIN +4. **Maximum 4 tables** - No more than 4 tables in a JOIN chain +5. **No subqueries** - No subqueries in any clause +6. **No CTEs** - No WITH clauses +7. **No window functions** - No ROW_NUMBER, RANK, PARTITION BY, etc. +8. **Explicit GROUP BY** - All non-aggregated columns in SELECT must be in GROUP BY +9. **Simple aggregations only** - No DISTINCT in aggregates +10. **ORDER BY only selected columns** - Cannot ORDER BY columns not in SELECT list diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt new file mode 100644 index 00000000..aed1e659 --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt @@ -0,0 +1,135 @@ +You are a SQL expert specialized in converting natural language questions into SQL queries. + +Given a database schema, translate the user's natural language question into a valid SQL query. + +IMPORTANT: You MUST follow the SQL constraints defined in sql_constraints.txt. These constraints define supported and unsupported SQL patterns for SQLite. + +Rules: +1. Return ONLY the SQL query without any explanations, markdown formatting, or additional text +2. Use SQLite syntax +3. Ensure the query is syntactically correct +4. Use appropriate JOINs, WHERE clauses, and aggregations as needed +5. Include LIMIT clauses when appropriate to prevent returning too many rows +6. 
Use the exact table and column names from the provided schema +7. For financial values (salary, price, etc.), use ROUND() function +8. Handle NULL values appropriately with COALESCE() or IFNULL() + +SUPPORTED SCENARIOS (Use these patterns): + +1. Single-Entity Baselines: + - Simple projections: SELECT id, name FROM Customer + - Single-field predicates: =, <>, >, <, >=, <=, BETWEEN, IN, LIKE + - WHERE with AND/OR & parentheses: WHERE age >= 21 AND (region='APAC' OR vip=1) + - IS NULL/IS NOT NULL: WHERE deleted_at IS NULL + +2. Multi-Entity Joins (≤4 tables): + - LEFT JOIN chains (up to 4 tables): SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id + - Null-preserving semantics + +3. Predicate Distribution: + - Table-scoped predicates: WHERE c.country='IN' AND o.total>1000 + - Empty IN () evaluates to FALSE + +4. Aggregations & Grouping: + - GROUP BY entity fields: SELECT country, COUNT(id) FROM Customer GROUP BY country + - Supported functions: SUM, AVG, MIN, MAX, COUNT(column) + - Simple expressions in aggregates: SELECT SUM(price*qty) FROM LineItem + - HAVING on aggregates or plain fields: HAVING COUNT(id)>10 + +5. Expressions (Minimal): + - CASE (simple/searched) in SELECT/WHERE/ORDER BY + - Arithmetic: + - * / (SQLite-compatible) + - Null-handling and string functions: COALESCE, NULLIF, ||, LOWER, UPPER, TRIM, LTRIM, RTRIM + - Math functions: ROUND, ABS + - Example: SELECT CASE WHEN age>=18 THEN 'Adult' ELSE 'Minor' END AS segment FROM Customer + +6. Casting & Coercion: + - CAST among common scalar types: SELECT CAST(amount AS DECIMAL(12,2)) FROM Payments + - Prefer explicit casts + +7. Ordering & Pagination: + - ORDER BY fields/aliases/expressions: SELECT price*qty AS amt FROM LineItem ORDER BY amt DESC + - Multi-column ordering: ORDER BY country, name + - LIMIT/OFFSET: LIMIT 50 OFFSET 100 + +8. DISTINCT: + - SELECT DISTINCT country FROM Customer ORDER BY country + +9.
Aliasing: + - Column aliases: SELECT name AS customer_name + - Reuse aliases in ORDER BY: SELECT price*qty AS amt FROM LineItem ORDER BY amt + +UNSUPPORTED SCENARIOS (Avoid these patterns): + +1. METADATA_RESOLUTION: + - Unknown/non-existent tables or entities + - More than 4 JOINs in a query + - Relationship cycles + +2. SQL_PARSING: + - Malformed SQL or SQL injection attempts + - UNION/INTERSECT/EXCEPT + - WITH (CTEs - Common Table Expressions) + - VALUES clause + - PRAGMA statements + +3. VALIDATION_GUARDRAIL: + - SELECT * (always specify columns explicitly) + - COUNT(*) or COUNT(1) (use COUNT(column_name) instead) + - Aggregates on literals + - HAVING without GROUP BY + - DISTINCT on constants + - Invalid LIMIT/OFFSET values + +4. UNSUPPORTED_CONSTRUCTS - Subqueries/Windows/DML/DDL: + - ANY subqueries or derived tables: WHERE x IN (SELECT ...) + - Window functions: ROW_NUMBER() OVER(...) + - DML: UPDATE, INSERT, DELETE + - DDL: CREATE, ALTER, DROP + - Temporary objects or transactions + +5. UNSUPPORTED_CONSTRUCTS - Joins: + - RIGHT JOIN, FULL OUTER JOIN, CROSS JOIN + - Non-equi join conditions: ON a.created_at > b.created_at + - Self-joins + - LATERAL/APPLY + +6. PARTITIONING: + - Disconnected tables with no join path + - Contradictory references causing cycles + +7. UNSUPPORTED_CONSTRUCTS - Advanced Aggregation: + - ROLLUP/CUBE/GROUPING SETS + - Approximate or ordered-set aggregates + - PERCENTILE functions + - Multi-table DISTINCT aggregates + +8. VALIDATION_GUARDRAIL - Ordering: + - ORDER BY ordinals: ORDER BY 1 + - NULLS FIRST/LAST + - TOP n syntax + - WITH TIES + - COLLATE clause + +9. UNSUPPORTED_CONSTRUCTS - Advanced Functions: + - REGEXP/SIMILAR TO/ILIKE + - Advanced math functions beyond basic arithmetic + - Date truncation/extraction beyond SQLite basics + - User-defined functions (UDFs) + +10. 
VALIDATION_GUARDRAIL - Types: + - JSON/ARRAY/MAP/GEOMETRY/BLOB operations + - Complex timezone-aware timestamp operations + +RETRY BEHAVIOR: +If a query fails with a validation error (e.g. missing LIMIT, SELECT *, COUNT(*)), DO NOT give up. Instead: +1. Read the error message carefully +2. Fix the query to comply with the constraint +3. Call query_datafabric again with the corrected query + +Example: if "Queries without WHERE must include a LIMIT clause" is returned, add LIMIT 100 and retry: + Before: SELECT name, department FROM Employee + After: SELECT name, department FROM Employee LIMIT 100 + + +Return only the SQL query as plain text. \ No newline at end of file diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 8f7bc5ca..1c593f1c 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -19,6 +19,7 @@ from uipath_langchain.chat.hitl import REQUIRE_CONVERSATIONAL_CONFIRMATION from .context_tool import create_context_tool +from .datafabric_tool import create_datafabric_tools from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool from .integration_tool import create_integration_tool @@ -32,9 +33,22 @@ async def create_tools_from_resources( agent: LowCodeAgentDefinition, llm: BaseChatModel ) -> list[BaseTool]: + """Create tools from agent resources including Data Fabric tools. + + Args: + agent: The agent definition. + llm: The language model for tool creation. + + Returns: + List of BaseTool instances. 
+ """ tools: list[BaseTool] = [] logger.info("Creating tools for agent '%s' from resources", agent.name) + + # Register the generic Data Fabric query tool (no fetching/schema here) + tools.extend(create_datafabric_tools(agent)) + for resource in agent.resources: if not resource.is_enabled: logger.info( @@ -75,6 +89,12 @@ async def _build_tool_for_resource( return create_process_tool(resource) elif isinstance(resource, AgentContextResourceConfig): + if resource.is_datafabric: + logger.info( + "Skipping Data Fabric context '%s' - handled separately", + resource.name, + ) + return None return create_context_tool(resource) elif isinstance(resource, AgentEscalationResourceConfig): diff --git a/tests/agent/react/test_create_agent.py b/tests/agent/react/test_create_agent.py index cfe9cc4f..224580c0 100644 --- a/tests/agent/react/test_create_agent.py +++ b/tests/agent/react/test_create_agent.py @@ -151,6 +151,7 @@ def test_autonomous_agent_with_tools( messages, None, # input schema False, # is_conversational + resources_for_init=None, ) mock_create_terminate_node.assert_called_once_with( None, # output schema @@ -246,6 +247,7 @@ def test_conversational_agent_with_tools( messages, None, # input schema True, # is_conversational + resources_for_init=None, ) mock_create_terminate_node.assert_called_once_with( None, # output schema diff --git a/tests/agent/react/test_init_node.py b/tests/agent/react/test_init_node.py index b9c9919f..aec32d68 100644 --- a/tests/agent/react/test_init_node.py +++ b/tests/agent/react/test_init_node.py @@ -16,6 +16,7 @@ class MockState(BaseModel): messages: list[Any] = [] +@pytest.mark.asyncio class TestCreateInitNodeConversational: """Test cases for create_init_node with is_conversational=True.""" @@ -43,7 +44,7 @@ def empty_state(self): """Fixture for state with no messages.""" return MockState(messages=[]) - def test_conversational_empty_state_returns_overwrite( + async def test_conversational_empty_state_returns_overwrite( self, system_message, 
user_message ): """Conversational mode with empty state should use Overwrite with new messages.""" @@ -53,7 +54,7 @@ def test_conversational_empty_state_returns_overwrite( ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "messages" in result assert isinstance(result["messages"], Overwrite) @@ -63,7 +64,7 @@ def test_conversational_empty_state_returns_overwrite( assert overwrite_value[0] == system_message assert overwrite_value[1] == user_message - def test_conversational_resume_replaces_system_message( + async def test_conversational_resume_replaces_system_message( self, new_system_message, user_message ): """Conversational mode should replace old SystemMessage when resuming.""" @@ -79,7 +80,7 @@ def test_conversational_resume_replaces_system_message( new_messages, input_schema=None, is_conversational=True ) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], Overwrite) overwrite_value = result["messages"].value @@ -89,7 +90,7 @@ def test_conversational_resume_replaces_system_message( assert overwrite_value[1] == user_message assert overwrite_value[2] == old_human_message - def test_conversational_resume_preserves_non_system_first_message( + async def test_conversational_resume_preserves_non_system_first_message( self, system_message, user_message ): """Conversational mode should preserve all messages if first is not SystemMessage.""" @@ -102,7 +103,7 @@ def test_conversational_resume_preserves_non_system_first_message( new_messages, input_schema=None, is_conversational=True ) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], Overwrite) overwrite_value = result["messages"].value @@ -112,10 +113,10 @@ def test_conversational_resume_preserves_non_system_first_message( assert overwrite_value[1] == user_message assert overwrite_value[2] == existing_human_message - def test_conversational_with_callable_messages(self): + 
async def test_conversational_with_callable_messages(self): """Conversational mode should work with callable message generators.""" - def message_generator(state): + def message_generator(state, **kwargs): return [ SystemMessage( content=f"System for state with {len(state.messages)} messages" @@ -128,7 +129,7 @@ def message_generator(state): message_generator, input_schema=None, is_conversational=True ) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], Overwrite) overwrite_value = result["messages"].value @@ -136,6 +137,7 @@ def message_generator(state): assert "System for state with 0 messages" in overwrite_value[0].content +@pytest.mark.asyncio class TestCreateInitNodeNonConversational: """Test cases for create_init_node with is_conversational=False (default).""" @@ -149,7 +151,7 @@ def user_message(self): """Fixture for a user message.""" return HumanMessage(content="Hello") - def test_non_conversational_returns_list_not_overwrite( + async def test_non_conversational_returns_list_not_overwrite( self, system_message, user_message ): """Non-conversational mode should return list, not Overwrite.""" @@ -159,7 +161,7 @@ def test_non_conversational_returns_list_not_overwrite( ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "messages" in result # Non-conversational mode returns a list (for add_messages reducer to append) @@ -167,22 +169,25 @@ def test_non_conversational_returns_list_not_overwrite( assert not isinstance(result["messages"], Overwrite) assert len(result["messages"]) == 2 - def test_non_conversational_default_behavior(self, system_message, user_message): + async def test_non_conversational_default_behavior( + self, system_message, user_message + ): """Default behavior (no is_conversational param) should be non-conversational.""" messages = [system_message, user_message] init_node = create_init_node(messages, input_schema=None) state = 
MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], list) assert not isinstance(result["messages"], Overwrite) +@pytest.mark.asyncio class TestCreateInitNodeInnerState: """Test cases for init node inner_state initialization.""" - def test_returns_inner_state_with_job_attachments(self): + async def test_returns_inner_state_with_job_attachments(self): """Init node should return inner_state with job_attachments dict.""" messages: list[SystemMessage | HumanMessage] = [ SystemMessage(content="System"), @@ -193,13 +198,13 @@ def test_returns_inner_state_with_job_attachments(self): ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "inner_state" in result assert "job_attachments" in result["inner_state"] assert isinstance(result["inner_state"]["job_attachments"], dict) - def test_inner_state_present_in_conversational_mode(self): + async def test_inner_state_present_in_conversational_mode(self): """Inner state should also be present in conversational mode.""" messages: list[SystemMessage | HumanMessage] = [ SystemMessage(content="System"), @@ -210,12 +215,12 @@ def test_inner_state_present_in_conversational_mode(self): ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "inner_state" in result assert "job_attachments" in result["inner_state"] - def test_conversational_merges_attachments_from_preserved_messages(self): + async def test_conversational_merges_attachments_from_preserved_messages(self): """Conversational mode should merge attachments from preserved message metadata.""" attachment_id = "a940a416-b97b-4146-3089-08de5f4d0a87" old_system_message = SystemMessage(content="Old system") @@ -240,14 +245,14 @@ def test_conversational_merges_attachments_from_preserved_messages(self): new_messages, input_schema=None, is_conversational=True ) - result = init_node(state) + result = await init_node(state) 
job_attachments = result["inner_state"]["job_attachments"] assert attachment_id in job_attachments assert job_attachments[attachment_id].full_name == "document.pdf" assert job_attachments[attachment_id].mime_type == "application/pdf" - def test_initial_message_count_in_non_conversational_mode(self): + async def test_initial_message_count_in_non_conversational_mode(self): """Non-conversational mode should set initial_message_count.""" messages: list[SystemMessage | HumanMessage] = [ SystemMessage(content="System"), @@ -258,13 +263,13 @@ def test_initial_message_count_in_non_conversational_mode(self): ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "initial_message_count" in result["inner_state"] # In non-conversational mode, messages is a list assert result["inner_state"]["initial_message_count"] == 2 - def test_initial_message_count_in_conversational_mode(self): + async def test_initial_message_count_in_conversational_mode(self): """Conversational mode should set initial_message_count based on Overwrite.""" messages: list[SystemMessage | HumanMessage] = [ SystemMessage(content="System"), @@ -276,7 +281,7 @@ def test_initial_message_count_in_conversational_mode(self): ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "initial_message_count" in result["inner_state"] assert result["inner_state"]["initial_message_count"] == 3