From a29ef1e228cd75bf4d2211b9bd2fbc569be3a87a Mon Sep 17 00:00:00 2001 From: Harshit Rohatgi Date: Fri, 13 Feb 2026 01:12:30 +0530 Subject: [PATCH 1/8] Integration-e2e --- pyproject.toml | 4 + src/uipath_langchain/agent/tools/__init__.py | 10 + .../agent/tools/context_tool.py | 6 + .../agent/tools/datafabric_tool/__init__.py | 15 + .../tools/datafabric_tool/datafabric_tool.py | 369 ++++++++++++++++++ .../tools/datafabric_tool/sql_constraints.txt | 205 ++++++++++ .../tools/datafabric_tool/system_prompt.txt | 124 ++++++ .../agent/tools/tool_factory.py | 30 +- uv.lock | 2 + 9 files changed, 763 insertions(+), 2 deletions(-) create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/__init__.py create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt diff --git a/pyproject.toml b/pyproject.toml index b6424d0c6..63ba42db9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ dev = [ "numpy>=1.24.0", "pytest_httpx>=0.35.0", "rust-just>=1.39.0", + "uipath-langchain", ] [tool.hatch.build.targets.wheel] @@ -116,3 +117,6 @@ name = "testpypi" url = "https://test.pypi.org/simple/" publish-url = "https://test.pypi.org/legacy/" explicit = true + +[tool.uv.sources] +uipath-langchain = { workspace = true } diff --git a/src/uipath_langchain/agent/tools/__init__.py b/src/uipath_langchain/agent/tools/__init__.py index 2f2f8214e..ce29925a5 100644 --- a/src/uipath_langchain/agent/tools/__init__.py +++ b/src/uipath_langchain/agent/tools/__init__.py @@ -1,6 +1,12 @@ """Tool creation and management for LowCode agents.""" from .context_tool import create_context_tool +from .datafabric_tool import ( + create_datafabric_tools, + fetch_entity_schemas, + format_schemas_for_context, + get_datafabric_contexts, +) from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool from .integration_tool import create_integration_tool @@ -15,11 +21,15 @@ "create_tools_from_resources", "create_tool_node", "create_context_tool", + "create_datafabric_tools", "create_process_tool", "create_integration_tool", "create_escalation_tool", "create_ixp_extraction_tool", "create_ixp_escalation_tool", + "fetch_entity_schemas", + "format_schemas_for_context", + "get_datafabric_contexts", "UiPathToolNode", "ToolWrapperMixin", ] diff --git a/src/uipath_langchain/agent/tools/context_tool.py b/src/uipath_langchain/agent/tools/context_tool.py index ac917863c..b970deaf5 100644 --- a/src/uipath_langchain/agent/tools/context_tool.py +++ b/src/uipath_langchain/agent/tools/context_tool.py @@ -40,6 +40,12 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: return handle_deep_rag(tool_name, resource) elif retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower(): return handle_batch_transform(tool_name, resource) + elif retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC.value.lower(): + # Data Fabric contexts are handled by create_datafabric_tools() in tool_factory.py + raise ValueError( + "Data Fabric context should be handled via create_datafabric_tools(), " + "not create_context_tool()" + ) else: return handle_semantic_search(tool_name, resource) diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py new file mode 100644 index 000000000..ccfcf433d --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py @@ -0,0 +1,15 @@ +"""Data Fabric tool module for entity-based SQL queries.""" + +from .datafabric_tool import ( + create_datafabric_tools, + fetch_entity_schemas, + format_schemas_for_context, + get_datafabric_contexts, +) + +__all__ = [ + "create_datafabric_tools", + "fetch_entity_schemas", + "format_schemas_for_context", + "get_datafabric_contexts", +] diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py new file mode 100644 index 000000000..5709c36b7 --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -0,0 +1,369 @@ +"""Data Fabric tool creation for entity-based queries. + +This module provides functionality to: +1. Fetch and format entity schemas for agent context hydration +2. Create SQL-based query tools for the agent +""" + +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any + +from langchain_core.tools import BaseTool +from pydantic import BaseModel, Field +from uipath.agent.models.agent import ( + AgentContextResourceConfig, + LowCodeAgentDefinition, +) +from uipath.platform.entities import Entity, FieldMetadata + +from ..base_uipath_structured_tool import BaseUiPathStructuredTool +from ..utils import sanitize_tool_name + +logger = logging.getLogger(__name__) + +# --- Prompt and Constraints Loading --- + +_PROMPTS_DIR = Path(__file__).parent + + +@lru_cache(maxsize=1) +def _load_sql_constraints() -> str: + """Load SQL constraints from sql_constraints.txt.""" + constraints_path = _PROMPTS_DIR / "sql_constraints.txt" + try: + return constraints_path.read_text(encoding="utf-8") + except FileNotFoundError: + logger.warning(f"SQL constraints file not found: {constraints_path}") + return "" + + +@lru_cache(maxsize=1) +def _load_system_prompt() -> str: + """Load SQL generation strategy from system_prompt.txt.""" + prompt_path = _PROMPTS_DIR / "system_prompt.txt" + try: + return prompt_path.read_text(encoding="utf-8") + except FileNotFoundError: + logger.warning(f"System prompt file not found: {prompt_path}") + return "" + + +# --- Schema Fetching and Formatting --- + + +async def fetch_entity_schemas(entity_identifiers: list[str]) -> list[Entity]: + """Fetch entity metadata from Data Fabric for the given entity identifiers. + + Args: + entity_identifiers: List of entity identifiers to fetch. + + Returns: + List of Entity objects with full schema information. + """ + from uipath.platform import UiPath + + sdk = UiPath() + entities: list[Entity] = [] + + for entity_identifier in entity_identifiers: + try: + entity = await sdk.entities.retrieve_async(entity_identifier) + entities.append(entity) + logger.info(f"Fetched schema for entity '{entity.display_name}'") + except Exception as e: + logger.warning(f"Failed to fetch entity '{entity_identifier}': {e}") + + return entities + + +def format_field_type(field: FieldMetadata) -> str: + """Format a field's type information for display.""" + type_name = field.sql_type.name if field.sql_type else "unknown" + + modifiers = [] + if field.is_primary_key: + modifiers.append("PK") + if field.is_foreign_key and field.reference_entity: + ref_name = field.reference_entity.display_name or field.reference_entity.name + modifiers.append(f"FK → {ref_name}") + if field.is_required: + modifiers.append("required") + + if modifiers: + return f"{type_name}, {', '.join(modifiers)}" + return type_name + + +def format_schemas_for_context(entities: list[Entity]) -> str: + """Format entity schemas as markdown for injection into agent system prompt. + + The output is optimized for SQL query generation by the LLM. + + Args: + entities: List of Entity objects with schema information. + + Returns: + Markdown-formatted string describing entity schemas. + """ + if not entities: + return "" + + lines = [ + "## Available Data Fabric Entities", + "", + ] + + for entity in entities: + display_name = entity.display_name or entity.name + lines.append(f"### Entity: {display_name}") + if entity.description: + lines.append(f"_{entity.description}_") + lines.append("") + lines.append("| Field | Type |") + lines.append("|-------|------|") + + for field in entity.fields or []: + if field.is_hidden_field or field.is_system_field: + continue + field_name = field.display_name or field.name + field_type = format_field_type(field) + lines.append(f"| {field_name} | {field_type} |") + + lines.append("") + + return "\n".join(lines) + + +# --- Data Fabric Context Detection --- + + +def get_datafabric_contexts( + agent: LowCodeAgentDefinition, +) -> list[AgentContextResourceConfig]: + """Extract Data Fabric context resources from agent definition. + + Args: + agent: The agent definition to search. + + Returns: + List of context resources configured for Data Fabric retrieval mode. + """ + datafabric_contexts: list[AgentContextResourceConfig] = [] + + for resource in agent.resources: + if not isinstance(resource, AgentContextResourceConfig): + continue + if not resource.is_enabled: + continue + if resource.settings.retrieval_mode.lower() == "datafabric": + datafabric_contexts.append(resource) + + return datafabric_contexts + + +# --- Tool Creation --- + + +class QueryEntityInput(BaseModel): + """Input schema for the query_entity tool.""" + + entity_identifier: str = Field( + ..., description="The entity identifier to query" + ) + sql_where: str = Field( + default="", + description="SQL WHERE clause to filter records (without the WHERE keyword). " + "Example: 'Status = \"Active\" AND Amount > 100'", + ) + limit: int = Field( + default=1000, description="Maximum number of records to return" + ) + + +class QueryEntityOutput(BaseModel): + """Output schema for the query_entity tool.""" + + records: list[dict[str, Any]] = Field( + ..., description="List of entity records matching the query" + ) + total_count: int = Field(..., description="Total number of matching records") + + +async def create_datafabric_tools( + agent: LowCodeAgentDefinition, +) -> tuple[list[BaseTool], str]: + """Create Data Fabric tools and schema context from agent definition. + + This function: + 1. Finds all Data Fabric context resources in the agent + 2. Fetches entity schemas for context hydration + 3. Returns tools and schema context string + + Args: + agent: The agent definition containing Data Fabric context resources. + + Returns: + Tuple of (tools, schema_context) where: + - tools: List of BaseTool instances for querying entities + - schema_context: Formatted schema string to inject into system prompt + """ + tools: list[BaseTool] = [] + all_entities: list[Entity] = [] + + datafabric_contexts = get_datafabric_contexts(agent) + + if not datafabric_contexts: + return tools, "" + + logger.info(f"Found {len(datafabric_contexts)} Data Fabric context resource(s)") + + for context in datafabric_contexts: + entity_identifiers = context.settings.entity_identifiers or [] + + if not entity_identifiers: + logger.warning( + f"Data Fabric context '{context.name}' has no entity_identifiers configured" + ) + continue + + # Fetch entity schemas + entities = await fetch_entity_schemas(entity_identifiers) + all_entities.extend(entities) + + context_tools = _create_sdk_based_tools(context, entities) + tools.extend(context_tools) + + logger.info( + f"Created {len(context_tools)} tools for Data Fabric context '{context.name}'" + ) + + # Format all entity schemas for context injection + schema_context = format_schemas_for_context(all_entities) + + return tools, schema_context + + +def _create_sdk_based_tools( + context: AgentContextResourceConfig, + entities: list[Entity], +) -> list[BaseTool]: + """Create SDK-based tools for querying entities using SQL. + + Each tool accepts a full SQL query and executes it via the SDK's + query_entity_records_async method. + """ + tools: list[BaseTool] = [] + MAX_RECORDS_IN_RESPONSE = 50 # Limit records to prevent context overflow + + for entity in entities: + tool_name = sanitize_tool_name(f"query_{entity.name}") + + # Create a closure to capture the entity name + entity_display_name = entity.display_name or entity.name + + async def query_fn( + sql_query: str, + _entity_name: str = entity_display_name, + _max_records: int = MAX_RECORDS_IN_RESPONSE, + ) -> dict[str, Any]: + """Execute a SQL query against the Data Fabric entity.""" + from uipath.platform import UiPath + + print(f"[DEBUG] query_fn called for entity '{_entity_name}' with SQL: {sql_query}") + logger.info(f"Executing SQL query for entity '{_entity_name}': {sql_query}") + + sdk = UiPath() + try: + records = await sdk.entities.query_entity_records_async( + sql_query=sql_query, + ) + total_count = len(records) + truncated = total_count > _max_records + returned_records = records[:_max_records] if truncated else records + + print(f"[DEBUG] Retrieved {total_count} records, returning {len(returned_records)} for entity '{_entity_name}'") + + result = { + "records": returned_records, + "total_count": total_count, + "returned_count": len(returned_records), + "entity": _entity_name, + "sql_query": sql_query, + } + if truncated: + result["truncated"] = True + result["message"] = f"Showing {len(returned_records)} of {total_count} records. Use more specific filters or LIMIT to narrow results." + return result + except Exception as e: + logger.error(f"SQL query failed for entity '{_entity_name}': {e}") + return { + "records": [], + "total_count": 0, + "error": str(e), + "sql_query": sql_query, + } + + entity_description = entity.description or f"Query {entity_display_name} records" + + # Extract actual field names from entity schema (exclude system/hidden fields) + field_names = [ + f.display_name or f.name + for f in (entity.fields or []) + if not f.is_hidden_field and not f.is_system_field + ] + fields_str = ", ".join(field_names[:10]) # Limit to first 10 fields + if len(field_names) > 10: + fields_str += f", ... ({len(field_names)} total)" + + # Identify categorical fields for segmentation (text-like, non-PK) + categorical_fields = [ + f.display_name or f.name + for f in (entity.fields or []) + if not f.is_hidden_field and not f.is_system_field + and not f.is_primary_key + and f.sql_type and f.sql_type.name.lower() in ("text", "nvarchar", "varchar", "string", "ntext") + ] + segment_field = categorical_fields[0] if categorical_fields else (field_names[1] if len(field_names) > 1 else "Category") + count_field = field_names[0] if field_names else "Id" + + # Build intent-based query examples using actual entity fields + intent_examples = ( + f"QUERY PATTERNS for {entity_display_name}:\n" + f"- 'show all' → SELECT {fields_str} FROM {entity_display_name} LIMIT 100\n" + f"- 'top N by X' → SELECT {fields_str} FROM {entity_display_name} ORDER BY X DESC LIMIT N\n" + f"- 'top N segments/categories/groups' → SELECT {segment_field}, COUNT({count_field}) as count FROM {entity_display_name} GROUP BY {segment_field} ORDER BY count DESC LIMIT N\n" + f"- 'filter by X=value' → SELECT {fields_str} FROM {entity_display_name} WHERE X = 'value' LIMIT 100\n" + f"- 'average/sum/count of X' → SELECT AVG(X) as avg_x FROM {entity_display_name} LIMIT 1\n" + f"RULES: ALWAYS use explicit columns (no SELECT *). ALWAYS include LIMIT. Extract filter values from user message." + ) + + tools.append( + BaseUiPathStructuredTool( + name=tool_name, + description=( + f"{context.description}. {entity_description}. " + f"Available fields: {fields_str}. " + f"Generate SQL based on user's request." + ), + args_schema={ + "type": "object", + "properties": { + "sql_query": { + "type": "string", + "description": intent_examples, + }, + }, + "required": ["sql_query"], + }, + coroutine=query_fn, + metadata={ + "tool_type": "datafabric_sql", + "display_name": f"Query {entity_display_name}", + "entity_name": entity_display_name, + }, + ) + ) + + return tools diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt b/src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt new file mode 100644 index 000000000..804537d54 --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt @@ -0,0 +1,205 @@ +# SQL Query Constraints for SQLite + +## SUPPORTED SCENARIOS + +### 1. Single-Entity Baselines +- Simple projections with explicit column names (NO SELECT *) +- Single-field predicates: =, <>, >, <, >=, <=, BETWEEN, IN, LIKE +- WHERE clauses with AND/OR & parentheses +- IS NULL / IS NOT NULL + +**Examples:** +- SELECT id, name FROM Customer +- SELECT id, name FROM Customer WHERE age >= 21 AND (region='APAC' OR vip=1) +- SELECT id, name FROM Customer WHERE deleted_at IS NULL + +### 2. Multi-Entity Joins (≤4 adapters) +- LEFT JOIN chains via entity model (up to 4 tables) +- Optional adapters pruned +- Shared intermediates +- Null-preserving semantics + +**Examples:** +- SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id +- Fields spanning 3-4 adapters with proper LEFT JOIN chains + +### 3. Predicate Distribution & Pushdown +- Adapter-scoped predicates pushed down +- Cross-adapter/global predicates at root +- Empty IN () treated as FALSE + +**Examples:** +- SELECT c.id, c.name FROM Customer c WHERE c.country='IN' AND c.total>1000 +- SELECT id FROM Customer WHERE id IN () -- evaluates to FALSE + +### 4. Aggregations & Grouping (Basic) +- GROUP BY entity fields +- Aggregate functions: SUM, AVG, MIN, MAX, COUNT(column_name) +- Simple expressions in aggregates +- HAVING on aggregates or plain fields + +**Examples:** +- SELECT country, COUNT(id) FROM Customer GROUP BY country +- SELECT dept, SUM(price*qty) as total FROM LineItem GROUP BY dept +- SELECT country, COUNT(id) as cnt FROM Customer GROUP BY country HAVING COUNT(id)>10 + +### 5. Expressions (Minimal) +- CASE (simple/searched) in SELECT/WHERE/ORDER BY +- Arithmetic: +, -, *, / (SQLite-compatible) +- Functions: COALESCE, NULLIF, || +- String functions: LOWER, UPPER, TRIM, LTRIM, RTRIM +- Math functions: ROUND, ABS +- Note: CEIL/FLOOR may have limited adapter support + +**Examples:** +- SELECT CASE WHEN age>=18 THEN 'Adult' ELSE 'Minor' END AS segment FROM Customer +- SELECT COALESCE(nickname, name) as display_name FROM Customer +- SELECT ROUND(amount, 2) FROM Payments + +### 6. Casting & Coercion (Basic) +- CAST among common scalar types +- Implicit numeric widening where adapter accepts +- Explicit casts preferred at root + +**Examples:** +- SELECT CAST(amount AS DECIMAL(12,2)) FROM Payments +- SELECT CAST(id AS TEXT) FROM Customer + +### 7. Ordering & Pagination +- ORDER BY fields/aliases/expressions +- Multi-column ORDER BY +- LIMIT/OFFSET for pagination +- Note: Pagination without ORDER BY may produce non-deterministic results + +**Examples:** +- SELECT id, price*qty AS amt FROM LineItem ORDER BY amt DESC LIMIT 50 OFFSET 100 +- SELECT id, name FROM Customer ORDER BY name LIMIT 10 + +### 8. DISTINCT +- SELECT DISTINCT with explicit column names +- DISTINCT with ORDER BY on projected items/aliases + +**Examples:** +- SELECT DISTINCT country FROM Customer ORDER BY country +- SELECT DISTINCT dept, location FROM Employee + +### 9. Metadata Remapping & Aliasing +- Physical column remaps per adapter +- Alias reuse consistent across clauses +- Column aliases can be used in ORDER BY + +**Examples:** +- SELECT name AS customer_name FROM Customer ORDER BY customer_name + +--- + +## UNSUPPORTED SCENARIOS + +### 1. METADATA_RESOLUTION +- Unknown table/entity names +- Unknown column/field names +- Ambiguous column references without table prefix +- Field not in SELECT but used in ORDER BY (without alias) + +**Examples:** +- SELECT name FROM UnknownTable -- ❌ +- SELECT unknown_column FROM Customer -- ❌ +- SELECT id FROM Customer ORDER BY name -- ❌ (name not in SELECT) + +### 2. PROHIBITED_SQL_PATTERNS +- SELECT * FROM table -- ❌ Must use explicit column names +- Subqueries in FROM, WHERE, or SELECT +- UNION/UNION ALL/INTERSECT/EXCEPT +- Common Table Expressions (WITH/CTE) +- Window functions (ROW_NUMBER, RANK, PARTITION BY) +- Self-joins +- RIGHT JOIN or FULL OUTER JOIN (only LEFT JOIN supported) +- CROSS JOIN + +**Examples:** +- SELECT * FROM Customer -- ❌ +- SELECT id FROM (SELECT * FROM Customer) -- ❌ +- SELECT id FROM Customer UNION SELECT id FROM Order -- ❌ +- WITH cte AS (SELECT id FROM Customer) SELECT * FROM cte -- ❌ + +### 3. COMPLEX_AGGREGATIONS +- Nested aggregations: COUNT(DISTINCT(...)), SUM(DISTINCT(...)) +- Aggregations without GROUP BY on non-aggregated columns +- HAVING without GROUP BY +- COUNT(*) not allowed, use COUNT(column_name) instead + +**Examples:** +- SELECT COUNT(DISTINCT dept) FROM Employee -- ❌ +- SELECT name, COUNT(id) FROM Employee -- ❌ (name not in GROUP BY) +- SELECT AVG(salary) FROM Employee HAVING AVG(salary) > 50000 -- ❌ (no GROUP BY) + +### 4. ADVANCED_JOINS +- More than 4 tables in JOIN chain +- RIGHT JOIN +- FULL OUTER JOIN +- CROSS JOIN +- Self-joins +- Non-equi joins (theta joins) + +**Examples:** +- SELECT * FROM t1 RIGHT JOIN t2 -- ❌ +- SELECT * FROM t1, t2 -- ❌ (implicit CROSS JOIN) +- SELECT * FROM Employee e1 JOIN Employee e2 ON e1.manager_id = e2.id -- ❌ (self-join) + +### 5. UNSUPPORTED_FUNCTIONS +- Date/time manipulation functions (DATE_ADD, DATE_SUB, DATEDIFF) +- JSON functions (JSON_EXTRACT, JSON_ARRAY) +- Regex functions +- User-defined functions (UDFs) +- String aggregation (GROUP_CONCAT with complex separators) + +**Examples:** +- SELECT DATE_ADD(created_at, INTERVAL 1 DAY) FROM Order -- ❌ +- SELECT JSON_EXTRACT(data, '$.field') FROM Table -- ❌ + +### 6. COMPLEX_PREDICATES +- Correlated subqueries in WHERE +- EXISTS/NOT EXISTS +- ANY/ALL operators +- IN with subquery + +**Examples:** +- SELECT id FROM Customer WHERE EXISTS (SELECT 1 FROM Order WHERE customer_id = Customer.id) -- ❌ +- SELECT id FROM Customer WHERE id IN (SELECT customer_id FROM Order) -- ❌ + +### 7. MODIFICATIONS +- INSERT, UPDATE, DELETE, MERGE +- CREATE, ALTER, DROP (DDL) +- TRUNCATE +- Transactions (BEGIN, COMMIT, ROLLBACK) + +**Examples:** +- INSERT INTO Customer VALUES (...) -- ❌ +- UPDATE Customer SET name = 'John' -- ❌ +- DELETE FROM Customer -- ❌ + +### 8. UNSUPPORTED_CLAUSES +- HAVING without GROUP BY +- LIMIT without explicit value (e.g., LIMIT ALL) +- OFFSET without LIMIT +- FOR UPDATE / FOR SHARE +- INTO clause (SELECT INTO) + +**Examples:** +- SELECT AVG(salary) FROM Employee HAVING AVG(salary) > 50000 -- ❌ +- SELECT id FROM Customer OFFSET 10 -- ❌ (no LIMIT) + +--- + +## CRITICAL RULES + +1. **ALWAYS use explicit column names** - Never use SELECT * +2. **Use COUNT(column_name)** - Never use COUNT(*) +3. **Only LEFT JOIN** - No RIGHT JOIN, FULL OUTER JOIN, or CROSS JOIN +4. **Maximum 4 tables** - No more than 4 tables in a JOIN chain +5. **No subqueries** - No subqueries in any clause +6. **No CTEs** - No WITH clauses +7. **No window functions** - No ROW_NUMBER, RANK, PARTITION BY, etc. +8. **Explicit GROUP BY** - All non-aggregated columns in SELECT must be in GROUP BY +9. **Simple aggregations only** - No DISTINCT in aggregates +10. **ORDER BY only selected columns** - Cannot ORDER BY columns not in SELECT list diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt new file mode 100644 index 000000000..03df458fb --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt @@ -0,0 +1,124 @@ +You are a SQL expert specialized in converting natural language questions into SQL queries. + +Given a database schema, translate the user's natural language question into a valid SQL query. + +IMPORTANT: You MUST follow the SQL constraints defined in sql_constraints.txt. These constraints define supported and unsupported SQL patterns for SQLite. + +Rules: +1. Return ONLY the SQL query without any explanations, markdown formatting, or additional text +2. Use SQLite syntax +3. Ensure the query is syntactically correct +4. Use appropriate JOINs, WHERE clauses, and aggregations as needed +5. Include LIMIT clauses when appropriate to prevent returning too many rows +6. Use the exact table and column names from the provided schema +7. For financial values (salary, price, etc.), use ROUND() function +8. Handle NULL values appropriately with COALESCE() or IFNULL() + +SUPPORTED SCENARIOS (Use these patterns): + +1. Single-Entity Baselines: + - Simple projections: SELECT id, name FROM Customer + - Single-field predicates: =, <>, >, <, BETWEEN, IN, LIKE + - WHERE with AND/OR & parentheses: WHERE age >= 21 AND (region='APAC' OR vip=1) + - IS NULL/IS NOT NULL: WHERE deleted_at IS NULL + +2. Multi-Entity Joins (≤4 tables): + - LEFT JOIN chains (up to 4 tables): SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id + - Null-preserving semantics + +3. Predicate Distribution: + - Table-scoped predicates: WHERE c.country='IN' AND o.total>1000 + - Empty IN () evaluates to FALSE + +4. Aggregations & Grouping: + - GROUP BY entity fields: SELECT country, COUNT(id) FROM Customer GROUP BY country + - Supported functions: SUM, AVG, MIN, MAX, COUNT(column) + - Simple expressions in aggregates: SELECT SUM(price*qty) FROM LineItem + - HAVING on aggregates or plain fields: HAVING COUNT(id)>10 + +5. Expressions (Minimal): + - CASE (simple/searched) in SELECT/WHERE/ORDER BY + - Arithmetic: + - * / (SQLite-compatible) + - String functions: COALESCE, NULLIF, ||, LOWER, UPPER, TRIM, LTRIM, RTRIM + - Math functions: ROUND, ABS + - Example: SELECT CASE WHEN age>=18 THEN 'Adult' ELSE 'Minor' END AS segment FROM Customer + +6. Casting & Coercion: + - CAST among common scalar types: SELECT CAST(amount AS DECIMAL(12,2)) FROM Payments + - Prefer explicit casts + +7. Ordering & Pagination: + - ORDER BY fields/aliases/expressions: SELECT price*qty AS amt FROM LineItem ORDER BY amt DESC + - Multi-column ordering: ORDER BY country, name + - LIMIT/OFFSET: LIMIT 50 OFFSET 100 + +8. DISTINCT: + - SELECT DISTINCT country FROM Customer ORDER BY country + +9. Aliasing: + - Column aliases: SELECT name AS customer_name + - Reuse aliases in ORDER BY: SELECT price*qty AS amt FROM LineItem ORDER BY amt + +UNSUPPORTED SCENARIOS (Avoid these patterns): + +1. METADATA_RESOLUTION: + - Unknown/non-existent tables or entities + - More than 4 JOINs in a query + - Relationship cycles + +2. SQL_PARSING: + - Malformed SQL or SQL injection attempts + - UNION/INTERSECT/EXCEPT + - WITH (CTEs - Common Table Expressions) + - VALUES clause + - PRAGMA statements + +3. VALIDATION_GUARDRAIL: + - SELECT * (always specify columns explicitly) + - COUNT(*) or COUNT(1) (use COUNT(column_name) instead) + - Aggregates on literals + - HAVING without GROUP BY + - DISTINCT on constants + - Invalid LIMIT/OFFSET values + +4. UNSUPPORTED_CONSTRUCTS - Subqueries/Windows/DML/DDL: + - ANY subqueries or derived tables: WHERE x IN (SELECT ...) + - Window functions: ROW_NUMBER() OVER(...) + - DML: UPDATE, INSERT, DELETE + - DDL: CREATE, ALTER, DROP + - Temporary objects or transactions + +5. UNSUPPORTED_CONSTRUCTS - Joins: + - RIGHT JOIN, FULL OUTER JOIN, CROSS JOIN + - Non-equi join conditions: ON a.created_at > b.created_at + - Self-joins + - LATERAL/APPLY + +6. PARTITIONING: + - Disconnected tables with no join path + - Contradictory references causing cycles + +7. UNSUPPORTED_CONSTRUCTS - Advanced Aggregation: + - ROLLUP/CUBE/GROUPING SETS + - Approximate or ordered-set aggregates + - PERCENTILE functions + - Multi-table DISTINCT aggregates + +8. VALIDATION_GUARDRAIL - Ordering: + - ORDER BY ordinals: ORDER BY 1 + - NULLS FIRST/LAST + - TOP n syntax + - WITH TIES + - COLLATE clause + +9. UNSUPPORTED_CONSTRUCTS - Advanced Functions: + - REGEXP/SIMILAR TO/ILIKE + - Advanced math functions beyond basic arithmetic + - Date truncation/extraction beyond SQLite basics + - User-defined functions (UDFs) + +10. VALIDATION_GUARDRAIL - Types: + - JSON/ARRAY/MAP/GEOMETRY/BLOB operations + - Complex timezone-aware timestamp operations + +Return only the SQL query as plain text. \ No newline at end of file diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 8a87fec87..5ba0bf1e5 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -17,6 +17,7 @@ ) from .context_tool import create_context_tool +from .datafabric_tool import create_datafabric_tools from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool from .integration_tool import create_integration_tool @@ -29,10 +30,26 @@ async def create_tools_from_resources( agent: LowCodeAgentDefinition, llm: BaseChatModel -) -> list[BaseTool]: +) -> tuple[list[BaseTool], str]: + """Create tools from agent resources including Data Fabric tools. + + Args: + agent: The agent definition. + llm: The language model for tool creation. + + Returns: + Tuple of (tools, datafabric_schema_context). + """ tools: list[BaseTool] = [] + datafabric_schema_context: str = "" logger.info("Creating tools for agent '%s' from resources", agent.name) + + # Handle Data Fabric tools first (they need special handling) + datafabric_tools, schema_context = await create_datafabric_tools(agent) + tools.extend(datafabric_tools) + datafabric_schema_context = schema_context + for resource in agent.resources: if not resource.is_enabled: logger.info( @@ -54,7 +71,7 @@ async def create_tools_from_resources( else: tools.append(tool) - return tools + return tools, datafabric_schema_context async def _build_tool_for_resource( @@ -64,6 +81,15 @@ async def _build_tool_for_resource( return create_process_tool(resource) elif isinstance(resource, AgentContextResourceConfig): + # Skip Data Fabric contexts - handled separately via create_datafabric_tools() + retrieval_mode = resource.settings.retrieval_mode + mode_value = retrieval_mode.value if hasattr(retrieval_mode, 'value') else str(retrieval_mode) + if mode_value.lower() == "datafabric": + logger.info( + "Skipping Data Fabric context '%s' - handled separately", + resource.name, + ) + return None return create_context_tool(resource) elif isinstance(resource, AgentEscalationResourceConfig): diff --git a/uv.lock b/uv.lock index 395ec051b..7aa274c2f 100644 --- a/uv.lock +++ b/uv.lock @@ -3356,6 +3356,7 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "rust-just" }, + { name = "uipath-langchain" }, { name = "virtualenv" }, ] @@ -3396,6 +3397,7 @@ dev = [ { name = "pytest-mock", specifier = ">=3.11.1" }, { name = "ruff", specifier = ">=0.9.4" }, { name = "rust-just", specifier = ">=1.39.0" }, + { name = "uipath-langchain", editable = "." }, { name = "virtualenv", specifier = ">=20.36.1" }, ] From 7ca356413892f0400af6d81dbfc9a2feb2dd44f0 Mon Sep 17 00:00:00 2001 From: Harshit Rohatgi Date: Fri, 13 Feb 2026 17:49:32 +0530 Subject: [PATCH 2/8] Added the prompts to be injected here. --- .../tools/datafabric_tool/datafabric_tool.py | 87 ++++++++++++------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py index 5709c36b7..7b1d62f55 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -100,20 +100,37 @@ def format_schemas_for_context(entities: list[Entity]) -> str: """Format entity schemas as markdown for injection into agent system prompt. The output is optimized for SQL query generation by the LLM. + Includes: SQL strategy prompt, constraints, entity schemas, and query patterns. Args: entities: List of Entity objects with schema information. Returns: - Markdown-formatted string describing entity schemas. + Markdown-formatted string describing entity schemas and SQL guidance. """ if not entities: return "" - lines = [ - "## Available Data Fabric Entities", - "", - ] + lines = [] + + # Add SQL generation strategy from system_prompt.txt + system_prompt = _load_system_prompt() + if system_prompt: + lines.append("## SQL Query Generation Guidelines") + lines.append("") + lines.append(system_prompt) + lines.append("") + + # Add SQL constraints from sql_constraints.txt + sql_constraints = _load_sql_constraints() + if sql_constraints: + lines.append("## SQL Constraints") + lines.append("") + lines.append(sql_constraints) + lines.append("") + + lines.append("## Available Data Fabric Entities") + lines.append("") for entity in entities: display_name = entity.display_name or entity.name @@ -124,13 +141,44 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append("| Field | Type |") lines.append("|-------|------|") + # Collect field info for query pattern examples + field_names = [] + numeric_field = None + text_field = None + for field in entity.fields or []: if field.is_hidden_field or field.is_system_field: continue field_name = field.display_name or field.name field_type = format_field_type(field) + field_names.append(field_name) lines.append(f"| {field_name} | {field_type} |") + # Track field types for examples + sql_type = field.sql_type.name.lower() if field.sql_type else "" + if not numeric_field and sql_type in ("int", "decimal", "float", "double", "bigint"): + numeric_field = field_name + if not text_field and sql_type in ("varchar", "nvarchar", "text", "string", "ntext"): + text_field = field_name + + lines.append("") + + # Add entity-specific query pattern examples + group_field = text_field or (field_names[0] if field_names else "Category") + agg_field = numeric_field or (field_names[1] if len(field_names) > 1 else "Amount") + filter_field = text_field or (field_names[0] if field_names else "Name") + fields_sample = ", ".join(field_names[:5]) if field_names else "*" + + lines.append(f"**Query Patterns for {display_name}:**") + lines.append("") + lines.append("| User Intent | SQL Pattern |") + lines.append("|-------------|-------------|") + lines.append(f"| 'Show all {display_name.lower()}' | `SELECT {fields_sample} FROM {display_name} LIMIT 100` |") + lines.append(f"| 'Find by X' | `SELECT {fields_sample} FROM {display_name} WHERE {filter_field} = 'value' LIMIT 100` |") + lines.append(f"| 'Top N by Y' | `SELECT {fields_sample} FROM {display_name} ORDER BY {agg_field} DESC LIMIT N` |") + lines.append(f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field}` |") + lines.append(f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |") + lines.append(f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {display_name}` |") lines.append("") return "\n".join(lines) @@ -317,42 +365,23 @@ async def query_fn( if len(field_names) > 10: fields_str += f", ... ({len(field_names)} total)" - # Identify categorical fields for segmentation (text-like, non-PK) - categorical_fields = [ - f.display_name or f.name - for f in (entity.fields or []) - if not f.is_hidden_field and not f.is_system_field - and not f.is_primary_key - and f.sql_type and f.sql_type.name.lower() in ("text", "nvarchar", "varchar", "string", "ntext") - ] - segment_field = categorical_fields[0] if categorical_fields else (field_names[1] if len(field_names) > 1 else "Category") - count_field = field_names[0] if field_names else "Id" - - # Build intent-based query examples using actual entity fields - intent_examples = ( - f"QUERY PATTERNS for {entity_display_name}:\n" - f"- 'show all' → SELECT {fields_str} FROM {entity_display_name} LIMIT 100\n" - f"- 'top N by X' → SELECT {fields_str} FROM {entity_display_name} ORDER BY X DESC LIMIT N\n" - f"- 'top N segments/categories/groups' → SELECT {segment_field}, COUNT({count_field}) as count FROM {entity_display_name} GROUP BY {segment_field} ORDER BY count DESC LIMIT N\n" - f"- 'filter by X=value' → SELECT {fields_str} FROM {entity_display_name} WHERE X = 'value' LIMIT 100\n" - f"- 'average/sum/count of X' → SELECT AVG(X) as avg_x FROM {entity_display_name} LIMIT 1\n" - f"RULES: ALWAYS use explicit columns (no SELECT *). ALWAYS include LIMIT. Extract filter values from user message." - ) - tools.append( BaseUiPathStructuredTool( name=tool_name, description=( f"{context.description}. {entity_description}. " f"Available fields: {fields_str}. " - f"Generate SQL based on user's request." + f"Use SQL patterns from system prompt based on user intent." ), args_schema={ "type": "object", "properties": { "sql_query": { "type": "string", - "description": intent_examples, + "description": ( + f"Complete SQL SELECT statement for {entity_display_name}. " + f"Use exact column names from schema. Include LIMIT unless aggregating." + ), }, }, "required": ["sql_query"], From a6d4b2e013253321243dc2ad177e8eff619f58f3 Mon Sep 17 00:00:00 2001 From: Harshit Rohatgi Date: Tue, 3 Mar 2026 11:02:07 +0530 Subject: [PATCH 3/8] Added the data_fabric tool --- pyproject.toml | 4 + src/uipath_langchain/agent/tools/__init__.py | 10 + .../agent/tools/context_tool.py | 6 + .../agent/tools/datafabric_tool/__init__.py | 15 + .../tools/datafabric_tool/datafabric_tool.py | 398 ++++++++++++++++++ .../tools/datafabric_tool/sql_constraints.txt | 205 +++++++++ .../tools/datafabric_tool/system_prompt.txt | 124 ++++++ .../agent/tools/tool_factory.py | 30 +- uv.lock | 2 + 9 files changed, 792 insertions(+), 2 deletions(-) create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/__init__.py create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt create mode 100644 src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt diff --git a/pyproject.toml b/pyproject.toml index b6424d0c6..63ba42db9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ dev = [ "numpy>=1.24.0", "pytest_httpx>=0.35.0", "rust-just>=1.39.0", + "uipath-langchain", ] [tool.hatch.build.targets.wheel] @@ -116,3 +117,6 @@ name = "testpypi" url = "https://test.pypi.org/simple/" publish-url = "https://test.pypi.org/legacy/" explicit = true + +[tool.uv.sources] +uipath-langchain = { workspace = true } diff --git a/src/uipath_langchain/agent/tools/__init__.py b/src/uipath_langchain/agent/tools/__init__.py index 2f2f8214e..ce29925a5 100644 --- a/src/uipath_langchain/agent/tools/__init__.py +++ b/src/uipath_langchain/agent/tools/__init__.py @@ -1,6 +1,12 @@ """Tool creation and management for LowCode agents.""" from .context_tool import create_context_tool +from .datafabric_tool import ( + create_datafabric_tools, + fetch_entity_schemas, + format_schemas_for_context, + get_datafabric_contexts, +) from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool from .integration_tool import create_integration_tool @@ -15,11 +21,15 @@ "create_tools_from_resources", "create_tool_node", "create_context_tool", + "create_datafabric_tools", "create_process_tool", "create_integration_tool", "create_escalation_tool", "create_ixp_extraction_tool", "create_ixp_escalation_tool", + "fetch_entity_schemas", + "format_schemas_for_context", + "get_datafabric_contexts", "UiPathToolNode", "ToolWrapperMixin", ] diff --git a/src/uipath_langchain/agent/tools/context_tool.py b/src/uipath_langchain/agent/tools/context_tool.py index ac917863c..b970deaf5 100644 --- a/src/uipath_langchain/agent/tools/context_tool.py +++ b/src/uipath_langchain/agent/tools/context_tool.py @@ -40,6 +40,12 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: return handle_deep_rag(tool_name, resource) elif retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower(): return handle_batch_transform(tool_name, resource) + elif retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC.value.lower(): + # Data Fabric contexts are handled by create_datafabric_tools() in tool_factory.py + raise ValueError( + "Data Fabric context should be handled via create_datafabric_tools(), " + "not create_context_tool()" + ) else: return handle_semantic_search(tool_name, resource) diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py new file mode 100644 index 000000000..ccfcf433d --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py @@ -0,0 +1,15 @@ +"""Data Fabric tool module for entity-based SQL queries.""" + +from .datafabric_tool import ( + create_datafabric_tools, + fetch_entity_schemas, + format_schemas_for_context, + get_datafabric_contexts, +) + +__all__ = [ + "create_datafabric_tools", + "fetch_entity_schemas", + "format_schemas_for_context", + "get_datafabric_contexts", +] diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py new file mode 100644 index 000000000..7b1d62f55 --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -0,0 +1,398 @@ +"""Data Fabric tool creation for entity-based queries. + +This module provides functionality to: +1. Fetch and format entity schemas for agent context hydration +2. Create SQL-based query tools for the agent +""" + +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any + +from langchain_core.tools import BaseTool +from pydantic import BaseModel, Field +from uipath.agent.models.agent import ( + AgentContextResourceConfig, + LowCodeAgentDefinition, +) +from uipath.platform.entities import Entity, FieldMetadata + +from ..base_uipath_structured_tool import BaseUiPathStructuredTool +from ..utils import sanitize_tool_name + +logger = logging.getLogger(__name__) + +# --- Prompt and Constraints Loading --- + +_PROMPTS_DIR = Path(__file__).parent + + +@lru_cache(maxsize=1) +def _load_sql_constraints() -> str: + """Load SQL constraints from sql_constraints.txt.""" + constraints_path = _PROMPTS_DIR / "sql_constraints.txt" + try: + return constraints_path.read_text(encoding="utf-8") + except FileNotFoundError: + logger.warning(f"SQL constraints file not found: {constraints_path}") + return "" + + +@lru_cache(maxsize=1) +def _load_system_prompt() -> str: + """Load SQL generation strategy from system_prompt.txt.""" + prompt_path = _PROMPTS_DIR / "system_prompt.txt" + try: + return prompt_path.read_text(encoding="utf-8") + except FileNotFoundError: + logger.warning(f"System prompt file not found: {prompt_path}") + return "" + + +# --- Schema Fetching and Formatting --- + + +async def fetch_entity_schemas(entity_identifiers: list[str]) -> list[Entity]: + """Fetch entity metadata from Data Fabric for the given entity identifiers. + + Args: + entity_identifiers: List of entity identifiers to fetch. + + Returns: + List of Entity objects with full schema information. + """ + from uipath.platform import UiPath + + sdk = UiPath() + entities: list[Entity] = [] + + for entity_identifier in entity_identifiers: + try: + entity = await sdk.entities.retrieve_async(entity_identifier) + entities.append(entity) + logger.info(f"Fetched schema for entity '{entity.display_name}'") + except Exception as e: + logger.warning(f"Failed to fetch entity '{entity_identifier}': {e}") + + return entities + + +def format_field_type(field: FieldMetadata) -> str: + """Format a field's type information for display.""" + type_name = field.sql_type.name if field.sql_type else "unknown" + + modifiers = [] + if field.is_primary_key: + modifiers.append("PK") + if field.is_foreign_key and field.reference_entity: + ref_name = field.reference_entity.display_name or field.reference_entity.name + modifiers.append(f"FK → {ref_name}") + if field.is_required: + modifiers.append("required") + + if modifiers: + return f"{type_name}, {', '.join(modifiers)}" + return type_name + + +def format_schemas_for_context(entities: list[Entity]) -> str: + """Format entity schemas as markdown for injection into agent system prompt. + + The output is optimized for SQL query generation by the LLM. + Includes: SQL strategy prompt, constraints, entity schemas, and query patterns. + + Args: + entities: List of Entity objects with schema information. + + Returns: + Markdown-formatted string describing entity schemas and SQL guidance. + """ + if not entities: + return "" + + lines = [] + + # Add SQL generation strategy from system_prompt.txt + system_prompt = _load_system_prompt() + if system_prompt: + lines.append("## SQL Query Generation Guidelines") + lines.append("") + lines.append(system_prompt) + lines.append("") + + # Add SQL constraints from sql_constraints.txt + sql_constraints = _load_sql_constraints() + if sql_constraints: + lines.append("## SQL Constraints") + lines.append("") + lines.append(sql_constraints) + lines.append("") + + lines.append("## Available Data Fabric Entities") + lines.append("") + + for entity in entities: + display_name = entity.display_name or entity.name + lines.append(f"### Entity: {display_name}") + if entity.description: + lines.append(f"_{entity.description}_") + lines.append("") + lines.append("| Field | Type |") + lines.append("|-------|------|") + + # Collect field info for query pattern examples + field_names = [] + numeric_field = None + text_field = None + + for field in entity.fields or []: + if field.is_hidden_field or field.is_system_field: + continue + field_name = field.display_name or field.name + field_type = format_field_type(field) + field_names.append(field_name) + lines.append(f"| {field_name} | {field_type} |") + + # Track field types for examples + sql_type = field.sql_type.name.lower() if field.sql_type else "" + if not numeric_field and sql_type in ("int", "decimal", "float", "double", "bigint"): + numeric_field = field_name + if not text_field and sql_type in ("varchar", "nvarchar", "text", "string", "ntext"): + text_field = field_name + + lines.append("") + + # Add entity-specific query pattern examples + group_field = text_field or (field_names[0] if field_names else "Category") + agg_field = numeric_field or (field_names[1] if len(field_names) > 1 else "Amount") + filter_field = text_field or (field_names[0] if field_names else "Name") + fields_sample = ", ".join(field_names[:5]) if field_names else "*" + + lines.append(f"**Query Patterns for {display_name}:**") + lines.append("") + lines.append("| User Intent | SQL Pattern |") + lines.append("|-------------|-------------|") + lines.append(f"| 'Show all {display_name.lower()}' | `SELECT {fields_sample} FROM {display_name} LIMIT 100` |") + lines.append(f"| 'Find by X' | `SELECT {fields_sample} FROM {display_name} WHERE {filter_field} = 'value' LIMIT 100` |") + lines.append(f"| 'Top N by Y' | `SELECT {fields_sample} FROM {display_name} ORDER BY {agg_field} DESC LIMIT N` |") + lines.append(f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field}` |") + lines.append(f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |") + lines.append(f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {display_name}` |") + lines.append("") + + return "\n".join(lines) + + +# --- Data Fabric Context Detection --- + + +def get_datafabric_contexts( + agent: LowCodeAgentDefinition, +) -> list[AgentContextResourceConfig]: + """Extract Data Fabric context resources from agent definition. + + Args: + agent: The agent definition to search. + + Returns: + List of context resources configured for Data Fabric retrieval mode. + """ + datafabric_contexts: list[AgentContextResourceConfig] = [] + + for resource in agent.resources: + if not isinstance(resource, AgentContextResourceConfig): + continue + if not resource.is_enabled: + continue + if resource.settings.retrieval_mode.lower() == "datafabric": + datafabric_contexts.append(resource) + + return datafabric_contexts + + +# --- Tool Creation --- + + +class QueryEntityInput(BaseModel): + """Input schema for the query_entity tool.""" + + entity_identifier: str = Field( + ..., description="The entity identifier to query" + ) + sql_where: str = Field( + default="", + description="SQL WHERE clause to filter records (without the WHERE keyword). " + "Example: 'Status = \"Active\" AND Amount > 100'", + ) + limit: int = Field( + default=1000, description="Maximum number of records to return" + ) + + +class QueryEntityOutput(BaseModel): + """Output schema for the query_entity tool.""" + + records: list[dict[str, Any]] = Field( + ..., description="List of entity records matching the query" + ) + total_count: int = Field(..., description="Total number of matching records") + + +async def create_datafabric_tools( + agent: LowCodeAgentDefinition, +) -> tuple[list[BaseTool], str]: + """Create Data Fabric tools and schema context from agent definition. + + This function: + 1. Finds all Data Fabric context resources in the agent + 2. Fetches entity schemas for context hydration + 3. Returns tools and schema context string + + Args: + agent: The agent definition containing Data Fabric context resources. + + Returns: + Tuple of (tools, schema_context) where: + - tools: List of BaseTool instances for querying entities + - schema_context: Formatted schema string to inject into system prompt + """ + tools: list[BaseTool] = [] + all_entities: list[Entity] = [] + + datafabric_contexts = get_datafabric_contexts(agent) + + if not datafabric_contexts: + return tools, "" + + logger.info(f"Found {len(datafabric_contexts)} Data Fabric context resource(s)") + + for context in datafabric_contexts: + entity_identifiers = context.settings.entity_identifiers or [] + + if not entity_identifiers: + logger.warning( + f"Data Fabric context '{context.name}' has no entity_identifiers configured" + ) + continue + + # Fetch entity schemas + entities = await fetch_entity_schemas(entity_identifiers) + all_entities.extend(entities) + + context_tools = _create_sdk_based_tools(context, entities) + tools.extend(context_tools) + + logger.info( + f"Created {len(context_tools)} tools for Data Fabric context '{context.name}'" + ) + + # Format all entity schemas for context injection + schema_context = format_schemas_for_context(all_entities) + + return tools, schema_context + + +def _create_sdk_based_tools( + context: AgentContextResourceConfig, + entities: list[Entity], +) -> list[BaseTool]: + """Create SDK-based tools for querying entities using SQL. + + Each tool accepts a full SQL query and executes it via the SDK's + query_entity_records_async method. + """ + tools: list[BaseTool] = [] + MAX_RECORDS_IN_RESPONSE = 50 # Limit records to prevent context overflow + + for entity in entities: + tool_name = sanitize_tool_name(f"query_{entity.name}") + + # Create a closure to capture the entity name + entity_display_name = entity.display_name or entity.name + + async def query_fn( + sql_query: str, + _entity_name: str = entity_display_name, + _max_records: int = MAX_RECORDS_IN_RESPONSE, + ) -> dict[str, Any]: + """Execute a SQL query against the Data Fabric entity.""" + from uipath.platform import UiPath + + print(f"[DEBUG] query_fn called for entity '{_entity_name}' with SQL: {sql_query}") + logger.info(f"Executing SQL query for entity '{_entity_name}': {sql_query}") + + sdk = UiPath() + try: + records = await sdk.entities.query_entity_records_async( + sql_query=sql_query, + ) + total_count = len(records) + truncated = total_count > _max_records + returned_records = records[:_max_records] if truncated else records + + print(f"[DEBUG] Retrieved {total_count} records, returning {len(returned_records)} for entity '{_entity_name}'") + + result = { + "records": returned_records, + "total_count": total_count, + "returned_count": len(returned_records), + "entity": _entity_name, + "sql_query": sql_query, + } + if truncated: + result["truncated"] = True + result["message"] = f"Showing {len(returned_records)} of {total_count} records. Use more specific filters or LIMIT to narrow results." + return result + except Exception as e: + logger.error(f"SQL query failed for entity '{_entity_name}': {e}") + return { + "records": [], + "total_count": 0, + "error": str(e), + "sql_query": sql_query, + } + + entity_description = entity.description or f"Query {entity_display_name} records" + + # Extract actual field names from entity schema (exclude system/hidden fields) + field_names = [ + f.display_name or f.name + for f in (entity.fields or []) + if not f.is_hidden_field and not f.is_system_field + ] + fields_str = ", ".join(field_names[:10]) # Limit to first 10 fields + if len(field_names) > 10: + fields_str += f", ... ({len(field_names)} total)" + + tools.append( + BaseUiPathStructuredTool( + name=tool_name, + description=( + f"{context.description}. {entity_description}. " + f"Available fields: {fields_str}. " + f"Use SQL patterns from system prompt based on user intent." + ), + args_schema={ + "type": "object", + "properties": { + "sql_query": { + "type": "string", + "description": ( + f"Complete SQL SELECT statement for {entity_display_name}. " + f"Use exact column names from schema. Include LIMIT unless aggregating." + ), + }, + }, + "required": ["sql_query"], + }, + coroutine=query_fn, + metadata={ + "tool_type": "datafabric_sql", + "display_name": f"Query {entity_display_name}", + "entity_name": entity_display_name, + }, + ) + ) + + return tools diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt b/src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt new file mode 100644 index 000000000..804537d54 --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/sql_constraints.txt @@ -0,0 +1,205 @@ +# SQL Query Constraints for SQLite + +## SUPPORTED SCENARIOS + +### 1. Single-Entity Baselines +- Simple projections with explicit column names (NO SELECT *) +- Single-field predicates: =, <>, >, <, >=, <=, BETWEEN, IN, LIKE +- WHERE clauses with AND/OR & parentheses +- IS NULL / IS NOT NULL + +**Examples:** +- SELECT id, name FROM Customer +- SELECT id, name FROM Customer WHERE age >= 21 AND (region='APAC' OR vip=1) +- SELECT id, name FROM Customer WHERE deleted_at IS NULL + +### 2. Multi-Entity Joins (≤4 adapters) +- LEFT JOIN chains via entity model (up to 4 tables) +- Optional adapters pruned +- Shared intermediates +- Null-preserving semantics + +**Examples:** +- SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id +- Fields spanning 3-4 adapters with proper LEFT JOIN chains + +### 3. Predicate Distribution & Pushdown +- Adapter-scoped predicates pushed down +- Cross-adapter/global predicates at root +- Empty IN () treated as FALSE + +**Examples:** +- SELECT c.id, c.name FROM Customer c WHERE c.country='IN' AND c.total>1000 +- SELECT id FROM Customer WHERE id IN () -- evaluates to FALSE + +### 4. Aggregations & Grouping (Basic) +- GROUP BY entity fields +- Aggregate functions: SUM, AVG, MIN, MAX, COUNT(column_name) +- Simple expressions in aggregates +- HAVING on aggregates or plain fields + +**Examples:** +- SELECT country, COUNT(id) FROM Customer GROUP BY country +- SELECT dept, SUM(price*qty) as total FROM LineItem GROUP BY dept +- SELECT country, COUNT(id) as cnt FROM Customer GROUP BY country HAVING COUNT(id)>10 + +### 5. Expressions (Minimal) +- CASE (simple/searched) in SELECT/WHERE/ORDER BY +- Arithmetic: +, -, *, / (SQLite-compatible) +- Functions: COALESCE, NULLIF, || +- String functions: LOWER, UPPER, TRIM, LTRIM, RTRIM +- Math functions: ROUND, ABS +- Note: CEIL/FLOOR may have limited adapter support + +**Examples:** +- SELECT CASE WHEN age>=18 THEN 'Adult' ELSE 'Minor' END AS segment FROM Customer +- SELECT COALESCE(nickname, name) as display_name FROM Customer +- SELECT ROUND(amount, 2) FROM Payments + +### 6. Casting & Coercion (Basic) +- CAST among common scalar types +- Implicit numeric widening where adapter accepts +- Explicit casts preferred at root + +**Examples:** +- SELECT CAST(amount AS DECIMAL(12,2)) FROM Payments +- SELECT CAST(id AS TEXT) FROM Customer + +### 7. Ordering & Pagination +- ORDER BY fields/aliases/expressions +- Multi-column ORDER BY +- LIMIT/OFFSET for pagination +- Note: Pagination without ORDER BY may produce non-deterministic results + +**Examples:** +- SELECT id, price*qty AS amt FROM LineItem ORDER BY amt DESC LIMIT 50 OFFSET 100 +- SELECT id, name FROM Customer ORDER BY name LIMIT 10 + +### 8. DISTINCT +- SELECT DISTINCT with explicit column names +- DISTINCT with ORDER BY on projected items/aliases + +**Examples:** +- SELECT DISTINCT country FROM Customer ORDER BY country +- SELECT DISTINCT dept, location FROM Employee + +### 9. Metadata Remapping & Aliasing +- Physical column remaps per adapter +- Alias reuse consistent across clauses +- Column aliases can be used in ORDER BY + +**Examples:** +- SELECT name AS customer_name FROM Customer ORDER BY customer_name + +--- + +## UNSUPPORTED SCENARIOS + +### 1. METADATA_RESOLUTION +- Unknown table/entity names +- Unknown column/field names +- Ambiguous column references without table prefix +- Field not in SELECT but used in ORDER BY (without alias) + +**Examples:** +- SELECT name FROM UnknownTable -- ❌ +- SELECT unknown_column FROM Customer -- ❌ +- SELECT id FROM Customer ORDER BY name -- ❌ (name not in SELECT) + +### 2. PROHIBITED_SQL_PATTERNS +- SELECT * FROM table -- ❌ Must use explicit column names +- Subqueries in FROM, WHERE, or SELECT +- UNION/UNION ALL/INTERSECT/EXCEPT +- Common Table Expressions (WITH/CTE) +- Window functions (ROW_NUMBER, RANK, PARTITION BY) +- Self-joins +- RIGHT JOIN or FULL OUTER JOIN (only LEFT JOIN supported) +- CROSS JOIN + +**Examples:** +- SELECT * FROM Customer -- ❌ +- SELECT id FROM (SELECT * FROM Customer) -- ❌ +- SELECT id FROM Customer UNION SELECT id FROM Order -- ❌ +- WITH cte AS (SELECT id FROM Customer) SELECT * FROM cte -- ❌ + +### 3. COMPLEX_AGGREGATIONS +- Nested aggregations: COUNT(DISTINCT(...)), SUM(DISTINCT(...)) +- Aggregations without GROUP BY on non-aggregated columns +- HAVING without GROUP BY +- COUNT(*) not allowed, use COUNT(column_name) instead + +**Examples:** +- SELECT COUNT(DISTINCT dept) FROM Employee -- ❌ +- SELECT name, COUNT(id) FROM Employee -- ❌ (name not in GROUP BY) +- SELECT AVG(salary) FROM Employee HAVING AVG(salary) > 50000 -- ❌ (no GROUP BY) + +### 4. ADVANCED_JOINS +- More than 4 tables in JOIN chain +- RIGHT JOIN +- FULL OUTER JOIN +- CROSS JOIN +- Self-joins +- Non-equi joins (theta joins) + +**Examples:** +- SELECT * FROM t1 RIGHT JOIN t2 -- ❌ +- SELECT * FROM t1, t2 -- ❌ (implicit CROSS JOIN) +- SELECT * FROM Employee e1 JOIN Employee e2 ON e1.manager_id = e2.id -- ❌ (self-join) + +### 5. UNSUPPORTED_FUNCTIONS +- Date/time manipulation functions (DATE_ADD, DATE_SUB, DATEDIFF) +- JSON functions (JSON_EXTRACT, JSON_ARRAY) +- Regex functions +- User-defined functions (UDFs) +- String aggregation (GROUP_CONCAT with complex separators) + +**Examples:** +- SELECT DATE_ADD(created_at, INTERVAL 1 DAY) FROM Order -- ❌ +- SELECT JSON_EXTRACT(data, '$.field') FROM Table -- ❌ + +### 6. COMPLEX_PREDICATES +- Correlated subqueries in WHERE +- EXISTS/NOT EXISTS +- ANY/ALL operators +- IN with subquery + +**Examples:** +- SELECT id FROM Customer WHERE EXISTS (SELECT 1 FROM Order WHERE customer_id = Customer.id) -- ❌ +- SELECT id FROM Customer WHERE id IN (SELECT customer_id FROM Order) -- ❌ + +### 7. MODIFICATIONS +- INSERT, UPDATE, DELETE, MERGE +- CREATE, ALTER, DROP (DDL) +- TRUNCATE +- Transactions (BEGIN, COMMIT, ROLLBACK) + +**Examples:** +- INSERT INTO Customer VALUES (...) -- ❌ +- UPDATE Customer SET name = 'John' -- ❌ +- DELETE FROM Customer -- ❌ + +### 8. UNSUPPORTED_CLAUSES +- HAVING without GROUP BY +- LIMIT without explicit value (e.g., LIMIT ALL) +- OFFSET without LIMIT +- FOR UPDATE / FOR SHARE +- INTO clause (SELECT INTO) + +**Examples:** +- SELECT AVG(salary) FROM Employee HAVING AVG(salary) > 50000 -- ❌ +- SELECT id FROM Customer OFFSET 10 -- ❌ (no LIMIT) + +--- + +## CRITICAL RULES + +1. **ALWAYS use explicit column names** - Never use SELECT * +2. **Use COUNT(column_name)** - Never use COUNT(*) +3. **Only LEFT JOIN** - No RIGHT JOIN, FULL OUTER JOIN, or CROSS JOIN +4. **Maximum 4 tables** - No more than 4 tables in a JOIN chain +5. **No subqueries** - No subqueries in any clause +6. **No CTEs** - No WITH clauses +7. **No window functions** - No ROW_NUMBER, RANK, PARTITION BY, etc. +8. **Explicit GROUP BY** - All non-aggregated columns in SELECT must be in GROUP BY +9. **Simple aggregations only** - No DISTINCT in aggregates +10. **ORDER BY only selected columns** - Cannot ORDER BY columns not in SELECT list diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt new file mode 100644 index 000000000..03df458fb --- /dev/null +++ b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt @@ -0,0 +1,124 @@ +You are a SQL expert specialized in converting natural language questions into SQL queries. + +Given a database schema, translate the user's natural language question into a valid SQL query. + +IMPORTANT: You MUST follow the SQL constraints defined in sql_constraints.txt. These constraints define supported and unsupported SQL patterns for SQLite. + +Rules: +1. Return ONLY the SQL query without any explanations, markdown formatting, or additional text +2. Use SQLite syntax +3. Ensure the query is syntactically correct +4. Use appropriate JOINs, WHERE clauses, and aggregations as needed +5. Include LIMIT clauses when appropriate to prevent returning too many rows +6. Use the exact table and column names from the provided schema +7. For financial values (salary, price, etc.), use ROUND() function +8. Handle NULL values appropriately with COALESCE() or IFNULL() + +SUPPORTED SCENARIOS (Use these patterns): + +1. Single-Entity Baselines: + - Simple projections: SELECT id, name FROM Customer + - Single-field predicates: =, <>, >, <, BETWEEN, IN, LIKE + - WHERE with AND/OR & parentheses: WHERE age >= 21 AND (region='APAC' OR vip=1) + - IS NULL/IS NOT NULL: WHERE deleted_at IS NULL + +2. Multi-Entity Joins (≤4 tables): + - LEFT JOIN chains (up to 4 tables): SELECT o.id, c.name FROM Order o LEFT JOIN Customer c ON o.customer_id = c.id + - Null-preserving semantics + +3. Predicate Distribution: + - Table-scoped predicates: WHERE c.country='IN' AND o.total>1000 + - Empty IN () evaluates to FALSE + +4. Aggregations & Grouping: + - GROUP BY entity fields: SELECT country, COUNT(id) FROM Customer GROUP BY country + - Supported functions: SUM, AVG, MIN, MAX, COUNT(column) + - Simple expressions in aggregates: SELECT SUM(price*qty) FROM LineItem + - HAVING on aggregates or plain fields: HAVING COUNT(id)>10 + +5. Expressions (Minimal): + - CASE (simple/searched) in SELECT/WHERE/ORDER BY + - Arithmetic: + - * / (SQLite-compatible) + - String functions: COALESCE, NULLIF, ||, LOWER, UPPER, TRIM, LTRIM, RTRIM + - Math functions: ROUND, ABS + - Example: SELECT CASE WHEN age>=18 THEN 'Adult' ELSE 'Minor' END AS segment FROM Customer + +6. Casting & Coercion: + - CAST among common scalar types: SELECT CAST(amount AS DECIMAL(12,2)) FROM Payments + - Prefer explicit casts + +7. Ordering & Pagination: + - ORDER BY fields/aliases/expressions: SELECT price*qty AS amt FROM LineItem ORDER BY amt DESC + - Multi-column ordering: ORDER BY country, name + - LIMIT/OFFSET: LIMIT 50 OFFSET 100 + +8. DISTINCT: + - SELECT DISTINCT country FROM Customer ORDER BY country + +9. Aliasing: + - Column aliases: SELECT name AS customer_name + - Reuse aliases in ORDER BY: SELECT price*qty AS amt FROM LineItem ORDER BY amt + +UNSUPPORTED SCENARIOS (Avoid these patterns): + +1. METADATA_RESOLUTION: + - Unknown/non-existent tables or entities + - More than 4 JOINs in a query + - Relationship cycles + +2. SQL_PARSING: + - Malformed SQL or SQL injection attempts + - UNION/INTERSECT/EXCEPT + - WITH (CTEs - Common Table Expressions) + - VALUES clause + - PRAGMA statements + +3. VALIDATION_GUARDRAIL: + - SELECT * (always specify columns explicitly) + - COUNT(*) or COUNT(1) (use COUNT(column_name) instead) + - Aggregates on literals + - HAVING without GROUP BY + - DISTINCT on constants + - Invalid LIMIT/OFFSET values + +4. UNSUPPORTED_CONSTRUCTS - Subqueries/Windows/DML/DDL: + - ANY subqueries or derived tables: WHERE x IN (SELECT ...) + - Window functions: ROW_NUMBER() OVER(...) + - DML: UPDATE, INSERT, DELETE + - DDL: CREATE, ALTER, DROP + - Temporary objects or transactions + +5. UNSUPPORTED_CONSTRUCTS - Joins: + - RIGHT JOIN, FULL OUTER JOIN, CROSS JOIN + - Non-equi join conditions: ON a.created_at > b.created_at + - Self-joins + - LATERAL/APPLY + +6. PARTITIONING: + - Disconnected tables with no join path + - Contradictory references causing cycles + +7. UNSUPPORTED_CONSTRUCTS - Advanced Aggregation: + - ROLLUP/CUBE/GROUPING SETS + - Approximate or ordered-set aggregates + - PERCENTILE functions + - Multi-table DISTINCT aggregates + +8. VALIDATION_GUARDRAIL - Ordering: + - ORDER BY ordinals: ORDER BY 1 + - NULLS FIRST/LAST + - TOP n syntax + - WITH TIES + - COLLATE clause + +9. UNSUPPORTED_CONSTRUCTS - Advanced Functions: + - REGEXP/SIMILAR TO/ILIKE + - Advanced math functions beyond basic arithmetic + - Date truncation/extraction beyond SQLite basics + - User-defined functions (UDFs) + +10. VALIDATION_GUARDRAIL - Types: + - JSON/ARRAY/MAP/GEOMETRY/BLOB operations + - Complex timezone-aware timestamp operations + +Return only the SQL query as plain text. \ No newline at end of file diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 8a87fec87..5ba0bf1e5 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -17,6 +17,7 @@ ) from .context_tool import create_context_tool +from .datafabric_tool import create_datafabric_tools from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool from .integration_tool import create_integration_tool @@ -29,10 +30,26 @@ async def create_tools_from_resources( agent: LowCodeAgentDefinition, llm: BaseChatModel -) -> list[BaseTool]: +) -> tuple[list[BaseTool], str]: + """Create tools from agent resources including Data Fabric tools. + + Args: + agent: The agent definition. + llm: The language model for tool creation. + + Returns: + Tuple of (tools, datafabric_schema_context). + """ tools: list[BaseTool] = [] + datafabric_schema_context: str = "" logger.info("Creating tools for agent '%s' from resources", agent.name) + + # Handle Data Fabric tools first (they need special handling) + datafabric_tools, schema_context = await create_datafabric_tools(agent) + tools.extend(datafabric_tools) + datafabric_schema_context = schema_context + for resource in agent.resources: if not resource.is_enabled: logger.info( @@ -54,7 +71,7 @@ async def create_tools_from_resources( else: tools.append(tool) - return tools + return tools, datafabric_schema_context async def _build_tool_for_resource( @@ -64,6 +81,15 @@ async def _build_tool_for_resource( return create_process_tool(resource) elif isinstance(resource, AgentContextResourceConfig): + # Skip Data Fabric contexts - handled separately via create_datafabric_tools() + retrieval_mode = resource.settings.retrieval_mode + mode_value = retrieval_mode.value if hasattr(retrieval_mode, 'value') else str(retrieval_mode) + if mode_value.lower() == "datafabric": + logger.info( + "Skipping Data Fabric context '%s' - handled separately", + resource.name, + ) + return None return create_context_tool(resource) elif isinstance(resource, AgentEscalationResourceConfig): diff --git a/uv.lock b/uv.lock index 395ec051b..7aa274c2f 100644 --- a/uv.lock +++ b/uv.lock @@ -3356,6 +3356,7 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "rust-just" }, + { name = "uipath-langchain" }, { name = "virtualenv" }, ] @@ -3396,6 +3397,7 @@ dev = [ { name = "pytest-mock", specifier = ">=3.11.1" }, { name = "ruff", specifier = ">=0.9.4" }, { name = "rust-just", specifier = ">=1.39.0" }, + { name = "uipath-langchain", editable = "." }, { name = "virtualenv", specifier = ">=20.36.1" }, ] From 3510f6bc36d8e4fb5a82eeac3d5a4812ea53adea Mon Sep 17 00:00:00 2001 From: Harshit Rohatgi Date: Thu, 5 Mar 2026 10:31:17 +0530 Subject: [PATCH 4/8] Updated --- src/uipath_langchain/agent/react/__init__.py | 3 +- src/uipath_langchain/agent/react/agent.py | 8 +- src/uipath_langchain/agent/react/init_node.py | 39 ++- src/uipath_langchain/agent/react/types.py | 5 +- src/uipath_langchain/agent/tools/__init__.py | 2 + .../agent/tools/context_tool.py | 2 +- .../agent/tools/datafabric_tool/__init__.py | 2 + .../tools/datafabric_tool/datafabric_tool.py | 302 +++++++----------- .../agent/tools/tool_factory.py | 13 +- tests/agent/react/test_create_agent.py | 2 + tests/agent/react/test_init_node.py | 37 ++- 11 files changed, 193 insertions(+), 222 deletions(-) diff --git a/src/uipath_langchain/agent/react/__init__.py b/src/uipath_langchain/agent/react/__init__.py index 1cd32a9aa..96e2ef1af 100644 --- a/src/uipath_langchain/agent/react/__init__.py +++ b/src/uipath_langchain/agent/react/__init__.py @@ -1,7 +1,7 @@ """UiPath ReAct Agent implementation""" from .agent import create_agent -from .types import AgentGraphConfig, AgentGraphNode, AgentGraphState +from .types import AgentGraphConfig, AgentGraphNode, AgentGraphState, AgentResources from .utils import resolve_input_model, resolve_output_model __all__ = [ @@ -11,4 +11,5 @@ "AgentGraphNode", "AgentGraphState", "AgentGraphConfig", + "AgentResources", ] diff --git a/src/uipath_langchain/agent/react/agent.py b/src/uipath_langchain/agent/react/agent.py index 1bcb64289..7c6736843 100644 --- a/src/uipath_langchain/agent/react/agent.py +++ b/src/uipath_langchain/agent/react/agent.py @@ -35,6 +35,7 @@ AgentGraphConfig, AgentGraphNode, AgentGraphState, + AgentResources, AgentSettings, ) from .utils import create_state_with_input @@ -53,6 +54,7 @@ def create_agent( output_schema: Type[OutputT] | None = None, config: AgentGraphConfig | None = None, guardrails: Sequence[tuple[BaseGuardrail, GuardrailAction]] | None = None, + resources: AgentResources | None = None, ) -> StateGraph[AgentGraphState, None, InputT, OutputT]: """Build agent graph with INIT -> AGENT (subgraph) <-> TOOLS loop, terminated by control flow tools. @@ -86,7 +88,11 @@ def create_agent( llm_tools: list[BaseTool] = [*agent_tools, *flow_control_tools] init_node = create_init_node( - messages, input_schema, config.is_conversational, agent_settings + messages, + input_schema, + config.is_conversational, + agent_settings, + resources_for_init=resources, ) tool_nodes = create_tool_node(agent_tools) diff --git a/src/uipath_langchain/agent/react/init_node.py b/src/uipath_langchain/agent/react/init_node.py index bbf1a18a7..47f402fe0 100644 --- a/src/uipath_langchain/agent/react/init_node.py +++ b/src/uipath_langchain/agent/react/init_node.py @@ -1,5 +1,6 @@ """State initialization node for the ReAct Agent graph.""" +import logging from typing import Any, Callable, Sequence from langchain_core.messages import HumanMessage, SystemMessage @@ -9,22 +10,52 @@ from .job_attachments import ( get_job_attachments, ) -from .types import AgentSettings +from .types import AgentResources, AgentSettings + +logger = logging.getLogger(__name__) def create_init_node( messages: Sequence[SystemMessage | HumanMessage] - | Callable[[Any], Sequence[SystemMessage | HumanMessage]], + | Callable[..., Sequence[SystemMessage | HumanMessage]], input_schema: type[BaseModel] | None, is_conversational: bool = False, agent_settings: AgentSettings | None = None, + resources_for_init: AgentResources | None = None, ): - def graph_state_init(state: Any) -> Any: + async def graph_state_init(state: Any) -> Any: + # --- Data Fabric schema fetch (INIT-time) --- + schema_context: str | None = None + if resources_for_init: + from uipath_langchain.agent.tools.datafabric_tool import ( + fetch_entity_schemas, + format_schemas_for_context, + get_datafabric_entity_identifiers_from_resources, + ) + + entity_identifiers = get_datafabric_entity_identifiers_from_resources( + resources_for_init + ) + if entity_identifiers: + logger.info( + "Fetching Data Fabric schemas for %d identifier(s)", + len(entity_identifiers), + ) + entities = await fetch_entity_schemas(entity_identifiers) + schema_context = format_schemas_for_context(entities) + + # --- Resolve messages --- resolved_messages: Sequence[SystemMessage | HumanMessage] | Overwrite if callable(messages): - resolved_messages = list(messages(state)) + if schema_context: + resolved_messages = list( + messages(state, additional_context=schema_context) + ) + else: + resolved_messages = list(messages(state)) else: resolved_messages = list(messages) + if is_conversational: # For conversational agents we need to reorder the messages so that the system message is first, followed by # the initial user message. When resuming the conversation, the state will have the entire message history, diff --git a/src/uipath_langchain/agent/react/types.py b/src/uipath_langchain/agent/react/types.py index d5b6128ef..56d33db48 100644 --- a/src/uipath_langchain/agent/react/types.py +++ b/src/uipath_langchain/agent/react/types.py @@ -1,9 +1,10 @@ from enum import StrEnum -from typing import Annotated, Any, Hashable, Optional +from typing import Annotated, Any, Hashable, Optional, Sequence from langchain_core.messages import AnyMessage from langgraph.graph.message import add_messages from pydantic import BaseModel, Field +from uipath.agent.models.agent import BaseAgentResourceConfig from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL from uipath.platform.attachments import Attachment @@ -15,6 +16,8 @@ FLOW_CONTROL_TOOLS = [END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name] +AgentResources = Sequence[BaseAgentResourceConfig] + class AgentSettings(BaseModel): """Agent settings extracted from the LLM model.""" diff --git a/src/uipath_langchain/agent/tools/__init__.py b/src/uipath_langchain/agent/tools/__init__.py index ce29925a5..6df842997 100644 --- a/src/uipath_langchain/agent/tools/__init__.py +++ b/src/uipath_langchain/agent/tools/__init__.py @@ -6,6 +6,7 @@ fetch_entity_schemas, format_schemas_for_context, get_datafabric_contexts, + get_datafabric_entity_identifiers_from_resources, ) from .escalation_tool import create_escalation_tool from .extraction_tool import create_ixp_extraction_tool @@ -30,6 +31,7 @@ "fetch_entity_schemas", "format_schemas_for_context", "get_datafabric_contexts", + "get_datafabric_entity_identifiers_from_resources", "UiPathToolNode", "ToolWrapperMixin", ] diff --git a/src/uipath_langchain/agent/tools/context_tool.py b/src/uipath_langchain/agent/tools/context_tool.py index b970deaf5..b27ff1cf8 100644 --- a/src/uipath_langchain/agent/tools/context_tool.py +++ b/src/uipath_langchain/agent/tools/context_tool.py @@ -40,7 +40,7 @@ def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: return handle_deep_rag(tool_name, resource) elif retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower(): return handle_batch_transform(tool_name, resource) - elif retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC.value.lower(): + elif retrieval_mode == "datafabric": # Data Fabric contexts are handled by create_datafabric_tools() in tool_factory.py raise ValueError( "Data Fabric context should be handled via create_datafabric_tools(), " diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py index ccfcf433d..f000a1979 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py +++ b/src/uipath_langchain/agent/tools/datafabric_tool/__init__.py @@ -5,6 +5,7 @@ fetch_entity_schemas, format_schemas_for_context, get_datafabric_contexts, + get_datafabric_entity_identifiers_from_resources, ) __all__ = [ @@ -12,4 +13,5 @@ "fetch_entity_schemas", "format_schemas_for_context", "get_datafabric_contexts", + "get_datafabric_entity_identifiers_from_resources", ] diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py index 7b1d62f55..7c5b7101c 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -1,25 +1,25 @@ """Data Fabric tool creation for entity-based queries. -This module provides functionality to: -1. Fetch and format entity schemas for agent context hydration -2. Create SQL-based query tools for the agent +This module provides: +1. A single generic ``query_datafabric`` tool (no per-entity knowledge at build time) +2. Schema fetching & formatting helpers consumed by the INIT node at runtime +3. Helpers to extract entity identifiers from agent definitions """ import logging from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, Sequence from langchain_core.tools import BaseTool -from pydantic import BaseModel, Field from uipath.agent.models.agent import ( AgentContextResourceConfig, + BaseAgentResourceConfig, LowCodeAgentDefinition, ) from uipath.platform.entities import Entity, FieldMetadata from ..base_uipath_structured_tool import BaseUiPathStructuredTool -from ..utils import sanitize_tool_name logger = logging.getLogger(__name__) @@ -111,9 +111,8 @@ def format_schemas_for_context(entities: list[Entity]) -> str: if not entities: return "" - lines = [] + lines: list[str] = [] - # Add SQL generation strategy from system_prompt.txt system_prompt = _load_system_prompt() if system_prompt: lines.append("## SQL Query Generation Guidelines") @@ -121,7 +120,6 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append(system_prompt) lines.append("") - # Add SQL constraints from sql_constraints.txt sql_constraints = _load_sql_constraints() if sql_constraints: lines.append("## SQL Constraints") @@ -141,10 +139,9 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append("| Field | Type |") lines.append("|-------|------|") - # Collect field info for query pattern examples - field_names = [] - numeric_field = None - text_field = None + field_names: list[str] = [] + numeric_field: str | None = None + text_field: str | None = None for field in entity.fields or []: if field.is_hidden_field or field.is_system_field: @@ -154,7 +151,6 @@ def format_schemas_for_context(entities: list[Entity]) -> str: field_names.append(field_name) lines.append(f"| {field_name} | {field_type} |") - # Track field types for examples sql_type = field.sql_type.name.lower() if field.sql_type else "" if not numeric_field and sql_type in ("int", "decimal", "float", "double", "bigint"): numeric_field = field_name @@ -163,7 +159,6 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append("") - # Add entity-specific query pattern examples group_field = text_field or (field_names[0] if field_names else "Category") agg_field = numeric_field or (field_names[1] if len(field_names) > 1 else "Amount") filter_field = text_field or (field_names[0] if field_names else "Name") @@ -198,201 +193,130 @@ def get_datafabric_contexts( Returns: List of context resources configured for Data Fabric retrieval mode. """ - datafabric_contexts: list[AgentContextResourceConfig] = [] + return _filter_datafabric_contexts(agent.resources) - for resource in agent.resources: - if not isinstance(resource, AgentContextResourceConfig): - continue - if not resource.is_enabled: - continue - if resource.settings.retrieval_mode.lower() == "datafabric": - datafabric_contexts.append(resource) - - return datafabric_contexts - - -# --- Tool Creation --- - - -class QueryEntityInput(BaseModel): - """Input schema for the query_entity tool.""" - - entity_identifier: str = Field( - ..., description="The entity identifier to query" - ) - sql_where: str = Field( - default="", - description="SQL WHERE clause to filter records (without the WHERE keyword). " - "Example: 'Status = \"Active\" AND Amount > 100'", - ) - limit: int = Field( - default=1000, description="Maximum number of records to return" - ) - - -class QueryEntityOutput(BaseModel): - """Output schema for the query_entity tool.""" - - records: list[dict[str, Any]] = Field( - ..., description="List of entity records matching the query" - ) - total_count: int = Field(..., description="Total number of matching records") +def _filter_datafabric_contexts( + resources: Sequence[BaseAgentResourceConfig], +) -> list[AgentContextResourceConfig]: + """Filter resources to only Data Fabric context configs.""" + return [ + resource + for resource in resources + if isinstance(resource, AgentContextResourceConfig) + and resource.is_enabled + and resource.settings.retrieval_mode.lower() == "datafabric" + ] -async def create_datafabric_tools( - agent: LowCodeAgentDefinition, -) -> tuple[list[BaseTool], str]: - """Create Data Fabric tools and schema context from agent definition. - This function: - 1. Finds all Data Fabric context resources in the agent - 2. Fetches entity schemas for context hydration - 3. Returns tools and schema context string +def get_datafabric_entity_identifiers_from_resources( + resources: Sequence[BaseAgentResourceConfig], +) -> list[str]: + """Extract Data Fabric entity identifiers from a sequence of resource configs. Args: - agent: The agent definition containing Data Fabric context resources. + resources: Resource configs (typically from ``agent_definition.resources``). Returns: - Tuple of (tools, schema_context) where: - - tools: List of BaseTool instances for querying entities - - schema_context: Formatted schema string to inject into system prompt + Flat list of entity identifier strings across all Data Fabric contexts. """ - tools: list[BaseTool] = [] - all_entities: list[Entity] = [] - - datafabric_contexts = get_datafabric_contexts(agent) + identifiers: list[str] = [] + for context in _filter_datafabric_contexts(resources): + identifiers.extend(context.settings.entity_identifiers or []) + return identifiers - if not datafabric_contexts: - return tools, "" - logger.info(f"Found {len(datafabric_contexts)} Data Fabric context resource(s)") +# --- Generic Tool Creation --- - for context in datafabric_contexts: - entity_identifiers = context.settings.entity_identifiers or [] +_MAX_RECORDS_IN_RESPONSE = 50 - if not entity_identifiers: - logger.warning( - f"Data Fabric context '{context.name}' has no entity_identifiers configured" - ) - continue - # Fetch entity schemas - entities = await fetch_entity_schemas(entity_identifiers) - all_entities.extend(entities) +def create_datafabric_query_tool() -> BaseTool: + """Create a single generic ``query_datafabric`` tool. - context_tools = _create_sdk_based_tools(context, entities) - tools.extend(context_tools) + The tool accepts an arbitrary SQL SELECT query and dispatches it to + ``sdk.entities.query_entity_records_async()``. Entity knowledge is + *not* baked in — the LLM receives schema guidance via the system + prompt (injected at INIT time) and constructs raw SQL. + """ - logger.info( - f"Created {len(context_tools)} tools for Data Fabric context '{context.name}'" - ) + async def _query_datafabric(sql_query: str) -> dict[str, Any]: + from uipath.platform import UiPath - # Format all entity schemas for context injection - schema_context = format_schemas_for_context(all_entities) + logger.debug(f"query_datafabric called with SQL: {sql_query}") - return tools, schema_context + sdk = UiPath() + try: + records = await sdk.entities.query_entity_records_async( + sql_query=sql_query, + ) + total_count = len(records) + truncated = total_count > _MAX_RECORDS_IN_RESPONSE + returned_records = records[:_MAX_RECORDS_IN_RESPONSE] if truncated else records + + result: dict[str, Any] = { + "records": returned_records, + "total_count": total_count, + "returned_count": len(returned_records), + "sql_query": sql_query, + } + if truncated: + result["truncated"] = True + result["message"] = ( + f"Showing {len(returned_records)} of {total_count} records. " + "Use more specific filters or LIMIT to narrow results." + ) + return result + except Exception as e: + logger.error(f"SQL query failed: {e}") + return { + "records": [], + "total_count": 0, + "error": str(e), + "sql_query": sql_query, + } + + return BaseUiPathStructuredTool( + name="query_datafabric", + description=( + "Execute a SQL SELECT query against Data Fabric entities. " + "Refer to the entity schemas in the system prompt for available tables and columns. " + "Include LIMIT unless aggregating." + ), + args_schema={ + "type": "object", + "properties": { + "sql_query": { + "type": "string", + "description": ( + "Complete SQL SELECT statement. " + "Use exact table and column names from the entity schemas in the system prompt." + ), + }, + }, + "required": ["sql_query"], + }, + coroutine=_query_datafabric, + metadata={"tool_type": "datafabric_sql"}, + ) -def _create_sdk_based_tools( - context: AgentContextResourceConfig, - entities: list[Entity], +def create_datafabric_tools( + agent: LowCodeAgentDefinition, ) -> list[BaseTool]: - """Create SDK-based tools for querying entities using SQL. + """Register the generic Data Fabric query tool if the agent has DF contexts. - Each tool accepts a full SQL query and executes it via the SDK's - query_entity_records_async method. - """ - tools: list[BaseTool] = [] - MAX_RECORDS_IN_RESPONSE = 50 # Limit records to prevent context overflow + No fetching, no formatting, no schema — purely tool registration. + Schema hydration happens at INIT time. - for entity in entities: - tool_name = sanitize_tool_name(f"query_{entity.name}") - - # Create a closure to capture the entity name - entity_display_name = entity.display_name or entity.name - - async def query_fn( - sql_query: str, - _entity_name: str = entity_display_name, - _max_records: int = MAX_RECORDS_IN_RESPONSE, - ) -> dict[str, Any]: - """Execute a SQL query against the Data Fabric entity.""" - from uipath.platform import UiPath - - print(f"[DEBUG] query_fn called for entity '{_entity_name}' with SQL: {sql_query}") - logger.info(f"Executing SQL query for entity '{_entity_name}': {sql_query}") - - sdk = UiPath() - try: - records = await sdk.entities.query_entity_records_async( - sql_query=sql_query, - ) - total_count = len(records) - truncated = total_count > _max_records - returned_records = records[:_max_records] if truncated else records - - print(f"[DEBUG] Retrieved {total_count} records, returning {len(returned_records)} for entity '{_entity_name}'") - - result = { - "records": returned_records, - "total_count": total_count, - "returned_count": len(returned_records), - "entity": _entity_name, - "sql_query": sql_query, - } - if truncated: - result["truncated"] = True - result["message"] = f"Showing {len(returned_records)} of {total_count} records. Use more specific filters or LIMIT to narrow results." - return result - except Exception as e: - logger.error(f"SQL query failed for entity '{_entity_name}': {e}") - return { - "records": [], - "total_count": 0, - "error": str(e), - "sql_query": sql_query, - } - - entity_description = entity.description or f"Query {entity_display_name} records" - - # Extract actual field names from entity schema (exclude system/hidden fields) - field_names = [ - f.display_name or f.name - for f in (entity.fields or []) - if not f.is_hidden_field and not f.is_system_field - ] - fields_str = ", ".join(field_names[:10]) # Limit to first 10 fields - if len(field_names) > 10: - fields_str += f", ... ({len(field_names)} total)" - - tools.append( - BaseUiPathStructuredTool( - name=tool_name, - description=( - f"{context.description}. {entity_description}. " - f"Available fields: {fields_str}. " - f"Use SQL patterns from system prompt based on user intent." - ), - args_schema={ - "type": "object", - "properties": { - "sql_query": { - "type": "string", - "description": ( - f"Complete SQL SELECT statement for {entity_display_name}. " - f"Use exact column names from schema. Include LIMIT unless aggregating." - ), - }, - }, - "required": ["sql_query"], - }, - coroutine=query_fn, - metadata={ - "tool_type": "datafabric_sql", - "display_name": f"Query {entity_display_name}", - "entity_name": entity_display_name, - }, - ) - ) + Args: + agent: The agent definition containing Data Fabric context resources. + + Returns: + A list containing the single ``query_datafabric`` tool, or empty. + """ + if not get_datafabric_entity_identifiers_from_resources(agent.resources): + return [] - return tools + logger.info("Registering generic query_datafabric tool") + return [create_datafabric_query_tool()] diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index 5ba0bf1e5..ae05d1178 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -30,7 +30,7 @@ async def create_tools_from_resources( agent: LowCodeAgentDefinition, llm: BaseChatModel -) -> tuple[list[BaseTool], str]: +) -> list[BaseTool]: """Create tools from agent resources including Data Fabric tools. Args: @@ -38,17 +38,14 @@ async def create_tools_from_resources( llm: The language model for tool creation. Returns: - Tuple of (tools, datafabric_schema_context). + List of BaseTool instances. """ tools: list[BaseTool] = [] - datafabric_schema_context: str = "" logger.info("Creating tools for agent '%s' from resources", agent.name) - # Handle Data Fabric tools first (they need special handling) - datafabric_tools, schema_context = await create_datafabric_tools(agent) - tools.extend(datafabric_tools) - datafabric_schema_context = schema_context + # Register the generic Data Fabric query tool (no fetching/schema here) + tools.extend(create_datafabric_tools(agent)) for resource in agent.resources: if not resource.is_enabled: @@ -71,7 +68,7 @@ async def create_tools_from_resources( else: tools.append(tool) - return tools, datafabric_schema_context + return tools async def _build_tool_for_resource( diff --git a/tests/agent/react/test_create_agent.py b/tests/agent/react/test_create_agent.py index 0ea8e3253..1888ff203 100644 --- a/tests/agent/react/test_create_agent.py +++ b/tests/agent/react/test_create_agent.py @@ -171,6 +171,7 @@ def test_autonomous_agent_with_tools( llm_provider=mock_model.llm_provider, api_flavor=mock_model.api_flavor, ), + resources_for_init=None, ) mock_create_terminate_node.assert_called_once_with( None, # output schema @@ -270,6 +271,7 @@ def test_conversational_agent_with_tools( llm_provider=mock_model.llm_provider, api_flavor=mock_model.api_flavor, ), + resources_for_init=None, ) mock_create_terminate_node.assert_called_once_with( None, # output schema diff --git a/tests/agent/react/test_init_node.py b/tests/agent/react/test_init_node.py index b1c95ae99..7812d329d 100644 --- a/tests/agent/react/test_init_node.py +++ b/tests/agent/react/test_init_node.py @@ -16,6 +16,7 @@ class MockState(BaseModel): messages: list[Any] = [] +@pytest.mark.asyncio class TestCreateInitNodeConversational: """Test cases for create_init_node with is_conversational=True.""" @@ -43,7 +44,7 @@ def empty_state(self): """Fixture for state with no messages.""" return MockState(messages=[]) - def test_conversational_empty_state_returns_overwrite( + async def test_conversational_empty_state_returns_overwrite( self, system_message, user_message ): """Conversational mode with empty state should use Overwrite with new messages.""" @@ -53,7 +54,7 @@ def test_conversational_empty_state_returns_overwrite( ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "messages" in result assert isinstance(result["messages"], Overwrite) @@ -63,7 +64,7 @@ def test_conversational_empty_state_returns_overwrite( assert overwrite_value[0] == system_message assert overwrite_value[1] == user_message - def test_conversational_resume_replaces_system_message( + async def test_conversational_resume_replaces_system_message( self, new_system_message, user_message ): """Conversational mode should replace old SystemMessage when resuming.""" @@ -79,7 +80,7 @@ def test_conversational_resume_replaces_system_message( new_messages, input_schema=None, is_conversational=True ) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], Overwrite) overwrite_value = result["messages"].value @@ -89,7 +90,7 @@ def test_conversational_resume_replaces_system_message( assert overwrite_value[1] == user_message assert overwrite_value[2] == old_human_message - def test_conversational_resume_preserves_non_system_first_message( + async def test_conversational_resume_preserves_non_system_first_message( self, system_message, user_message ): """Conversational mode should preserve all messages if first is not SystemMessage.""" @@ -102,7 +103,7 @@ def test_conversational_resume_preserves_non_system_first_message( new_messages, input_schema=None, is_conversational=True ) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], Overwrite) overwrite_value = result["messages"].value @@ -112,10 +113,10 @@ def test_conversational_resume_preserves_non_system_first_message( assert overwrite_value[1] == user_message assert overwrite_value[2] == existing_human_message - def test_conversational_with_callable_messages(self): + async def test_conversational_with_callable_messages(self): """Conversational mode should work with callable message generators.""" - def message_generator(state): + def message_generator(state, **kwargs): return [ SystemMessage( content=f"System for state with {len(state.messages)} messages" @@ -128,7 +129,7 @@ def message_generator(state): message_generator, input_schema=None, is_conversational=True ) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], Overwrite) overwrite_value = result["messages"].value @@ -136,6 +137,7 @@ def message_generator(state): assert "System for state with 0 messages" in overwrite_value[0].content +@pytest.mark.asyncio class TestCreateInitNodeNonConversational: """Test cases for create_init_node with is_conversational=False (default).""" @@ -149,7 +151,7 @@ def user_message(self): """Fixture for a user message.""" return HumanMessage(content="Hello") - def test_non_conversational_returns_list_not_overwrite( + async def test_non_conversational_returns_list_not_overwrite( self, system_message, user_message ): """Non-conversational mode should return list, not Overwrite.""" @@ -159,7 +161,7 @@ def test_non_conversational_returns_list_not_overwrite( ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "messages" in result # Non-conversational mode returns a list (for add_messages reducer to append) @@ -167,22 +169,23 @@ def test_non_conversational_returns_list_not_overwrite( assert not isinstance(result["messages"], Overwrite) assert len(result["messages"]) == 2 - def test_non_conversational_default_behavior(self, system_message, user_message): + async def test_non_conversational_default_behavior(self, system_message, user_message): """Default behavior (no is_conversational param) should be non-conversational.""" messages = [system_message, user_message] init_node = create_init_node(messages, input_schema=None) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert isinstance(result["messages"], list) assert not isinstance(result["messages"], Overwrite) +@pytest.mark.asyncio class TestCreateInitNodeInnerState: """Test cases for init node inner_state initialization.""" - def test_returns_inner_state_with_job_attachments(self): + async def test_returns_inner_state_with_job_attachments(self): """Init node should return inner_state with job_attachments dict.""" messages: list[SystemMessage | HumanMessage] = [ SystemMessage(content="System"), @@ -193,13 +196,13 @@ def test_returns_inner_state_with_job_attachments(self): ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "inner_state" in result assert "job_attachments" in result["inner_state"] assert isinstance(result["inner_state"]["job_attachments"], dict) - def test_inner_state_present_in_conversational_mode(self): + async def test_inner_state_present_in_conversational_mode(self): """Inner state should also be present in conversational mode.""" messages: list[SystemMessage | HumanMessage] = [ SystemMessage(content="System"), @@ -210,7 +213,7 @@ def test_inner_state_present_in_conversational_mode(self): ) state = MockState(messages=[]) - result = init_node(state) + result = await init_node(state) assert "inner_state" in result assert "job_attachments" in result["inner_state"] From 767c707f206bd73a943920b77b7cf7660c2b197a Mon Sep 17 00:00:00 2001 From: Harshit Rohatgi Date: Tue, 17 Mar 2026 13:45:49 +0530 Subject: [PATCH 5/8] Added updated agent definition --- .../agent/tools/context_tool.py | 534 +++++++----------- .../tools/datafabric_tool/datafabric_tool.py | 52 +- .../agent/tools/tool_factory.py | 5 +- tests/agent/react/test_init_node.py | 4 +- 4 files changed, 233 insertions(+), 362 deletions(-) diff --git a/src/uipath_langchain/agent/tools/context_tool.py b/src/uipath_langchain/agent/tools/context_tool.py index 184f1f180..435d96871 100644 --- a/src/uipath_langchain/agent/tools/context_tool.py +++ b/src/uipath_langchain/agent/tools/context_tool.py @@ -1,101 +1,58 @@ """Context tool creation for semantic index retrieval.""" import uuid -from typing import Any, Dict, Optional +from typing import Any, Optional, Type from langchain_core.documents import Document -from langchain_core.messages import ToolCall -from langchain_core.tools import BaseTool, StructuredTool -from pydantic import BaseModel, Field, TypeAdapter, create_model +from langchain_core.tools import StructuredTool +from langgraph.types import interrupt +from pydantic import BaseModel, Field from uipath.agent.models.agent import ( AgentContextResourceConfig, AgentContextRetrievalMode, - AgentToolArgumentProperties, ) from uipath.eval.mocks import mockable -from uipath.platform import UiPath -from uipath.platform.common import CreateBatchTransform, CreateDeepRag, UiPathConfig +from uipath.platform.common import CreateBatchTransform, CreateDeepRag from uipath.platform.context_grounding import ( BatchTransformOutputColumn, + BatchTransformResponse, CitationMode, - DeepRagContent, + DeepRagResponse, ) -from uipath.runtime.errors import UiPathErrorCategory -from uipath_langchain._utils import get_execution_folder_path -from uipath_langchain.agent.exceptions import AgentStartupError, AgentStartupErrorCode -from uipath_langchain.agent.react.jsonschema_pydantic_converter import ( - create_model as create_model_from_schema, -) -from uipath_langchain.agent.react.types import AgentGraphState -from uipath_langchain.agent.tools.internal_tools.schema_utils import ( - BATCH_TRANSFORM_OUTPUT_SCHEMA, -) -from uipath_langchain.agent.tools.static_args import handle_static_args from uipath_langchain.retrievers import ContextGroundingRetriever -from .durable_interrupt import durable_interrupt -from .structured_tool_with_argument_properties import ( - StructuredToolWithArgumentProperties, -) from .structured_tool_with_output_type import StructuredToolWithOutputType -from .tool_node import ToolWrapperReturnType from .utils import sanitize_tool_name -_ARG_PROPS_ADAPTER = TypeAdapter(Dict[str, AgentToolArgumentProperties]) - - -def _get_argument_properties( - resource: AgentContextResourceConfig, -) -> dict[str, AgentToolArgumentProperties]: - """Extract argumentProperties from the resource's extra fields. - - AgentContextResourceConfig doesn't declare argument_properties yet, - but BaseCfg(extra="allow") preserves the raw JSON value. - """ - raw = ( - resource.model_extra.get("argumentProperties") if resource.model_extra else None - ) - if not raw: - return {} - return _ARG_PROPS_ADAPTER.validate_python(raw) - - -def _build_folder_path_prefix_arg_props( - resource: AgentContextResourceConfig, -) -> dict[str, Any]: - """Build argument_properties for folder_path_prefix from settings. - - Fallback for when settings bag doesn't include argumentProperties - at the resource level but does set settings.folder_path_prefix - with variant="argument". - """ - assert resource.settings.folder_path_prefix is not None - argument_path = (resource.settings.folder_path_prefix.value or "").strip("{}") - return { - "folder_path_prefix": { - "variant": "argument", - "argumentPath": argument_path, - "isSensitive": False, - } - } - def is_static_query(resource: AgentContextResourceConfig) -> bool: """Check if the resource configuration uses a static query variant.""" - if resource.settings.query is None or resource.settings.query.variant is None: + if ( + resource.settings is None + or resource.settings.query is None + or resource.settings.query.variant is None + ): return False return resource.settings.query.variant.lower() == "static" def create_context_tool(resource: AgentContextResourceConfig) -> StructuredTool: + if resource.settings is None: + raise ValueError( + f"Context resource '{resource.name}' is missing required settings." + ) + if resource.index_name is None: + raise ValueError( + f"Context resource '{resource.name}' is missing required index name." + ) tool_name = sanitize_tool_name(resource.name) retrieval_mode = resource.settings.retrieval_mode.lower() if retrieval_mode == AgentContextRetrievalMode.DEEP_RAG.value.lower(): return handle_deep_rag(tool_name, resource) elif retrieval_mode == AgentContextRetrievalMode.BATCH_TRANSFORM.value.lower(): return handle_batch_transform(tool_name, resource) - elif retrieval_mode == "datafabric": + elif retrieval_mode == AgentContextRetrievalMode.DATA_FABRIC.value.lower(): # Data Fabric contexts are handled by create_datafabric_tools() in tool_factory.py raise ValueError( "Data Fabric context should be handled via create_datafabric_tools(), " @@ -110,19 +67,18 @@ def handle_semantic_search( ) -> StructuredTool: ensure_valid_fields(resource) + # needed for type checking + assert resource.settings is not None + assert resource.index_name is not None + assert resource.settings.query is not None assert resource.settings.query.variant is not None retriever = ContextGroundingRetriever( index_name=resource.index_name, - folder_path=get_execution_folder_path(), + folder_path=resource.folder_path, number_of_results=resource.settings.result_count, ) - static = is_static_query(resource) - prompt = resource.settings.query.value if static else None - if static: - assert prompt is not None - class ContextOutputSchemaModel(BaseModel): documents: list[Document] = Field( ..., description="List of retrieved documents." @@ -130,35 +86,39 @@ class ContextOutputSchemaModel(BaseModel): output_model = ContextOutputSchemaModel - schema_fields: dict[str, Any] = ( - {} - if static - else { - "query": ( - str, - Field(..., description="The query to search for in the knowledge base"), - ), - } - ) - input_model = create_model("SemanticSearchInput", **schema_fields) + if is_static_query(resource): + static_query_value = resource.settings.query.value + assert static_query_value is not None + input_model = None + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model, + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. + ) + async def context_tool_fn() -> dict[str, Any]: + return {"documents": await retriever.ainvoke(static_query_value)} - @mockable( - name=resource.name, - description=resource.description, - input_schema=input_model.model_json_schema(), - output_schema=output_model.model_json_schema(), - example_calls=[], # Examples cannot be provided for context. - ) - async def context_tool_fn(query: Optional[str] = None) -> dict[str, Any]: - actual_query = prompt or query - assert actual_query is not None - docs = await retriever.ainvoke(actual_query) - return { - "documents": [ - {"metadata": doc.metadata, "page_content": doc.page_content} - for doc in docs - ] - } + else: + # Dynamic query - requires query parameter + class ContextInputSchemaModel(BaseModel): + query: str = Field( + ..., description="The query to search for in the knowledge base" + ) + + input_model = ContextInputSchemaModel + + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. + ) + async def context_tool_fn(query: str) -> dict[str, Any]: + return {"documents": await retriever.ainvoke(query)} return StructuredToolWithOutputType( name=tool_name, @@ -169,8 +129,6 @@ async def context_tool_fn(query: Optional[str] = None) -> dict[str, Any]: metadata={ "tool_type": "context", "display_name": resource.name, - "index_name": resource.index_name, - "context_retrieval_mode": resource.settings.retrieval_mode, }, ) @@ -180,116 +138,80 @@ def handle_deep_rag( ) -> StructuredTool: ensure_valid_fields(resource) + # needed for type checking + assert resource.settings is not None + assert resource.index_name is not None + assert resource.settings.query is not None assert resource.settings.query.variant is not None index_name = resource.index_name if not resource.settings.citation_mode: - raise AgentStartupError( - code=AgentStartupErrorCode.INVALID_TOOL_CONFIG, - title="Missing citation mode", - detail="Citation mode is required for Deep RAG. Please set the citation_mode field in context settings.", - category=UiPathErrorCategory.USER, - ) + raise ValueError("Citation mode is required for Deep RAG") citation_mode = CitationMode(resource.settings.citation_mode.value) - static = is_static_query(resource) - prompt = resource.settings.query.value if static else None - if static: - assert prompt is not None - - static_folder_path_prefix = None - if ( - resource.settings.folder_path_prefix - and resource.settings.folder_path_prefix.value - and resource.settings.folder_path_prefix.variant == "static" - ): - static_folder_path_prefix = resource.settings.folder_path_prefix.value + output_model = DeepRagResponse - file_extension = None - if resource.settings.file_extension and resource.settings.file_extension.value: - file_extension = resource.settings.file_extension.value + if is_static_query(resource): + # Static query - no input parameter needed + static_prompt = resource.settings.query.value + assert static_prompt is not None + input_model = None - output_model = create_model( - "DeepRagOutputModel", - __base__=DeepRagContent, - deep_rag_id=(str, Field(alias="deepRagId")), - ) - - arg_props = _get_argument_properties(resource) - - has_folder_path_prefix_arg = "folder_path_prefix" in arg_props or ( - resource.settings.folder_path_prefix - and resource.settings.folder_path_prefix.variant == "argument" - ) - - schema_fields: dict[str, Any] = ( - {} - if static - else { - "query": ( - str, - Field( - ..., - description="Describe the task: what to research across documents, what to synthesize, and how to cite sources", - ), - ), - } - ) - - if has_folder_path_prefix_arg: - schema_fields["folder_path_prefix"] = ( - str, - Field( - default=None, - description="The folder path prefix within the index to filter on", - ), + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model, + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. ) - if "folder_path_prefix" not in arg_props: - arg_props = _build_folder_path_prefix_arg_props(resource) + async def context_tool_fn() -> dict[str, Any]: + # TODO: add glob pattern support + return interrupt( + CreateDeepRag( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=static_prompt, + citation_mode=citation_mode, + ) + ) - input_model = create_model("DeepRagInput", **schema_fields) + else: + # Dynamic query - requires query parameter + class DeepRagInputSchemaModel(BaseModel): + query: str = Field( + ..., + description="Describe the task: what to research across documents, what to synthesize, and how to cite sources", + ) - @mockable( - name=resource.name, - description=resource.description, - input_schema=input_model.model_json_schema(), - output_schema=output_model.model_json_schema(), - example_calls=[], # Examples cannot be provided for context. - ) - async def context_tool_fn( - query: Optional[str] = None, folder_path_prefix: Optional[str] = None - ) -> dict[str, Any]: - actual_prompt = prompt or query - glob_pattern = build_glob_pattern( - folder_path_prefix=static_folder_path_prefix or folder_path_prefix, - file_extension=file_extension, - ) + input_model = DeepRagInputSchemaModel - @durable_interrupt - async def create_deep_rag(): - return CreateDeepRag( - name=f"task-{uuid.uuid4()}", - index_name=index_name, - prompt=actual_prompt, - citation_mode=citation_mode, - index_folder_path=get_execution_folder_path(), - glob_pattern=glob_pattern, + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. + ) + async def context_tool_fn(query: str) -> dict[str, Any]: + # TODO: add glob pattern support + return interrupt( + CreateDeepRag( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=query, + citation_mode=citation_mode, + ) ) - return await create_deep_rag() - - return StructuredToolWithArgumentProperties( + return StructuredToolWithOutputType( name=tool_name, description=resource.description, args_schema=input_model, coroutine=context_tool_fn, output_type=output_model, - argument_properties=arg_props, metadata={ "tool_type": "context", "display_name": resource.name, - "index_name": resource.index_name, - "context_retrieval_mode": resource.settings.retrieval_mode, }, ) @@ -299,18 +221,16 @@ def handle_batch_transform( ) -> StructuredTool: ensure_valid_fields(resource) + # needed for type checking + assert resource.settings is not None + assert resource.index_name is not None assert resource.settings.query is not None assert resource.settings.query.variant is not None index_name = resource.index_name - index_folder_path = get_execution_folder_path() + index_folder_path = resource.folder_path if not resource.settings.web_search_grounding: - raise AgentStartupError( - code=AgentStartupErrorCode.INVALID_TOOL_CONFIG, - title="Missing web search grounding", - detail="Web search grounding field is required for Batch Transform. Please set the web_search_grounding field in context settings.", - category=UiPathErrorCategory.USER, - ) + raise ValueError("Web search grounding field is required for Batch Transform") enable_web_search_grounding = ( resource.settings.web_search_grounding.value.lower() == "enabled" ) @@ -319,11 +239,8 @@ def handle_batch_transform( if (output_columns := resource.settings.output_columns) is None or not len( output_columns ): - raise AgentStartupError( - code=AgentStartupErrorCode.INVALID_TOOL_CONFIG, - title="Missing output columns", - detail="Batch transform requires at least one output column to be specified in settings.output_columns. Please add output columns to the context configuration.", - category=UiPathErrorCategory.USER, + raise ValueError( + "Batch transform requires at least one output column to be specified in settings.output_columns" ) for column in output_columns: @@ -334,176 +251,103 @@ def handle_batch_transform( ) ) - static = is_static_query(resource) - prompt = resource.settings.query.value if static else None - if static: - assert prompt is not None + output_model = BatchTransformResponse - static_folder_path_prefix = None - if ( - resource.settings.folder_path_prefix - and resource.settings.folder_path_prefix.value - and resource.settings.folder_path_prefix.variant == "static" - ): - static_folder_path_prefix = resource.settings.folder_path_prefix.value + input_model: Optional[Type[BaseModel]] - arg_props = _get_argument_properties(resource) + if is_static_query(resource): + # Static query - only destination_path parameter needed + static_prompt = resource.settings.query.value + assert static_prompt is not None - has_folder_path_prefix_arg = "folder_path_prefix" in arg_props or ( - resource.settings.folder_path_prefix - and resource.settings.folder_path_prefix.variant == "argument" - ) - - output_model = create_model_from_schema(BATCH_TRANSFORM_OUTPUT_SCHEMA) + class StaticBatchTransformSchemaModel(BaseModel): + destination_path: str = Field( + default="output.csv", + description="The relative file path destination for the modified csv file", + ) - schema_fields: dict[str, Any] = {} - if not static: - schema_fields["query"] = ( - str, - Field( - ..., - description="Describe the task for each row: what to analyze, what to extract, and how to populate the output columns", - ), - ) - schema_fields["destination_path"] = ( - str, - Field( - default="output.csv", - description="The relative file path destination for the modified csv file", - ), - ) - if has_folder_path_prefix_arg: - schema_fields["folder_path_prefix"] = ( - str, - Field( - default=None, - description="The folder path prefix within the index to filter on", - ), - ) - if "folder_path_prefix" not in arg_props: - arg_props = _build_folder_path_prefix_arg_props(resource) - input_model = create_model("BatchTransformInput", **schema_fields) + input_model = StaticBatchTransformSchemaModel - @mockable( - name=resource.name, - description=resource.description, - input_schema=input_model.model_json_schema(), - output_schema=output_model.model_json_schema(), - example_calls=[], # Examples cannot be provided for context. - ) - async def context_tool_fn( - query: Optional[str] = None, - destination_path: str = "output.csv", - folder_path_prefix: Optional[str] = None, - ) -> dict[str, Any]: - actual_prompt = prompt or query - glob_pattern = build_glob_pattern( - folder_path_prefix=static_folder_path_prefix or folder_path_prefix, - file_extension=None, + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. ) + async def context_tool_fn( + destination_path: str = "output.csv", + ) -> dict[str, Any]: + # TODO: storage_bucket_folder_path_prefix support + return interrupt( + CreateBatchTransform( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=static_prompt, + destination_path=destination_path, + index_folder_path=index_folder_path, + enable_web_search_grounding=enable_web_search_grounding, + output_columns=batch_transform_output_columns, + ) + ) - @durable_interrupt - async def create_batch_transform(): - return CreateBatchTransform( - name=f"task-{uuid.uuid4()}", - index_name=index_name, - prompt=actual_prompt, - destination_path=destination_path, - index_folder_path=index_folder_path, - enable_web_search_grounding=enable_web_search_grounding, - output_columns=batch_transform_output_columns, - storage_bucket_folder_path_prefix=glob_pattern, + else: + # Dynamic query - requires both query and destination_path parameters + class DynamicBatchTransformSchemaModel(BaseModel): + query: str = Field( + ..., + description="Describe the task for each row: what to analyze, what to extract, and how to populate the output columns", + ) + destination_path: str = Field( + default="output.csv", + description="The relative file path destination for the modified csv file", ) - await create_batch_transform() + input_model = DynamicBatchTransformSchemaModel - uipath = UiPath() - result_attachment_id = await uipath.jobs.create_attachment_async( - name=destination_path, - source_path=destination_path, - job_key=UiPathConfig.job_key, + @mockable( + name=resource.name, + description=resource.description, + input_schema=input_model.model_json_schema(), + output_schema=output_model.model_json_schema(), + example_calls=[], # Examples cannot be provided for context. ) + async def context_tool_fn( + query: str, destination_path: str = "output.csv" + ) -> dict[str, Any]: + # TODO: storage_bucket_folder_path_prefix support + return interrupt( + CreateBatchTransform( + name=f"task-{uuid.uuid4()}", + index_name=index_name, + prompt=query, + destination_path=destination_path, + index_folder_path=index_folder_path, + enable_web_search_grounding=enable_web_search_grounding, + output_columns=batch_transform_output_columns, + ) + ) - return { - "result": { - "ID": str(result_attachment_id), - "FullName": destination_path, - "MimeType": "text/csv", - } - } - - from uipath_langchain.agent.wrappers import get_job_attachment_wrapper - - job_attachment_wrapper = get_job_attachment_wrapper(output_type=output_model) - - async def context_batch_transform_wrapper( - tool: BaseTool, - call: ToolCall, - state: AgentGraphState, - ) -> ToolWrapperReturnType: - call["args"] = handle_static_args(resource, state, call["args"]) - return await job_attachment_wrapper(tool, call, state) - - tool = StructuredToolWithArgumentProperties( + return StructuredToolWithOutputType( name=tool_name, description=resource.description, args_schema=input_model, coroutine=context_tool_fn, output_type=output_model, - argument_properties=arg_props, metadata={ "tool_type": "context", "display_name": resource.name, - "index_name": resource.index_name, - "context_retrieval_mode": resource.settings.retrieval_mode, - "output_schema": output_model, }, ) - tool.set_tool_wrappers(awrapper=context_batch_transform_wrapper) - return tool def ensure_valid_fields(resource_config: AgentContextResourceConfig): + assert resource_config.settings is not None + if not resource_config.settings.query: + raise ValueError("Query object is required") + if not resource_config.settings.query.variant: - raise AgentStartupError( - code=AgentStartupErrorCode.INVALID_TOOL_CONFIG, - title="Missing query variant", - detail="Query variant is required. Please set the query variant in context settings.", - category=UiPathErrorCategory.USER, - ) + raise ValueError("Query variant is required") if is_static_query(resource_config) and not resource_config.settings.query.value: - raise AgentStartupError( - code=AgentStartupErrorCode.INVALID_TOOL_CONFIG, - title="Missing static query value", - detail="Static query requires a query value to be set. Please provide a value for the static query in context settings.", - category=UiPathErrorCategory.USER, - ) - - -def build_glob_pattern( - folder_path_prefix: str | None, file_extension: str | None -) -> str: - # Handle prefix - prefix = "**" - if folder_path_prefix: - prefix = folder_path_prefix.rstrip("/") - - if not prefix.startswith("**"): - if prefix.startswith("/"): - prefix = prefix[1:] - - # Handle extension - extension = "*" - if file_extension: - ext = file_extension.lower() - if ext in {"pdf", "txt", "docx", "csv"}: - extension = f"*.{ext}" - else: - extension = f"*.{ext}" - - # Final pattern logic - if not prefix or prefix == "**": - return "**/*" if extension == "*" else f"**/{extension}" - - return f"{prefix}/{extension}" + raise ValueError("Static query requires a query value to be set") diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py index 7c5b7101c..03df15de8 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -152,15 +152,29 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append(f"| {field_name} | {field_type} |") sql_type = field.sql_type.name.lower() if field.sql_type else "" - if not numeric_field and sql_type in ("int", "decimal", "float", "double", "bigint"): + if not numeric_field and sql_type in ( + "int", + "decimal", + "float", + "double", + "bigint", + ): numeric_field = field_name - if not text_field and sql_type in ("varchar", "nvarchar", "text", "string", "ntext"): + if not text_field and sql_type in ( + "varchar", + "nvarchar", + "text", + "string", + "ntext", + ): text_field = field_name lines.append("") group_field = text_field or (field_names[0] if field_names else "Category") - agg_field = numeric_field or (field_names[1] if len(field_names) > 1 else "Amount") + agg_field = numeric_field or ( + field_names[1] if len(field_names) > 1 else "Amount" + ) filter_field = text_field or (field_names[0] if field_names else "Name") fields_sample = ", ".join(field_names[:5]) if field_names else "*" @@ -168,12 +182,24 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append("") lines.append("| User Intent | SQL Pattern |") lines.append("|-------------|-------------|") - lines.append(f"| 'Show all {display_name.lower()}' | `SELECT {fields_sample} FROM {display_name} LIMIT 100` |") - lines.append(f"| 'Find by X' | `SELECT {fields_sample} FROM {display_name} WHERE {filter_field} = 'value' LIMIT 100` |") - lines.append(f"| 'Top N by Y' | `SELECT {fields_sample} FROM {display_name} ORDER BY {agg_field} DESC LIMIT N` |") - lines.append(f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field}` |") - lines.append(f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |") - lines.append(f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {display_name}` |") + lines.append( + f"| 'Show all {display_name.lower()}' | `SELECT {fields_sample} FROM {display_name} LIMIT 100` |" + ) + lines.append( + f"| 'Find by X' | `SELECT {fields_sample} FROM {display_name} WHERE {filter_field} = 'value' LIMIT 100` |" + ) + lines.append( + f"| 'Top N by Y' | `SELECT {fields_sample} FROM {display_name} ORDER BY {agg_field} DESC LIMIT N` |" + ) + lines.append( + f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field}` |" + ) + lines.append( + f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" + ) + lines.append( + f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {display_name}` |" + ) lines.append("") return "\n".join(lines) @@ -205,7 +231,7 @@ def _filter_datafabric_contexts( for resource in resources if isinstance(resource, AgentContextResourceConfig) and resource.is_enabled - and resource.settings.retrieval_mode.lower() == "datafabric" + and resource.is_datafabric ] @@ -222,7 +248,7 @@ def get_datafabric_entity_identifiers_from_resources( """ identifiers: list[str] = [] for context in _filter_datafabric_contexts(resources): - identifiers.extend(context.settings.entity_identifiers or []) + identifiers.extend(context.datafabric_entity_identifiers) return identifiers @@ -252,7 +278,9 @@ async def _query_datafabric(sql_query: str) -> dict[str, Any]: ) total_count = len(records) truncated = total_count > _MAX_RECORDS_IN_RESPONSE - returned_records = records[:_MAX_RECORDS_IN_RESPONSE] if truncated else records + returned_records = ( + records[:_MAX_RECORDS_IN_RESPONSE] if truncated else records + ) result: dict[str, Any] = { "records": returned_records, diff --git a/src/uipath_langchain/agent/tools/tool_factory.py b/src/uipath_langchain/agent/tools/tool_factory.py index ae05d1178..00de2e682 100644 --- a/src/uipath_langchain/agent/tools/tool_factory.py +++ b/src/uipath_langchain/agent/tools/tool_factory.py @@ -78,10 +78,7 @@ async def _build_tool_for_resource( return create_process_tool(resource) elif isinstance(resource, AgentContextResourceConfig): - # Skip Data Fabric contexts - handled separately via create_datafabric_tools() - retrieval_mode = resource.settings.retrieval_mode - mode_value = retrieval_mode.value if hasattr(retrieval_mode, 'value') else str(retrieval_mode) - if mode_value.lower() == "datafabric": + if resource.is_datafabric: logger.info( "Skipping Data Fabric context '%s' - handled separately", resource.name, diff --git a/tests/agent/react/test_init_node.py b/tests/agent/react/test_init_node.py index a4e985a14..aec32d688 100644 --- a/tests/agent/react/test_init_node.py +++ b/tests/agent/react/test_init_node.py @@ -169,7 +169,9 @@ async def test_non_conversational_returns_list_not_overwrite( assert not isinstance(result["messages"], Overwrite) assert len(result["messages"]) == 2 - async def test_non_conversational_default_behavior(self, system_message, user_message): + async def test_non_conversational_default_behavior( + self, system_message, user_message + ): """Default behavior (no is_conversational param) should be non-conversational.""" messages = [system_message, user_message] init_node = create_init_node(messages, input_schema=None) From 98460a9551a6a24a0d81181e4950d650266454ca Mon Sep 17 00:00:00 2001 From: milind-jain-uipath Date: Tue, 24 Mar 2026 12:03:39 +0530 Subject: [PATCH 6/8] use name & field names instead of display name --- .../tools/datafabric_tool/datafabric_tool.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py index 03df15de8..053345083 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -131,8 +131,9 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append("") for entity in entities: + sql_table = entity.name display_name = entity.display_name or entity.name - lines.append(f"### Entity: {display_name}") + lines.append(f"### Entity: {display_name} (SQL table: `{sql_table}`)") if entity.description: lines.append(f"_{entity.description}_") lines.append("") @@ -146,7 +147,7 @@ def format_schemas_for_context(entities: list[Entity]) -> str: for field in entity.fields or []: if field.is_hidden_field or field.is_system_field: continue - field_name = field.display_name or field.name + field_name = field.name field_type = format_field_type(field) field_names.append(field_name) lines.append(f"| {field_name} | {field_type} |") @@ -178,27 +179,27 @@ def format_schemas_for_context(entities: list[Entity]) -> str: filter_field = text_field or (field_names[0] if field_names else "Name") fields_sample = ", ".join(field_names[:5]) if field_names else "*" - lines.append(f"**Query Patterns for {display_name}:**") + lines.append(f"**Query Patterns for {sql_table}:**") lines.append("") lines.append("| User Intent | SQL Pattern |") lines.append("|-------------|-------------|") lines.append( - f"| 'Show all {display_name.lower()}' | `SELECT {fields_sample} FROM {display_name} LIMIT 100` |" + f"| 'Show all' | `SELECT {fields_sample} FROM {sql_table} LIMIT 100` |" ) lines.append( - f"| 'Find by X' | `SELECT {fields_sample} FROM {display_name} WHERE {filter_field} = 'value' LIMIT 100` |" + f"| 'Find by X' | `SELECT {fields_sample} FROM {sql_table} WHERE {filter_field} = 'value' LIMIT 100` |" ) lines.append( - f"| 'Top N by Y' | `SELECT {fields_sample} FROM {display_name} ORDER BY {agg_field} DESC LIMIT N` |" + f"| 'Top N by Y' | `SELECT {fields_sample} FROM {sql_table} ORDER BY {agg_field} DESC LIMIT N` |" ) lines.append( - f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field}` |" + f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {sql_table} GROUP BY {group_field}` |" ) lines.append( - f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {display_name} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" + f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {sql_table} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" ) lines.append( - f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {display_name}` |" + f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {sql_table}` |" ) lines.append("") From 55a49b20b5539b52015030b6d8db2ae0bb8d1b1c Mon Sep 17 00:00:00 2001 From: milind-jain-uipath Date: Tue, 24 Mar 2026 14:41:45 +0530 Subject: [PATCH 7/8] removed count(*) examples --- pyproject.toml | 4 ---- .../agent/tools/datafabric_tool/datafabric_tool.py | 5 +++-- uv.lock | 2 -- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6eb8ce183..1b716997c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,6 @@ dev = [ "numpy>=1.24.0", "pytest_httpx>=0.35.0", "rust-just>=1.39.0", - "uipath-langchain", ] [tool.hatch.build.targets.wheel] @@ -119,6 +118,3 @@ name = "testpypi" url = "https://test.pypi.org/simple/" publish-url = "https://test.pypi.org/legacy/" explicit = true - -[tool.uv.sources] -uipath-langchain = { workspace = true } diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py index 053345083..e44d64fa9 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -192,11 +192,12 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append( f"| 'Top N by Y' | `SELECT {fields_sample} FROM {sql_table} ORDER BY {agg_field} DESC LIMIT N` |" ) + count_col = field_names[0] if field_names else "id" lines.append( - f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {sql_table} GROUP BY {group_field}` |" + f"| 'Count by X' | `SELECT {group_field}, COUNT({count_col}) as count FROM {sql_table} GROUP BY {group_field}` |" ) lines.append( - f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {sql_table} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" + f"| 'Top N segments' | `SELECT {group_field}, COUNT({count_col}) as count FROM {sql_table} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" ) lines.append( f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {sql_table}` |" diff --git a/uv.lock b/uv.lock index c970a18a9..715b0c9ac 100644 --- a/uv.lock +++ b/uv.lock @@ -3378,7 +3378,6 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "rust-just" }, - { name = "uipath-langchain" }, { name = "virtualenv" }, ] @@ -3421,7 +3420,6 @@ dev = [ { name = "pytest-mock", specifier = ">=3.11.1" }, { name = "ruff", specifier = ">=0.9.4" }, { name = "rust-just", specifier = ">=1.39.0" }, - { name = "uipath-langchain", editable = "." }, { name = "virtualenv", specifier = ">=20.36.1" }, ] From fff4b2e97d8d9fc81f93f3536ff91b88216aa532 Mon Sep 17 00:00:00 2001 From: milind-jain-uipath Date: Tue, 24 Mar 2026 17:14:23 +0530 Subject: [PATCH 8/8] removed count(*) examples + prompt change --- pyproject.toml | 4 ---- .../agent/tools/datafabric_tool/datafabric_tool.py | 5 +++-- .../agent/tools/datafabric_tool/system_prompt.txt | 11 +++++++++++ uv.lock | 2 -- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6eb8ce183..1b716997c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,6 @@ dev = [ "numpy>=1.24.0", "pytest_httpx>=0.35.0", "rust-just>=1.39.0", - "uipath-langchain", ] [tool.hatch.build.targets.wheel] @@ -119,6 +118,3 @@ name = "testpypi" url = "https://test.pypi.org/simple/" publish-url = "https://test.pypi.org/legacy/" explicit = true - -[tool.uv.sources] -uipath-langchain = { workspace = true } diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py index 053345083..e44d64fa9 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py +++ b/src/uipath_langchain/agent/tools/datafabric_tool/datafabric_tool.py @@ -192,11 +192,12 @@ def format_schemas_for_context(entities: list[Entity]) -> str: lines.append( f"| 'Top N by Y' | `SELECT {fields_sample} FROM {sql_table} ORDER BY {agg_field} DESC LIMIT N` |" ) + count_col = field_names[0] if field_names else "id" lines.append( - f"| 'Count by X' | `SELECT {group_field}, COUNT(*) as count FROM {sql_table} GROUP BY {group_field}` |" + f"| 'Count by X' | `SELECT {group_field}, COUNT({count_col}) as count FROM {sql_table} GROUP BY {group_field}` |" ) lines.append( - f"| 'Top N segments' | `SELECT {group_field}, COUNT(*) as count FROM {sql_table} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" + f"| 'Top N segments' | `SELECT {group_field}, COUNT({count_col}) as count FROM {sql_table} GROUP BY {group_field} ORDER BY count DESC LIMIT N` |" ) lines.append( f"| 'Sum/Avg of Y' | `SELECT SUM({agg_field}) as total FROM {sql_table}` |" diff --git a/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt index 03df458fb..aed1e659d 100644 --- a/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt +++ b/src/uipath_langchain/agent/tools/datafabric_tool/system_prompt.txt @@ -121,4 +121,15 @@ UNSUPPORTED SCENARIOS (Avoid these patterns): - JSON/ARRAY/MAP/GEOMETRY/BLOB operations - Complex timezone-aware timestamp operations +RETRY BEHAVIOR: +If a query fails with a validation error (e.g. missing LIMIT, SELECT *, COUNT(*)), DO NOT give up. Instead: +1. Read the error message carefully +2. Fix the query to comply with the constraint +3. Call query_datafabric again with the corrected query + +Example: if "Queries without WHERE must include a LIMIT clause" is returned, add LIMIT 100 and retry: + Before: SELECT name, department FROM Employee + After: SELECT name, department FROM Employee LIMIT 100 + + Return only the SQL query as plain text. \ No newline at end of file diff --git a/uv.lock b/uv.lock index c970a18a9..715b0c9ac 100644 --- a/uv.lock +++ b/uv.lock @@ -3378,7 +3378,6 @@ dev = [ { name = "pytest-mock" }, { name = "ruff" }, { name = "rust-just" }, - { name = "uipath-langchain" }, { name = "virtualenv" }, ] @@ -3421,7 +3420,6 @@ dev = [ { name = "pytest-mock", specifier = ">=3.11.1" }, { name = "ruff", specifier = ">=0.9.4" }, { name = "rust-just", specifier = ">=1.39.0" }, - { name = "uipath-langchain", editable = "." }, { name = "virtualenv", specifier = ">=20.36.1" }, ]