diff --git a/pyproject.toml b/pyproject.toml index 35df3ec58..351f0ca74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-langchain" -version = "0.9.12" +version = "0.9.13" description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/samples/joke-agent-decorator/README.md b/samples/joke-agent-decorator/README.md new file mode 100644 index 000000000..603f50a30 --- /dev/null +++ b/samples/joke-agent-decorator/README.md @@ -0,0 +1,197 @@ +# Joke Agent (Decorator-based Guardrails) + +A simple LangGraph agent that generates family-friendly jokes based on a given topic using UiPath's LLM. This sample demonstrates all three guardrail decorator types — PII, Prompt Injection, and Deterministic — applied directly to the LLM, agent, and tool without a middleware stack. + +## Requirements + +- Python 3.11+ + +## Installation + +```bash +uv venv -p 3.11 .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +uv sync +``` + +## Usage + +Run the joke agent: + +```bash +uv run uipath run agent '{"topic": "banana"}' +``` + +### Input Format + +```json +{ + "topic": "banana" +} +``` + +### Output Format + +```json +{ + "joke": "Why did the banana go to the doctor? Because it wasn't peeling well!" +} +``` + +## Guardrails Overview + +This sample achieves full parity with the middleware-based `joke-agent` sample using only decorators. 
The table below shows which scope each guardrail covers: + +| Decorator | Target | Scope | Action | +|---|---|---|---| +| `@prompt_injection_guardrail` | `create_llm` factory | LLM | `BlockAction` — blocks on detection | +| `@pii_detection_guardrail` | `create_llm` factory | LLM | `LogAction(WARNING)` — logs and continues | +| `@pii_detection_guardrail` | `analyze_joke_syntax` tool | TOOL | `LogAction(WARNING)` — logs email/phone | +| `@deterministic_guardrail` | `analyze_joke_syntax` tool | TOOL (PRE) | `CustomFilterAction` — replaces "donkey" with "[censored]" | +| `@deterministic_guardrail` | `analyze_joke_syntax` tool | TOOL (PRE) | `BlockAction` — blocks jokes > 1000 chars | +| `@deterministic_guardrail` | `analyze_joke_syntax` tool | TOOL (POST) | `CustomFilterAction` — always-on output transform | +| `@pii_detection_guardrail` | `create_joke_agent` factory | AGENT | `LogAction(WARNING)` — logs agent-level PII | + +## Guardrail Decorators + +### LLM-level guardrails + +Stacked decorators on a factory function. The outermost decorator runs first: + +```python +@prompt_injection_guardrail( + threshold=0.5, + action=BlockAction(), + name="LLM Prompt Injection Detection", + enabled_for_evals=False, # default is True +) +@pii_detection_guardrail( + entities=[PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5)], + action=LogAction(severity_level=LoggingSeverityLevel.WARNING), + name="LLM PII Detection", +) +def create_llm(): + return UiPathChat(model="gpt-4o-2024-08-06", temperature=0.7) + +llm = create_llm() +``` + +### Tool-level guardrails + +`@deterministic_guardrail` applies local rule functions — no UiPath API call. Rules receive the tool input dict and return `True` to signal a violation. `@pii_detection_guardrail` at TOOL scope evaluates via the UiPath guardrails API. 
+ +```python +@deterministic_guardrail( + rules=[lambda args: "donkey" in args.get("joke", "").lower()], + action=CustomFilterAction(word_to_filter="donkey", replacement="[censored]"), + stage=GuardrailExecutionStage.PRE, + name="Joke Content Word Filter", + enabled_for_evals=False, # default is True +) +@deterministic_guardrail( + rules=[lambda args: len(args.get("joke", "")) > 1000], + action=BlockAction(), + stage=GuardrailExecutionStage.PRE, + name="Joke Content Length Limiter", +) +@deterministic_guardrail( + rules=[], # empty rules = always apply (unconditional transform) + action=CustomFilterAction(word_to_filter="words", replacement="words++"), + stage=GuardrailExecutionStage.POST, + name="Joke Content Always Filter", +) +@pii_detection_guardrail( + entities=[ + PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5), + PIIDetectionEntity(PIIDetectionEntityType.PHONE_NUMBER, 0.5), + ], + action=LogAction(severity_level=LoggingSeverityLevel.WARNING), + name="Tool PII Detection", +) +@tool +def analyze_joke_syntax(joke: str) -> str: + ... +``` + +### Agent-level guardrail + +```python +@pii_detection_guardrail( + entities=[PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5)], + action=LogAction( + severity_level=LoggingSeverityLevel.WARNING, + message="PII detected from agent guardrails decorator", + ), + name="Agent PII Detection", + enabled_for_evals=False, # default is True +) +def create_joke_agent(): + return create_agent(model=llm, tools=[analyze_joke_syntax], ...) + +agent = create_joke_agent() +``` + +### Custom action + +`CustomFilterAction` (defined locally in `graph.py`) demonstrates how to implement a custom `GuardrailAction`. 
When a violation is detected, it replaces the offending word in the tool input dict or string, logs the change, then returns the modified data so execution continues with the sanitized input: + +```python +@dataclass +class CustomFilterAction(GuardrailAction): + word_to_filter: str + replacement: str = "***" + + def handle_validation_result(self, result, data, guardrail_name): + # filter word from dict/str and return modified data + ... +``` + +## Rule semantics (`@deterministic_guardrail`) + +- A rule with **1 parameter** receives the tool input dict (`PRE` stage). +- A rule with **2 parameters** receives `(input_dict, output_dict)` (`POST` stage). +- A rule returns `True` to signal a **violation**, `False` to **pass**. +- **All** rules must detect a violation for the guardrail to trigger. If any rule passes, the guardrail passes. +- **Empty `rules=[]`** always triggers the action (useful for unconditional transforms). + +## `enabled_for_evals` override + +All decorator guardrails accept `enabled_for_evals` (default `True`). Set it to `False` +when you want runtime guardrail behavior but do not want that guardrail enabled for eval scenarios. + +## Verification + +To manually verify each guardrail fires, run from this directory: + +```bash +uv run uipath run agent '{"topic": "donkey"}' +``` + +**Scenario 1 — word filter (PRE):** the LLM includes "donkey" in the joke passed to `analyze_joke_syntax`. `CustomFilterAction` replaces it with `[censored]` before the tool executes. Look for `[FILTER][Joke Content Word Filter]` in stdout. + +**Scenario 2 — length limiter (PRE):** if the generated joke exceeds 1000 characters, `BlockAction` raises `AgentRuntimeError(TERMINATION_GUARDRAIL_VIOLATION)` before the tool is called. 
+ +**Scenario 3 — PII at tool and agent scope:** supply a topic containing an email address: + +```bash +uv run uipath run agent '{"topic": "donkey, test@example.com"}' +``` + +Both the agent-scope and LLM-scope `@pii_detection_guardrail` decorators log a `WARNING` when the email is detected. The tool-scope `@pii_detection_guardrail` logs when the email reaches the tool input. + +## Differences from the Middleware Approach (`joke-agent`) + +| Aspect | Middleware (`joke-agent`) | Decorator (`joke-agent-decorator`) | +|---|---|---| +| Configuration | Middleware class instances passed to `create_agent(middleware=[...])` | `@decorator` stacked on the target object | +| Scope | Explicit `scopes=[...]` list | Inferred automatically from the decorated object | +| Tool guardrails | `UiPathDeterministicGuardrailMiddleware(tools=[...])` | `@deterministic_guardrail` directly on the `@tool` | +| Custom loops | Not supported (requires `create_agent`) | Works in any custom LangChain loop | +| API calls | Via middleware stack | Direct `uipath.guardrails.evaluate_guardrail()` | + +## Example Topics + +- `"banana"` — normal run, all guardrails pass +- `"donkey"` — triggers the word filter on `analyze_joke_syntax` +- `"donkey, test@example.com"` — triggers word filter + PII guardrails at all scopes +- `"computer"`, `"coffee"`, `"pizza"`, `"weather"` diff --git a/samples/joke-agent-decorator/graph.py b/samples/joke-agent-decorator/graph.py new file mode 100644 index 000000000..221efaf6c --- /dev/null +++ b/samples/joke-agent-decorator/graph.py @@ -0,0 +1,249 @@ +"""Joke generating agent that creates family-friendly jokes based on a topic.""" + +import logging +import re +from dataclasses import dataclass +from typing import Any + +from langchain.agents import create_agent +from langchain_core.messages import HumanMessage +from langchain_core.tools import tool +from langgraph.constants import END, START +from langgraph.graph import StateGraph +from pydantic import BaseModel +from 
uipath.core.guardrails import ( + GuardrailValidationResult, + GuardrailValidationResultType, +) + +from uipath_langchain.chat import UiPathChat +from uipath_langchain.guardrails import ( + BlockAction, + DeterministicValidator, + GuardrailAction, + GuardrailExecutionStage, + LogAction, + LoggingSeverityLevel, + PIIDetectionEntity, + PIIValidator, + PromptInjectionValidator, + guardrail, +) +from uipath_langchain.guardrails.enums import PIIDetectionEntityType + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Custom filter action (defined locally) +# --------------------------------------------------------------------------- + +@dataclass +class CustomFilterAction(GuardrailAction): + """Filters/replaces a word in tool input when a violation is detected.""" + + word_to_filter: str + replacement: str = "***" + + def _filter(self, text: str) -> str: + return re.sub(re.escape(self.word_to_filter), self.replacement, text, flags=re.IGNORECASE) + + def handle_validation_result( + self, + result: GuardrailValidationResult, + data: str | dict[str, Any], + guardrail_name: str, + ) -> str | dict[str, Any] | None: + if result.result != GuardrailValidationResultType.VALIDATION_FAILED: + return None + if isinstance(data, str): + filtered = self._filter(data) + print(f"[FILTER][{guardrail_name}] '{self.word_to_filter}' replaced → '{filtered[:80]}'") + return filtered + if isinstance(data, dict): + filtered_data = data.copy() + for key in ["joke", "text", "content", "message", "input", "output"]: + if key in filtered_data and isinstance(filtered_data[key], str): + filtered_data[key] = self._filter(filtered_data[key]) + print(f"[FILTER][{guardrail_name}] dict filtered") + return filtered_data + return data + + +# --------------------------------------------------------------------------- +# Input / Output schemas +# --------------------------------------------------------------------------- + +class 
Input(BaseModel): + """Input schema for the joke agent.""" + topic: str + + +class Output(BaseModel): + """Output schema for the joke agent.""" + joke: str + + +# --------------------------------------------------------------------------- +# Reusable validators (declared once, used in multiple @guardrail decorators) +# --------------------------------------------------------------------------- + +pii_email = PIIValidator( + entities=[PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5)], +) + +pii_email_phone = PIIValidator( + entities=[ + PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5), + PIIDetectionEntity(PIIDetectionEntityType.PHONE_NUMBER, 0.5), + ], +) + + +# --------------------------------------------------------------------------- +# LLM with guardrails (prompt injection + PII at LLM scope) +# --------------------------------------------------------------------------- + +@guardrail( + validator=PromptInjectionValidator(threshold=0.5), + action=BlockAction(), + name="LLM Prompt Injection Detection", + stage=GuardrailExecutionStage.PRE, +) +@guardrail( + validator=pii_email, + action=LogAction(severity_level=LoggingSeverityLevel.WARNING), + name="LLM PII Detection", + stage=GuardrailExecutionStage.PRE, +) +def create_llm(): + """Create LLM instance with guardrails.""" + return UiPathChat(model="gpt-4o-2024-08-06", temperature=0.7) + + +llm = create_llm() + + +# --------------------------------------------------------------------------- +# Tool with guardrails (deterministic + PII at TOOL scope) +# --------------------------------------------------------------------------- + +@guardrail( + validator=DeterministicValidator( + rules=[lambda args: "donkey" in args.get("joke", "").lower()], + ), + action=CustomFilterAction(word_to_filter="donkey", replacement="[censored]"), + stage=GuardrailExecutionStage.PRE, + name="Joke Content Word Filter", +) +@guardrail( + validator=DeterministicValidator( + rules=[lambda args: len(args.get("joke", "")) > 1000], + ), 
+ action=BlockAction(title="Joke is too long", detail="The generated joke is too long"), + stage=GuardrailExecutionStage.PRE, + name="Joke Content Length Limiter", +) +@guardrail( + validator=DeterministicValidator(rules=[]), + action=CustomFilterAction(word_to_filter="words", replacement="words++"), + stage=GuardrailExecutionStage.POST, + name="Joke Content Always Filter", +) +@guardrail( + validator=pii_email_phone, + action=LogAction( + severity_level=LoggingSeverityLevel.WARNING, + message="Email or phone number detected", + ), + name="Tool PII Detection", + stage=GuardrailExecutionStage.PRE, +) +@tool +def analyze_joke_syntax(joke: str) -> str: + """Analyze the syntax of a joke by counting words and letters. + + Args: + joke: The joke text to analyze + + Returns: + A string with the analysis results showing word count and letter count + """ + words = joke.split() + word_count = len(words) + letter_count = sum(1 for char in joke if char.isalpha()) + return f"Words number: {word_count}\nLetters: {letter_count}" + + +# --------------------------------------------------------------------------- +# System prompt +# --------------------------------------------------------------------------- + +SYSTEM_PROMPT = """You are an AI assistant designed to generate family-friendly jokes. Your process is as follows: + +1. Generate a family-friendly joke based on the given topic. +2. Use the analyze_joke_syntax tool to analyze the joke's syntax (word count and letter count). +3. Ensure your output includes the joke. + +When creating jokes, ensure they are: + +1. Appropriate for children +2. Free from offensive language or themes +3. Clever and entertaining +4. Not based on stereotypes or sensitive topics + +If you're unable to generate a suitable joke for any reason, politely explain why and offer to try again with a different topic. + +Example joke: Topic: "banana" Joke: "Why did the banana go to the doctor? Because it wasn't peeling well!" 
+ +Remember to always include the 'joke' property in your output to match the required schema.""" + + +# --------------------------------------------------------------------------- +# Agent with PII guardrail at AGENT scope +# --------------------------------------------------------------------------- + +@guardrail( + validator=PIIValidator( + entities=[PIIDetectionEntity(PIIDetectionEntityType.PERSON, 0.5)], + ), + action=BlockAction( + title="Person name detection", + detail="Person name detected and is not allowed", + ), + name="Agent PII Detection", + stage=GuardrailExecutionStage.PRE, +) +def create_joke_agent(): + """Create the joke agent with guardrails.""" + return create_agent( + model=llm, + tools=[analyze_joke_syntax], + system_prompt=SYSTEM_PROMPT, + ) + + +agent = create_joke_agent() + + +# --------------------------------------------------------------------------- +# Wrapper graph node +# --------------------------------------------------------------------------- + +async def joke_node(state: Input) -> Output: + """Convert topic to messages, call agent, and extract joke.""" + messages = [ + HumanMessage(content=f"Generate a family-friendly joke based on the topic: {state.topic}") + ] + result = await agent.ainvoke({"messages": messages}) + joke = result["messages"][-1].content + return Output(joke=joke) + + +# Build wrapper graph with custom input/output schemas +builder = StateGraph(Input, input=Input, output=Output) +builder.add_node("joke", joke_node) +builder.add_edge(START, "joke") +builder.add_edge("joke", END) + +graph = builder.compile() diff --git a/samples/joke-agent-decorator/langgraph.json b/samples/joke-agent-decorator/langgraph.json new file mode 100644 index 000000000..c465a881b --- /dev/null +++ b/samples/joke-agent-decorator/langgraph.json @@ -0,0 +1,7 @@ +{ + "dependencies": ["."], + "graphs": { + "agent": "./graph.py:graph" + }, + "env": ".env" +} diff --git a/samples/joke-agent-decorator/pyproject.toml 
b/samples/joke-agent-decorator/pyproject.toml new file mode 100644 index 000000000..e0a580628 --- /dev/null +++ b/samples/joke-agent-decorator/pyproject.toml @@ -0,0 +1,18 @@ +[project] +name = "joke-agent-decorator" +version = "0.0.1" +description = "Joke generating agent that creates family-friendly jokes based on a topic - using decorator-based guardrails" +authors = [{ name = "Andrei Petraru", email = "andrei.petraru@uipath.com" }] +requires-python = ">=3.11" +dependencies = [ + "uipath-langchain>=0.8.28", + "uipath>2.7.0", +] + +[dependency-groups] +dev = [ + "uipath-dev>=0.0.14", +] + +[tool.uv.sources] +uipath-langchain = { path = "../..", editable = true } diff --git a/samples/joke-agent/graph.py b/samples/joke-agent/graph.py index 85025cae3..31acb3251 100644 --- a/samples/joke-agent/graph.py +++ b/samples/joke-agent/graph.py @@ -15,12 +15,12 @@ PIIDetectionEntity, GuardrailExecutionStage, LogAction, - PIIDetectionEntityType, UiPathDeterministicGuardrailMiddleware, UiPathPIIDetectionMiddleware, UiPathPromptInjectionMiddleware, ) from uipath_langchain.guardrails.actions import LoggingSeverityLevel +from uipath_langchain.guardrails.enums import PIIDetectionEntityType # Define input schema for the agent @@ -102,12 +102,13 @@ def analyze_joke_syntax(joke: str) -> str: PIIDetectionEntity(PIIDetectionEntityType.PHONE_NUMBER, 0.5), ], tools=[analyze_joke_syntax], + enabled_for_evals=False, ), *UiPathPromptInjectionMiddleware( name="Prompt Injection Detection", - scopes=[GuardrailScope.LLM], action=BlockAction(), threshold=0.5, + enabled_for_evals=False, ), # Custom FilterAction example: demonstrates how developers can implement their own actions *UiPathDeterministicGuardrailMiddleware( @@ -121,6 +122,7 @@ def analyze_joke_syntax(joke: str) -> str: ), stage=GuardrailExecutionStage.PRE, name="Joke Content Validator", + enabled_for_evals=False, ), *UiPathDeterministicGuardrailMiddleware( tools=[analyze_joke_syntax], diff --git 
a/src/uipath_langchain/guardrails/__init__.py b/src/uipath_langchain/guardrails/__init__.py index efd50e194..fc4f81751 100644 --- a/src/uipath_langchain/guardrails/__init__.py +++ b/src/uipath_langchain/guardrails/__init__.py @@ -7,8 +7,16 @@ from uipath.agent.models.agent import AgentGuardrailSeverityLevel from uipath.core.guardrails import GuardrailScope -from .actions import BlockAction, LogAction -from .enums import GuardrailExecutionStage, PIIDetectionEntityType +from .actions import BlockAction, LogAction, LoggingSeverityLevel +from .decorators import ( + DeterministicValidator, + GuardrailValidatorBase, + PIIValidator, + PromptInjectionValidator, + RuleFunction, + guardrail, +) +from .enums import GuardrailExecutionStage from .middlewares import ( UiPathDeterministicGuardrailMiddleware, UiPathPIIDetectionMiddleware, @@ -17,15 +25,27 @@ from .models import GuardrailAction, PIIDetectionEntity __all__ = [ - "PIIDetectionEntityType", + # Decorator + "guardrail", + # Validators + "GuardrailValidatorBase", + "PIIValidator", + "PromptInjectionValidator", + "DeterministicValidator", + "RuleFunction", + # Models & enums + "PIIDetectionEntity", "GuardrailExecutionStage", "GuardrailScope", - "PIIDetectionEntity", "GuardrailAction", + # Actions "LogAction", "BlockAction", + "LoggingSeverityLevel", + # Middlewares (unchanged) "UiPathPIIDetectionMiddleware", "UiPathPromptInjectionMiddleware", "UiPathDeterministicGuardrailMiddleware", - "AgentGuardrailSeverityLevel", # Re-export for convenience + # Re-exports + "AgentGuardrailSeverityLevel", ] diff --git a/src/uipath_langchain/guardrails/actions.py b/src/uipath_langchain/guardrails/actions.py index d4fcd5ec7..624aef052 100644 --- a/src/uipath_langchain/guardrails/actions.py +++ b/src/uipath_langchain/guardrails/actions.py @@ -68,11 +68,9 @@ def handle_validation_result( """Handle validation result by logging it.""" if result.result == GuardrailValidationResultType.VALIDATION_FAILED: log_level = self.severity_level - 
log_level_name = logging.getLevelName(log_level) message = self.message or f"Failed: {result.reason}" logger = logging.getLogger(__name__) - logger.log(log_level, message) - print(f"[{log_level_name}][GUARDRAIL] [{guardrail_name}] {message}") + logger.log(log_level, "[GUARDRAIL] [%s] %s", guardrail_name, message) return None diff --git a/src/uipath_langchain/guardrails/decorators/__init__.py b/src/uipath_langchain/guardrails/decorators/__init__.py new file mode 100644 index 000000000..272d434ed --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/__init__.py @@ -0,0 +1,19 @@ +"""Guardrail decorators package.""" + +from .guardrail import guardrail +from .validators import ( + DeterministicValidator, + GuardrailValidatorBase, + PIIValidator, + PromptInjectionValidator, + RuleFunction, +) + +__all__ = [ + "guardrail", + "GuardrailValidatorBase", + "PIIValidator", + "PromptInjectionValidator", + "DeterministicValidator", + "RuleFunction", +] diff --git a/src/uipath_langchain/guardrails/decorators/_base.py b/src/uipath_langchain/guardrails/decorators/_base.py new file mode 100644 index 000000000..56283e068 --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/_base.py @@ -0,0 +1,600 @@ +"""Shared base utilities for guardrail decorators.""" + +import inspect +from dataclasses import dataclass +from functools import wraps +from typing import Any, Callable, Sequence + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage +from langchain_core.tools import BaseTool +from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph +from langgraph.types import Command +from uipath.core.guardrails import ( + GuardrailScope, + GuardrailValidationResult, + GuardrailValidationResultType, +) +from uipath.platform import UiPath +from uipath.platform.guardrails import BuiltInValidatorGuardrail + +from ..enums import GuardrailExecutionStage +from 
..models import GuardrailAction + + +@dataclass +class GuardrailMetadata: + """Metadata for a guardrail decorator. + + Args: + guardrail_type: Type of guardrail ("pii", "prompt_injection", "deterministic") + scope: Scope where guardrail applies (AGENT, LLM, TOOL) + config: Type-specific configuration dictionary + name: Name of the guardrail + description: Optional description + guardrail: The BuiltInValidatorGuardrail instance for API-based evaluation + wrap_tool: Optional callable that wraps a BaseTool with this guardrail's logic. + Set by each decorator so that _wrap_function_with_guardrail can delegate + tool wrapping without knowing the concrete guardrail type. + """ + + guardrail_type: str + scope: GuardrailScope + config: dict[str, Any] + name: str + description: str | None = None + guardrail: BuiltInValidatorGuardrail | None = None + wrap_tool: Callable[["BaseTool", "GuardrailMetadata"], "BaseTool"] | None = None + wrap_llm: ( + Callable[["BaseChatModel", "GuardrailMetadata"], "BaseChatModel"] | None + ) = None + + +def _get_or_create_metadata_list(obj: Any) -> list[GuardrailMetadata]: + """Get or create the guardrail metadata list on an object.""" + if not hasattr(obj, "_guardrail_metadata"): + obj._guardrail_metadata = [] + return obj._guardrail_metadata + + +def _store_guardrail_metadata(obj: Any, metadata: GuardrailMetadata) -> None: + """Store guardrail metadata on an object.""" + metadata_list = _get_or_create_metadata_list(obj) + metadata_list.append(metadata) + + +def _extract_guardrail_metadata(obj: Any) -> list[GuardrailMetadata]: + """Extract all guardrail metadata from an object.""" + if hasattr(obj, "_guardrail_metadata"): + return list(obj._guardrail_metadata) + return [] + + +def _get_last_human_message(messages: list[BaseMessage]) -> HumanMessage | None: + """Return the last HumanMessage in a list, or None if absent.""" + for msg in reversed(messages): + if isinstance(msg, HumanMessage): + return msg + return None + + +def 
_get_last_ai_message(messages: list[BaseMessage]) -> AIMessage | None: + """Return the last AIMessage in a list, or None if absent.""" + for msg in reversed(messages): + if isinstance(msg, AIMessage): + return msg + return None + + +def _extract_message_text(msg: BaseMessage) -> str: + """Extract text content from a single message.""" + if isinstance(msg.content, str): + return msg.content + if isinstance(msg.content, list): + parts = [ + part.get("text", "") + for part in msg.content + if isinstance(part, dict) and part.get("type") == "text" + ] + return "\n".join(filter(None, parts)) + return "" + + +def _apply_message_text_modification(msg: BaseMessage, modified: str) -> None: + """Apply a modified text string back to a message in-place. + + For str content, replaces it directly. For multimodal list content, + replaces the first text part. + """ + if isinstance(msg.content, str): + msg.content = modified + elif isinstance(msg.content, list): + for part in msg.content: + if isinstance(part, dict) and part.get("type") == "text": + part["text"] = modified + break + + +def _detect_scope(obj: Any) -> GuardrailScope: + """Detect the guardrail scope from an object. + + Returns: + GuardrailScope.TOOL for BaseTool instances. + GuardrailScope.LLM for BaseChatModel instances. + GuardrailScope.AGENT for StateGraph or CompiledStateGraph instances, + including subgraphs — guardrails apply at the boundary of any graph + execution, not only the top-level agent. + GuardrailScope.AGENT for plain functions/methods (agent factory functions), + optionally annotated with a StateGraph or CompiledStateGraph return type. 
+ """ + if isinstance(obj, BaseTool): + return GuardrailScope.TOOL + + if isinstance(obj, BaseChatModel): + return GuardrailScope.LLM + + if isinstance(obj, StateGraph): + return GuardrailScope.AGENT + + if isinstance(obj, CompiledStateGraph): + return GuardrailScope.AGENT + + if inspect.isfunction(obj) or inspect.ismethod(obj): + sig = inspect.signature(obj) + if sig.return_annotation != inspect.Signature.empty: + if sig.return_annotation in (StateGraph, CompiledStateGraph) or ( + hasattr(sig.return_annotation, "__origin__") + and sig.return_annotation.__origin__ in (StateGraph, CompiledStateGraph) + ): + return GuardrailScope.AGENT + return GuardrailScope.AGENT + + raise ValueError( + f"Cannot determine scope for object of type {type(obj)}. " + "Object must be a BaseTool, BaseChatModel, StateGraph, CompiledStateGraph, " + "or a callable function/method (agent factory)." + ) + + +def _evaluate_guardrail( + data: str | dict[str, Any], + guardrail: BuiltInValidatorGuardrail, + uipath: UiPath, +) -> GuardrailValidationResult: + """Evaluate a guardrail against data via the UiPath API.""" + return uipath.guardrails.evaluate_guardrail(data, guardrail) + + +def _handle_guardrail_result( + result: GuardrailValidationResult, + data: str | dict[str, Any], + action: GuardrailAction, + guardrail_name: str, +) -> str | dict[str, Any] | None: + """Handle guardrail validation result using action.""" + if result.result == GuardrailValidationResultType.VALIDATION_FAILED: + return action.handle_validation_result(result, data, guardrail_name) + return None + + +def _evaluate_rules( + rules: Sequence[Callable[..., bool]], + stage: GuardrailExecutionStage, + input_data: dict[str, Any] | None, + output_data: dict[str, Any] | None, + guardrail_name: str = "Rule", +) -> GuardrailValidationResult: + """Evaluate deterministic rules and return a validation result. + + All rules must detect violations to trigger. If any rule passes (returns False), + the guardrail passes. 
Empty rules always trigger the action. + """ + import logging + + logger = logging.getLogger(__name__) + + if not rules: + return GuardrailValidationResult( + result=GuardrailValidationResultType.VALIDATION_FAILED, + reason="Empty rules — always apply action", + ) + + violations: list[str] = [] + passed_rules: list[str] = [] + evaluated_count = 0 + + for rule in rules: + try: + sig = inspect.signature(rule) + param_count = len(sig.parameters) + + if stage == GuardrailExecutionStage.PRE: + if input_data is None or param_count != 1: + continue + violation = rule(input_data) + evaluated_count += 1 + else: + if output_data is None: + continue + if param_count == 2 and input_data is not None: + violation = rule(input_data, output_data) + elif param_count == 1: + violation = rule(output_data) + else: + continue + evaluated_count += 1 + + if violation: + violations.append(f"Rule {guardrail_name} detected violation") + else: + passed_rules.append(f"Rule {guardrail_name}") + except Exception as e: + logger.error(f"Error in rule function {guardrail_name}: {e}", exc_info=True) + violations.append(f"Rule {guardrail_name} raised exception: {str(e)}") + evaluated_count += 1 + + if evaluated_count == 0: + return GuardrailValidationResult( + result=GuardrailValidationResultType.PASSED, + reason="No applicable rules to evaluate", + ) + + if passed_rules: + return GuardrailValidationResult( + result=GuardrailValidationResultType.PASSED, + reason=f"Rules passed: {', '.join(passed_rules)}", + ) + + return GuardrailValidationResult( + result=GuardrailValidationResultType.VALIDATION_FAILED, + reason="; ".join(violations), + ) + + +# --------------------------------------------------------------------------- +# Module-level tool I/O helpers shared by PII and deterministic tool wrappers +# --------------------------------------------------------------------------- + + +def _is_tool_call_envelope(tool_input: Any) -> bool: + """Return True if tool_input is a LangGraph tool-call envelope 
dict.""" + return ( + isinstance(tool_input, dict) + and "args" in tool_input + and tool_input.get("type") == "tool_call" + ) + + +def _extract_input(tool_input: Any) -> dict[str, Any]: + """Normalise tool input to a dict for rule/guardrail evaluation. + + LangGraph passes the raw tool-call dict ({"name": ..., "args": {...}, "id": ..., + "type": "tool_call"}) to tool.invoke/ainvoke. Unwrap "args" so rules can access + the actual tool arguments (e.g. args.get("joke", "")) directly. + """ + if _is_tool_call_envelope(tool_input): + args = tool_input["args"] + if isinstance(args, dict): + return args + if isinstance(tool_input, dict): + return tool_input + return {"input": tool_input} + + +def _rewrap_input(original_tool_input: Any, modified_args: dict[str, Any]) -> Any: + """Re-wrap modified args back into the original tool-call envelope (if applicable).""" + if _is_tool_call_envelope(original_tool_input): + import copy + + wrapped = copy.copy(original_tool_input) + wrapped["args"] = modified_args + return wrapped + return modified_args + + +def _extract_output(result: Any) -> dict[str, Any]: + """Normalise tool output to a dict for guardrail/rule evaluation. + + Handles ToolMessage and Command (returned when the tool is called through + LangGraph's tool node) by extracting their string content first, then + parsing as JSON/literal-eval. Falls back to {"output": content} for + plain strings and {"output": result} for anything else. 
+ """ + import ast + import json + + content: Any = result + if isinstance(result, Command): + update = result.update if hasattr(result, "update") else {} + messages = update.get("messages", []) if isinstance(update, dict) else [] + if messages and isinstance(messages[0], ToolMessage): + content = messages[0].content + else: + return {} + elif isinstance(result, ToolMessage): + content = result.content + + if isinstance(content, dict): + return content + if isinstance(result, dict): + return result + if isinstance(content, str): + try: + parsed = json.loads(content) + return parsed if isinstance(parsed, dict) else {"output": parsed} + except ValueError: + try: + parsed = ast.literal_eval(content) + return parsed if isinstance(parsed, dict) else {"output": parsed} + except (ValueError, SyntaxError): + return {"output": content} + return {"output": content} + + +# --------------------------------------------------------------------------- +# Module-level LLM guardrail helpers shared by PII and prompt-injection wrappers +# --------------------------------------------------------------------------- + + +def _apply_llm_input_guardrail( + messages: list[BaseMessage], + guardrail: BuiltInValidatorGuardrail, + uipath: UiPath, + action: GuardrailAction, + guardrail_name: str, +) -> None: + """Evaluate a guardrail against the last HumanMessage (PRE stage). + + Only the most recent user input is evaluated — prior turns were already + evaluated in previous invocations. Modifies the message content in-place + if the action returns a replacement string. 
+ """ + msg = _get_last_human_message(messages) + if msg is None: + return + text = _extract_message_text(msg) + if not text: + return + try: + eval_result = _evaluate_guardrail(text, guardrail, uipath) + except Exception: + return + modified = _handle_guardrail_result(eval_result, text, action, guardrail_name) + if isinstance(modified, str) and modified != text: + _apply_message_text_modification(msg, modified) + + +def _apply_llm_output_guardrail( + response: AIMessage, + guardrail: BuiltInValidatorGuardrail, + uipath: UiPath, + action: GuardrailAction, + guardrail_name: str, +) -> None: + """Evaluate a guardrail against the LLM response text (POST stage). + + Modifies ``response.content`` in-place if the action returns a replacement string. + """ + if not isinstance(response.content, str) or not response.content: + return + try: + eval_result = _evaluate_guardrail(response.content, guardrail, uipath) + except Exception: + return + modified = _handle_guardrail_result( + eval_result, response.content, action, guardrail_name + ) + if isinstance(modified, str) and modified != response.content: + response.content = modified + + +def _apply_guardrail_to_message_list( + messages: list[BaseMessage], + guardrail: BuiltInValidatorGuardrail, + uipath: UiPath, + action: GuardrailAction, + guardrail_name: str, + target_type: type[BaseMessage] = HumanMessage, +) -> None: + """Evaluate a guardrail against the last message of target_type and modify it in-place. + + Pass target_type=HumanMessage (default) for PRE/input evaluation, + or target_type=AIMessage for POST/output evaluation. 
+ """ + msg: BaseMessage | None = None + for m in reversed(messages): + if isinstance(m, target_type): + msg = m + break + if msg is None: + return + text = _extract_message_text(msg) + if not text: + return + try: + result = _evaluate_guardrail(text, guardrail, uipath) + except Exception: + return + modified = _handle_guardrail_result(result, text, action, guardrail_name) + if isinstance(modified, str) and modified != text: + _apply_message_text_modification(msg, modified) + + +def _apply_guardrail_to_input_messages( + input_data: Any, + guardrail: BuiltInValidatorGuardrail, + uipath: UiPath, + action: GuardrailAction, + guardrail_name: str, +) -> None: + """If input is a dict with a 'messages' list, apply guardrail to it in-place.""" + if not isinstance(input_data, dict) or "messages" not in input_data: + return + messages = input_data["messages"] + if not isinstance(messages, list): + return + _apply_guardrail_to_message_list( + messages, guardrail, uipath, action, guardrail_name + ) + + +def _apply_guardrail_to_output_messages( + output: Any, + guardrail: BuiltInValidatorGuardrail, + uipath: UiPath, + action: GuardrailAction, + guardrail_name: str, +) -> None: + """If output is a dict with a 'messages' list, apply guardrail to the last AIMessage in-place.""" + if not isinstance(output, dict) or "messages" not in output: + return + messages = output["messages"] + if not isinstance(messages, list): + return + _apply_guardrail_to_message_list( + messages, guardrail, uipath, action, guardrail_name, target_type=AIMessage + ) + + +def _wrap_stategraph_with_guardrail( + graph: StateGraph[Any, Any], metadata: GuardrailMetadata +) -> StateGraph[Any, Any]: + """Wrap StateGraph invoke/ainvoke to apply guardrails.""" + built_in_guardrail = metadata.guardrail + if built_in_guardrail is None: + return graph + action = metadata.config["action"] + guardrail_name = metadata.name + uipath = UiPath() + + if hasattr(graph, "invoke"): + original_invoke = graph.invoke + + 
@wraps(original_invoke) + def wrapped_invoke(input, config=None, **kwargs): + _apply_guardrail_to_input_messages( + input, built_in_guardrail, uipath, action, guardrail_name + ) + output = original_invoke(input, config, **kwargs) + _apply_guardrail_to_output_messages( + output, built_in_guardrail, uipath, action, guardrail_name + ) + return output + + graph.invoke = wrapped_invoke + + if hasattr(graph, "ainvoke"): + original_ainvoke = graph.ainvoke + + @wraps(original_ainvoke) + async def wrapped_ainvoke(input, config=None, **kwargs): + _apply_guardrail_to_input_messages( + input, built_in_guardrail, uipath, action, guardrail_name + ) + output = await original_ainvoke(input, config, **kwargs) + _apply_guardrail_to_output_messages( + output, built_in_guardrail, uipath, action, guardrail_name + ) + return output + + graph.ainvoke = wrapped_ainvoke + + return graph + + +def _wrap_compiled_graph_with_guardrail( + graph: CompiledStateGraph[Any, Any, Any], metadata: GuardrailMetadata +) -> CompiledStateGraph[Any, Any, Any]: + """Wrap a CompiledStateGraph's invoke/ainvoke to apply guardrails.""" + built_in_guardrail = metadata.guardrail + if built_in_guardrail is None: + return graph + action = metadata.config["action"] + guardrail_name = metadata.name + uipath = UiPath() + + original_invoke = graph.invoke + original_ainvoke = graph.ainvoke + + @wraps(original_invoke) + def wrapped_invoke(input, config=None, **kwargs): + _apply_guardrail_to_input_messages( + input, built_in_guardrail, uipath, action, guardrail_name + ) + output = original_invoke(input, config, **kwargs) + _apply_guardrail_to_output_messages( + output, built_in_guardrail, uipath, action, guardrail_name + ) + return output + + @wraps(original_ainvoke) + async def wrapped_ainvoke(input, config=None, **kwargs): + _apply_guardrail_to_input_messages( + input, built_in_guardrail, uipath, action, guardrail_name + ) + output = await original_ainvoke(input, config, **kwargs) + _apply_guardrail_to_output_messages( + 
output, built_in_guardrail, uipath, action, guardrail_name + ) + return output + + graph.invoke = wrapped_invoke # type: ignore[method-assign] + graph.ainvoke = wrapped_ainvoke # type: ignore[method-assign] + return graph + + +def _wrap_function_with_guardrail( + func: Callable[..., Any], metadata: GuardrailMetadata +) -> Callable[..., Any]: + """Wrap a function to apply guardrails. + + After calling the function, inspects the return value: + - StateGraph / CompiledStateGraph: delegates to the appropriate graph wrapper + - BaseChatModel: delegates to the LLM wrapper + - BaseTool: delegates to the tool wrapper + """ + built_in_guardrail = metadata.guardrail + if built_in_guardrail is None: + return func + action = metadata.config["action"] + guardrail_name = metadata.name + uipath = UiPath() + + @wraps(func) + def wrapped_func(*args, **kwargs): + result = func(*args, **kwargs) + if isinstance(result, StateGraph): + return _wrap_stategraph_with_guardrail(result, metadata) + if isinstance(result, CompiledStateGraph): + return _wrap_compiled_graph_with_guardrail(result, metadata) + if isinstance(result, BaseChatModel): + if metadata.wrap_llm is not None: + return metadata.wrap_llm(result, metadata) + if isinstance(result, BaseTool) and metadata.wrap_tool is not None: + return metadata.wrap_tool(result, metadata) + _apply_guardrail_to_output_messages( + result, built_in_guardrail, uipath, action, guardrail_name + ) + return result + + @wraps(func) + async def wrapped_async_func(*args, **kwargs): + result = await func(*args, **kwargs) + if isinstance(result, StateGraph): + return _wrap_stategraph_with_guardrail(result, metadata) + if isinstance(result, CompiledStateGraph): + return _wrap_compiled_graph_with_guardrail(result, metadata) + if isinstance(result, BaseChatModel): + if metadata.wrap_llm is not None: + return metadata.wrap_llm(result, metadata) + if isinstance(result, BaseTool) and metadata.wrap_tool is not None: + return metadata.wrap_tool(result, metadata) + 
_apply_guardrail_to_output_messages( + result, built_in_guardrail, uipath, action, guardrail_name + ) + return result + + if inspect.iscoroutinefunction(func): + return wrapped_async_func + return wrapped_func diff --git a/src/uipath_langchain/guardrails/decorators/guardrail.py b/src/uipath_langchain/guardrails/decorators/guardrail.py new file mode 100644 index 000000000..e3966a781 --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/guardrail.py @@ -0,0 +1,664 @@ +"""Single @guardrail decorator for all guardrail types.""" + +import inspect +import logging +from functools import wraps +from typing import Any, Callable + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage, ToolMessage +from langchain_core.tools import BaseTool +from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph +from langgraph.types import Command +from uipath.core.guardrails import ( + GuardrailValidationResult, +) +from uipath.platform import UiPath +from uipath.platform.guardrails import BuiltInValidatorGuardrail + +from ...agent.exceptions import AgentRuntimeError +from ..enums import GuardrailExecutionStage +from ..middlewares._utils import create_modified_tool_result +from ..models import GuardrailAction +from ._base import ( + _apply_message_text_modification, + _detect_scope, + _evaluate_guardrail, + _extract_input, + _extract_message_text, + _extract_output, + _get_last_ai_message, + _get_last_human_message, + _handle_guardrail_result, + _rewrap_input, +) +from .validators._base import GuardrailValidatorBase + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Evaluator factory +# --------------------------------------------------------------------------- + +_EvaluatorFn = Callable[ + [ + "str | dict[str, Any]", # data + GuardrailExecutionStage, # stage + "dict[str, Any] | None", # input_data + "dict[str, 
Any] | None", # output_data + ], + GuardrailValidationResult, +] + + +def _make_evaluator( + validator: GuardrailValidatorBase, + built_in_guardrail: BuiltInValidatorGuardrail | None, +) -> _EvaluatorFn: + """Return a unified evaluation callable for use in all wrappers. + + If *built_in_guardrail* is provided the callable hits the UiPath API (lazily + initialising ``UiPath()``). Otherwise it delegates to ``validator.evaluate()``. + + Args: + validator: The validator instance (used for local evaluation path). + built_in_guardrail: Pre-built ``BuiltInValidatorGuardrail``, or ``None``. + + Returns: + Callable with signature ``(data, stage, input_data, output_data)``. + """ + if built_in_guardrail is not None: + _uipath_holder: list[UiPath] = [] + + def _api_eval( + data: str | dict[str, Any], + stage: GuardrailExecutionStage, + input_data: dict[str, Any] | None, + output_data: dict[str, Any] | None, + ) -> GuardrailValidationResult: + if not _uipath_holder: + _uipath_holder.append(UiPath()) + return _evaluate_guardrail(data, built_in_guardrail, _uipath_holder[0]) + + return _api_eval + + def _local_eval( + data: str | dict[str, Any], + stage: GuardrailExecutionStage, + input_data: dict[str, Any] | None, + output_data: dict[str, Any] | None, + ) -> GuardrailValidationResult: + return validator.evaluate(data, stage, input_data, output_data) + + return _local_eval + + +# --------------------------------------------------------------------------- +# Generic tool wrapper +# --------------------------------------------------------------------------- + + +def _wrap_tool_with_guardrail( + tool: BaseTool, + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, + stage: GuardrailExecutionStage, +) -> BaseTool: + """Wrap a ``BaseTool`` to apply the guardrail at PRE, POST, or PRE_AND_POST. + + Uses Pydantic ``__class__`` swapping so all Pydantic fields and the + ``StructuredTool`` interface are fully inherited by the guarded subclass. 
+ + Args: + tool: ``BaseTool`` instance to wrap. + evaluator: Unified evaluation callable from ``_make_evaluator()``. + action: Action to invoke on validation failure. + name: Guardrail name (used in action and log messages). + stage: When to run the guardrail (PRE, POST, or PRE_AND_POST). + + Returns: + The same tool object with its ``__class__`` swapped to a guarded subclass. + """ + _stage = stage + + def _apply_pre(tool_input: Any) -> Any: + input_data = _extract_input(tool_input) + try: + result = evaluator( + input_data, GuardrailExecutionStage.PRE, input_data, None + ) + except Exception as exc: + logger.error( + "Error evaluating guardrail (pre) for tool %r: %s", + tool.name, + exc, + exc_info=True, + ) + return tool_input + try: + modified = _handle_guardrail_result(result, input_data, action, name) + except AgentRuntimeError: + raise + if modified is not None and isinstance(modified, dict): + return _rewrap_input(tool_input, modified) + return tool_input + + def _apply_post(tool_input: Any, raw_result: Any) -> Any: + input_data = _extract_input(tool_input) + output_data = _extract_output(raw_result) + try: + result = evaluator( + output_data, GuardrailExecutionStage.POST, input_data, output_data + ) + except Exception as exc: + logger.error( + "Error evaluating guardrail (post) for tool %r: %s", + tool.name, + exc, + exc_info=True, + ) + return raw_result + try: + modified = _handle_guardrail_result(result, output_data, action, name) + except AgentRuntimeError: + raise + if modified is not None: + if isinstance(raw_result, (ToolMessage, Command)): + return create_modified_tool_result(raw_result, modified) + return modified + return raw_result + + ConcreteToolType = type(tool) + + class _GuardedTool(ConcreteToolType): # type: ignore[valid-type, misc] + def invoke(self, tool_input: Any, config: Any = None, **kwargs: Any) -> Any: + guarded_input = tool_input + if _stage in ( + GuardrailExecutionStage.PRE, + GuardrailExecutionStage.PRE_AND_POST, + ): + 
guarded_input = _apply_pre(tool_input) + result = super().invoke(guarded_input, config, **kwargs) + if _stage in ( + GuardrailExecutionStage.POST, + GuardrailExecutionStage.PRE_AND_POST, + ): + result = _apply_post(guarded_input, result) + return result + + # ainvoke is intentionally NOT overridden here. + # StructuredTool.ainvoke (for sync tools without a coroutine) delegates to + # self.invoke via run_in_executor. Overriding ainvoke would cause the POST + # guardrail to fire twice: once inside self.invoke and once after super().ainvoke() + # returns. Guardrails are applied correctly through invoke alone. + + tool.__class__ = _GuardedTool + return tool + + +# --------------------------------------------------------------------------- +# Generic LLM wrapper +# --------------------------------------------------------------------------- + + +def _wrap_llm_with_guardrail( + llm: BaseChatModel, + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, + stage: GuardrailExecutionStage, +) -> BaseChatModel: + """Wrap a ``BaseChatModel`` to apply the guardrail at PRE, POST, or PRE_AND_POST. + + PRE: evaluates the last ``HumanMessage`` before the LLM is called. + POST: evaluates the ``AIMessage`` response after the LLM returns. + + Args: + llm: ``BaseChatModel`` instance to wrap. + evaluator: Unified evaluation callable from ``_make_evaluator()``. + action: Action to invoke on validation failure. + name: Guardrail name. + stage: When to run the guardrail. + + Returns: + The same LLM object with its ``__class__`` swapped to a guarded subclass. 
+ """ + _stage = stage + + ConcreteType = type(llm) + + class _GuardedLLM(ConcreteType): # type: ignore[valid-type, misc] + def invoke(self, messages: Any, config: Any = None, **kwargs: Any) -> Any: + if isinstance(messages, list) and _stage in ( + GuardrailExecutionStage.PRE, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_llm_pre(messages, evaluator, action, name) + response = super().invoke(messages, config, **kwargs) + if isinstance(response, AIMessage) and _stage in ( + GuardrailExecutionStage.POST, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_llm_post(response, evaluator, action, name) + return response + + async def ainvoke( + self, messages: Any, config: Any = None, **kwargs: Any + ) -> Any: + if isinstance(messages, list) and _stage in ( + GuardrailExecutionStage.PRE, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_llm_pre(messages, evaluator, action, name) + response = await super().ainvoke(messages, config, **kwargs) + if isinstance(response, AIMessage) and _stage in ( + GuardrailExecutionStage.POST, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_llm_post(response, evaluator, action, name) + return response + + llm.__class__ = _GuardedLLM + return llm + + +def _apply_llm_pre( + messages: list[BaseMessage], + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, +) -> None: + """Evaluate the last HumanMessage in-place (PRE stage, LLM scope).""" + msg = _get_last_human_message(messages) + if msg is None: + return + text = _extract_message_text(msg) + if not text: + return + try: + result = evaluator(text, GuardrailExecutionStage.PRE, None, None) + except Exception: + return + modified = _handle_guardrail_result(result, text, action, name) + if isinstance(modified, str) and modified != text: + _apply_message_text_modification(msg, modified) + + +def _apply_llm_post( + response: AIMessage, + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, +) -> None: + """Evaluate the AIMessage content in-place (POST 
stage, LLM scope).""" + if not isinstance(response.content, str) or not response.content: + return + try: + result = evaluator(response.content, GuardrailExecutionStage.POST, None, None) + except Exception: + return + modified = _handle_guardrail_result(result, response.content, action, name) + if isinstance(modified, str) and modified != response.content: + response.content = modified + + +# --------------------------------------------------------------------------- +# Generic StateGraph / CompiledStateGraph wrappers (AGENT scope) +# --------------------------------------------------------------------------- + + +def _wrap_stategraph_with_guardrail( + graph: StateGraph[Any, Any], + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, + stage: GuardrailExecutionStage, +) -> StateGraph[Any, Any]: + """Wrap a ``StateGraph``'s invoke/ainvoke to apply the guardrail (AGENT scope). + + Args: + graph: ``StateGraph`` instance to wrap. + evaluator: Unified evaluation callable. + action: Action to invoke on validation failure. + name: Guardrail name. + stage: When to run the guardrail. + + Returns: + The same graph with patched ``invoke`` / ``ainvoke`` methods. 
+ """ + if hasattr(graph, "invoke"): + original_invoke = graph.invoke + + @wraps(original_invoke) + def wrapped_invoke(input: Any, config: Any = None, **kwargs: Any) -> Any: + if stage in ( + GuardrailExecutionStage.PRE, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_input_guardrail(input, evaluator, action, name) + output = original_invoke(input, config, **kwargs) + if stage in ( + GuardrailExecutionStage.POST, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_output_guardrail(output, evaluator, action, name) + return output + + graph.invoke = wrapped_invoke + + if hasattr(graph, "ainvoke"): + original_ainvoke = graph.ainvoke + + @wraps(original_ainvoke) + async def wrapped_ainvoke(input: Any, config: Any = None, **kwargs: Any) -> Any: + if stage in ( + GuardrailExecutionStage.PRE, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_input_guardrail(input, evaluator, action, name) + output = await original_ainvoke(input, config, **kwargs) + if stage in ( + GuardrailExecutionStage.POST, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_output_guardrail(output, evaluator, action, name) + return output + + graph.ainvoke = wrapped_ainvoke + + return graph + + +def _wrap_compiled_graph_with_guardrail( + graph: CompiledStateGraph[Any, Any, Any], + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, + stage: GuardrailExecutionStage, +) -> CompiledStateGraph[Any, Any, Any]: + """Wrap a ``CompiledStateGraph``'s invoke/ainvoke (AGENT scope). + + Args: + graph: ``CompiledStateGraph`` instance to wrap. + evaluator: Unified evaluation callable. + action: Action to invoke on validation failure. + name: Guardrail name. + stage: When to run the guardrail. + + Returns: + The same compiled graph with patched ``invoke`` / ``ainvoke`` methods. 
+ """ + original_invoke = graph.invoke + original_ainvoke = graph.ainvoke + + @wraps(original_invoke) + def wrapped_invoke(input: Any, config: Any = None, **kwargs: Any) -> Any: + if stage in ( + GuardrailExecutionStage.PRE, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_input_guardrail(input, evaluator, action, name) + output = original_invoke(input, config, **kwargs) + if stage in ( + GuardrailExecutionStage.POST, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_output_guardrail(output, evaluator, action, name) + return output + + @wraps(original_ainvoke) + async def wrapped_ainvoke(input: Any, config: Any = None, **kwargs: Any) -> Any: + if stage in ( + GuardrailExecutionStage.PRE, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_input_guardrail(input, evaluator, action, name) + output = await original_ainvoke(input, config, **kwargs) + if stage in ( + GuardrailExecutionStage.POST, + GuardrailExecutionStage.PRE_AND_POST, + ): + _apply_agent_output_guardrail(output, evaluator, action, name) + return output + + graph.invoke = wrapped_invoke # type: ignore[method-assign] + graph.ainvoke = wrapped_ainvoke # type: ignore[method-assign] + return graph + + +def _apply_agent_input_guardrail( + input: Any, + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, +) -> None: + """Evaluate the last HumanMessage from agent input in-place.""" + if not isinstance(input, dict) or "messages" not in input: + return + messages = input["messages"] + if not isinstance(messages, list): + return + msg = _get_last_human_message(messages) + if msg is None: + return + text = _extract_message_text(msg) + if not text: + return + try: + result = evaluator(text, GuardrailExecutionStage.PRE, None, None) + except Exception: + return + modified = _handle_guardrail_result(result, text, action, name) + if isinstance(modified, str) and modified != text: + _apply_message_text_modification(msg, modified) + + +def _apply_agent_output_guardrail( + output: 
Any, + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, +) -> None: + """Evaluate the last AIMessage from agent output in-place.""" + if not isinstance(output, dict) or "messages" not in output: + return + messages = output["messages"] + if not isinstance(messages, list): + return + msg = _get_last_ai_message(messages) + if msg is None: + return + text = _extract_message_text(msg) + if not text: + return + try: + result = evaluator(text, GuardrailExecutionStage.POST, None, None) + except Exception: + return + modified = _handle_guardrail_result(result, text, action, name) + if isinstance(modified, str) and modified != text: + _apply_message_text_modification(msg, modified) + + +# --------------------------------------------------------------------------- +# Factory function wrapper (AGENT scope: function returning LLM/tool/graph) +# --------------------------------------------------------------------------- + + +def _wrap_factory_function( + func: Callable[..., Any], + evaluator: _EvaluatorFn, + action: GuardrailAction, + name: str, + stage: GuardrailExecutionStage, +) -> Callable[..., Any]: + """Wrap a factory function, applying the guardrail to its return value. + + After calling the function the return type is inspected and delegated to + the appropriate typed wrapper. + + Args: + func: Factory function to wrap. + evaluator: Unified evaluation callable. + action: Action to invoke on validation failure. + name: Guardrail name. + stage: When to run the guardrail. + + Returns: + Wrapped function (sync or async, matching the original). 
+ """ + + def _dispatch(result: Any) -> Any: + if isinstance(result, CompiledStateGraph): + return _wrap_compiled_graph_with_guardrail( + result, evaluator, action, name, stage + ) + if isinstance(result, StateGraph): + return _wrap_stategraph_with_guardrail( + result, evaluator, action, name, stage + ) + if isinstance(result, BaseChatModel): + return _wrap_llm_with_guardrail(result, evaluator, action, name, stage) + if isinstance(result, BaseTool): + return _wrap_tool_with_guardrail(result, evaluator, action, name, stage) + return result + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def wrapped_async(*args: Any, **kwargs: Any) -> Any: + return _dispatch(await func(*args, **kwargs)) + + return wrapped_async + + @wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + return _dispatch(func(*args, **kwargs)) + + return wrapped + + +# --------------------------------------------------------------------------- +# Public @guardrail decorator +# --------------------------------------------------------------------------- + + +def guardrail( + func: Any = None, + *, + validator: GuardrailValidatorBase, + action: GuardrailAction, + name: str = "Guardrail", + description: str | None = None, + stage: GuardrailExecutionStage = GuardrailExecutionStage.PRE_AND_POST, + enabled_for_evals: bool = True, +) -> Any: + """Apply a guardrail to an LLM, tool, or agent factory function. + + The guardrail is described by a *validator* (what to check) and an *action* + (how to respond on violation). Scope is auto-detected from the decorated + object type, and validated against the validator's ``supported_scopes`` / + ``supported_stages`` at decoration time. + + Can be stacked: multiple ``@guardrail`` decorators on the same object chain + via Pydantic ``__class__`` swapping (tools/LLMs) or function wrapping (agents). + + Args: + func: Object to decorate when used without parentheses (rare). + validator: ``GuardrailValidatorBase`` instance defining what to check. 
+ action: ``GuardrailAction`` instance defining how to respond on violation. + name: Human-readable name for this guardrail instance. + description: Optional description (used when building API guardrails). + stage: When to evaluate — ``PRE``, ``POST``, or ``PRE_AND_POST``. + Defaults to ``PRE_AND_POST``. + enabled_for_evals: Whether this guardrail is active in evaluation + scenarios. Defaults to ``True``. + + Returns: + The decorated object (same type as input). + + Raises: + ValueError: If the validator does not support the detected scope or + the requested stage, or if ``action`` is missing / invalid. + + Example:: + + pii = PIIValidator(entities=[PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5)]) + + @guardrail(validator=pii, action=LogAction(), name="LLM PII", stage=GuardrailExecutionStage.PRE) + def create_llm(): + return UiPathChat(model="gpt-4o") + + @guardrail(validator=PIIValidator(entities=[PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5)]), + action=BlockAction(), name="Tool PII") + @tool + def my_tool(text: str) -> str: ... + """ + if action is None: + raise ValueError("action must be provided") + if not isinstance(action, GuardrailAction): + raise ValueError("action must be an instance of GuardrailAction") + if not isinstance(enabled_for_evals, bool): + raise ValueError("enabled_for_evals must be a boolean") + + def _apply(obj: Any) -> Any: + # Factory functions (plain callables that are not tool/LLM/graph instances) + # have an unknown scope until they are called and their return type is known. + # Handle them separately to avoid false scope-validation failures. + _is_factory = callable(obj) and not isinstance( + obj, (BaseTool, BaseChatModel, StateGraph, CompiledStateGraph) + ) + + if _is_factory: + # TOOL-only validators must be applied directly to @tool instances, + # not to factory functions — the scope is unresolvable at decoration time. 
+ if validator.supported_scopes and all( + s.value == "Tool" for s in validator.supported_scopes + ): + raise ValueError( + f"@guardrail with {type(validator).__name__} can only be applied " + "to BaseTool instances (decorated with @tool). " + "Apply it directly to the tool, not to a factory function." + ) + validator.validate_stage(stage) + # Use the validator's primary scope (first supported, or AGENT if + # unrestricted) to build the API guardrail instance. + api_scope = ( + validator.supported_scopes[0] + if validator.supported_scopes + else _detect_scope(obj) + ) + else: + scope = _detect_scope(obj) + validator.validate_scope(scope) + validator.validate_stage(stage) + api_scope = scope + + built_in_guardrail = validator.build_built_in_guardrail( + api_scope, name, description, enabled_for_evals + ) + evaluator = _make_evaluator(validator, built_in_guardrail) + + if isinstance(obj, BaseTool): + return _wrap_tool_with_guardrail(obj, evaluator, action, name, stage) + if isinstance(obj, BaseChatModel): + return _wrap_llm_with_guardrail(obj, evaluator, action, name, stage) + if isinstance(obj, CompiledStateGraph): + return _wrap_compiled_graph_with_guardrail( + obj, evaluator, action, name, stage + ) + if isinstance(obj, StateGraph): + return _wrap_stategraph_with_guardrail(obj, evaluator, action, name, stage) + if callable(obj): + return _wrap_factory_function(obj, evaluator, action, name, stage) + raise ValueError( + f"@guardrail cannot be applied to {type(obj)!r}. " + "Target must be a BaseTool, BaseChatModel, StateGraph, " + "CompiledStateGraph, or a callable factory function." 
+ ) + + if func is None: + return _apply + return _apply(func) diff --git a/src/uipath_langchain/guardrails/decorators/validators/__init__.py b/src/uipath_langchain/guardrails/decorators/validators/__init__.py new file mode 100644 index 000000000..a611e8807 --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/validators/__init__.py @@ -0,0 +1,14 @@ +"""Guardrail validator classes for use with the @guardrail decorator.""" + +from ._base import GuardrailValidatorBase +from .deterministic import DeterministicValidator, RuleFunction +from .pii import PIIValidator +from .prompt_injection import PromptInjectionValidator + +__all__ = [ + "GuardrailValidatorBase", + "PIIValidator", + "PromptInjectionValidator", + "DeterministicValidator", + "RuleFunction", +] diff --git a/src/uipath_langchain/guardrails/decorators/validators/_base.py b/src/uipath_langchain/guardrails/decorators/validators/_base.py new file mode 100644 index 000000000..7c7f6a845 --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/validators/_base.py @@ -0,0 +1,109 @@ +"""Abstract base class for guardrail validators.""" + +from typing import Any, ClassVar + +from uipath.core.guardrails import GuardrailScope, GuardrailValidationResult +from uipath.platform.guardrails import BuiltInValidatorGuardrail + +from ...enums import GuardrailExecutionStage # guardrails/enums.py + + +class GuardrailValidatorBase: + """Abstract base class for guardrail validators. + + Defines WHAT to validate. The @guardrail decorator defines HOW to respond + (action, name, stage, enabled_for_evals). + + A validator instance can be declared once and reused across multiple + @guardrail decorators with different actions, names, or stages. 
+ + Subclasses implement either: + - ``build_built_in_guardrail()`` for UiPath API-based validation (PII, PromptInjection) + - ``evaluate()`` for local Python-based validation (Deterministic) + """ + + supported_scopes: ClassVar[list[GuardrailScope]] = [] + """Scopes this validator supports. Empty list means all scopes are allowed.""" + + supported_stages: ClassVar[list[GuardrailExecutionStage]] = [] + """Stages this validator supports. Empty list means all stages are allowed.""" + + def build_built_in_guardrail( + self, + scope: GuardrailScope, + name: str, + description: str | None, + enabled_for_evals: bool, + ) -> BuiltInValidatorGuardrail | None: + """Build a UiPath API guardrail instance for this validator. + + API-based validators (PII, PromptInjection) override this to return a + ``BuiltInValidatorGuardrail``. Local validators return ``None`` (default). + + Args: + scope: The resolved scope of the decorated object. + name: Name for the guardrail. + description: Optional description. + enabled_for_evals: Whether enabled in evaluation scenarios. + + Returns: + ``BuiltInValidatorGuardrail`` for API-based evaluation, or ``None`` + to use local ``evaluate()`` instead. + """ + return None + + def evaluate( + self, + data: str | dict[str, Any], + stage: GuardrailExecutionStage, + input_data: dict[str, Any] | None, + output_data: dict[str, Any] | None, + ) -> GuardrailValidationResult: + """Perform local validation (no UiPath API call). + + Local validators (Deterministic) override this. Only called when + ``build_built_in_guardrail()`` returns ``None``. + + Args: + data: The primary data being evaluated (message text, tool I/O dict). + stage: Current execution stage (PRE or POST). + input_data: Normalised tool/agent input dict, or ``None`` if unavailable. + output_data: Normalised tool/agent output dict, or ``None`` at PRE stage. + + Returns: + ``GuardrailValidationResult`` with PASSED or VALIDATION_FAILED. 
+ """ + raise NotImplementedError( + f"{type(self).__name__} must implement either build_built_in_guardrail() " + "for API-based validation or evaluate() for local validation." + ) + + def validate_scope(self, scope: GuardrailScope) -> None: + """Raise ``ValueError`` if ``scope`` is not supported by this validator. + + Args: + scope: The resolved scope of the decorated object. + + Raises: + ValueError: If ``supported_scopes`` is non-empty and ``scope`` is absent. + """ + if self.supported_scopes and scope not in self.supported_scopes: + raise ValueError( + f"{type(self).__name__} does not support scope {scope!r}. " + f"Supported scopes: {[s.value for s in self.supported_scopes]}" + ) + + def validate_stage(self, stage: GuardrailExecutionStage) -> None: + """Raise ``ValueError`` if ``stage`` is not supported by this validator. + + Args: + stage: The requested execution stage. + + Raises: + ValueError: If ``supported_stages`` is non-empty and ``stage`` is absent. + """ + if self.supported_stages and stage not in self.supported_stages: + raise ValueError( + f"{type(self).__name__} does not support stage {stage!r}. " + f"Supported stages: {[s.value for s in self.supported_stages]}" + ) diff --git a/src/uipath_langchain/guardrails/decorators/validators/deterministic.py b/src/uipath_langchain/guardrails/decorators/validators/deterministic.py new file mode 100644 index 000000000..d4c32958a --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/validators/deterministic.py @@ -0,0 +1,98 @@ +"""Deterministic (rule-based) guardrail validator.""" + +import inspect +from typing import Any, Callable, ClassVar, Sequence + +from uipath.core.guardrails import GuardrailScope, GuardrailValidationResult + +from ...enums import GuardrailExecutionStage +from ._base import GuardrailValidatorBase + +RuleFunction = ( + Callable[[dict[str, Any]], bool] | Callable[[dict[str, Any], dict[str, Any]], bool] +) +"""Type alias for deterministic rule functions. 
+
+A rule is a callable that returns ``True`` to signal a violation, ``False``
+to pass. It accepts either:
+
+- One parameter: ``(input_dict) -> bool`` — evaluated at the PRE stage (and at
+  the POST stage when the tool input dict is unavailable).
+- Two parameters: ``(input_dict, output_dict) -> bool`` — evaluated at the POST
+  stage when the tool input dict is available.
+
+All rules must detect a violation for the guardrail action to trigger (AND
+semantics). An empty rules list always triggers the action.
+"""
+
+
+class DeterministicValidator(GuardrailValidatorBase):
+    """Validates tool input/output using local Python rule functions.
+
+    No UiPath API call is made. Rules run in-process, making this suitable
+    for fast, deterministic checks such as keyword filtering, length limits,
+    or regex matching.
+
+    Restricted to TOOL scope only.
+
+    Args:
+        rules: Sequence of rule callables. Each rule receives the tool input
+            dict (1-parameter rules) or both input and output dicts
+            (2-parameter rules). Returns ``True`` to flag a violation.
+            ALL rules must flag a violation for the action to trigger.
+            An empty list always triggers the action.
+
+    Raises:
+        ValueError: If any rule is not callable or has an unsupported parameter count.
+
+    Example::
+
+        donkey_filter = DeterministicValidator(
+            rules=[lambda args: "donkey" in args.get("joke", "").lower()]
+        )
+
+        @guardrail(
+            validator=donkey_filter,
+            action=CustomFilterAction(word_to_filter="donkey", replacement="[censored]"),
+            stage=GuardrailExecutionStage.PRE,
+            name="Joke Content Word Filter",
+        )
+        @tool
+        def analyze_joke_syntax(joke: str) -> str: ...
+ """ + + supported_scopes = [GuardrailScope.TOOL] + supported_stages: ClassVar[list[GuardrailExecutionStage]] = [] # all stages allowed + + def __init__(self, rules: Sequence[RuleFunction] = ()) -> None: + for i, rule in enumerate(rules): + if not callable(rule): + raise ValueError(f"Rule {i + 1} must be callable, got {type(rule)}") + sig = inspect.signature(rule) + param_count = len(sig.parameters) + if param_count not in (1, 2): + raise ValueError( + f"Rule {i + 1} must have 1 or 2 parameters, got {param_count}" + ) + self.rules = list(rules) + + def evaluate( + self, + data: str | dict[str, Any], + stage: GuardrailExecutionStage, + input_data: dict[str, Any] | None, + output_data: dict[str, Any] | None, + ) -> GuardrailValidationResult: + """Evaluate rules locally against tool input/output dicts. + + Args: + data: Unused — rules operate on ``input_data`` / ``output_data`` directly. + stage: Current stage (PRE evaluates 1-param rules; POST evaluates + 2-param rules when ``input_data`` is available, else 1-param rules). + input_data: Normalised tool input dict. + output_data: Normalised tool output dict (``None`` at PRE stage). + + Returns: + ``GuardrailValidationResult`` with PASSED or VALIDATION_FAILED. 
+ """ + from .._base import _evaluate_rules # decorators/_base.py + + return _evaluate_rules(self.rules, stage, input_data, output_data) diff --git a/src/uipath_langchain/guardrails/decorators/validators/pii.py b/src/uipath_langchain/guardrails/decorators/validators/pii.py new file mode 100644 index 000000000..1d960c83e --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/validators/pii.py @@ -0,0 +1,110 @@ +"""PII detection guardrail validator.""" + +from typing import Any, Sequence +from uuid import uuid4 + +from uipath.core.guardrails import GuardrailScope, GuardrailSelector +from uipath.platform.guardrails import ( + BuiltInValidatorGuardrail, + EnumListParameterValue, + MapEnumParameterValue, +) + +from ...models import PIIDetectionEntity +from ._base import GuardrailValidatorBase + + +class PIIValidator(GuardrailValidatorBase): + """Validates data for PII (Personally Identifiable Information) entities. + + Uses the UiPath PII detection API to identify entities such as email addresses, + phone numbers, credit card numbers, and other PII types. + + Supported at all scopes (AGENT, LLM, TOOL) and all stages. + + A single ``PIIValidator`` instance can be declared once and reused across + multiple ``@guardrail`` decorators with different actions or stages. + + Args: + entities: One or more ``PIIDetectionEntity`` objects specifying which PII + types to detect and their confidence thresholds (0.0–1.0). + + Raises: + ValueError: If ``entities`` is empty. + + Example:: + + pii_email = PIIValidator( + entities=[PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5)] + ) + + @guardrail( + validator=pii_email, + action=LogAction(severity_level=LoggingSeverityLevel.WARNING), + name="LLM PII Detection", + stage=GuardrailExecutionStage.PRE, + ) + def create_llm(): + return UiPathChat(model="gpt-4o") + + @guardrail( + validator=pii_email, + action=BlockAction(), + name="Tool PII Detection", + ) + @tool + def my_tool(text: str) -> str: ... 
+ """ + + # All scopes and stages supported — inherits empty lists from base (unrestricted). + + def __init__(self, entities: Sequence[PIIDetectionEntity]) -> None: + if not entities: + raise ValueError("entities must be provided and non-empty") + self.entities = list(entities) + + def build_built_in_guardrail( + self, + scope: GuardrailScope, + name: str, + description: str | None, + enabled_for_evals: bool, + ) -> BuiltInValidatorGuardrail: + """Build a PII detection ``BuiltInValidatorGuardrail`` for the UiPath API. + + Args: + scope: The resolved scope of the decorated object. + name: Name for the guardrail. + description: Optional description. + enabled_for_evals: Whether enabled in evaluation scenarios. + + Returns: + Configured ``BuiltInValidatorGuardrail`` for PII detection. + """ + entity_names = [entity.name for entity in self.entities] + entity_thresholds: dict[str, Any] = { + entity.name: entity.threshold for entity in self.entities + } + + return BuiltInValidatorGuardrail( + id=str(uuid4()), + name=name, + description=description + or f"Detects PII entities: {', '.join(entity_names)}", + enabled_for_evals=enabled_for_evals, + selector=GuardrailSelector(scopes=[scope]), + guardrail_type="builtInValidator", + validator_type="pii_detection", + validator_parameters=[ + EnumListParameterValue( + parameter_type="enum-list", + id="entities", + value=entity_names, + ), + MapEnumParameterValue( + parameter_type="map-enum", + id="entityThresholds", + value=entity_thresholds, + ), + ], + ) diff --git a/src/uipath_langchain/guardrails/decorators/validators/prompt_injection.py b/src/uipath_langchain/guardrails/decorators/validators/prompt_injection.py new file mode 100644 index 000000000..7214ff70e --- /dev/null +++ b/src/uipath_langchain/guardrails/decorators/validators/prompt_injection.py @@ -0,0 +1,80 @@ +"""Prompt injection detection guardrail validator.""" + +from uuid import uuid4 + +from uipath.core.guardrails import GuardrailScope, GuardrailSelector +from 
uipath.platform.guardrails import BuiltInValidatorGuardrail +from uipath.platform.guardrails.guardrails import NumberParameterValue + +from ...enums import GuardrailExecutionStage +from ._base import GuardrailValidatorBase + + +class PromptInjectionValidator(GuardrailValidatorBase): + """Validates LLM input for prompt injection attacks. + + Uses the UiPath prompt injection detection API. Restricted to LLM scope + and PRE stage only — prompt injection is an input-only concern. + + Args: + threshold: Detection confidence threshold (0.0–1.0). Default: ``0.5``. + + Raises: + ValueError: If ``threshold`` is outside [0.0, 1.0]. + + Example:: + + prompt_inject = PromptInjectionValidator(threshold=0.7) + + @guardrail( + validator=prompt_inject, + action=BlockAction(), + name="LLM Prompt Injection Detection", + ) + def create_llm(): + return UiPathChat(model="gpt-4o") + """ + + supported_scopes = [GuardrailScope.LLM] + supported_stages = [GuardrailExecutionStage.PRE] + + def __init__(self, threshold: float = 0.5) -> None: + if not 0.0 <= threshold <= 1.0: + raise ValueError(f"threshold must be between 0.0 and 1.0, got {threshold}") + self.threshold = threshold + + def build_built_in_guardrail( + self, + scope: GuardrailScope, + name: str, + description: str | None, + enabled_for_evals: bool, + ) -> BuiltInValidatorGuardrail: + """Build a prompt injection ``BuiltInValidatorGuardrail`` for the UiPath API. + + Args: + scope: The resolved scope of the decorated object (must be LLM). + name: Name for the guardrail. + description: Optional description. + enabled_for_evals: Whether enabled in evaluation scenarios. + + Returns: + Configured ``BuiltInValidatorGuardrail`` for prompt injection detection. 
+ """ + return BuiltInValidatorGuardrail( + id=str(uuid4()), + name=name, + description=description + or f"Detects prompt injection with threshold {self.threshold}", + enabled_for_evals=enabled_for_evals, + selector=GuardrailSelector(scopes=[GuardrailScope.LLM]), + guardrail_type="builtInValidator", + validator_type="prompt_injection", + validator_parameters=[ + NumberParameterValue( + parameter_type="number", + id="threshold", + value=self.threshold, + ), + ], + ) diff --git a/src/uipath_langchain/guardrails/enums.py b/src/uipath_langchain/guardrails/enums.py index c13ecf3a9..f73e4d23f 100644 --- a/src/uipath_langchain/guardrails/enums.py +++ b/src/uipath_langchain/guardrails/enums.py @@ -37,7 +37,7 @@ class PIIDetectionEntityType(str, Enum): class GuardrailExecutionStage(str, Enum): - """Execution stage for deterministic guardrails.""" + """Execution stage for guardrails.""" PRE = "pre" # Pre-execution only POST = "post" # Post-execution only diff --git a/src/uipath_langchain/guardrails/middlewares/deterministic.py b/src/uipath_langchain/guardrails/middlewares/deterministic.py index 606d4d5b7..fdaaceab7 100644 --- a/src/uipath_langchain/guardrails/middlewares/deterministic.py +++ b/src/uipath_langchain/guardrails/middlewares/deterministic.py @@ -65,6 +65,7 @@ class UiPathDeterministicGuardrailMiddleware: rules=[], action=CustomFilterAction(...), stage=GuardrailExecutionStage.POST, + enabled_for_evals=False, ) agent = create_agent( @@ -91,6 +92,8 @@ class UiPathDeterministicGuardrailMiddleware: - GuardrailExecutionStage.PRE_AND_POST: Validate both input and output name: Optional name for the guardrail (defaults to "Deterministic Guardrail") description: Optional description for the guardrail + enabled_for_evals: Whether this guardrail is enabled for evaluation scenarios. + Defaults to True. 
""" def __init__( @@ -102,6 +105,7 @@ def __init__( *, name: str = "Deterministic Guardrail", description: str | None = None, + enabled_for_evals: bool = True, ): """Initialize deterministic guardrail middleware.""" if not tools: @@ -112,6 +116,8 @@ def __init__( raise ValueError( f"stage must be an instance of GuardrailExecutionStage, got {type(stage)}" ) + if not isinstance(enabled_for_evals, bool): + raise ValueError("enabled_for_evals must be a boolean") for i, rule in enumerate(rules): if not callable(rule): @@ -139,6 +145,7 @@ def __init__( self.action = action self._stage = stage self._name = name + self.enabled_for_evals = enabled_for_evals self._description = description or "Deterministic guardrail with custom rules" self._middleware_instances = self._create_middleware_instances() diff --git a/src/uipath_langchain/guardrails/middlewares/pii_detection.py b/src/uipath_langchain/guardrails/middlewares/pii_detection.py index be5a95334..b40d8eece 100644 --- a/src/uipath_langchain/guardrails/middlewares/pii_detection.py +++ b/src/uipath_langchain/guardrails/middlewares/pii_detection.py @@ -31,6 +31,8 @@ MapEnumParameterValue, ) +from uipath_langchain.agent.exceptions import AgentRuntimeError + from ..models import GuardrailAction, PIIDetectionEntity from ._utils import ( create_modified_tool_request, @@ -70,6 +72,7 @@ def analyze_joke_syntax(joke: str) -> str: PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5), PIIDetectionEntity(PIIDetectionEntityType.ADDRESS, 0.7), ], + enabled_for_evals=True, ) # PII detection for specific tools (using tool reference directly) @@ -78,6 +81,7 @@ def analyze_joke_syntax(joke: str) -> str: action=LogAction(severity_level=LoggingSeverityLevel.WARNING), entities=[PIIDetectionEntity(PIIDetectionEntityType.EMAIL, 0.5)], tools=[analyze_joke_syntax], + enabled_for_evals=False, ) agent = create_agent( @@ -97,6 +101,8 @@ def analyze_joke_syntax(joke: str) -> str: If TOOL scope is not specified, this parameter is ignored. 
name: Optional name for the guardrail (defaults to "PII Detection") description: Optional description for the guardrail + enabled_for_evals: Whether this guardrail is enabled for evaluation scenarios. + Defaults to True. """ def __init__( @@ -108,6 +114,7 @@ def __init__( tools: Sequence[str | BaseTool] | None = None, name: str = "PII Detection", description: str | None = None, + enabled_for_evals: bool = True, ): """Initialize PII detection guardrail middleware.""" if not scopes: @@ -116,6 +123,8 @@ def __init__( raise ValueError("At least one entity must be specified") if not isinstance(action, GuardrailAction): raise ValueError("action must be an instance of GuardrailAction") + if not isinstance(enabled_for_evals, bool): + raise ValueError("enabled_for_evals must be a boolean") self._tool_names: list[str] | None = None if tools is not None: @@ -144,6 +153,7 @@ def __init__( self.action = action self.entities = list(entities) self._name = name + self.enabled_for_evals = enabled_for_evals self._description = ( description or f"Detects PII entities: {', '.join(e.name for e in entities)}" @@ -230,6 +240,8 @@ async def _wrap_tool_call_func( ) if modified_input is not None and isinstance(modified_input, dict): request = create_modified_tool_request(request, modified_input) + except AgentRuntimeError: + raise except Exception as e: logger.error( f"Error evaluating PII guardrail for tool '{tool_name}': {e}", @@ -274,7 +286,7 @@ def _create_guardrail(self) -> BuiltInValidatorGuardrail: id=str(uuid4()), name=self._name, description=self._description, - enabled_for_evals=True, + enabled_for_evals=self.enabled_for_evals, selector=GuardrailSelector(**selector_kwargs), guardrail_type="builtInValidator", validator_type="pii_detection", @@ -334,5 +346,7 @@ def _check_messages(self, messages: list[BaseMessage]) -> None: if isinstance(msg.content, str) and text in msg.content: msg.content = msg.content.replace(text, modified_text, 1) break + except AgentRuntimeError: + raise 
except Exception as e: logger.error(f"Error evaluating PII guardrail: {e}", exc_info=True) diff --git a/src/uipath_langchain/guardrails/middlewares/prompt_injection.py b/src/uipath_langchain/guardrails/middlewares/prompt_injection.py index 787fd10bb..7711723b0 100644 --- a/src/uipath_langchain/guardrails/middlewares/prompt_injection.py +++ b/src/uipath_langchain/guardrails/middlewares/prompt_injection.py @@ -16,6 +16,8 @@ from uipath.platform.guardrails import BuiltInValidatorGuardrail, GuardrailScope from uipath.platform.guardrails.guardrails import NumberParameterValue +from uipath_langchain.agent.exceptions import AgentRuntimeError + from ..models import GuardrailAction from ._utils import extract_text_from_messages @@ -37,6 +39,7 @@ class UiPathPromptInjectionMiddleware: middleware = UiPathPromptInjectionMiddleware( action=LogAction(severity_level=LoggingSeverityLevel.WARNING), threshold=0.5, + enabled_for_evals=True, ) ``` @@ -47,6 +50,8 @@ class UiPathPromptInjectionMiddleware: threshold: Detection threshold (0.0 to 1.0) name: Optional name for the guardrail (defaults to "Prompt Injection Detection") description: Optional description for the guardrail + enabled_for_evals: Whether this guardrail is enabled for evaluation scenarios. + Defaults to True. 
""" def __init__( @@ -57,12 +62,15 @@ def __init__( scopes: Sequence[GuardrailScope] | None = None, name: str = "Prompt Injection Detection", description: str | None = None, + enabled_for_evals: bool = True, ): """Initialize prompt injection detection guardrail middleware.""" if not isinstance(action, GuardrailAction): raise ValueError("action must be an instance of GuardrailAction") if not 0.0 <= threshold <= 1.0: raise ValueError(f"Threshold must be between 0.0 and 1.0, got {threshold}") + if not isinstance(enabled_for_evals, bool): + raise ValueError("enabled_for_evals must be a boolean") scopes_list = list(scopes) if scopes is not None else [GuardrailScope.LLM] if scopes_list != [GuardrailScope.LLM]: @@ -75,6 +83,7 @@ def __init__( self.action = action self.threshold = threshold self._name = name + self.enabled_for_evals = enabled_for_evals self._description = ( description or f"Detects prompt injection attempts with threshold {threshold}" @@ -118,7 +127,7 @@ def _create_guardrail(self) -> BuiltInValidatorGuardrail: id=str(uuid4()), name=self._name, description=self._description, - enabled_for_evals=True, + enabled_for_evals=self.enabled_for_evals, selector=GuardrailSelector(scopes=self.scopes), guardrail_type="builtInValidator", validator_type="prompt_injection", @@ -168,6 +177,8 @@ def _check_messages(self, messages: list[BaseMessage]) -> None: if isinstance(msg.content, str) and text in msg.content: msg.content = msg.content.replace(text, modified_text, 1) break + except AgentRuntimeError: + raise except Exception as e: logger.error( f"Error evaluating prompt injection guardrail: {e}", exc_info=True diff --git a/uv.lock b/uv.lock index 4a41c373e..6d6ff018d 100644 --- a/uv.lock +++ b/uv.lock @@ -3333,7 +3333,7 @@ wheels = [ [[package]] name = "uipath-langchain" -version = "0.9.12" +version = "0.9.13" source = { editable = "." } dependencies = [ { name = "httpx" },