From 8950be120b896bb6cbb0c90fd83c0ca063ff68b9 Mon Sep 17 00:00:00 2001 From: Raghav Sunil Date: Mon, 22 Jun 2026 22:45:22 +0000 Subject: [PATCH 1/6] feat(payments): Add LangGraph integration for payment handling --- .../payments/integrations/langgraph/README.md | 462 ++++++++++ .../integrations/langgraph/__init__.py | 12 + .../payments/integrations/langgraph/config.py | 151 ++++ .../payments/integrations/langgraph/errors.py | 45 + .../integrations/langgraph/middleware.py | 831 ++++++++++++++++++ .../payments/integrations/langgraph/tools.py | 194 ++++ .../integrations/langgraph/__init__.py | 0 .../integrations/langgraph/test_functional.py | 534 +++++++++++ .../integrations/langgraph/test_stage1.py | 332 +++++++ .../integrations/langgraph/test_stage2.py | 326 +++++++ .../integrations/langgraph/test_stage3.py | 347 ++++++++ .../integrations/langgraph/test_stage4.py | 238 +++++ .../integrations/langgraph/test_stage5.py | 246 ++++++ .../integrations/langgraph/test_stage6.py | 251 ++++++ .../integrations/langgraph/test_stage7.py | 376 ++++++++ 15 files changed, 4345 insertions(+) create mode 100644 src/bedrock_agentcore/payments/integrations/langgraph/README.md create mode 100644 src/bedrock_agentcore/payments/integrations/langgraph/__init__.py create mode 100644 src/bedrock_agentcore/payments/integrations/langgraph/config.py create mode 100644 src/bedrock_agentcore/payments/integrations/langgraph/errors.py create mode 100644 src/bedrock_agentcore/payments/integrations/langgraph/middleware.py create mode 100644 src/bedrock_agentcore/payments/integrations/langgraph/tools.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/__init__.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py create mode 100644 tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/README.md b/src/bedrock_agentcore/payments/integrations/langgraph/README.md new file mode 100644 index 00000000..f295c76e --- /dev/null +++ b/src/bedrock_agentcore/payments/integrations/langgraph/README.md @@ -0,0 +1,462 @@ +# LangGraph AgentCore Payments Middleware + +The AgentCore Payments Middleware enables LangGraph agents to autonomously handle [x402 Payment Required](https://www.x402.org/) responses. When a tool hits a paid API that returns HTTP 402, the middleware automatically detects the payment requirement, signs the payment via PaymentManager, and retries the request with payment credentials — all transparent to the LLM. + +## Overview + +- **Automatic x402 Payment Handling** — intercepts 402 responses, processes payment, retries with proof +- **Zero Wrapper Code** — no manual tool wrapping needed; just pass `middleware=[payments]` to `create_agent` +- **Multi-Format Detection** — handles `PAYMENT_REQUIRED:` marker, raw JSON `statusCode: 402`, and x402 payloads +- **Custom Handler Registry** — register handlers for tools with non-standard response formats +- **Built-in Tools** — payment-aware `http_request` + payment query tools auto-registered +- **Deterministic Error Messages** — tailored, actionable error messages returned to the LLM on failure +- **Async Support** — non-blocking `asyncio.sleep` and `asyncio.to_thread` for the async path +- **Auto-Session** — optionally create payment sessions lazily on first payment + +## How It Works + +``` +┌─────────┐ ┌──────────────────────────────┐ ┌────────────┐ +│ Agent │────▶│ wrap_tool_call (middleware) │────▶│ Tool │──── HTTP ───▶ Paid API +│ │ │ │ └────────────┘ │ +│ │ │ 1. Execute tool │ │ +│ │ │ 2. Detect 402 │◀── 402 + x402 payload ─────────┘ +│ │ │ 3. Sign payment (PM) │ +│ │ │ 4. Inject header │ +│ │ │ 5. Wait (blockchain delay) │ +│ │ │ 6. Retry tool │──── HTTP + payment header ──▶ Paid API +│ │◀────│ 7. Return 200 to agent │◀── 200 + content ──────────────┘ +└─────────┘ └──────────────────────────────┘ +``` + +## Quick Start + +```python +from langchain.agents import create_agent +from bedrock_agentcore.payments.integrations.langgraph import ( + AgentCorePaymentsConfig, + AgentCorePaymentsMiddleware, +) + +# 1. Config +config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-123", + user_id="user-123", + payment_instrument_id="instrument-456", + region="us-east-1", + auto_session=True, # session created automatically on first payment +) + +# 2. Middleware +payments = AgentCorePaymentsMiddleware(config) + +# 3. Agent — that's it +agent = create_agent( + model="claude-sonnet-4-20250514", + tools=[], # middleware auto-registers http_request + payment query tools + middleware=[payments], +) + +# 402 responses are handled automatically +result = agent.invoke({"messages": [{"role": "user", "content": "Fetch data from https://paid-api.example.com/data"}]}) +``` + +## Built-in Tools + +The middleware automatically registers these tools (available to the LLM): + +| Tool | Description | +|------|-------------| +| `http_request` | Call any HTTP endpoint. 402 responses are paid automatically. | +| `get_payment_instrument` | Query details about a payment instrument | +| `list_payment_instruments` | List all instruments for a user | +| `get_payment_instrument_balance` | Check wallet balance on a chain | +| `get_payment_session` | Query session budget, status, expiry | + +Set `provide_http_request=False` if you bring your own HTTP tool. + +## Custom Tool Integration Contract + +For your own tools to work with auto-payment, they need two things: + +### 1. Signal 402 (output) + +The tool must indicate a 402 response in its return value. Three formats are supported: + +**Option A: `PAYMENT_REQUIRED:` marker (recommended)** +```python +@tool +def my_api(query: str, headers: dict = None) -> str: + resp = httpx.get(URL, headers=headers or {}) + if resp.status_code == 402: + payload = {"statusCode": 402, "headers": dict(resp.headers), "body": resp.json()} + return f"PAYMENT_REQUIRED: {json.dumps(payload)}" + return resp.text +``` + +**Option B: Raw JSON with `statusCode: 402` (fallback detection)** +```python +@tool +def my_api(query: str, headers: dict = None) -> str: + resp = httpx.get(URL, headers=headers or {}) + return json.dumps({"statusCode": resp.status_code, "headers": dict(resp.headers), "body": resp.json()}) +``` + +**Option C: Custom handler (for non-standard formats)** +```python +config = AgentCorePaymentsConfig( + ..., + custom_handlers={"my_tool": MyCustomHandler()}, +) +``` + +### 2. Accept and forward `headers` (input) + +The tool **must** have a `headers` parameter and use it in its HTTP request. The middleware injects the payment header into `tool_args["headers"]` before retry: + +```python +@tool +def my_api(query: str, headers: dict = None) -> str: + resp = httpx.get(URL, headers=headers or {}) # ← forwards headers + ... +``` + +Without this, the payment header is injected but never sent to the server. + +### Minimal custom tool template + +```python +@tool +def my_paid_tool(query: str, headers: dict = None) -> str: + """Access a paid API. Payments handled automatically.""" + resp = httpx.get("https://paid-api.example.com/data", headers=headers or {}) + if resp.status_code == 402: + payload = {"statusCode": 402, "headers": dict(resp.headers), "body": resp.json()} + return f"PAYMENT_REQUIRED: {json.dumps(payload)}" + return json.dumps(resp.json()) +``` + +## Detection Priority + +When a tool returns, the middleware checks for 402 in this order: + +1. **Custom handler** (if registered for this tool name) — full control over detection +2. **`PAYMENT_REQUIRED:` marker** — explicit opt-in signal in content +3. **Lenient fallback** — parses raw JSON for `statusCode: 402` or `x402Version` + `accepts` fields + +This means MCP tools and other tools that return raw JSON are handled automatically without needing the marker or a custom handler. + +## Error Handling + +When payment processing fails, the middleware gives you two layers of control: + +1. **Error handler callback** (`on_payment_error`) — your code resolves the issue programmatically, and the middleware retries +2. **Deterministic error ToolMessages** — if no callback is set (or it returns `PROPAGATE`), the LLM receives a tailored error message + +### Error Handler Callback (Recommended) + +Register a callback to handle payment errors programmatically — auto-provision missing resources, refresh expired sessions, or create new instruments — **without the LLM ever seeing an error**. + +```python +from bedrock_agentcore.payments.integrations.langgraph import ( + AgentCorePaymentsConfig, + AgentCorePaymentsMiddleware, + PaymentErrorContext, + ErrorResolution, +) + +def handle_payment_error(ctx: PaymentErrorContext) -> ErrorResolution | str: + if ctx.exception_type in ("PaymentSessionNotFound", "PaymentSessionExpired"): + # Create a fresh session and retry + session = pm.create_payment_session( + user_id=ctx.config.user_id, + limits={"maxSpendAmount": {"value": "5.00", "currency": "USD"}}, + expiry_time_in_minutes=60, + ) + ctx.config.payment_session_id = session["paymentSessionId"] + return ErrorResolution.RETRY + + if ctx.exception_type == "InsufficientBudget": + # Create a new session with higher budget + session = pm.create_payment_session( + user_id=ctx.config.user_id, + limits={"maxSpendAmount": {"value": "10.00", "currency": "USD"}}, + expiry_time_in_minutes=60, + ) + ctx.config.payment_session_id = session["paymentSessionId"] + return ErrorResolution.RETRY + + if ctx.exception_type == "PaymentInstrumentConfigurationRequired": + # Custom message — direct the user to your setup page + return "Payment instrument not configured. Please visit https://myapp.com/wallet/setup to set up your wallet." + + # Can't handle — use the default deterministic error message + return ErrorResolution.PROPAGATE + +config = AgentCorePaymentsConfig( + payment_manager_arn="arn:...", + user_id="user-1", + payment_instrument_id="instr-1", + region="us-east-1", + on_payment_error=handle_payment_error, + max_error_retries=3, +) +``` + +The callback can return: + +| Return | Behavior | +|--------|----------| +| `ErrorResolution.RETRY` | Retry payment with updated config | +| `ErrorResolution.PROPAGATE` | Use default deterministic error message | +| `str` | Custom message sent to the LLM as `"PAYMENT ERROR: {your string}"` | + +#### How It Works + +``` +Payment exception occurs + │ + ├── on_payment_error is None? → deterministic error ToolMessage (default behavior) + │ + ▼ + Invoke callback(PaymentErrorContext) + │ + ├── Returns PROPAGATE → deterministic error ToolMessage to LLM + │ + ├── Returns RETRY → re-attempt payment with (potentially updated) config + │ │ + │ ├── Success → return paid content to LLM ✅ + │ └── Fails again → loop back to callback (up to max_error_retries) + │ + └── Callback raises exception → fall through to error ToolMessage (no crash) +``` + +#### PaymentErrorContext Fields + +| Field | Type | Description | +|-------|------|-------------| +| `exception` | `Exception` | The exception instance | +| `exception_type` | `str` | Class name (e.g., `"PaymentSessionExpired"`) | +| `exception_message` | `str` | `str(exception)` | +| `tool_name` | `str` | Tool that triggered the 402 | +| `tool_args` | `dict` | The tool call arguments | +| `payment_required_request` | `dict \| None` | The 402 payload (None if error before extraction) | +| `config` | `AgentCorePaymentsConfig` | Mutable reference — modify to fix the issue | +| `retry_count` | `int` | How many times we've retried (starts at 0) | + +#### Async Callbacks + +The callback can be `async def` — automatically awaited in the async path: + +```python +async def async_handler(ctx: PaymentErrorContext) -> ErrorResolution: + session = await create_session_async(ctx.config.user_id) + ctx.config.payment_session_id = session["id"] + return ErrorResolution.RETRY +``` + +#### Safety Guarantees + +- **Max retries**: `max_error_retries=3` (default) prevents infinite loops. Set to 0 to disable. +- **Exception safety**: If the callback raises, the middleware falls through to the error ToolMessage — never crashes. +- **Backward compatible**: `on_payment_error=None` (default) preserves existing behavior. + +#### Recommended Resolution Patterns + +| Exception | Typical Resolution | +|---|---| +| `PaymentSessionNotFound` / `PaymentSessionExpired` | Create a new session via `pm.create_payment_session(...)`, set `ctx.config.payment_session_id`, return RETRY | +| `InsufficientBudget` | Create a new session with higher limits, or PROPAGATE and let the user decide | +| `PaymentInstrumentConfigurationRequired` | Set `ctx.config.payment_instrument_id` from your app's user → instrument mapping, return RETRY | +| `PaymentInstrumentNotFound` | Likely a config error — PROPAGATE (instrument IDs shouldn't change at runtime) | +| `PaymentSessionConfigurationRequired` | Create a session and set it, or enable `auto_session=True` instead of using the callback for this | +| Generic `PaymentError` | Log it, PROPAGATE — usually transient or unrecoverable | + +### Deterministic Error Messages (Default / Fallback) + +When no callback is configured, or the callback returns `PROPAGATE`, the LLM receives a tailored error message with instructions not to retry: + +| Failure | ToolMessage Content | +|---------|-------------------| +| No instrument configured | `PAYMENT ERROR: No payment instrument configured. Do not retry this call. Inform the user they need to configure a payment instrument before making paid requests.` | +| No session configured | `PAYMENT ERROR: No payment session configured. Do not retry this call. Inform the user they need to create a payment session before making paid requests.` | +| Instrument not found | `PAYMENT ERROR: Payment instrument not found. Do not retry this call. Inform the user their payment instrument ID is invalid or has been deleted.` | +| Session not found | `PAYMENT ERROR: Payment session not found. Do not retry this call. Inform the user their payment session ID is invalid or has expired.` | +| Session expired | `PAYMENT ERROR: Payment session has expired. Do not retry this call. Inform the user they need to create a new payment session.` | +| Insufficient budget | `PAYMENT ERROR: Insufficient budget. The payment amount exceeds the remaining session limit. Do not retry this call. Inform the user they need to increase their session budget or create a new session with higher limits.` | +| Payment rejected by server | `PAYMENT ERROR: Payment was signed but rejected by the server ({detail}). Do not retry this call. Inform the user that the payment was not accepted by the merchant.` | +| Generic payment failure | `PAYMENT ERROR: Payment processing failed ({message}). Do not retry this call. Inform the user that payment could not be completed.` | +| Incompatible tool format | `PAYMENT ERROR: Could not apply payment credentials to this tool's request format. Do not retry this call. Inform the user this tool is not compatible with automatic payment processing.` | + + +## Custom Handlers + +Register custom `PaymentResponseHandler` implementations for tools with non-standard output formats. The custom handler is used for **all three phases**: detection, extraction, and injection. + +```python +from bedrock_agentcore.payments.integrations.handlers import PaymentResponseHandler + +class MyMCPHandler(PaymentResponseHandler): + def extract_status_code(self, result): + # Parse your tool's output format to detect 402 + ... + + def extract_headers(self, result): + # Extract HTTP headers from the 402 response + ... + + def extract_body(self, result): + # Extract the x402 payment body + ... + + def validate_tool_input(self, tool_input): + # Check that tool_input is suitable for header injection + return isinstance(tool_input, dict) + + def apply_payment_header(self, tool_input, payment_header): + # Put the payment header where your tool reads it from + tool_input["headers"] = tool_input.get("headers", {}) + tool_input["headers"].update(payment_header) + return True + +config = AgentCorePaymentsConfig( + ..., + custom_handlers={"my_mcp_tool": MyMCPHandler()}, +) +``` + +## MCP Server Tools + +MCP tools connected via `langchain-mcp-adapters` work with the middleware. Since the adapter serializes MCP responses into `ToolMessage.content` as strings, the fallback detection handles the common case (raw JSON with `statusCode: 402`). For non-standard MCP response formats, register a custom handler. + +```python +from langchain_mcp_adapters.client import MultiServerMCPClient + +client = MultiServerMCPClient({"paid_api": {"transport": "stdio", "command": "python", "args": ["server.py"]}}) +mcp_tools = await client.get_tools() + +agent = create_agent( + model=model, + tools=mcp_tools, + middleware=[payments], # auto-detects 402 from MCP tool responses +) +``` + +## Configuration Reference + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `payment_manager_arn` | `str` | *required* | ARN of the payment manager resource | +| `user_id` | `str \| None` | `None` | User ID. Required for SigV4 auth; optional with bearer token | +| `payment_instrument_id` | `str \| None` | `None` | Instrument ID for x402 signing | +| `payment_session_id` | `str \| None` | `None` | Session ID for budget enforcement | +| `payment_connector_id` | `str \| None` | `None` | Connector ID (optional) | +| `region` | `str \| None` | `None` | AWS region | +| `network_preferences_config` | `list[str] \| None` | `None` | Ordered CAIP-2 network identifiers | +| `auto_payment` | `bool` | `True` | Enable/disable automatic 402 processing | +| `auto_session` | `bool` | `False` | Auto-create session on first 402 | +| `auto_session_budget` | `str` | `"1.00"` | Budget (USD) for auto-created sessions | +| `auto_session_expiry_minutes` | `int` | `60` | Expiry for auto-created sessions | +| `agent_name` | `str \| None` | `None` | Agent name for data-plane headers | +| `bearer_token` | `str \| None` | `None` | Static JWT. Mutually exclusive with `token_provider` | +| `token_provider` | `Callable \| None` | `None` | Callable returning fresh JWT. Mutually exclusive with `bearer_token` | +| `payment_tool_allowlist` | `list[str] \| None` | `None` | Tools eligible for payment. `None` = all | +| `provide_http_request` | `bool` | `True` | Register built-in `http_request` tool | +| `post_payment_retry_delay_seconds` | `float` | `3.0` | Delay after signing before retry | +| `custom_handlers` | `dict[str, Handler] \| None` | `None` | Custom handlers keyed by tool name | +| `on_payment_error` | `Callable \| None` | `None` | Error callback for programmatic recovery. See Error Handler Callback. | +| `max_error_retries` | `int` | `3` | Max times callback can return RETRY per tool call. 0 disables. | + +## Payment Tool Allowlist + +Restrict which tools get payment processing: + +```python +config = AgentCorePaymentsConfig( + ..., + payment_tool_allowlist=["http_request", "my_paid_api"], +) +``` + +Modify the allowlist at runtime: + +```python +# Add tools +config.add_to_allowlist("new_paid_tool", "another_tool") + +# Remove tools (reverts to all-eligible if list becomes empty) +config.remove_from_allowlist("my_paid_api") +``` + +Tools not in the list pass through untouched. When `None` (default), all tools are eligible. + +## Auto-Session + +Skip manual session creation — the middleware creates one lazily on the first 402: + +```python +config = AgentCorePaymentsConfig( + payment_manager_arn="arn:...", + user_id="user-1", + payment_instrument_id="instr-1", + region="us-east-1", + auto_session=True, # enable lazy creation + auto_session_budget="5.00", # $5 budget + auto_session_expiry_minutes=120, # 2 hours +) +``` + +The session is created once and reused for all subsequent payments in that middleware instance. + +## Disabling Auto-Payment + +Use the middleware only for its built-in query tools without 402 interception: + +```python +config = AgentCorePaymentsConfig( + ..., + auto_payment=False, +) +``` + +## Bearer Token Authentication + +For payment managers using `CUSTOM_JWT` authorizer: + +```python +# Static token +config = AgentCorePaymentsConfig( + payment_manager_arn="arn:...", + bearer_token="eyJhbGciOiJSUzI1NiJ9...", + payment_instrument_id="instr-1", + auto_session=True, +) + +# Dynamic token provider (recommended for production) +config = AgentCorePaymentsConfig( + payment_manager_arn="arn:...", + token_provider=lambda: fetch_fresh_jwt(), + payment_instrument_id="instr-1", + auto_session=True, +) +``` + +With bearer auth, `user_id` is optional (derived from JWT `sub` claim). + +## Comparison: With vs Without Middleware + +**Without middleware** (manual wrapping): +- Write a wrapper function per tool type (~30-50 lines each) +- Handle 402 detection, x402 parsing, signing, retry manually +- No error handling — exceptions crash the tool call +- No blockchain timing delay — fast facilitators may reject +- No budget error messages to the LLM +- Adding a new tool = another wrapper + +**With middleware:** +```python +config = AgentCorePaymentsConfig(payment_manager_arn="...", user_id="...", payment_instrument_id="...", auto_session=True) +agent = create_agent(model=model, tools=[my_tools], middleware=[AgentCorePaymentsMiddleware(config)]) +``` + +Done. All tools handled automatically. diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/__init__.py b/src/bedrock_agentcore/payments/integrations/langgraph/__init__.py new file mode 100644 index 00000000..e0704f45 --- /dev/null +++ b/src/bedrock_agentcore/payments/integrations/langgraph/__init__.py @@ -0,0 +1,12 @@ +"""LangGraph integration for AgentCore Payments.""" + +from .config import AgentCorePaymentsConfig +from .errors import ErrorResolution, PaymentErrorContext +from .middleware import AgentCorePaymentsMiddleware + +__all__ = [ + "AgentCorePaymentsConfig", + "AgentCorePaymentsMiddleware", + "ErrorResolution", + "PaymentErrorContext", +] diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/config.py b/src/bedrock_agentcore/payments/integrations/langgraph/config.py new file mode 100644 index 00000000..1ebd76ba --- /dev/null +++ b/src/bedrock_agentcore/payments/integrations/langgraph/config.py @@ -0,0 +1,151 @@ +"""Configuration for AgentCorePaymentsMiddleware (LangGraph integration).""" + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +from ..handlers import PaymentResponseHandler + + +@dataclass +class AgentCorePaymentsConfig: + """Configuration for AgentCorePaymentsMiddleware. + + Attributes: + payment_manager_arn: ARN of the payment manager resource. + user_id: User ID for payment processing. Required for SigV4 auth. + payment_instrument_id: Payment instrument ID for x402 signing. + payment_session_id: Payment session ID for budget enforcement. + payment_connector_id: Payment connector ID (optional). + region: AWS region for the payment manager. + network_preferences_config: Ordered list of network CAIP2 identifiers. + auto_payment: Whether to automatically process 402 responses. Default True. + agent_name: Agent name propagated via HTTP header on data-plane calls. + bearer_token: Static JWT for OAuth/CUSTOM_JWT auth. Mutually exclusive with token_provider. + token_provider: Callable returning a fresh JWT. Mutually exclusive with bearer_token. + payment_tool_allowlist: Tool names eligible for payment processing. None = all tools. + provide_http_request: Whether middleware registers its built-in http_request tool. + post_payment_retry_delay_seconds: Delay after signing before retry. Default 3.0s. + custom_handlers: Custom PaymentResponseHandler instances keyed by tool name. + Takes precedence over the built-in handler registry during resolution. + auto_session: Whether to auto-create a payment session on first 402 if + payment_session_id is not set. Default False. + auto_session_budget: Budget for auto-created sessions (USD). Default "1.00". + auto_session_expiry_minutes: Expiry time for auto-created sessions. Default 60. + on_payment_error: Optional callback invoked when a payment exception occurs. + Receives PaymentErrorContext, returns ErrorResolution.RETRY or .PROPAGATE. + When None (default), errors produce deterministic ToolMessages directly. + max_error_retries: Maximum times the error callback can return RETRY per tool call. + Default 3. Set to 0 to disable the callback entirely. + """ + + payment_manager_arn: str + user_id: Optional[str] = None + payment_instrument_id: Optional[str] = None + payment_session_id: Optional[str] = None + payment_connector_id: Optional[str] = None + region: Optional[str] = None + network_preferences_config: Optional[List[str]] = None + auto_payment: bool = True + agent_name: Optional[str] = None + bearer_token: Optional[str] = None + token_provider: Optional[Callable[[], str]] = None + payment_tool_allowlist: Optional[List[str]] = None + provide_http_request: bool = True + post_payment_retry_delay_seconds: float = 3.0 + custom_handlers: Optional[Dict[str, Any]] = field(default=None) + auto_session: bool = False + auto_session_budget: str = "1.00" + auto_session_expiry_minutes: int = 60 + on_payment_error: Optional[Callable] = None + max_error_retries: int = 3 + + def __post_init__(self) -> None: + """Validate configuration after initialization.""" + if not self.payment_manager_arn: + raise ValueError("payment_manager_arn is required") + if not self.payment_manager_arn.startswith("arn:"): + raise ValueError(f"Invalid ARN format: {self.payment_manager_arn}") + + if self.bearer_token is not None and self.token_provider is not None: + raise ValueError("bearer_token and token_provider are mutually exclusive") + if self.bearer_token is not None and not isinstance(self.bearer_token, str): + raise ValueError(f"bearer_token must be a string, got {type(self.bearer_token).__name__}") + if self.token_provider is not None and not callable(self.token_provider): + raise ValueError(f"token_provider must be callable, got {type(self.token_provider).__name__}") + + if not self.user_id and self.bearer_token is None and self.token_provider is None: + raise ValueError("user_id is required for SigV4 auth (when bearer_token/token_provider not set)") + if self.user_id is not None and self.user_id and not self.user_id.strip(): + raise ValueError("user_id cannot be whitespace-only") + + if not isinstance(self.auto_payment, bool): + raise ValueError(f"auto_payment must be a boolean, got {type(self.auto_payment).__name__}") + if not isinstance(self.provide_http_request, bool): + raise ValueError(f"provide_http_request must be a boolean, got {type(self.provide_http_request).__name__}") + + if self.payment_tool_allowlist is not None: + if not isinstance(self.payment_tool_allowlist, list): + raise ValueError("payment_tool_allowlist must be a list of tool name strings") + if not all(isinstance(t, str) for t in self.payment_tool_allowlist): + raise ValueError("All entries in payment_tool_allowlist must be strings") + + if not isinstance(self.post_payment_retry_delay_seconds, (int, float)) or isinstance( + self.post_payment_retry_delay_seconds, bool + ): + raise ValueError( + f"post_payment_retry_delay_seconds must be a number, got " + f"{type(self.post_payment_retry_delay_seconds).__name__}" + ) + if self.post_payment_retry_delay_seconds < 0: + raise ValueError( + f"post_payment_retry_delay_seconds must be >= 0, got {self.post_payment_retry_delay_seconds}" + ) + + if self.custom_handlers is not None: + if not isinstance(self.custom_handlers, dict): + raise ValueError("custom_handlers must be a dict mapping tool names to PaymentResponseHandler instances") + if not all(isinstance(k, str) for k in self.custom_handlers): + raise ValueError("All keys in custom_handlers must be strings") + if not all(isinstance(v, PaymentResponseHandler) for v in self.custom_handlers.values()): + raise ValueError("All values in custom_handlers must be PaymentResponseHandler instances") + + if self.on_payment_error is not None and not callable(self.on_payment_error): + raise ValueError(f"on_payment_error must be callable, got {type(self.on_payment_error).__name__}") + + if not isinstance(self.max_error_retries, int) or isinstance(self.max_error_retries, bool): + raise ValueError(f"max_error_retries must be an int, got {type(self.max_error_retries).__name__}") + if self.max_error_retries < 0: + raise ValueError(f"max_error_retries must be >= 0, got {self.max_error_retries}") + + def add_to_allowlist(self, *tool_names: str) -> None: + """Add tool names to the payment allowlist. + + Creates the allowlist if it doesn't exist yet (switching from "all tools" + to explicit allowlist mode). + + Args: + tool_names: One or more tool names to add. + """ + if self.payment_tool_allowlist is None: + self.payment_tool_allowlist = [] + for name in tool_names: + if not isinstance(name, str): + raise ValueError(f"Tool name must be a string, got {type(name).__name__}") + if name not in self.payment_tool_allowlist: + self.payment_tool_allowlist.append(name) + + def remove_from_allowlist(self, *tool_names: str) -> None: + """Remove tool names from the payment allowlist. + + If the allowlist becomes empty, sets it to None (all tools eligible). + + Args: + tool_names: One or more tool names to remove. + """ + if self.payment_tool_allowlist is None: + return + for name in tool_names: + if name in self.payment_tool_allowlist: + self.payment_tool_allowlist.remove(name) + if not self.payment_tool_allowlist: + self.payment_tool_allowlist = None diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/errors.py b/src/bedrock_agentcore/payments/integrations/langgraph/errors.py new file mode 100644 index 00000000..b0f79ac2 --- /dev/null +++ b/src/bedrock_agentcore/payments/integrations/langgraph/errors.py @@ -0,0 +1,45 @@ +"""Error callback types for AgentCorePaymentsMiddleware.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, Optional + +if TYPE_CHECKING: + from .config import AgentCorePaymentsConfig + + +class ErrorResolution(Enum): + """Return value from on_payment_error callback.""" + + RETRY = "retry" + PROPAGATE = "propagate" + + +@dataclass +class PaymentErrorContext: + """Context passed to the on_payment_error callback. + + The developer can inspect the exception, mutate `config` to fix the issue + (e.g., set payment_instrument_id), and return ErrorResolution.RETRY. + + Attributes: + exception: The exception instance that triggered the callback. + exception_type: String name of the exception class. + exception_message: str(exception). + tool_name: Name of the tool that triggered the 402. + tool_args: The tool call arguments dict. + payment_required_request: The 402 payload dict (may be None if error before extraction). + config: Mutable reference to AgentCorePaymentsConfig. + retry_count: How many times we've already retried via the callback. + """ + + exception: Exception + exception_type: str + exception_message: str + tool_name: str + tool_args: Dict[str, Any] + payment_required_request: Optional[Dict[str, Any]] + config: "AgentCorePaymentsConfig" + retry_count: int diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py new file mode 100644 index 00000000..868e3f81 --- /dev/null +++ b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py @@ -0,0 +1,831 @@ +"""AgentCorePaymentsMiddleware for LangGraph agents.""" + +import logging +import time +import uuid +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union + +from langchain.agents.middleware import AgentMiddleware +from langchain.messages import ToolMessage +from langgraph.prebuilt.tool_node import ToolCallRequest +from langgraph.types import Command + +from bedrock_agentcore.payments.integrations.handlers import ( + PaymentResponseHandler, + get_payment_handler, +) +from bedrock_agentcore.payments.manager import ( + InsufficientBudget, + PaymentError, + PaymentInstrumentConfigurationRequired, + PaymentInstrumentNotFound, + PaymentSessionConfigurationRequired, + PaymentSessionExpired, + PaymentSessionNotFound, + PaymentManager, +) + +from .config import AgentCorePaymentsConfig +from .tools import ( + make_get_payment_instrument_balance_tool, + make_get_payment_instrument_tool, + make_get_payment_session_tool, + make_http_request_tool, + make_list_payment_instruments_tool, +) + +logger = logging.getLogger(__name__) + +# Deterministic error messages per exception type. +# The LLM sees these messages and should follow the "Do not retry" instruction. +_ERROR_MESSAGES: Dict[type, str] = { + PaymentInstrumentConfigurationRequired: ( + "No payment instrument configured. Do not retry this call. " + "Inform the user they need to configure a payment instrument before making paid requests." + ), + PaymentSessionConfigurationRequired: ( + "No payment session configured. Do not retry this call. " + "Inform the user they need to create a payment session before making paid requests." + ), + PaymentInstrumentNotFound: ( + "Payment instrument not found. Do not retry this call. " + "Inform the user their payment instrument ID is invalid or has been deleted." + ), + PaymentSessionNotFound: ( + "Payment session not found. Do not retry this call. " + "Inform the user their payment session ID is invalid or has expired." + ), + PaymentSessionExpired: ( + "Payment session has expired. Do not retry this call. " + "Inform the user they need to create a new payment session." + ), + InsufficientBudget: ( + "Insufficient budget. The payment amount exceeds the remaining session limit. " + "Do not retry this call. Inform the user they need to increase their session budget " + "or create a new session with higher limits." + ), +} + + +class _FallbackHandler: + """Minimal handler wrapping pre-parsed 402 data from fallback detection.""" + + def __init__(self, parsed: Dict[str, Any]): + self._parsed = parsed + + def extract_status_code(self, result: Any) -> Optional[int]: + return self._parsed.get("statusCode") + + def extract_headers(self, result: Any) -> Optional[Dict[str, Any]]: + return self._parsed.get("headers", {}) + + def extract_body(self, result: Any) -> Optional[Dict[str, Any]]: + return self._parsed.get("body", {}) + + +class AgentCorePaymentsMiddleware(AgentMiddleware): + """Middleware that intercepts tool calls to handle x402 Payment Required responses. + + This middleware wraps tool execution to automatically detect HTTP 402 responses, + process x402 payment requirements via PaymentManager, and retry the tool call + with payment credentials. + + Usage: + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:...", + user_id="user-123", + payment_instrument_id="instrument-456", + payment_session_id="session-789", + ) + middleware = AgentCorePaymentsMiddleware(config) + agent = create_agent(model=..., tools=[...], middleware=[middleware]) + """ + + def __init__(self, config: AgentCorePaymentsConfig) -> None: + """Initialize middleware with config and create PaymentManager. + + Args: + config: Payment configuration. + + Raises: + RuntimeError: If PaymentManager initialization fails. + """ + super().__init__() + self.config = config + try: + self.payment_manager = PaymentManager( + payment_manager_arn=config.payment_manager_arn, + region_name=config.region, + agent_name=config.agent_name, + bearer_token=config.bearer_token, + token_provider=config.token_provider, + ) + except Exception as e: + raise RuntimeError(f"Failed to initialize PaymentManager: {e}") from e + self.tools = self._build_tools() + logger.info("AgentCorePaymentsMiddleware initialized") + + def _build_tools(self) -> list: + """Build the list of tools to register with the agent. + + Returns: + List of tool callables. Includes http_request if provide_http_request=True. + """ + tools = [] + if self.config.provide_http_request: + tools.append(make_http_request_tool(self)) + tools.append(make_get_payment_instrument_tool(self)) + tools.append(make_list_payment_instruments_tool(self)) + tools.append(make_get_payment_instrument_balance_tool(self)) + tools.append(make_get_payment_session_tool(self)) + return tools + + @staticmethod + def _prepare_for_handler(content: Any) -> Optional[Dict[str, List[Dict[str, str]]]]: + """Normalize ToolMessage.content into handler-compatible shape. + + Args: + content: ToolMessage.content — either a str, list, or None. + + Returns: + Dict with "content" key containing list of {"text": ...} blocks, or None. + """ + if content is None: + return None + if isinstance(content, str): + return {"content": [{"text": content}]} + if isinstance(content, list): + blocks = [] + for item in content: + if isinstance(item, dict) and "text" in item: + blocks.append(item) + elif isinstance(item, str): + blocks.append({"text": item}) + else: + blocks.append(item) + return {"content": blocks} + return None + + def _get_handler(self, tool_name: str, tool_args: Dict[str, Any]) -> PaymentResponseHandler: + """Resolve the payment response handler for a tool. + + Resolution priority: + 1. Custom handlers from config (highest priority) + 2. Built-in registry (name-based → MCP shape → generic fallback) + + Args: + tool_name: Name of the tool. + tool_args: Tool call arguments dict. + + Returns: + The resolved PaymentResponseHandler instance. + """ + if self.config.custom_handlers and tool_name in self.config.custom_handlers: + logger.debug("Using custom handler for tool: %s", tool_name) + return self.config.custom_handlers[tool_name] + return get_payment_handler(tool_name, tool_args) + + @staticmethod + def _fallback_detect_402(content: Any) -> Optional[Dict[str, Any]]: + """Lenient fallback: detect 402 from raw JSON without the PAYMENT_REQUIRED: marker. + + Handles MCP tools and other tools that return raw JSON responses like: + - {"statusCode": 402, "headers": {...}, "body": {...}} + - {"x402Version": 1, "accepts": [...]} + + Args: + content: ToolMessage.content (str or list). + + Returns: + Parsed payment-required dict if 402 detected, None otherwise. + """ + import json as _json + + texts = [] + if isinstance(content, str): + texts.append(content) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and "text" in item: + texts.append(item["text"]) + elif isinstance(item, str): + texts.append(item) + + for text in texts: + try: + parsed = _json.loads(text) + except (ValueError, TypeError): + continue + if not isinstance(parsed, dict): + continue + + # Check for statusCode: 402 + if parsed.get("statusCode") == 402: + logger.debug("Fallback detection: found statusCode 402 in raw JSON") + return parsed + + # Check for httpStatus: 402 (MCP structuredContent format) + if parsed.get("httpStatus") == 402: + logger.debug("Fallback detection: found httpStatus 402 (MCP format)") + return { + "statusCode": 402, + "headers": parsed.get("responseHeaders", {}), + "body": parsed.get("structuredContent", {}), + } + + # Check for x402 payload (x402Version + accepts) at top level + if "x402Version" in parsed and "accepts" in parsed: + logger.debug("Fallback detection: found x402 payload in raw JSON") + return {"statusCode": 402, "headers": {}, "body": parsed} + + # Check for x402 payload nested in structuredContent + sc = parsed.get("structuredContent") + if isinstance(sc, dict) and "x402Version" in sc and "accepts" in sc: + logger.debug("Fallback detection: found x402 payload in structuredContent") + return { + "statusCode": 402, + "headers": parsed.get("responseHeaders", {}), + "body": sc, + } + + return None + + def _generate_payment_header(self, payment_required_request: Dict[str, Any]) -> Dict[str, str]: + """Generate payment header via PaymentManager. + + Args: + payment_required_request: Dict with statusCode, headers, body from the 402 response. + + Returns: + Dict with payment header name → value. + + Raises: + PaymentInstrumentConfigurationRequired: If payment_instrument_id not set. + PaymentSessionConfigurationRequired: If payment_session_id not set. + PaymentError: If payment processing fails. + """ + if self.config.payment_instrument_id is None: + raise PaymentInstrumentConfigurationRequired( + "payment_instrument_id is required for x402 payments." + ) + if self.config.payment_session_id is None: + if self.config.auto_session: + self._create_auto_session() + else: + raise PaymentSessionConfigurationRequired( + "payment_session_id is required for x402 payments." + ) + + return self.payment_manager.generate_payment_header( + user_id=self.config.user_id, + payment_instrument_id=self.config.payment_instrument_id, + payment_session_id=self.config.payment_session_id, + payment_required_request=payment_required_request, + network_preferences=self.config.network_preferences_config, + client_token=str(uuid.uuid4()), + payment_connector_id=self.config.payment_connector_id, + ) + + def _create_auto_session(self) -> None: + """Lazily create a payment session on first 402 when auto_session=True.""" + logger.info( + "auto_session: creating payment session (budget=$%s, expiry=%dmin)", + self.config.auto_session_budget, + self.config.auto_session_expiry_minutes, + ) + session = self.payment_manager.create_payment_session( + user_id=self.config.user_id, + limits={"maxSpendAmount": {"value": self.config.auto_session_budget, "currency": "USD"}}, + expiry_time_in_minutes=self.config.auto_session_expiry_minutes, + ) + self.config.payment_session_id = session["paymentSessionId"] + logger.info("auto_session: created session %s", self.config.payment_session_id) + + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Union[ToolMessage, Command]], + ) -> Union[ToolMessage, Command]: + """Wrap tool execution with 402 payment detection, signing, and retry. + + Args: + request: The tool call request. + handler: Callable that executes the tool. + + Returns: + The tool execution result or an error ToolMessage. + """ + result = handler(request) + + # Guard: Command results have no tool output to inspect + if isinstance(result, Command): + return result + + # Guard: auto_payment disabled + if not self.config.auto_payment: + return result + + # Guard: allowlist filtering + tool_name = request.tool_call["name"] + if self.config.payment_tool_allowlist is not None: + if tool_name not in self.config.payment_tool_allowlist: + return result + + # 402 detection + # Priority: custom handler → GenericPaymentHandler (marker) → lenient fallback (raw JSON) + tool_args = request.tool_call.get("args", {}) + prepared = self._prepare_for_handler(result.content) + if prepared is None: + return result + + has_custom_handler = ( + self.config.custom_handlers is not None + and tool_name in self.config.custom_handlers + ) + + if has_custom_handler: + detection_handler = self.config.custom_handlers[tool_name] + else: + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler + detection_handler = GenericPaymentHandler() + + status_code = detection_handler.extract_status_code(prepared) + + # Lenient fallback: if no custom handler and marker detection didn't find 402, + # try parsing raw JSON for statusCode:402 or x402Version fields. + # This handles MCP tools and other tools that return raw JSON without the marker. + if status_code != 402 and not has_custom_handler: + fallback = self._fallback_detect_402(result.content) + if fallback is not None: + status_code = 402 + # Switch detection_handler to a wrapper that returns the parsed data + detection_handler = _FallbackHandler(fallback) + + if status_code != 402: + return result + + logger.info("Detected 402 Payment Required from tool: %s", tool_name) + + # Payment processing with comprehensive error handling + payment_required_request = None + try: + # Extract payment requirement details + headers_402 = detection_handler.extract_headers(prepared) or {} + body_402 = detection_handler.extract_body(prepared) or {} + payment_required_request = { + "statusCode": 402, + "headers": headers_402, + "body": body_402, + } + + # Generate payment header + payment_header = self._generate_payment_header(payment_required_request) + + # Resolve handler for header injection + # Custom handler handles all phases; otherwise resolve by tool shape + if has_custom_handler: + injection_handler = detection_handler + else: + injection_handler = self._get_handler(tool_name, tool_args) + + # Inject header into tool args + if not injection_handler.validate_tool_input(tool_args): + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials to this tool's request format."), + ) + if not injection_handler.apply_payment_header(tool_args, payment_header): + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials to this tool's request format."), + ) + + # Blockchain timing delay + delay = self.config.post_payment_retry_delay_seconds + if delay > 0: + logger.info("Waiting %.1fs before retry for blockchain timing", delay) + time.sleep(delay) + + # Re-execute the tool with payment credentials + retry_result = handler(request) + + if isinstance(retry_result, Command): + return retry_result + + # Post-payment rejection detection + retry_prepared = self._prepare_for_handler(retry_result.content) + if retry_prepared is not None: + # Use fresh detection on the retry result (not the frozen fallback handler) + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _retry_handler = _GH() + retry_status = _retry_handler.extract_status_code(retry_prepared) + # Also check via fallback if marker not found + if retry_status != 402: + retry_fallback = self._fallback_detect_402(retry_result.content) + if retry_fallback is not None: + retry_status = 402 + _retry_handler = _FallbackHandler(retry_fallback) + if retry_status == 402: + retry_body = _retry_handler.extract_body(retry_prepared) or {} + error_detail = ( + retry_body.get("error", "unknown error") + if isinstance(retry_body, dict) + else "unknown error" + ) + return self._error_tool_message( + request, + PaymentError( + f"Payment was signed but rejected by the server ({error_detail})." + ), + ) + + return retry_result + + except Exception as e: + logger.error("Payment processing error for tool %s: %s: %s", tool_name, type(e).__name__, e) + if self.config.on_payment_error is not None and self.config.max_error_retries > 0: + resolution = self._invoke_error_handler( + exception=e, + tool_name=tool_name, + tool_args=tool_args, + payment_required_request=payment_required_request, + request=request, + handler=handler, + ) + if resolution is not None: + return resolution + return self._error_tool_message(request, e) + + def _invoke_error_handler( + self, + exception: Exception, + tool_name: str, + tool_args: Dict[str, Any], + payment_required_request: Optional[Dict[str, Any]], + request: "ToolCallRequest", + handler: Callable, + ) -> Optional[Union[ToolMessage, Command]]: + """Invoke on_payment_error callback and retry if requested. + + Returns: + ToolMessage/Command if retry succeeded or max retries exhausted. + None if callback returned PROPAGATE (caller uses default error path). + """ + from .errors import ErrorResolution, PaymentErrorContext + + retry_count = 0 + current_exception = exception + + while retry_count < self.config.max_error_retries: + ctx = PaymentErrorContext( + exception=current_exception, + exception_type=type(current_exception).__name__, + exception_message=str(current_exception), + tool_name=tool_name, + tool_args=tool_args, + payment_required_request=payment_required_request, + config=self.config, + retry_count=retry_count, + ) + + try: + resolution = self.config.on_payment_error(ctx) + except Exception as cb_err: + logger.error("on_payment_error callback raised: %s", cb_err) + return None + + # str return = custom message to the LLM + if isinstance(resolution, str): + return ToolMessage( + content=f"PAYMENT ERROR: {resolution}", + tool_call_id=request.tool_call["id"], + status="error", + ) + + if resolution != ErrorResolution.RETRY: + return None + + retry_count += 1 + logger.info("on_payment_error returned RETRY (attempt %d/%d)", retry_count, self.config.max_error_retries) + + try: + payment_header = self._generate_payment_header(payment_required_request or {}) + + has_custom = self.config.custom_handlers and tool_name in self.config.custom_handlers + injection_handler = ( + self.config.custom_handlers[tool_name] if has_custom + else self._get_handler(tool_name, tool_args) + ) + + if not injection_handler.validate_tool_input(tool_args): + return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + if not injection_handler.apply_payment_header(tool_args, payment_header): + return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + + delay = self.config.post_payment_retry_delay_seconds + if delay > 0: + time.sleep(delay) + + retry_result = handler(request) + if isinstance(retry_result, Command): + return retry_result + + # Post-payment rejection check + retry_prepared = self._prepare_for_handler(retry_result.content) + if retry_prepared is not None: + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _rh = _GH() + retry_status = _rh.extract_status_code(retry_prepared) + if retry_status != 402: + fallback = self._fallback_detect_402(retry_result.content) + if fallback: + retry_status = 402 + if retry_status == 402: + retry_body = _rh.extract_body(retry_prepared) or {} + detail = retry_body.get("error", "unknown") if isinstance(retry_body, dict) else "unknown" + return self._error_tool_message(request, PaymentError(f"Payment signed but rejected after recovery ({detail}).")) + + return retry_result + + except Exception as retry_err: + logger.error("Payment retry after error handler failed: %s", retry_err) + current_exception = retry_err + continue + + logger.warning("max_error_retries (%d) exhausted", self.config.max_error_retries) + return self._error_tool_message(request, current_exception) + + @staticmethod + def _error_tool_message(request: ToolCallRequest, exception: Exception) -> ToolMessage: + """Create a ToolMessage with a deterministic error message for the LLM. + + Looks up the exception type in the error message map. Falls back to a + generic message that includes the exception string for unrecognized types. + + Args: + request: The original tool call request (for tool_call_id). + exception: The exception to report. + + Returns: + ToolMessage with status="error" and deterministic content. + """ + msg = _ERROR_MESSAGES.get(type(exception)) + if msg is None: + if isinstance(exception, PaymentError): + msg = ( + f"Payment processing failed ({exception}). " + "Do not retry this call. Inform the user that payment could not be completed." + ) + else: + msg = ( + f"An unexpected error occurred during payment processing ({exception}). " + "Do not retry this call. Inform the user that payment could not be completed." + ) + + return ToolMessage( + content=f"PAYMENT ERROR: {msg}", + tool_call_id=request.tool_call["id"], + status="error", + ) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[Union[ToolMessage, Command]]], + ) -> Union[ToolMessage, Command]: + """Async version of wrap_tool_call. + + Uses asyncio.sleep for non-blocking delay and asyncio.to_thread for + the synchronous PaymentManager.generate_payment_header call. + + Args: + request: The tool call request. + handler: Async callable that executes the tool. + + Returns: + The tool execution result or an error ToolMessage. + """ + import asyncio + + result = await handler(request) + + if isinstance(result, Command): + return result + if not self.config.auto_payment: + return result + + tool_name = request.tool_call["name"] + if self.config.payment_tool_allowlist is not None: + if tool_name not in self.config.payment_tool_allowlist: + return result + + tool_args = request.tool_call.get("args", {}) + prepared = self._prepare_for_handler(result.content) + if prepared is None: + return result + + has_custom_handler = ( + self.config.custom_handlers is not None + and tool_name in self.config.custom_handlers + ) + + if has_custom_handler: + detection_handler = self.config.custom_handlers[tool_name] + else: + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler + detection_handler = GenericPaymentHandler() + + status_code = detection_handler.extract_status_code(prepared) + + # Lenient fallback for async path (same as sync) + if status_code != 402 and not has_custom_handler: + fallback = self._fallback_detect_402(result.content) + if fallback is not None: + status_code = 402 + detection_handler = _FallbackHandler(fallback) + + if status_code != 402: + return result + + logger.info("Detected 402 Payment Required from tool (async): %s", tool_name) + + payment_required_request = None + try: + headers_402 = detection_handler.extract_headers(prepared) or {} + body_402 = detection_handler.extract_body(prepared) or {} + payment_required_request = { + "statusCode": 402, + "headers": headers_402, + "body": body_402, + } + + payment_header = await asyncio.to_thread( + self._generate_payment_header, payment_required_request + ) + + if has_custom_handler: + injection_handler = detection_handler + else: + injection_handler = self._get_handler(tool_name, tool_args) + + if not injection_handler.validate_tool_input(tool_args): + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials to this tool's request format."), + ) + if not injection_handler.apply_payment_header(tool_args, payment_header): + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials to this tool's request format."), + ) + + delay = self.config.post_payment_retry_delay_seconds + if delay > 0: + logger.info("Waiting %.1fs before retry for blockchain timing (async)", delay) + await asyncio.sleep(delay) + + retry_result = await handler(request) + + if isinstance(retry_result, Command): + return retry_result + + retry_prepared = self._prepare_for_handler(retry_result.content) + if retry_prepared is not None: + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _retry_handler = _GH() + retry_status = _retry_handler.extract_status_code(retry_prepared) + if retry_status != 402: + retry_fallback = self._fallback_detect_402(retry_result.content) + if retry_fallback is not None: + retry_status = 402 + _retry_handler = _FallbackHandler(retry_fallback) + if retry_status == 402: + retry_body = _retry_handler.extract_body(retry_prepared) or {} + error_detail = ( + retry_body.get("error", "unknown error") + if isinstance(retry_body, dict) + else "unknown error" + ) + return self._error_tool_message( + request, + PaymentError( + f"Payment was signed but rejected by the server ({error_detail})." + ), + ) + + return retry_result + + except Exception as e: + logger.error("Payment processing error (async) for tool %s: %s: %s", tool_name, type(e).__name__, e) + if self.config.on_payment_error is not None and self.config.max_error_retries > 0: + resolution = await self._ainvoke_error_handler( + exception=e, + tool_name=tool_name, + tool_args=tool_args, + payment_required_request=payment_required_request, + request=request, + handler=handler, + ) + if resolution is not None: + return resolution + return self._error_tool_message(request, e) + + async def _ainvoke_error_handler( + self, + exception: Exception, + tool_name: str, + tool_args: Dict[str, Any], + payment_required_request: Optional[Dict[str, Any]], + request: "ToolCallRequest", + handler: Callable, + ) -> Optional[Union[ToolMessage, Command]]: + """Async version of _invoke_error_handler. Supports async callbacks.""" + import asyncio + import inspect + from .errors import ErrorResolution, PaymentErrorContext + + retry_count = 0 + current_exception = exception + + while retry_count < self.config.max_error_retries: + ctx = PaymentErrorContext( + exception=current_exception, + exception_type=type(current_exception).__name__, + exception_message=str(current_exception), + tool_name=tool_name, + tool_args=tool_args, + payment_required_request=payment_required_request, + config=self.config, + retry_count=retry_count, + ) + + try: + if inspect.iscoroutinefunction(self.config.on_payment_error): + resolution = await self.config.on_payment_error(ctx) + else: + resolution = self.config.on_payment_error(ctx) + except Exception as cb_err: + logger.error("on_payment_error callback raised (async): %s", cb_err) + return None + + # str return = custom message to the LLM + if isinstance(resolution, str): + return ToolMessage( + content=f"PAYMENT ERROR: {resolution}", + tool_call_id=request.tool_call["id"], + status="error", + ) + + if resolution != ErrorResolution.RETRY: + return None + + retry_count += 1 + logger.info("on_payment_error returned RETRY (async, attempt %d/%d)", retry_count, self.config.max_error_retries) + + try: + payment_header = await asyncio.to_thread( + self._generate_payment_header, payment_required_request or {} + ) + + has_custom = self.config.custom_handlers and tool_name in self.config.custom_handlers + injection_handler = ( + self.config.custom_handlers[tool_name] if has_custom + else self._get_handler(tool_name, tool_args) + ) + + if not injection_handler.validate_tool_input(tool_args): + return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + if not injection_handler.apply_payment_header(tool_args, payment_header): + return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + + delay = self.config.post_payment_retry_delay_seconds + if delay > 0: + await asyncio.sleep(delay) + + retry_result = await handler(request) + if isinstance(retry_result, Command): + return retry_result + + retry_prepared = self._prepare_for_handler(retry_result.content) + if retry_prepared is not None: + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _rh = _GH() + retry_status = _rh.extract_status_code(retry_prepared) + if retry_status != 402: + fallback = self._fallback_detect_402(retry_result.content) + if fallback: + retry_status = 402 + if retry_status == 402: + retry_body = _rh.extract_body(retry_prepared) or {} + detail = retry_body.get("error", "unknown") if isinstance(retry_body, dict) else "unknown" + return self._error_tool_message(request, PaymentError(f"Payment signed but rejected after recovery ({detail}).")) + + return retry_result + + except Exception as retry_err: + logger.error("Payment retry after error handler failed (async): %s", retry_err) + current_exception = retry_err + continue + + logger.warning("max_error_retries (%d) exhausted (async)", self.config.max_error_retries) + return self._error_tool_message(request, current_exception) diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/tools.py b/src/bedrock_agentcore/payments/integrations/langgraph/tools.py new file mode 100644 index 00000000..142d2373 --- /dev/null +++ b/src/bedrock_agentcore/payments/integrations/langgraph/tools.py @@ -0,0 +1,194 @@ +"""Built-in tools for AgentCorePaymentsMiddleware.""" + +import json +import logging +from typing import Any, Dict, Optional, Union + +import httpx +from langchain.tools import tool + +logger = logging.getLogger(__name__) + + +def make_http_request_tool(middleware: Any): + """Create an http_request tool that closes over the middleware instance. + + Returns PAYMENT_REQUIRED: marker on 402 for automatic payment processing. + """ + + @tool + def http_request( + url: str, + method: str = "GET", + headers: Optional[Dict[str, str]] = None, + body: Optional[Union[Dict[str, Any], str]] = None, + ) -> str: + """Call an HTTP endpoint. 402 Payment Required responses are settled automatically. + + Args: + url: The full URL to request. + method: HTTP method. Defaults to GET. + headers: Optional request headers. + body: Optional request body. Dict is sent as JSON; str is sent as-is. + + Returns: + JSON string with statusCode, headers, and body. Prefixed with + PAYMENT_REQUIRED: on 402 for automatic payment processing. + """ + request_headers = dict(headers) if headers else {} + method_upper = method.upper() + + try: + with httpx.Client(timeout=30.0, follow_redirects=True) as client: + if body is None or method_upper in ("GET", "HEAD"): + resp = client.request(method_upper, url, headers=request_headers) + elif isinstance(body, str): + resp = client.request(method_upper, url, headers=request_headers, content=body) + else: + resp = client.request(method_upper, url, headers=request_headers, json=body) + except httpx.RequestError as exc: + return json.dumps({"statusCode": 0, "error": f"Request failed: {exc}", "url": url}) + + response_headers = dict(resp.headers) + try: + response_body: Any = resp.json() + except Exception: + response_body = {"text": resp.text} + + payload = { + "statusCode": resp.status_code, + "headers": response_headers, + "body": response_body, + } + + if resp.status_code == 402: + return f"PAYMENT_REQUIRED: {json.dumps(payload)}" + + return json.dumps(payload) + + return http_request + + +def make_get_payment_instrument_tool(middleware: Any): + """Create a get_payment_instrument tool.""" + + @tool + def get_payment_instrument( + payment_instrument_id: Optional[str] = None, + user_id: Optional[str] = None, + payment_connector_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Retrieve details about a specific payment instrument. + + Args: + payment_instrument_id: Instrument ID (falls back to middleware config). + user_id: User ID (falls back to middleware config). + payment_connector_id: Connector ID (optional). + + Returns: + Payment instrument details dictionary. + """ + resolved_id = (payment_instrument_id.strip() if payment_instrument_id else None) or middleware.config.payment_instrument_id + resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id + return middleware.payment_manager.get_payment_instrument( + user_id=resolved_user, + payment_instrument_id=resolved_id, + payment_connector_id=payment_connector_id, + ) + + return get_payment_instrument + + +def make_list_payment_instruments_tool(middleware: Any): + """Create a list_payment_instruments tool.""" + + @tool + def list_payment_instruments( + user_id: Optional[str] = None, + payment_connector_id: Optional[str] = None, + max_results: int = 100, + next_token: Optional[str] = None, + ) -> Dict[str, Any]: + """List all payment instruments for a user. + + Args: + user_id: User ID (falls back to middleware config). + payment_connector_id: Filter by connector (optional). + max_results: Maximum results to return (default 100). + next_token: Pagination token (optional). + + Returns: + Dictionary with paymentInstruments list and optional nextToken. + """ + resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id + return middleware.payment_manager.list_payment_instruments( + user_id=resolved_user, + payment_connector_id=payment_connector_id, + max_results=max_results, + next_token=next_token, + ) + + return list_payment_instruments + + +def make_get_payment_instrument_balance_tool(middleware: Any): + """Create a get_payment_instrument_balance tool.""" + + @tool + def get_payment_instrument_balance( + payment_instrument_id: str, + chain: str = "BASE_SEPOLIA", + token: str = "USDC", + payment_connector_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Get the token balance for a payment instrument on a blockchain. + + Args: + payment_instrument_id: Instrument ID (required). + chain: Blockchain chain (e.g., BASE_SEPOLIA, SOLANA_DEVNET). + token: Token to query (e.g., USDC). + payment_connector_id: Connector ID (falls back to config). + user_id: User ID (falls back to config). + + Returns: + Dictionary with balance information. + """ + resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id + resolved_connector = payment_connector_id or middleware.config.payment_connector_id + return middleware.payment_manager.get_payment_instrument_balance( + payment_connector_id=resolved_connector, + payment_instrument_id=payment_instrument_id, + chain=chain, + token=token, + user_id=resolved_user, + ) + + return get_payment_instrument_balance + + +def make_get_payment_session_tool(middleware: Any): + """Create a get_payment_session tool.""" + + @tool + def get_payment_session( + payment_session_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Retrieve details about a specific payment session. + + Args: + payment_session_id: Session ID (falls back to middleware config). + user_id: User ID (falls back to middleware config). + + Returns: + Payment session details dictionary. + """ + resolved_id = (payment_session_id.strip() if payment_session_id else None) or middleware.config.payment_session_id + resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id + return middleware.payment_manager.get_payment_session( + user_id=resolved_user, + payment_session_id=resolved_id, + ) + + return get_payment_session diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/__init__.py b/tests/bedrock_agentcore/payments/integrations/langgraph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py new file mode 100644 index 00000000..dd02a1d7 --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py @@ -0,0 +1,534 @@ +"""Functional test: full 402 → sign → retry → 200 flow against a real testnet. + +Run: + python -m pytest tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py -v -s + +Required environment variables: + ACP_PAYMENT_MANAGER_ARN - PaymentManager ARN + ACP_USER_ID - User ID for payment processing + ACP_PAYMENT_INSTRUMENT_ID - Payment instrument ID (funded wallet) + ACP_REGION - AWS region (e.g., us-east-1) + ACP_TESTNET_URL - x402-enabled testnet endpoint that returns 402 + +Optional: + ACP_PAYMENT_SESSION_ID - Active payment session ID (auto-created if not set) + ACP_PAYMENT_CONNECTOR_ID - Payment connector ID + ACP_RETRY_DELAY - Override retry delay (default 3.0s) +""" + +import json +import os + +import pytest +from langchain.messages import ToolMessage +from unittest.mock import MagicMock + +from bedrock_agentcore.payments.integrations.langgraph import ( + AgentCorePaymentsConfig, + AgentCorePaymentsMiddleware, +) +from bedrock_agentcore.payments.manager import PaymentManager + +# Skip entire module if required env vars not configured +pytestmark = pytest.mark.skipif( + not all(os.environ.get(k) for k in [ + "ACP_PAYMENT_MANAGER_ARN", + "ACP_USER_ID", + "ACP_PAYMENT_INSTRUMENT_ID", + "ACP_REGION", + "ACP_TESTNET_URL", + ]), + reason="Testnet env vars not set (ACP_PAYMENT_MANAGER_ARN, ACP_USER_ID, ACP_PAYMENT_INSTRUMENT_ID, ACP_REGION, ACP_TESTNET_URL)", +) + + +@pytest.fixture(scope="module") +def payment_session_id(): + """Use existing session ID from env, or auto-create one ($1.00, 60 min).""" + existing = os.environ.get("ACP_PAYMENT_SESSION_ID") + if existing: + print(f"\n[Using existing session: {existing}]") + return existing + + print("\n[ACP_PAYMENT_SESSION_ID not set — auto-creating session ($1.00, 60 min)]") + pm = PaymentManager( + payment_manager_arn=os.environ["ACP_PAYMENT_MANAGER_ARN"], + region_name=os.environ["ACP_REGION"], + ) + session = pm.create_payment_session( + user_id=os.environ["ACP_USER_ID"], + limits={"maxSpendAmount": {"value": "1.00", "currency": "USD"}}, + expiry_time_in_minutes=60, + ) + session_id = session["paymentSessionId"] + print(f"[Created session: {session_id}]") + return session_id + + +@pytest.fixture(scope="module") +def config(payment_session_id): + return AgentCorePaymentsConfig( + payment_manager_arn=os.environ["ACP_PAYMENT_MANAGER_ARN"], + user_id=os.environ["ACP_USER_ID"], + payment_instrument_id=os.environ["ACP_PAYMENT_INSTRUMENT_ID"], + payment_session_id=payment_session_id, + payment_connector_id=os.environ.get("ACP_PAYMENT_CONNECTOR_ID"), + region=os.environ["ACP_REGION"], + post_payment_retry_delay_seconds=float(os.environ.get("ACP_RETRY_DELAY", "3.0")), + ) + + +@pytest.fixture(scope="module") +def middleware(config): + return AgentCorePaymentsMiddleware(config) + + +@pytest.fixture(scope="module") +def testnet_url(): + return os.environ["ACP_TESTNET_URL"] + + +class TestFullPaymentFlow: + """End-to-end: http_request tool → 402 → middleware signs → retry → 200.""" + + def test_http_request_tool_gets_402(self, middleware, testnet_url): + """The built-in http_request tool returns PAYMENT_REQUIRED on 402.""" + tool = next(t for t in middleware.tools if t.name == "http_request") + result = tool.invoke({"url": testnet_url}) + + print(f"\n[http_request raw result]: {result[:200]}...") + assert "PAYMENT_REQUIRED:" in result, f"Expected 402 from testnet, got: {result[:100]}" + + parsed = json.loads(result[len("PAYMENT_REQUIRED: "):]) + assert parsed["statusCode"] == 402 + print(f"[402 body keys]: {list(parsed.get('body', {}).keys())}") + + def test_wrap_tool_call_full_flow(self, middleware, testnet_url): + """wrap_tool_call intercepts 402, signs payment, retries, gets 200.""" + # Simulate what LangGraph does: create a ToolCallRequest-like object + # with tool_call dict, then a handler that calls http_request + tool = next(t for t in middleware.tools if t.name == "http_request") + + tool_args = {"url": testnet_url, "method": "GET", "headers": {}} + + request = MagicMock() + request.tool_call = { + "name": "http_request", + "args": tool_args, + "id": "functional-test-1", + } + + call_count = [0] + + def handler(req): + """Simulate LangGraph's tool execution — calls the actual http_request tool.""" + call_count[0] += 1 + content = tool.invoke(req.tool_call["args"]) + return ToolMessage(content=content, tool_call_id=req.tool_call["id"]) + + print(f"\n[Calling wrap_tool_call against {testnet_url}]") + result = middleware.wrap_tool_call(request, handler) + + print(f"[Handler called {call_count[0]} time(s)]") + print(f"[Result content]: {result.content[:200]}...") + + # Should have been called twice: initial 402 + retry with payment header + assert call_count[0] == 2, f"Expected 2 calls (402 + retry), got {call_count[0]}" + + # Result should NOT be a payment error + assert "PAYMENT ERROR" not in result.content, f"Payment failed: {result.content}" + + # Result should be successful (200) + parsed = json.loads(result.content) + assert parsed["statusCode"] == 200, f"Expected 200 on retry, got: {parsed.get('statusCode')}" + print(f"[Success] Paid content received: {json.dumps(parsed['body'], indent=2)[:200]}") + + def test_wrap_tool_call_mcp_gateway_shape(self, middleware, testnet_url): + """wrap_tool_call works with MCP Gateway shaped tool input (parameters.headers).""" + import httpx + + # MCP Gateway tools have args like: {"toolName": "...", "parameters": {"url": ..., "headers": {}}} + # The MCPRequestPaymentHandler injects headers into parameters.headers + tool_args = { + "toolName": "fetch_paid_content", + "parameters": {"url": testnet_url, "method": "GET", "headers": {}}, + } + + request = MagicMock() + request.tool_call = { + "name": "mcp_proxy_tool", + "args": tool_args, + "id": "functional-mcp-test", + } + + call_count = [0] + + def handler(req): + """Simulate MCP proxy: uses parameters.url and parameters.headers to make the real call.""" + call_count[0] += 1 + params = req.tool_call["args"]["parameters"] + url = params["url"] + headers = params.get("headers", {}) + + with httpx.Client(timeout=30.0, follow_redirects=True) as client: + resp = client.request("GET", url, headers=headers) + + resp_headers = dict(resp.headers) + try: + resp_body = resp.json() + except Exception: + resp_body = {"text": resp.text} + + payload = {"statusCode": resp.status_code, "headers": resp_headers, "body": resp_body} + + if resp.status_code == 402: + content = f"PAYMENT_REQUIRED: {json.dumps(payload)}" + else: + content = json.dumps(payload) + + return ToolMessage(content=content, tool_call_id=req.tool_call["id"]) + + print(f"\n[MCP Gateway shape test against {testnet_url}]") + result = middleware.wrap_tool_call(request, handler) + + print(f"[Handler called {call_count[0]} time(s)]") + print(f"[Result content]: {result.content[:200]}...") + + assert call_count[0] == 2, f"Expected 2 calls (402 + retry), got {call_count[0]}" + assert "PAYMENT ERROR" not in result.content, f"Payment failed: {result.content}" + + parsed = json.loads(result.content) + assert parsed["statusCode"] == 200, f"Expected 200, got: {parsed.get('statusCode')}" + + # Verify header was injected into parameters.headers (not top-level headers) + injected = tool_args["parameters"]["headers"] + print(f"[MCP parameters.headers]: {list(injected.keys())}") + assert len(injected) > 0, "No payment header injected into parameters.headers" + print(f"[MCP Gateway flow succeeded with 200]") + + def test_payment_header_was_injected(self, middleware, testnet_url): + """After wrap_tool_call, the tool_args dict has a payment header.""" + tool = next(t for t in middleware.tools if t.name == "http_request") + + tool_args = {"url": testnet_url, "method": "GET", "headers": {}} + + request = MagicMock() + request.tool_call = { + "name": "http_request", + "args": tool_args, + "id": "functional-test-2", + } + + def handler(req): + content = tool.invoke(req.tool_call["args"]) + return ToolMessage(content=content, tool_call_id=req.tool_call["id"]) + + middleware.wrap_tool_call(request, handler) + + # After the flow, headers should contain a payment header + injected_headers = tool_args.get("headers", {}) + print(f"\n[Injected headers]: {list(injected_headers.keys())}") + assert len(injected_headers) > 0, "No payment header was injected" + # Common header names: X-PAYMENT (v1) or PAYMENT-SIGNATURE (v2) + has_payment_header = any( + k.upper() in ("X-PAYMENT", "PAYMENT-SIGNATURE", "PAYMENT") + for k in injected_headers + ) + assert has_payment_header, f"Expected payment header, got: {list(injected_headers.keys())}" + + +class TestPaymentQueryTools: + """Functional tests for payment query tools against real PaymentManager.""" + + def test_get_payment_instrument(self, middleware): + """get_payment_instrument returns real instrument data.""" + tool = next(t for t in middleware.tools if t.name == "get_payment_instrument") + result = tool.invoke({}) + print(f"\n[get_payment_instrument]: {str(result)[:300]}") + assert "paymentInstrumentId" in result or "payment_instrument_id" in str(result).lower() + + def test_get_payment_session(self, middleware): + """get_payment_session returns real session data.""" + tool = next(t for t in middleware.tools if t.name == "get_payment_session") + result = tool.invoke({}) + print(f"\n[get_payment_session]: {str(result)[:300]}") + assert "paymentSessionId" in result or "payment_session_id" in str(result).lower() + + +class TestFallbackDetectionFunctional: + """Functional test: tool returns raw JSON (no PAYMENT_REQUIRED: marker) and fallback detects 402.""" + + def test_raw_json_tool_full_flow(self, middleware, testnet_url): + """A tool returning raw JSON without the marker still gets payment processing via fallback.""" + import httpx + + tool_args = {"url": testnet_url, "headers": {}} + + request = MagicMock() + request.tool_call = {"name": "raw_api_tool", "args": tool_args, "id": "fallback-test"} + + call_count = [0] + + def handler(req): + """Tool that returns raw JSON — NO PAYMENT_REQUIRED: prefix.""" + call_count[0] += 1 + url = req.tool_call["args"]["url"] + headers = req.tool_call["args"].get("headers", {}) + + with httpx.Client(timeout=30.0, follow_redirects=True) as client: + resp = client.request("GET", url, headers=headers) + + resp_headers = dict(resp.headers) + try: + resp_body = resp.json() + except Exception: + resp_body = {"text": resp.text} + + # Raw JSON — no marker + payload = json.dumps({"statusCode": resp.status_code, "headers": resp_headers, "body": resp_body}) + return ToolMessage(content=payload, tool_call_id=req.tool_call["id"]) + + print(f"\n[Fallback detection test (no marker, no custom handler) against {testnet_url}]") + result = middleware.wrap_tool_call(request, handler) + + print(f"[Handler called {call_count[0]} time(s)]") + print(f"[Result content]: {result.content[:200]}...") + + # Fallback detected 402 from raw JSON and processed payment + assert call_count[0] == 2, f"Expected 2 calls (402 + retry), got {call_count[0]}" + assert "PAYMENT ERROR" not in result.content, f"Payment failed: {result.content}" + + parsed = json.loads(result.content) + assert parsed["statusCode"] == 200, f"Expected 200, got {parsed.get('statusCode')}" + print(f"[Fallback detection flow succeeded — no marker, no custom handler, got 200]") + + +class TestCustomHandlerRegistry: + """Functional test: custom handler resolves and processes payment correctly.""" + + def test_custom_handler_full_flow(self, config, testnet_url): + """A custom handler registered for a tool name is used for detection, extraction, and injection.""" + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler + + # Custom handler that tracks all three phases + class TrackingHandler(GenericPaymentHandler): + def __init__(self): + self.detect_called = False + self.extract_called = False + self.inject_called = False + + def extract_status_code(self, result): + self.detect_called = True + return super().extract_status_code(result) + + def extract_headers(self, result): + self.extract_called = True + return super().extract_headers(result) + + def apply_payment_header(self, tool_input, payment_header): + self.inject_called = True + return super().apply_payment_header(tool_input, payment_header) + + custom_handler = TrackingHandler() + + from dataclasses import replace + custom_config = replace(config, custom_handlers={"my_http_tool": custom_handler}) + mw = AgentCorePaymentsMiddleware(custom_config) + + http_tool = next(t for t in mw.tools if t.name == "http_request") + tool_args = {"url": testnet_url, "method": "GET", "headers": {}} + + request = MagicMock() + request.tool_call = {"name": "my_http_tool", "args": tool_args, "id": "custom-handler-test"} + + call_count = [0] + + def handler(req): + call_count[0] += 1 + content = http_tool.invoke(req.tool_call["args"]) + return ToolMessage(content=content, tool_call_id=req.tool_call["id"]) + + print(f"\n[Testing custom handler registry against {testnet_url}]") + result = mw.wrap_tool_call(request, handler) + + assert custom_handler.detect_called, "Custom handler's extract_status_code was not invoked" + assert custom_handler.extract_called, "Custom handler's extract_headers was not invoked" + assert custom_handler.inject_called, "Custom handler's apply_payment_header was not invoked" + print(f"[Custom handler used for detection: ✓, extraction: ✓, injection: ✓]") + + assert call_count[0] == 2, f"Expected 2 calls, got {call_count[0]}" + assert "PAYMENT ERROR" not in result.content, f"Payment failed: {result.content}" + + parsed = json.loads(result.content) + assert parsed["statusCode"] == 200 + print(f"[Custom handler flow succeeded with 200]") + + def test_custom_handler_non_marker_tool(self, config, testnet_url): + """Custom handler detects 402 from a tool that does NOT use the PAYMENT_REQUIRED: marker.""" + import httpx + from bedrock_agentcore.payments.integrations.handlers import PaymentResponseHandler + + # Custom handler that detects 402 from raw JSON (no marker prefix) + class RawJsonHandler(PaymentResponseHandler): + """Handles tools that return raw JSON like {"statusCode": 402, "headers": {...}, "body": {...}}""" + + def __init__(self): + self.detect_called = False + + def extract_status_code(self, result): + self.detect_called = True + import json as _json + # result is {"content": [{"text": "..."}]} from _prepare_for_handler + content = result.get("content", []) + for block in content: + text = block.get("text", "") if isinstance(block, dict) else "" + try: + parsed = _json.loads(text) + if isinstance(parsed, dict): + return parsed.get("statusCode") + except (ValueError, TypeError): + pass + return None + + def extract_headers(self, result): + import json as _json + content = result.get("content", []) + for block in content: + text = block.get("text", "") if isinstance(block, dict) else "" + try: + parsed = _json.loads(text) + if isinstance(parsed, dict): + return parsed.get("headers", {}) + except (ValueError, TypeError): + pass + return None + + def extract_body(self, result): + import json as _json + content = result.get("content", []) + for block in content: + text = block.get("text", "") if isinstance(block, dict) else "" + try: + parsed = _json.loads(text) + if isinstance(parsed, dict): + return parsed.get("body", {}) + except (ValueError, TypeError): + pass + return None + + def validate_tool_input(self, tool_input): + return isinstance(tool_input, dict) + + def apply_payment_header(self, tool_input, payment_header): + if "headers" not in tool_input: + tool_input["headers"] = {} + tool_input["headers"].update(payment_header) + return True + + custom_handler = RawJsonHandler() + + from dataclasses import replace + custom_config = replace(config, custom_handlers={"raw_http_tool": custom_handler}) + mw = AgentCorePaymentsMiddleware(custom_config) + + tool_args = {"url": testnet_url, "headers": {}} + + request = MagicMock() + request.tool_call = {"name": "raw_http_tool", "args": tool_args, "id": "non-marker-test"} + + call_count = [0] + + def handler(req): + """Tool that returns raw JSON WITHOUT the PAYMENT_REQUIRED: marker.""" + call_count[0] += 1 + url = req.tool_call["args"]["url"] + headers = req.tool_call["args"].get("headers", {}) + + with httpx.Client(timeout=30.0, follow_redirects=True) as client: + resp = client.request("GET", url, headers=headers) + + resp_headers = dict(resp.headers) + try: + resp_body = resp.json() + except Exception: + resp_body = {"text": resp.text} + + # NO PAYMENT_REQUIRED: prefix — just raw JSON + payload = json.dumps({"statusCode": resp.status_code, "headers": resp_headers, "body": resp_body}) + return ToolMessage(content=payload, tool_call_id=req.tool_call["id"]) + + print(f"\n[Non-marker tool with custom handler against {testnet_url}]") + result = mw.wrap_tool_call(request, handler) + + print(f"[Handler called {call_count[0]} time(s)]") + print(f"[Custom handler detect_called: {custom_handler.detect_called}]") + print(f"[Result content]: {result.content[:200]}...") + + # Custom handler detected 402 from raw JSON + assert custom_handler.detect_called, "Custom handler was not invoked for detection" + + # Full flow worked + assert call_count[0] == 2, f"Expected 2 calls (402 + retry), got {call_count[0]}" + assert "PAYMENT ERROR" not in result.content, f"Payment failed: {result.content}" + + parsed = json.loads(result.content) + assert parsed["statusCode"] == 200, f"Expected 200, got {parsed.get('statusCode')}" + print(f"[Non-marker custom handler flow succeeded with 200]") + from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler + + # Custom handler that tracks all three phases + class TrackingHandler(GenericPaymentHandler): + def __init__(self): + self.detect_called = False + self.extract_called = False + self.inject_called = False + + def extract_status_code(self, result): + self.detect_called = True + return super().extract_status_code(result) + + def extract_headers(self, result): + self.extract_called = True + return super().extract_headers(result) + + def apply_payment_header(self, tool_input, payment_header): + self.inject_called = True + return super().apply_payment_header(tool_input, payment_header) + + custom_handler = TrackingHandler() + + # Create middleware with custom handler for "my_http_tool" + from dataclasses import replace + custom_config = replace(config, custom_handlers={"my_http_tool": custom_handler}) + mw = AgentCorePaymentsMiddleware(custom_config) + + # Use the real http_request tool under the hood, but the tool_call name is "my_http_tool" + http_tool = next(t for t in mw.tools if t.name == "http_request") + tool_args = {"url": testnet_url, "method": "GET", "headers": {}} + + request = MagicMock() + request.tool_call = {"name": "my_http_tool", "args": tool_args, "id": "custom-handler-test"} + + call_count = [0] + + def handler(req): + call_count[0] += 1 + content = http_tool.invoke(req.tool_call["args"]) + return ToolMessage(content=content, tool_call_id=req.tool_call["id"]) + + print(f"\n[Testing custom handler registry against {testnet_url}]") + result = mw.wrap_tool_call(request, handler) + + # Custom handler was used for all three phases + assert custom_handler.detect_called, "Custom handler's extract_status_code was not invoked" + assert custom_handler.extract_called, "Custom handler's extract_headers was not invoked" + assert custom_handler.inject_called, "Custom handler's apply_payment_header was not invoked" + print(f"[Custom handler used for detection: ✓, extraction: ✓, injection: ✓]") + + # Full flow still works (402 → sign → retry → 200) + assert call_count[0] == 2, f"Expected 2 calls, got {call_count[0]}" + assert "PAYMENT ERROR" not in result.content, f"Payment failed: {result.content}" + + parsed = json.loads(result.content) + assert parsed["statusCode"] == 200 + print(f"[Custom handler flow succeeded with 200]") diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py new file mode 100644 index 00000000..d0dcb2bb --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py @@ -0,0 +1,332 @@ +"""Tests for LangGraph AgentCorePaymentsConfig and AgentCorePaymentsMiddleware.""" + +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest + +from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware +from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler, PaymentResponseHandler + + +# --------------------------------------------------------------------------- +# Config validation tests +# --------------------------------------------------------------------------- + + +class TestAgentCorePaymentsConfigValidation: + """Test AgentCorePaymentsConfig field validation.""" + + def test_valid_minimal_config(self): + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="user-123", + ) + assert config.payment_manager_arn.startswith("arn:") + assert config.user_id == "user-123" + assert config.auto_payment is True + assert config.provide_http_request is True + assert config.post_payment_retry_delay_seconds == 3.0 + assert config.custom_handlers is None + + def test_empty_arn_raises(self): + with pytest.raises(ValueError, match="payment_manager_arn is required"): + AgentCorePaymentsConfig(payment_manager_arn="", user_id="u") + + def test_invalid_arn_format_raises(self): + with pytest.raises(ValueError, match="Invalid ARN format"): + AgentCorePaymentsConfig(payment_manager_arn="not-an-arn", user_id="u") + + def test_user_id_required_for_sigv4(self): + with pytest.raises(ValueError, match="user_id is required for SigV4"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1" + ) + + def test_user_id_optional_with_bearer_token(self): + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + bearer_token="jwt-token", + ) + assert config.user_id is None + + def test_user_id_optional_with_token_provider(self): + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + token_provider=lambda: "fresh", + ) + assert config.user_id is None + + def test_whitespace_user_id_raises(self): + with pytest.raises(ValueError, match="user_id cannot be whitespace-only"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id=" ", + ) + + def test_bearer_token_and_token_provider_mutually_exclusive(self): + with pytest.raises(ValueError, match="mutually exclusive"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + bearer_token="tok", + token_provider=lambda: "tok", + ) + + def test_bearer_token_must_be_string(self): + with pytest.raises(ValueError, match="bearer_token must be a string"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + bearer_token=123, # type: ignore + ) + + def test_token_provider_must_be_callable(self): + with pytest.raises(ValueError, match="token_provider must be callable"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + token_provider="not-callable", # type: ignore + ) + + def test_auto_payment_must_be_bool(self): + with pytest.raises(ValueError, match="auto_payment must be a boolean"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + auto_payment="yes", # type: ignore + ) + + def test_provide_http_request_must_be_bool(self): + with pytest.raises(ValueError, match="provide_http_request must be a boolean"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + provide_http_request=1, # type: ignore + ) + + def test_allowlist_must_be_list(self): + with pytest.raises(ValueError, match="payment_tool_allowlist must be a list"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + payment_tool_allowlist="http_request", # type: ignore + ) + + def test_allowlist_entries_must_be_strings(self): + with pytest.raises(ValueError, match="All entries in payment_tool_allowlist must be strings"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + payment_tool_allowlist=["ok", 123], # type: ignore + ) + + def test_delay_must_be_numeric(self): + with pytest.raises(ValueError, match="post_payment_retry_delay_seconds must be a number"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + post_payment_retry_delay_seconds="3", # type: ignore + ) + + def test_delay_must_be_non_negative(self): + with pytest.raises(ValueError, match="must be >= 0"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + post_payment_retry_delay_seconds=-1, + ) + + def test_delay_bool_rejected(self): + """Booleans are technically int subclass but should be rejected.""" + with pytest.raises(ValueError, match="post_payment_retry_delay_seconds must be a number"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + post_payment_retry_delay_seconds=True, # type: ignore + ) + + def test_custom_handlers_must_be_dict(self): + with pytest.raises(ValueError, match="custom_handlers must be a dict"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + custom_handlers=["not", "a", "dict"], # type: ignore + ) + + def test_custom_handlers_keys_must_be_strings(self): + with pytest.raises(ValueError, match="All keys in custom_handlers must be strings"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + custom_handlers={123: GenericPaymentHandler()}, # type: ignore + ) + + def test_custom_handlers_values_must_be_handler_instances(self): + with pytest.raises(ValueError, match="All values in custom_handlers must be PaymentResponseHandler"): + AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + custom_handlers={"my_tool": "not-a-handler"}, # type: ignore + ) + + def test_custom_handlers_valid(self): + handler = GenericPaymentHandler() + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + custom_handlers={"my_tool": handler}, + ) + assert config.custom_handlers == {"my_tool": handler} + + def test_valid_full_config(self): + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="user-1", + payment_instrument_id="instr-1", + payment_session_id="sess-1", + payment_connector_id="conn-1", + region="us-west-2", + network_preferences_config=["eip155:84532"], + auto_payment=False, + agent_name="my-agent", + payment_tool_allowlist=["http_request"], + provide_http_request=False, + post_payment_retry_delay_seconds=5.0, + ) + assert config.auto_payment is False + assert config.provide_http_request is False + assert config.post_payment_retry_delay_seconds == 5.0 + assert config.payment_tool_allowlist == ["http_request"] + + +# --------------------------------------------------------------------------- +# Middleware instantiation tests +# --------------------------------------------------------------------------- + + +class TestAgentCorePaymentsMiddlewareInstantiation: + """Test middleware creation and PaymentManager initialization.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_middleware_creates_payment_manager(self, mock_pm_cls): + """PaymentManager is created with config values during __init__.""" + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="user-1", + region="us-west-2", + agent_name="test-agent", + ) + mw = AgentCorePaymentsMiddleware(config) + + mock_pm_cls.assert_called_once_with( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + region_name="us-west-2", + agent_name="test-agent", + bearer_token=None, + token_provider=None, + ) + assert mw.config is config + assert mw.payment_manager is mock_pm_cls.return_value + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_middleware_passes_bearer_token(self, mock_pm_cls): + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + bearer_token="my-jwt", + ) + AgentCorePaymentsMiddleware(config) + mock_pm_cls.assert_called_once_with( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + region_name=None, + agent_name=None, + bearer_token="my-jwt", + token_provider=None, + ) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_middleware_raises_runtime_error_on_pm_failure(self, mock_pm_cls): + """RuntimeError raised if PaymentManager constructor throws.""" + mock_pm_cls.side_effect = Exception("boto3 broke") + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + ) + with pytest.raises(RuntimeError, match="Failed to initialize PaymentManager"): + AgentCorePaymentsMiddleware(config) + + +# --------------------------------------------------------------------------- +# Pass-through behavior tests +# --------------------------------------------------------------------------- + + +class TestAgentCorePaymentsMiddlewarePassThrough: + """Test that wrap_tool_call and awrap_tool_call pass through correctly.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_wrap_tool_call_passes_through(self, mock_pm_cls): + """Sync wrap_tool_call calls handler and returns its result.""" + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + ) + mw = AgentCorePaymentsMiddleware(config) + + mock_request = MagicMock() + mock_result = MagicMock() + mock_handler = MagicMock(return_value=mock_result) + + result = mw.wrap_tool_call(mock_request, mock_handler) + + mock_handler.assert_called_once_with(mock_request) + assert result is mock_result + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_awrap_tool_call_passes_through(self, mock_pm_cls): + """Async awrap_tool_call awaits handler and returns its result.""" + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + ) + mw = AgentCorePaymentsMiddleware(config) + + mock_request = MagicMock() + mock_result = MagicMock() + mock_handler = AsyncMock(return_value=mock_result) + + result = asyncio.run(mw.awrap_tool_call(mock_request, mock_handler)) + + mock_handler.assert_called_once_with(mock_request) + assert result is mock_result + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_wrap_tool_call_propagates_exceptions(self, mock_pm_cls): + """Exceptions from handler propagate through wrap_tool_call.""" + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + ) + mw = AgentCorePaymentsMiddleware(config) + + mock_request = MagicMock() + mock_handler = MagicMock(side_effect=RuntimeError("tool exploded")) + + with pytest.raises(RuntimeError, match="tool exploded"): + mw.wrap_tool_call(mock_request, mock_handler) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_awrap_tool_call_propagates_exceptions(self, mock_pm_cls): + """Exceptions from async handler propagate through awrap_tool_call.""" + config = AgentCorePaymentsConfig( + payment_manager_arn="arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + user_id="u", + ) + mw = AgentCorePaymentsMiddleware(config) + + mock_request = MagicMock() + mock_handler = AsyncMock(side_effect=ValueError("async boom")) + + with pytest.raises(ValueError, match="async boom"): + asyncio.run(mw.awrap_tool_call(mock_request, mock_handler)) diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py new file mode 100644 index 00000000..6f7af5cd --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py @@ -0,0 +1,326 @@ +"""Tests for Stage 2: 402 Detection + Adapter.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest +from langchain.messages import ToolMessage +from langgraph.types import Command + +from bedrock_agentcore.payments.integrations.handlers import ( + GenericPaymentHandler, + HttpRequestPaymentHandler, + MCPRequestPaymentHandler, + PaymentResponseHandler, +) +from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware + + +def _make_config(**overrides): + defaults = { + "payment_manager_arn": "arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + "user_id": "user-1", + } + defaults.update(overrides) + return AgentCorePaymentsConfig(**defaults) + + +def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): + req = MagicMock() + req.tool_call = {"name": tool_name, "args": tool_args or {}, "id": tool_id} + return req + + +PAYMENT_402_PAYLOAD = json.dumps({"statusCode": 402, "headers": {"x-pay": "v"}, "body": {"x402Version": 1}}) + + +# --------------------------------------------------------------------------- +# _prepare_for_handler tests +# --------------------------------------------------------------------------- + + +class TestPrepareForHandler: + """Test the adapter that normalizes ToolMessage.content for handlers.""" + + def test_string_content_wrapped(self): + content = f"PAYMENT_REQUIRED: {PAYMENT_402_PAYLOAD}" + result = AgentCorePaymentsMiddleware._prepare_for_handler(content) + assert result == {"content": [{"text": content}]} + + def test_list_of_dicts_passed_through(self): + content = [{"text": "PAYMENT_REQUIRED: " + PAYMENT_402_PAYLOAD}] + result = AgentCorePaymentsMiddleware._prepare_for_handler(content) + assert result == {"content": content} + + def test_list_of_strings_wrapped(self): + content = ["PAYMENT_REQUIRED: data", "other"] + result = AgentCorePaymentsMiddleware._prepare_for_handler(content) + assert result == {"content": [{"text": "PAYMENT_REQUIRED: data"}, {"text": "other"}]} + + def test_mixed_list(self): + content = [{"text": "foo"}, "bar"] + result = AgentCorePaymentsMiddleware._prepare_for_handler(content) + assert result == {"content": [{"text": "foo"}, {"text": "bar"}]} + + def test_empty_string(self): + result = AgentCorePaymentsMiddleware._prepare_for_handler("") + assert result == {"content": [{"text": ""}]} + + def test_none_returns_none(self): + assert AgentCorePaymentsMiddleware._prepare_for_handler(None) is None + + def test_non_str_non_list_returns_none(self): + assert AgentCorePaymentsMiddleware._prepare_for_handler(12345) is None + + +# --------------------------------------------------------------------------- +# _get_handler tests +# --------------------------------------------------------------------------- + + +class TestGetHandler: + """Test handler resolution priority.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_custom_handler_takes_priority(self, mock_pm): + custom = MagicMock(spec=PaymentResponseHandler) + config = _make_config(custom_handlers={"my_tool": custom}) + mw = AgentCorePaymentsMiddleware(config) + assert mw._get_handler("my_tool", {}) is custom + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_custom_handler_only_for_registered_name(self, mock_pm): + custom = MagicMock(spec=PaymentResponseHandler) + config = _make_config(custom_handlers={"my_tool": custom}) + mw = AgentCorePaymentsMiddleware(config) + handler = mw._get_handler("other_tool", {}) + assert handler is not custom + assert isinstance(handler, GenericPaymentHandler) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_builtin_name_based_handler(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + handler = mw._get_handler("http_request", {}) + assert isinstance(handler, HttpRequestPaymentHandler) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_mcp_shape_detected(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + handler = mw._get_handler("proxy_tool", {"toolName": "x", "parameters": {}}) + assert isinstance(handler, MCPRequestPaymentHandler) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_generic_fallback(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + handler = mw._get_handler("unknown_tool", {"url": "http://example.com"}) + assert isinstance(handler, GenericPaymentHandler) + + +# --------------------------------------------------------------------------- +# 402 detection tests +# --------------------------------------------------------------------------- + + +class TestDetection402: + """Test that 402 is correctly detected from various formats.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_402_detected_from_marker_string(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + content = f"PAYMENT_REQUIRED: {PAYMENT_402_PAYLOAD}" + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + + result = mw.wrap_tool_call(request, mock_handler) + + # 402 detected → signing attempted → error (no instrument configured) + assert isinstance(result, ToolMessage) + assert "PAYMENT ERROR" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_402_detected_from_content_list(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + content = [{"text": f"PAYMENT_REQUIRED: {PAYMENT_402_PAYLOAD}"}] + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + + result = mw.wrap_tool_call(request, mock_handler) + + # 402 detected → signing attempted → error (no instrument configured) + assert isinstance(result, ToolMessage) + assert "PAYMENT ERROR" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_non_402_status_passes_through(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + payload = json.dumps({"statusCode": 200, "headers": {}, "body": {}}) + content = f"PAYMENT_REQUIRED: {payload}" + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + result = mw.wrap_tool_call(request, mock_handler) + assert result is tool_msg + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_no_marker_passes_through(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + tool_msg = ToolMessage(content="Just normal output", tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + result = mw.wrap_tool_call(request, mock_handler) + assert result is tool_msg + + +# --------------------------------------------------------------------------- +# Guard condition tests +# --------------------------------------------------------------------------- + + +class TestGuardConditions: + """Test that guards correctly bypass 402 detection.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_command_result_passes_through(self, mock_pm): + mw = AgentCorePaymentsMiddleware(_make_config()) + cmd = Command(update={"key": "val"}) + + request = _make_request() + mock_handler = MagicMock(return_value=cmd) + + with patch.object(mw, "_get_handler") as spy: + result = mw.wrap_tool_call(request, mock_handler) + spy.assert_not_called() + + assert result is cmd + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_auto_payment_false_skips(self, mock_pm): + config = _make_config(auto_payment=False) + mw = AgentCorePaymentsMiddleware(config) + content = f"PAYMENT_REQUIRED: {PAYMENT_402_PAYLOAD}" + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + + with patch.object(mw, "_get_handler") as spy: + result = mw.wrap_tool_call(request, mock_handler) + spy.assert_not_called() + + assert result is tool_msg + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_tool_not_in_allowlist_skips(self, mock_pm): + config = _make_config(payment_tool_allowlist=["other_tool"]) + mw = AgentCorePaymentsMiddleware(config) + content = f"PAYMENT_REQUIRED: {PAYMENT_402_PAYLOAD}" + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request(tool_name="http_request") + mock_handler = MagicMock(return_value=tool_msg) + + with patch.object(mw, "_get_handler") as spy: + result = mw.wrap_tool_call(request, mock_handler) + spy.assert_not_called() + + assert result is tool_msg + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_tool_in_allowlist_proceeds(self, mock_pm): + config = _make_config(payment_tool_allowlist=["http_request"]) + mw = AgentCorePaymentsMiddleware(config) + content = f"PAYMENT_REQUIRED: {PAYMENT_402_PAYLOAD}" + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request(tool_name="http_request") + mock_handler = MagicMock(return_value=tool_msg) + + result = mw.wrap_tool_call(request, mock_handler) + + # 402 detected and processing attempted (error because no instrument) + assert isinstance(result, ToolMessage) + assert "PAYMENT ERROR" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_none_allowlist_processes_all(self, mock_pm): + config = _make_config(payment_tool_allowlist=None) + mw = AgentCorePaymentsMiddleware(config) + content = f"PAYMENT_REQUIRED: {PAYMENT_402_PAYLOAD}" + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request(tool_name="any_tool") + mock_handler = MagicMock(return_value=tool_msg) + + result = mw.wrap_tool_call(request, mock_handler) + + # 402 detected and processing attempted (error because no instrument) + assert isinstance(result, ToolMessage) + assert "PAYMENT ERROR" in result.content + + +# --------------------------------------------------------------------------- +# Fallback detection tests (raw JSON without PAYMENT_REQUIRED: marker) +# --------------------------------------------------------------------------- + + +class TestFallbackDetection: + """Test lenient fallback that detects 402 from raw JSON.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_raw_status_code_402_detected(self, mock_pm): + """Raw JSON with statusCode:402 triggers payment processing.""" + mw = AgentCorePaymentsMiddleware(_make_config()) + content = json.dumps({"statusCode": 402, "headers": {"h": "v"}, "body": {"x402Version": 1}}) + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + result = mw.wrap_tool_call(request, mock_handler) + + # 402 detected via fallback → payment processing attempted + assert "PAYMENT ERROR" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_x402_payload_detected(self, mock_pm): + """Raw JSON with x402Version + accepts triggers payment processing.""" + mw = AgentCorePaymentsMiddleware(_make_config()) + content = json.dumps({"x402Version": 1, "accepts": [{"network": "base-sepolia"}]}) + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + result = mw.wrap_tool_call(request, mock_handler) + + assert "PAYMENT ERROR" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_non_402_raw_json_passes_through(self, mock_pm): + """Raw JSON with statusCode:200 is not detected as 402.""" + mw = AgentCorePaymentsMiddleware(_make_config()) + content = json.dumps({"statusCode": 200, "body": {"data": "ok"}}) + tool_msg = ToolMessage(content=content, tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + result = mw.wrap_tool_call(request, mock_handler) + + assert result is tool_msg + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_plain_text_not_detected(self, mock_pm): + """Non-JSON text is not detected as 402.""" + mw = AgentCorePaymentsMiddleware(_make_config()) + tool_msg = ToolMessage(content="Hello world", tool_call_id="tc-1") + + request = _make_request() + mock_handler = MagicMock(return_value=tool_msg) + result = mw.wrap_tool_call(request, mock_handler) + + assert result is tool_msg diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py new file mode 100644 index 00000000..aced5a88 --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py @@ -0,0 +1,347 @@ +"""Tests for Stage 3: Payment Signing + Retry.""" + +import json +from unittest.mock import MagicMock, patch, call + +import pytest +from langchain.messages import ToolMessage +from langgraph.types import Command + +from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware +from bedrock_agentcore.payments.manager import ( + PaymentError, + PaymentInstrumentConfigurationRequired, + PaymentSessionConfigurationRequired, +) + + +def _make_config(**overrides): + defaults = { + "payment_manager_arn": "arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + "user_id": "user-1", + "payment_instrument_id": "instr-1", + "payment_session_id": "sess-1", + "post_payment_retry_delay_seconds": 0, # no delay in tests + } + defaults.update(overrides) + return AgentCorePaymentsConfig(**defaults) + + +def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): + req = MagicMock() + req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + return req + + +def _402_content(): + payload = json.dumps({"statusCode": 402, "headers": {"x-pay": "v"}, "body": {"x402Version": 1}}) + return f"PAYMENT_REQUIRED: {payload}" + + +def _200_content(): + return json.dumps({"statusCode": 200, "body": {"data": "paid content"}}) + + +# --------------------------------------------------------------------------- +# _generate_payment_header tests +# --------------------------------------------------------------------------- + + +class TestGeneratePaymentHeader: + """Test _generate_payment_header method.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_calls_pm_with_correct_params(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "signed"} + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + payment_req = {"statusCode": 402, "headers": {"h": "1"}, "body": {"b": "2"}} + result = mw._generate_payment_header(payment_req) + + assert result == {"X-PAYMENT": "signed"} + mock_pm.generate_payment_header.assert_called_once() + call_kwargs = mock_pm.generate_payment_header.call_args[1] + assert call_kwargs["user_id"] == "user-1" + assert call_kwargs["payment_instrument_id"] == "instr-1" + assert call_kwargs["payment_session_id"] == "sess-1" + assert call_kwargs["payment_required_request"] is payment_req + assert call_kwargs["client_token"] # uuid string, non-empty + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_raises_if_no_instrument_id(self, mock_pm_cls): + config = _make_config(payment_instrument_id=None) + mw = AgentCorePaymentsMiddleware(config) + + with pytest.raises(PaymentInstrumentConfigurationRequired): + mw._generate_payment_header({"statusCode": 402, "headers": {}, "body": {}}) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_raises_if_no_session_id(self, mock_pm_cls): + config = _make_config(payment_session_id=None) + mw = AgentCorePaymentsMiddleware(config) + + with pytest.raises(PaymentSessionConfigurationRequired): + mw._generate_payment_header({"statusCode": 402, "headers": {}, "body": {}}) + + +# --------------------------------------------------------------------------- +# Header injection tests +# --------------------------------------------------------------------------- + + +class TestHeaderInjection: + """Test that payment headers are correctly injected into tool args.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_header_injected_into_tool_args(self, mock_pm_cls): + """After signing, the payment header appears in tool_args['headers'].""" + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig123"} + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + tool_args = {"url": "http://x.com", "headers": {}} + request = _make_request(tool_args=tool_args) + + call_count = [0] + + def mock_handler(req): + call_count[0] += 1 + if call_count[0] == 1: + return ToolMessage(content=_402_content(), tool_call_id="tc-1") + return ToolMessage(content=_200_content(), tool_call_id="tc-1") + + mw.wrap_tool_call(request, mock_handler) + + # Verify header was injected into tool_args + assert tool_args["headers"]["X-PAYMENT"] == "sig123" + + +# --------------------------------------------------------------------------- +# Successful retry tests +# --------------------------------------------------------------------------- + + +class TestSuccessfulRetry: + """Test the full 402 → sign → retry → 200 flow.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_402_then_200_on_retry(self, mock_pm_cls): + """Tool returns 402, middleware signs, retries, gets 200.""" + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + success_msg = ToolMessage(content=_200_content(), tool_call_id="tc-1") + + call_count = [0] + + def mock_handler(req): + call_count[0] += 1 + if call_count[0] == 1: + return ToolMessage(content=_402_content(), tool_call_id="tc-1") + return success_msg + + result = mw.wrap_tool_call(request, mock_handler) + + assert result is success_msg + assert call_count[0] == 2 + mock_pm.generate_payment_header.assert_called_once() + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_handler_called_twice(self, mock_pm_cls): + """The execute handler is called exactly twice: initial + retry.""" + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + mw.wrap_tool_call(request, mock_handler) + assert mock_handler.call_count == 2 + + +# --------------------------------------------------------------------------- +# Post-payment rejection tests +# --------------------------------------------------------------------------- + + +class TestPostPaymentRejection: + """Test detection of 402 after successful signing.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_402_after_signing_returns_error(self, mock_pm_cls): + """If retry still returns 402, return error ToolMessage.""" + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + + # Both calls return 402 + mock_handler = MagicMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ]) + + result = mw.wrap_tool_call(request, mock_handler) + + assert isinstance(result, ToolMessage) + assert "PAYMENT ERROR" in result.content + assert "rejected" in result.content + assert result.tool_call_id == "tc-1" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_rejection_error_includes_body_error(self, mock_pm_cls): + """Error message from 402 body is included in the rejection message.""" + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + + payload_with_error = json.dumps({ + "statusCode": 402, + "headers": {}, + "body": {"error": "insufficient_balance"}, + }) + content_402_with_error = f"PAYMENT_REQUIRED: {payload_with_error}" + + mock_handler = MagicMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=content_402_with_error, tool_call_id="tc-1"), + ]) + + result = mw.wrap_tool_call(request, mock_handler) + assert "insufficient_balance" in result.content + + +# --------------------------------------------------------------------------- +# Delay tests +# --------------------------------------------------------------------------- + + +class TestRetryDelay: + """Test configurable delay before retry.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.time.sleep") + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_delay_applied_before_retry(self, mock_pm_cls, mock_sleep): + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + + config = _make_config(post_payment_retry_delay_seconds=3.0) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + mw.wrap_tool_call(request, mock_handler) + mock_sleep.assert_called_once_with(3.0) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.time.sleep") + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_zero_delay_skips_sleep(self, mock_pm_cls, mock_sleep): + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + + config = _make_config(post_payment_retry_delay_seconds=0) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + mw.wrap_tool_call(request, mock_handler) + mock_sleep.assert_not_called() + + +# --------------------------------------------------------------------------- +# Error ToolMessage tests +# --------------------------------------------------------------------------- + + +class TestErrorToolMessage: + """Test error messages returned for various failure cases.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_missing_instrument_returns_error_msg(self, mock_pm_cls): + config = _make_config(payment_instrument_id=None) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" in result.content + assert "payment instrument" in result.content + assert result.tool_call_id == "tc-1" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_missing_session_returns_error_msg(self, mock_pm_cls): + config = _make_config(payment_session_id=None) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" in result.content + assert "payment session" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_pm_error_returns_error_msg(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.side_effect = PaymentError("budget exceeded") + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" in result.content + assert "budget exceeded" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_validate_tool_input_fails_returns_error(self, mock_pm_cls): + """If handler can't validate tool input shape, return error.""" + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + + config = _make_config() + mw = AgentCorePaymentsMiddleware(config) + + # Tool args not a dict — force validate_tool_input to fail + # We pass args as a non-dict via direct manipulation + request = _make_request(tool_args="not-a-dict") + + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" in result.content + assert "request format" in result.content diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py new file mode 100644 index 00000000..0ec1fa29 --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py @@ -0,0 +1,238 @@ +"""Tests for Stage 4: Error Handling — deterministic messages + broad exception guard.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest +from langchain.messages import ToolMessage + +from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware +from bedrock_agentcore.payments.manager import ( + InsufficientBudget, + PaymentError, + PaymentInstrumentConfigurationRequired, + PaymentInstrumentNotFound, + PaymentSessionConfigurationRequired, + PaymentSessionExpired, + PaymentSessionNotFound, +) + + +def _make_config(**overrides): + defaults = { + "payment_manager_arn": "arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + "user_id": "user-1", + "payment_instrument_id": "instr-1", + "payment_session_id": "sess-1", + "post_payment_retry_delay_seconds": 0, + } + defaults.update(overrides) + return AgentCorePaymentsConfig(**defaults) + + +def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): + req = MagicMock() + req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + return req + + +def _402_content(): + payload = json.dumps({"statusCode": 402, "headers": {}, "body": {"x402Version": 1}}) + return f"PAYMENT_REQUIRED: {payload}" + + +# --------------------------------------------------------------------------- +# Deterministic error message tests +# --------------------------------------------------------------------------- + + +class TestDeterministicErrorMessages: + """Each exception type produces the correct deterministic message.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_instrument_config_required(self, mock_pm_cls): + config = _make_config(payment_instrument_id=None) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "No payment instrument configured" in result.content + assert "Do not retry this call" in result.content + assert result.status == "error" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_session_config_required(self, mock_pm_cls): + config = _make_config(payment_session_id=None) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "No payment session configured" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_instrument_not_found(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentInstrumentNotFound("not found") + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "Payment instrument not found" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_session_not_found(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentSessionNotFound("gone") + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "Payment session not found" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_session_expired(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentSessionExpired("expired") + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "Payment session has expired" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_insufficient_budget(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = InsufficientBudget("over limit") + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "Insufficient budget" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_generic_payment_error(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentError("something broke") + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "Payment processing failed" in result.content + assert "something broke" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_post_payment_rejection(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + mw = AgentCorePaymentsMiddleware(_make_config()) + + payload_with_error = json.dumps({"statusCode": 402, "headers": {}, "body": {"error": "bad_sig"}}) + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=f"PAYMENT_REQUIRED: {payload_with_error}", tool_call_id="tc-1"), + ]) + + result = mw.wrap_tool_call(request, handler) + assert "signed but rejected" in result.content + assert "bad_sig" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_tool_input_validation_failed(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args="not-a-dict") + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert "Could not apply payment credentials" in result.content + assert "Do not retry this call" in result.content + + +# --------------------------------------------------------------------------- +# Broad exception guard tests +# --------------------------------------------------------------------------- + + +class TestUnexpectedExceptionHandling: + """Unexpected exceptions are caught and returned as error ToolMessages.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_unexpected_runtime_error(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = RuntimeError("boom") + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + # Should NOT raise — returns error ToolMessage + result = mw.wrap_tool_call(request, handler) + assert isinstance(result, ToolMessage) + assert "unexpected error" in result.content + assert "boom" in result.content + assert "Do not retry this call" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_unexpected_type_error(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = TypeError("bad type") + mw = AgentCorePaymentsMiddleware(_make_config()) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, handler) + assert isinstance(result, ToolMessage) + assert "unexpected error" in result.content + assert "bad type" in result.content + + +# --------------------------------------------------------------------------- +# Guard regression tests +# --------------------------------------------------------------------------- + + +class TestGuardRegression: + """Guards bypass payment processing entirely — no error messages leak.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_auto_payment_false_no_error(self, mock_pm_cls): + config = _make_config(auto_payment=False) + mw = AgentCorePaymentsMiddleware(config) + + tool_msg = ToolMessage(content=_402_content(), tool_call_id="tc-1") + request = _make_request() + handler = MagicMock(return_value=tool_msg) + + result = mw.wrap_tool_call(request, handler) + assert result is tool_msg + assert "PAYMENT ERROR" not in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_not_in_allowlist_no_error(self, mock_pm_cls): + config = _make_config(payment_tool_allowlist=["other_tool"]) + mw = AgentCorePaymentsMiddleware(config) + + tool_msg = ToolMessage(content=_402_content(), tool_call_id="tc-1") + request = _make_request(tool_name="http_request") + handler = MagicMock(return_value=tool_msg) + + result = mw.wrap_tool_call(request, handler) + assert result is tool_msg + assert "PAYMENT ERROR" not in result.content diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py new file mode 100644 index 00000000..69c1ba79 --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py @@ -0,0 +1,246 @@ +"""Tests for Stage 5: Async awrap_tool_call.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langchain.messages import ToolMessage +from langgraph.types import Command + +from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware + + +def _make_config(**overrides): + defaults = { + "payment_manager_arn": "arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + "user_id": "user-1", + "payment_instrument_id": "instr-1", + "payment_session_id": "sess-1", + "post_payment_retry_delay_seconds": 0, + } + defaults.update(overrides) + return AgentCorePaymentsConfig(**defaults) + + +def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): + req = MagicMock() + req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + return req + + +def _402_content(): + payload = json.dumps({"statusCode": 402, "headers": {}, "body": {"x402Version": 1}}) + return f"PAYMENT_REQUIRED: {payload}" + + +def _200_content(): + return json.dumps({"statusCode": 200, "body": {"data": "paid"}}) + + +# --------------------------------------------------------------------------- +# Basic async pass-through +# --------------------------------------------------------------------------- + + +class TestAsyncPassThrough: + """Test basic async behavior for non-payment cases.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_non_402_passes_through(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config()) + tool_msg = ToolMessage(content="normal output", tool_call_id="tc-1") + handler = AsyncMock(return_value=tool_msg) + + result = asyncio.run(mw.awrap_tool_call(_make_request(), handler)) + assert result is tool_msg + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_command_passes_through(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config()) + cmd = Command(update={"k": "v"}) + handler = AsyncMock(return_value=cmd) + + result = asyncio.run(mw.awrap_tool_call(_make_request(), handler)) + assert result is cmd + + +# --------------------------------------------------------------------------- +# Full retry flow +# --------------------------------------------------------------------------- + + +class TestAsyncRetryFlow: + """Test 402 → sign → retry → 200 async flow.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_402_then_200_on_retry(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + mw = AgentCorePaymentsMiddleware(_make_config()) + + success_msg = ToolMessage(content=_200_content(), tool_call_id="tc-1") + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + success_msg, + ]) + + result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + assert result is success_msg + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_handler_awaited_twice(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + mw = AgentCorePaymentsMiddleware(_make_config()) + + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + assert handler.await_count == 2 + + +# --------------------------------------------------------------------------- +# asyncio.sleep verification +# --------------------------------------------------------------------------- + + +class TestAsyncSleepUsed: + """Verify asyncio.sleep is used (not time.sleep).""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_asyncio_sleep_called(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + config = _make_config(post_payment_retry_delay_seconds=3.0) + mw = AgentCorePaymentsMiddleware(config) + + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_async_sleep: + asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + mock_async_sleep.assert_called_once_with(3.0) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.time.sleep") + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_time_sleep_not_called(self, mock_pm_cls, mock_time_sleep): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + config = _make_config(post_payment_retry_delay_seconds=3.0) + mw = AgentCorePaymentsMiddleware(config) + + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + with patch("asyncio.sleep", new_callable=AsyncMock): + asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + + mock_time_sleep.assert_not_called() + + +# --------------------------------------------------------------------------- +# asyncio.to_thread verification +# --------------------------------------------------------------------------- + + +class TestAsyncToThread: + """Verify _generate_payment_header runs via asyncio.to_thread.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_generate_header_runs_in_thread(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + mw = AgentCorePaymentsMiddleware(_make_config()) + + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + with patch("asyncio.to_thread", new_callable=AsyncMock, return_value={"X-PAYMENT": "sig"}) as mock_to_thread: + asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + mock_to_thread.assert_called_once() + # First arg is the method + assert mock_to_thread.call_args[0][0] == mw._generate_payment_header + + +# --------------------------------------------------------------------------- +# Async error handling +# --------------------------------------------------------------------------- + + +class TestAsyncErrorHandling: + """Async path produces same error messages as sync path.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_missing_instrument_error(self, mock_pm_cls): + config = _make_config(payment_instrument_id=None) + mw = AgentCorePaymentsMiddleware(config) + + handler = AsyncMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + assert "PAYMENT ERROR" in result.content + assert "No payment instrument configured" in result.content + assert result.status == "error" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_post_payment_rejection(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} + mw = AgentCorePaymentsMiddleware(_make_config()) + + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ]) + + result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + assert "signed but rejected" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_unexpected_exception(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = RuntimeError("async boom") + mw = AgentCorePaymentsMiddleware(_make_config()) + + handler = AsyncMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + assert isinstance(result, ToolMessage) + assert "unexpected error" in result.content + assert "async boom" in result.content + + +# --------------------------------------------------------------------------- +# Async guards +# --------------------------------------------------------------------------- + + +class TestAsyncGuards: + """Guards work identically in async path.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_auto_payment_false_skips(self, mock_pm_cls): + config = _make_config(auto_payment=False) + mw = AgentCorePaymentsMiddleware(config) + + tool_msg = ToolMessage(content=_402_content(), tool_call_id="tc-1") + handler = AsyncMock(return_value=tool_msg) + + result = asyncio.run(mw.awrap_tool_call(_make_request(), handler)) + assert result is tool_msg + assert "PAYMENT ERROR" not in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_allowlist_skips(self, mock_pm_cls): + config = _make_config(payment_tool_allowlist=["other_tool"]) + mw = AgentCorePaymentsMiddleware(config) + + tool_msg = ToolMessage(content=_402_content(), tool_call_id="tc-1") + handler = AsyncMock(return_value=tool_msg) + + result = asyncio.run(mw.awrap_tool_call(_make_request(tool_name="http_request"), handler)) + assert result is tool_msg diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py new file mode 100644 index 00000000..25880ae9 --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py @@ -0,0 +1,251 @@ +"""Tests for Stage 6: Built-in Tools.""" + +import json +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest + +from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware + + +def _make_config(**overrides): + defaults = { + "payment_manager_arn": "arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + "user_id": "user-1", + "payment_instrument_id": "instr-1", + "payment_session_id": "sess-1", + "payment_connector_id": "conn-1", + "post_payment_retry_delay_seconds": 0, + } + defaults.update(overrides) + return AgentCorePaymentsConfig(**defaults) + + +def _get_tool_by_name(mw, name): + for t in mw.tools: + if t.name == name: + return t + return None + + +# --------------------------------------------------------------------------- +# Conditional registration tests +# --------------------------------------------------------------------------- + + +class TestConditionalRegistration: + """Test provide_http_request flag controls tool registration.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_provide_http_request_true_includes_tool(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config(provide_http_request=True)) + names = [t.name for t in mw.tools] + assert "http_request" in names + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_provide_http_request_false_excludes_tool(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config(provide_http_request=False)) + names = [t.name for t in mw.tools] + assert "http_request" not in names + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_query_tools_always_registered(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config(provide_http_request=False)) + names = [t.name for t in mw.tools] + assert "get_payment_instrument" in names + assert "list_payment_instruments" in names + assert "get_payment_instrument_balance" in names + assert "get_payment_session" in names + + +# --------------------------------------------------------------------------- +# http_request tool tests +# --------------------------------------------------------------------------- + + +class TestHttpRequestTool: + """Test http_request tool behavior.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_402_returns_payment_required_marker(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "http_request") + + mock_resp = MagicMock() + mock_resp.status_code = 402 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.json.return_value = {"x402Version": 1, "accepts": []} + + with patch("httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client_cls.return_value) + mock_client_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value.request.return_value = mock_resp + + result = tool.invoke({"url": "http://example.com"}) + + assert result.startswith("PAYMENT_REQUIRED: ") + parsed = json.loads(result[len("PAYMENT_REQUIRED: "):]) + assert parsed["statusCode"] == 402 + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_200_returns_json_payload(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "http_request") + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.json.return_value = {"data": "content"} + + with patch("httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client_cls.return_value) + mock_client_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value.request.return_value = mock_resp + + result = tool.invoke({"url": "http://example.com"}) + + parsed = json.loads(result) + assert parsed["statusCode"] == 200 + assert parsed["body"] == {"data": "content"} + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_network_error_returns_error(self, mock_pm_cls): + import httpx as _httpx + + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "http_request") + + with patch("httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client_cls.return_value) + mock_client_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_client_cls.return_value.request.side_effect = _httpx.ConnectError("connection refused") + + result = tool.invoke({"url": "http://example.com"}) + + parsed = json.loads(result) + assert parsed["statusCode"] == 0 + assert "connection refused" in parsed["error"] + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_post_with_json_body(self, mock_pm_cls): + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "http_request") + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {} + mock_resp.json.return_value = {"ok": True} + + with patch("httpx.Client") as mock_client_cls: + client = mock_client_cls.return_value + client.__enter__ = MagicMock(return_value=client) + client.__exit__ = MagicMock(return_value=False) + client.request.return_value = mock_resp + + tool.invoke({"url": "http://x.com", "method": "POST", "body": {"k": "v"}}) + client.request.assert_called_once_with("POST", "http://x.com", headers={}, json={"k": "v"}) + + +# --------------------------------------------------------------------------- +# Payment query tool tests +# --------------------------------------------------------------------------- + + +class TestPaymentQueryTools: + """Test payment query tools call PaymentManager correctly.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_get_payment_instrument(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.get_payment_instrument.return_value = {"paymentInstrumentId": "instr-1"} + + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "get_payment_instrument") + + result = tool.invoke({"payment_instrument_id": "instr-99", "user_id": "u2"}) + mock_pm.get_payment_instrument.assert_called_once_with( + user_id="u2", + payment_instrument_id="instr-99", + payment_connector_id=None, + ) + assert result == {"paymentInstrumentId": "instr-1"} + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_get_payment_instrument_falls_back_to_config(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.get_payment_instrument.return_value = {"paymentInstrumentId": "instr-1"} + + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "get_payment_instrument") + + tool.invoke({}) + mock_pm.get_payment_instrument.assert_called_once_with( + user_id="user-1", + payment_instrument_id="instr-1", + payment_connector_id=None, + ) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_list_payment_instruments(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.list_payment_instruments.return_value = {"paymentInstruments": []} + + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "list_payment_instruments") + + result = tool.invoke({"user_id": "u2"}) + mock_pm.list_payment_instruments.assert_called_once_with( + user_id="u2", + payment_connector_id=None, + max_results=100, + next_token=None, + ) + assert result == {"paymentInstruments": []} + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_get_payment_instrument_balance(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.get_payment_instrument_balance.return_value = {"tokenBalance": {"amount": "10.0"}} + + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "get_payment_instrument_balance") + + result = tool.invoke({"payment_instrument_id": "instr-1", "chain": "BASE_SEPOLIA"}) + mock_pm.get_payment_instrument_balance.assert_called_once_with( + payment_connector_id="conn-1", + payment_instrument_id="instr-1", + chain="BASE_SEPOLIA", + token="USDC", + user_id="user-1", + ) + assert result["tokenBalance"]["amount"] == "10.0" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_get_payment_session(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.get_payment_session.return_value = {"paymentSessionId": "sess-1"} + + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "get_payment_session") + + result = tool.invoke({"payment_session_id": "sess-99", "user_id": "u3"}) + mock_pm.get_payment_session.assert_called_once_with( + user_id="u3", + payment_session_id="sess-99", + ) + assert result == {"paymentSessionId": "sess-1"} + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_get_payment_session_falls_back_to_config(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.get_payment_session.return_value = {"paymentSessionId": "sess-1"} + + mw = AgentCorePaymentsMiddleware(_make_config()) + tool = _get_tool_by_name(mw, "get_payment_session") + + tool.invoke({}) + mock_pm.get_payment_session.assert_called_once_with( + user_id="user-1", + payment_session_id="sess-1", + ) diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py new file mode 100644 index 00000000..9c5f4840 --- /dev/null +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py @@ -0,0 +1,376 @@ +"""Tests for Stage 7: Error Handler Callback.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langchain.messages import ToolMessage +from langgraph.types import Command + +from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph.errors import ErrorResolution, PaymentErrorContext +from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware +from bedrock_agentcore.payments.manager import ( + InsufficientBudget, + PaymentError, + PaymentInstrumentConfigurationRequired, + PaymentSessionExpired, +) + + +def _402_content(): + payload = json.dumps({"statusCode": 402, "headers": {}, "body": {"x402Version": 1}}) + return f"PAYMENT_REQUIRED: {payload}" + + +def _200_content(): + return json.dumps({"statusCode": 200, "body": {"data": "paid"}}) + + +def _make_config(**overrides): + defaults = { + "payment_manager_arn": "arn:aws:bedrock-agentcore:us-east-1:123456789012:payment-manager/pm-1", + "user_id": "user-1", + "payment_instrument_id": "instr-1", + "payment_session_id": "sess-1", + "post_payment_retry_delay_seconds": 0, + } + defaults.update(overrides) + return AgentCorePaymentsConfig(**defaults) + + +def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): + req = MagicMock() + req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + return req + + +# --------------------------------------------------------------------------- +# Basic callback flow +# --------------------------------------------------------------------------- + + +class TestErrorHandlerBasicFlow: + """Callback invoked on payment errors and can resolve them.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_callback_fixes_missing_instrument(self, mock_pm_cls): + """Callback sets instrument_id, returns RETRY → payment succeeds.""" + mock_pm = mock_pm_cls.return_value + # First call fails (no instrument), after callback fixes it, second call succeeds + mock_pm.generate_payment_header.side_effect = [ + PaymentInstrumentConfigurationRequired("missing"), + {"X-PAYMENT": "sig"}, + ] + + def handler_cb(ctx): + ctx.config.payment_instrument_id = "fixed-instr" + return ErrorResolution.RETRY + + config = _make_config(payment_instrument_id=None, on_payment_error=handler_cb) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" not in result.content + assert json.loads(result.content)["statusCode"] == 200 + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_callback_not_invoked_when_none(self, mock_pm_cls): + """No callback → standard error ToolMessage.""" + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentError("fail") + config = _make_config(on_payment_error=None) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" in result.content + + +# --------------------------------------------------------------------------- +# PROPAGATE resolution +# --------------------------------------------------------------------------- + + +class TestPropagateResolution: + """Callback returns PROPAGATE or custom string → error message to LLM.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_propagate_falls_through(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = InsufficientBudget("over") + config = _make_config(on_payment_error=lambda ctx: ErrorResolution.PROPAGATE) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "Insufficient budget" in result.content + assert "Do not retry" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_string_return_custom_message(self, mock_pm_cls): + """Returning a string sends that custom message to the LLM.""" + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentError("fail") + + def cb(ctx): + return "Please visit https://myapp.com/setup to configure your wallet." + + config = _make_config(on_payment_error=cb) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" in result.content + assert "https://myapp.com/setup" in result.content + assert result.status == "error" + assert result.tool_call_id == "tc-1" + + +# --------------------------------------------------------------------------- +# Retry loop +# --------------------------------------------------------------------------- + + +class TestRetryLoop: + """Callback can retry multiple times up to max_error_retries.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_max_retries_exhausted(self, mock_pm_cls): + """Callback always retries but PM always fails → exhausts max retries.""" + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentError("always fails") + config = _make_config(on_payment_error=lambda ctx: ErrorResolution.RETRY, max_error_retries=3) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "PAYMENT ERROR" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_retry_count_increments(self, mock_pm_cls): + """retry_count passed to callback increments each time.""" + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentError("fail") + counts = [] + + def cb(ctx): + counts.append(ctx.retry_count) + return ErrorResolution.RETRY + + config = _make_config(on_payment_error=cb, max_error_retries=3) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + mw.wrap_tool_call(request, mock_handler) + assert counts == [0, 1, 2] + + +# --------------------------------------------------------------------------- +# Exception safety +# --------------------------------------------------------------------------- + + +class TestCallbackExceptionSafety: + """Buggy callbacks don't crash the agent.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_callback_raises(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentError("orig") + + def bad_cb(ctx): + raise RuntimeError("callback bug") + + config = _make_config(on_payment_error=bad_cb) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + # Falls through to original error message, no crash + assert "PAYMENT ERROR" in result.content + assert "orig" in result.content or "Payment processing failed" in result.content + + +# --------------------------------------------------------------------------- +# Context populated correctly +# --------------------------------------------------------------------------- + + +class TestContextPopulated: + """PaymentErrorContext has correct data.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_context_fields(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentSessionExpired("expired!") + captured = [] + + def cb(ctx): + captured.append(ctx) + return ErrorResolution.PROPAGATE + + config = _make_config(on_payment_error=cb) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_name="my_tool", tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + mw.wrap_tool_call(request, mock_handler) + + ctx = captured[0] + assert ctx.exception_type == "PaymentSessionExpired" + assert ctx.exception_message == "expired!" + assert ctx.tool_name == "my_tool" + assert ctx.tool_args == {"url": "http://x.com", "headers": {}} + assert ctx.config is config + assert ctx.retry_count == 0 + assert ctx.payment_required_request is not None + assert ctx.payment_required_request["statusCode"] == 402 + + +# --------------------------------------------------------------------------- +# Async callback support +# --------------------------------------------------------------------------- + + +class TestAsyncErrorHandler: + """Async callbacks are awaited correctly.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_async_callback_awaited(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.side_effect = [ + PaymentError("first fail"), + {"X-PAYMENT": "sig"}, + ] + + async def async_cb(ctx): + ctx.config.payment_session_id = "new-sess" + return ErrorResolution.RETRY + + config = _make_config(on_payment_error=async_cb) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + result = asyncio.run(mw.awrap_tool_call(request, handler)) + assert "PAYMENT ERROR" not in result.content + assert json.loads(result.content)["statusCode"] == 200 + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_sync_callback_in_async_path(self, mock_pm_cls): + mock_pm = mock_pm_cls.return_value + mock_pm.generate_payment_header.side_effect = [ + PaymentError("fail"), + {"X-PAYMENT": "sig"}, + ] + + def sync_cb(ctx): + return ErrorResolution.RETRY + + config = _make_config(on_payment_error=sync_cb) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + handler = AsyncMock(side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ]) + + result = asyncio.run(mw.awrap_tool_call(request, handler)) + assert "PAYMENT ERROR" not in result.content + + +# --------------------------------------------------------------------------- +# Config validation +# --------------------------------------------------------------------------- + + +class TestConfigValidation: + """New config fields are validated.""" + + def test_on_payment_error_must_be_callable(self): + with pytest.raises(ValueError, match="on_payment_error must be callable"): + _make_config(on_payment_error="not callable") + + def test_on_payment_error_none_is_valid(self): + config = _make_config(on_payment_error=None) + assert config.on_payment_error is None + + def test_max_error_retries_must_be_int(self): + with pytest.raises(ValueError, match="max_error_retries must be an int"): + _make_config(max_error_retries="three") + + def test_max_error_retries_must_be_non_negative(self): + with pytest.raises(ValueError, match="max_error_retries must be >= 0"): + _make_config(max_error_retries=-1) + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_max_error_retries_zero_disables_callback(self, mock_pm_cls): + """max_error_retries=0 means callback is never invoked.""" + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentError("fail") + called = [] + + def cb(ctx): + called.append(True) + return ErrorResolution.RETRY + + config = _make_config(on_payment_error=cb, max_error_retries=0) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert called == [] + assert "PAYMENT ERROR" in result.content + + +# --------------------------------------------------------------------------- +# Backward compatibility +# --------------------------------------------------------------------------- + + +class TestBackwardCompatibility: + """No regression when on_payment_error is None.""" + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_no_callback_instrument_missing(self, mock_pm_cls): + config = _make_config(payment_instrument_id=None, on_payment_error=None) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "No payment instrument configured" in result.content + assert "Do not retry" in result.content + + @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") + def test_no_callback_session_expired(self, mock_pm_cls): + mock_pm_cls.return_value.generate_payment_header.side_effect = PaymentSessionExpired("exp") + config = _make_config(on_payment_error=None) + mw = AgentCorePaymentsMiddleware(config) + + request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) + mock_handler = MagicMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) + + result = mw.wrap_tool_call(request, mock_handler) + assert "Payment session has expired" in result.content From e84026fed159ff591cf26c2ef5ebbd1588212596 Mon Sep 17 00:00:00 2001 From: Raghav Sunil Date: Tue, 23 Jun 2026 19:58:29 +0000 Subject: [PATCH 2/6] refactor(payments): Unify LangGraph and Strands config into single class Merge AgentCorePaymentsConfig (LangGraph) and AgentCorePaymentsPluginConfig (Strands) into a single dataclass in integrations/config.py. Both names remain available as aliases for backward compatibility. --- .../payments/integrations/config.py | 157 +++++++++++------- .../integrations/langgraph/__init__.py | 2 +- .../payments/integrations/langgraph/config.py | 151 ----------------- .../integrations/langgraph/middleware.py | 2 +- .../integrations/langgraph/test_stage1.py | 2 +- .../integrations/langgraph/test_stage2.py | 2 +- .../integrations/langgraph/test_stage3.py | 2 +- .../integrations/langgraph/test_stage4.py | 2 +- .../integrations/langgraph/test_stage5.py | 2 +- .../integrations/langgraph/test_stage6.py | 2 +- .../integrations/langgraph/test_stage7.py | 2 +- 11 files changed, 108 insertions(+), 218 deletions(-) delete mode 100644 src/bedrock_agentcore/payments/integrations/langgraph/config.py diff --git a/src/bedrock_agentcore/payments/integrations/config.py b/src/bedrock_agentcore/payments/integrations/config.py index 5456a1e7..f0ebd820 100644 --- a/src/bedrock_agentcore/payments/integrations/config.py +++ b/src/bedrock_agentcore/payments/integrations/config.py @@ -1,16 +1,19 @@ -"""Configuration for AgentCorePaymentsPlugin.""" +"""Configuration for AgentCore Payments integrations (Strands and LangGraph).""" -from dataclasses import dataclass -from typing import Callable, List, Optional +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +from .handlers import PaymentResponseHandler @dataclass class AgentCorePaymentsPluginConfig: - """Configuration for AgentCorePaymentsPlugin. + """Configuration for AgentCore Payments integrations. + + This unified config is used by both the Strands plugin and LangGraph middleware. Attributes: - payment_manager_arn: ARN of the payment manager service - region: AWS region for the payment manager + payment_manager_arn: ARN of the payment manager service. user_id: User ID for payment processing. Required for SigV4 auth. Optional for bearer token auth (JWT identifies the user). When set with bearer auth, propagated via X-Amzn-Bedrock-AgentCore-Payments-User-Id header. @@ -18,45 +21,29 @@ class AgentCorePaymentsPluginConfig: Can be set later via update_payment_instrument_id(). payment_session_id: Optional payment session ID for the transaction. Can be set later via update_payment_session_id(). - network_preferences_config: Optional list of network CAIP2 identifiers - in order of preference. If not provided, defaults to the system default. - auto_payment: Whether to automatically process 402 payment requirements. - Defaults to True to maintain existing behavior. - max_interrupt_retries: Maximum number of interrupt retries per tool use. - Defaults to 5. Set to 0 to disable interrupt retries entirely (no interrupts will be raised). - agent_name: Optional agent name to propagate via the - X-Amzn-Bedrock-AgentCore-Payments-Agent-Name HTTP header on every - AgentCore payments data-plane API call. When set, the header is automatically injected - by PaymentManager and propagated for Payments. - bearer_token: Optional static JWT bearer token for OAuth/CUSTOM_JWT authentication. - When set, PaymentManager uses Bearer token auth instead of SigV4. - Mutually exclusive with token_provider. - token_provider: Optional callable that returns a fresh JWT bearer token string. - Called before each request to support token refresh. - Mutually exclusive with bearer_token. - payment_tool_allowlist: Optional list of tool names that are eligible for - automatic X402 payment processing. When None (default), all tools are - eligible (preserving existing behavior). When set, only tool calls whose - name appears in this list will trigger payment processing; all others are - skipped. - provide_http_request: Whether the plugin should register its built-in - ``http_request`` ``@tool`` on the agent. Defaults to True so adding the - plugin gives a turnkey paid-HTTP experience. Set to False if you want - to ship your own ``http_request`` tool — Strands raises a ValueError - on duplicate tool names, so you must opt out of the plugin's version - before passing your own. Auto-payment of 402 responses still works - against any tool whose output carries the ``PAYMENT_REQUIRED:`` - content marker, so disabling this flag does not disable interception. - post_payment_retry_delay_seconds: Seconds to wait after generating a - payment header before allowing the tool to be retried. The x402 - EIP-3009 ``transferWithAuthorization`` contract requires - ``block.timestamp > validAfter`` (strict greater-than). Some signing - services set ``validAfter`` close to the current time, which can - cause the merchant facilitator to submit before ``validAfter`` - elapses, producing a misleading "invalid_payload" response. A small - delay between signing and retry lets the chain advance one block so - the authorization is valid by the time the seller submits. Defaults - to 3.0 seconds (about one Base Sepolia block). Set to 0 to disable. + payment_connector_id: Payment connector ID (optional). + region: AWS region for the payment manager. + network_preferences_config: Ordered list of network CAIP2 identifiers. + auto_payment: Whether to automatically process 402 responses. Default True. + agent_name: Agent name propagated via HTTP header on data-plane calls. + bearer_token: Static JWT for OAuth/CUSTOM_JWT auth. Mutually exclusive with token_provider. + token_provider: Callable returning a fresh JWT. Mutually exclusive with bearer_token. + payment_tool_allowlist: Tool names eligible for payment processing. None = all tools. + provide_http_request: Whether the integration registers its built-in http_request tool. + post_payment_retry_delay_seconds: Delay after signing before retry. Default 3.0s. + max_interrupt_retries: Maximum number of interrupt retries per tool use (Strands only). + Defaults to 5. Set to 0 to disable interrupt retries entirely. + custom_handlers: Custom PaymentResponseHandler instances keyed by tool name. + Takes precedence over the built-in handler registry during resolution. + auto_session: Whether to auto-create a payment session on first 402 if + payment_session_id is not set. Default False. + auto_session_budget: Budget for auto-created sessions (USD). Default "1.00". + auto_session_expiry_minutes: Expiry time for auto-created sessions. Default 60. + on_payment_error: Optional callback invoked when a payment exception occurs. + Receives PaymentErrorContext, returns ErrorResolution.RETRY or .PROPAGATE. + When None (default), errors produce deterministic ToolMessages directly. + max_error_retries: Maximum times the error callback can return RETRY per tool call. + Default 3. Set to 0 to disable the callback entirely. """ payment_manager_arn: str @@ -65,41 +52,45 @@ class AgentCorePaymentsPluginConfig: payment_session_id: Optional[str] = None payment_connector_id: Optional[str] = None region: Optional[str] = None - network_preferences_config: Optional[list[str]] = None + network_preferences_config: Optional[List[str]] = None auto_payment: bool = True - max_interrupt_retries: int = 5 agent_name: Optional[str] = None bearer_token: Optional[str] = None token_provider: Optional[Callable[[], str]] = None payment_tool_allowlist: Optional[List[str]] = None provide_http_request: bool = True post_payment_retry_delay_seconds: float = 3.0 + max_interrupt_retries: int = 5 + custom_handlers: Optional[Dict[str, Any]] = field(default=None) + auto_session: bool = False + auto_session_budget: str = "1.00" + auto_session_expiry_minutes: int = 60 + on_payment_error: Optional[Callable] = None + max_error_retries: int = 3 def __post_init__(self) -> None: """Validate configuration after initialization.""" if not self.payment_manager_arn: raise ValueError("payment_manager_arn is required") - if not self.payment_manager_arn.startswith("arn:"): raise ValueError(f"Invalid ARN format: {self.payment_manager_arn}") + if self.bearer_token is not None and self.token_provider is not None: + raise ValueError("bearer_token and token_provider are mutually exclusive") if self.bearer_token is not None and not isinstance(self.bearer_token, str): raise ValueError(f"bearer_token must be a string, got {type(self.bearer_token).__name__}") - if self.token_provider is not None and not callable(self.token_provider): raise ValueError(f"token_provider must be callable, got {type(self.token_provider).__name__}") - if self.user_id is not None and self.user_id and not self.user_id.strip(): - raise ValueError("user_id cannot be whitespace-only") - if not self.user_id and self.bearer_token is None and self.token_provider is None: raise ValueError("user_id is required for SigV4 auth (when bearer_token/token_provider not set)") + if self.user_id is not None and self.user_id and not self.user_id.strip(): + raise ValueError("user_id cannot be whitespace-only") if not isinstance(self.auto_payment, bool): raise ValueError(f"auto_payment must be a boolean, got {type(self.auto_payment).__name__}") - - if self.bearer_token is not None and self.token_provider is not None: - raise ValueError("bearer_token and token_provider are mutually exclusive. Provide only one.") + if not isinstance(self.provide_http_request, bool): + raise ValueError(f"provide_http_request must be a boolean, got {type(self.provide_http_request).__name__}") if self.payment_tool_allowlist is not None: if not isinstance(self.payment_tool_allowlist, list): @@ -107,9 +98,6 @@ def __post_init__(self) -> None: if not all(isinstance(t, str) for t in self.payment_tool_allowlist): raise ValueError("All entries in payment_tool_allowlist must be strings") - if not isinstance(self.provide_http_request, bool): - raise ValueError(f"provide_http_request must be a boolean, got {type(self.provide_http_request).__name__}") - if not isinstance(self.post_payment_retry_delay_seconds, (int, float)) or isinstance( self.post_payment_retry_delay_seconds, bool ): @@ -122,6 +110,22 @@ def __post_init__(self) -> None: f"post_payment_retry_delay_seconds must be >= 0, got {self.post_payment_retry_delay_seconds}" ) + if self.custom_handlers is not None: + if not isinstance(self.custom_handlers, dict): + raise ValueError("custom_handlers must be a dict mapping tool names to PaymentResponseHandler instances") + if not all(isinstance(k, str) for k in self.custom_handlers): + raise ValueError("All keys in custom_handlers must be strings") + if not all(isinstance(v, PaymentResponseHandler) for v in self.custom_handlers.values()): + raise ValueError("All values in custom_handlers must be PaymentResponseHandler instances") + + if self.on_payment_error is not None and not callable(self.on_payment_error): + raise ValueError(f"on_payment_error must be callable, got {type(self.on_payment_error).__name__}") + + if not isinstance(self.max_error_retries, int) or isinstance(self.max_error_retries, bool): + raise ValueError(f"max_error_retries must be an int, got {type(self.max_error_retries).__name__}") + if self.max_error_retries < 0: + raise ValueError(f"max_error_retries must be >= 0, got {self.max_error_retries}") + def update_payment_session_id(self, payment_session_id: str) -> None: """Update the payment session ID. @@ -141,3 +145,40 @@ def update_payment_instrument_id(self, payment_instrument_id: str) -> None: if not payment_instrument_id: raise ValueError("payment_instrument_id cannot be empty") self.payment_instrument_id = payment_instrument_id + + def add_to_allowlist(self, *tool_names: str) -> None: + """Add tool names to the payment allowlist. + + Creates the allowlist if it doesn't exist yet (switching from "all tools" + to explicit allowlist mode). + + Args: + tool_names: One or more tool names to add. + """ + if self.payment_tool_allowlist is None: + self.payment_tool_allowlist = [] + for name in tool_names: + if not isinstance(name, str): + raise ValueError(f"Tool name must be a string, got {type(name).__name__}") + if name not in self.payment_tool_allowlist: + self.payment_tool_allowlist.append(name) + + def remove_from_allowlist(self, *tool_names: str) -> None: + """Remove tool names from the payment allowlist. + + If the allowlist becomes empty, sets it to None (all tools eligible). + + Args: + tool_names: One or more tool names to remove. + """ + if self.payment_tool_allowlist is None: + return + for name in tool_names: + if name in self.payment_tool_allowlist: + self.payment_tool_allowlist.remove(name) + if not self.payment_tool_allowlist: + self.payment_tool_allowlist = None + + +# Backward-compatible alias for LangGraph imports +AgentCorePaymentsConfig = AgentCorePaymentsPluginConfig diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/__init__.py b/src/bedrock_agentcore/payments/integrations/langgraph/__init__.py index e0704f45..7b5239f7 100644 --- a/src/bedrock_agentcore/payments/integrations/langgraph/__init__.py +++ b/src/bedrock_agentcore/payments/integrations/langgraph/__init__.py @@ -1,6 +1,6 @@ """LangGraph integration for AgentCore Payments.""" -from .config import AgentCorePaymentsConfig +from ..config import AgentCorePaymentsConfig from .errors import ErrorResolution, PaymentErrorContext from .middleware import AgentCorePaymentsMiddleware diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/config.py b/src/bedrock_agentcore/payments/integrations/langgraph/config.py deleted file mode 100644 index 1ebd76ba..00000000 --- a/src/bedrock_agentcore/payments/integrations/langgraph/config.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Configuration for AgentCorePaymentsMiddleware (LangGraph integration).""" - -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional - -from ..handlers import PaymentResponseHandler - - -@dataclass -class AgentCorePaymentsConfig: - """Configuration for AgentCorePaymentsMiddleware. - - Attributes: - payment_manager_arn: ARN of the payment manager resource. - user_id: User ID for payment processing. Required for SigV4 auth. - payment_instrument_id: Payment instrument ID for x402 signing. - payment_session_id: Payment session ID for budget enforcement. - payment_connector_id: Payment connector ID (optional). - region: AWS region for the payment manager. - network_preferences_config: Ordered list of network CAIP2 identifiers. - auto_payment: Whether to automatically process 402 responses. Default True. - agent_name: Agent name propagated via HTTP header on data-plane calls. - bearer_token: Static JWT for OAuth/CUSTOM_JWT auth. Mutually exclusive with token_provider. - token_provider: Callable returning a fresh JWT. Mutually exclusive with bearer_token. - payment_tool_allowlist: Tool names eligible for payment processing. None = all tools. - provide_http_request: Whether middleware registers its built-in http_request tool. - post_payment_retry_delay_seconds: Delay after signing before retry. Default 3.0s. - custom_handlers: Custom PaymentResponseHandler instances keyed by tool name. - Takes precedence over the built-in handler registry during resolution. - auto_session: Whether to auto-create a payment session on first 402 if - payment_session_id is not set. Default False. - auto_session_budget: Budget for auto-created sessions (USD). Default "1.00". - auto_session_expiry_minutes: Expiry time for auto-created sessions. Default 60. - on_payment_error: Optional callback invoked when a payment exception occurs. - Receives PaymentErrorContext, returns ErrorResolution.RETRY or .PROPAGATE. - When None (default), errors produce deterministic ToolMessages directly. - max_error_retries: Maximum times the error callback can return RETRY per tool call. - Default 3. Set to 0 to disable the callback entirely. - """ - - payment_manager_arn: str - user_id: Optional[str] = None - payment_instrument_id: Optional[str] = None - payment_session_id: Optional[str] = None - payment_connector_id: Optional[str] = None - region: Optional[str] = None - network_preferences_config: Optional[List[str]] = None - auto_payment: bool = True - agent_name: Optional[str] = None - bearer_token: Optional[str] = None - token_provider: Optional[Callable[[], str]] = None - payment_tool_allowlist: Optional[List[str]] = None - provide_http_request: bool = True - post_payment_retry_delay_seconds: float = 3.0 - custom_handlers: Optional[Dict[str, Any]] = field(default=None) - auto_session: bool = False - auto_session_budget: str = "1.00" - auto_session_expiry_minutes: int = 60 - on_payment_error: Optional[Callable] = None - max_error_retries: int = 3 - - def __post_init__(self) -> None: - """Validate configuration after initialization.""" - if not self.payment_manager_arn: - raise ValueError("payment_manager_arn is required") - if not self.payment_manager_arn.startswith("arn:"): - raise ValueError(f"Invalid ARN format: {self.payment_manager_arn}") - - if self.bearer_token is not None and self.token_provider is not None: - raise ValueError("bearer_token and token_provider are mutually exclusive") - if self.bearer_token is not None and not isinstance(self.bearer_token, str): - raise ValueError(f"bearer_token must be a string, got {type(self.bearer_token).__name__}") - if self.token_provider is not None and not callable(self.token_provider): - raise ValueError(f"token_provider must be callable, got {type(self.token_provider).__name__}") - - if not self.user_id and self.bearer_token is None and self.token_provider is None: - raise ValueError("user_id is required for SigV4 auth (when bearer_token/token_provider not set)") - if self.user_id is not None and self.user_id and not self.user_id.strip(): - raise ValueError("user_id cannot be whitespace-only") - - if not isinstance(self.auto_payment, bool): - raise ValueError(f"auto_payment must be a boolean, got {type(self.auto_payment).__name__}") - if not isinstance(self.provide_http_request, bool): - raise ValueError(f"provide_http_request must be a boolean, got {type(self.provide_http_request).__name__}") - - if self.payment_tool_allowlist is not None: - if not isinstance(self.payment_tool_allowlist, list): - raise ValueError("payment_tool_allowlist must be a list of tool name strings") - if not all(isinstance(t, str) for t in self.payment_tool_allowlist): - raise ValueError("All entries in payment_tool_allowlist must be strings") - - if not isinstance(self.post_payment_retry_delay_seconds, (int, float)) or isinstance( - self.post_payment_retry_delay_seconds, bool - ): - raise ValueError( - f"post_payment_retry_delay_seconds must be a number, got " - f"{type(self.post_payment_retry_delay_seconds).__name__}" - ) - if self.post_payment_retry_delay_seconds < 0: - raise ValueError( - f"post_payment_retry_delay_seconds must be >= 0, got {self.post_payment_retry_delay_seconds}" - ) - - if self.custom_handlers is not None: - if not isinstance(self.custom_handlers, dict): - raise ValueError("custom_handlers must be a dict mapping tool names to PaymentResponseHandler instances") - if not all(isinstance(k, str) for k in self.custom_handlers): - raise ValueError("All keys in custom_handlers must be strings") - if not all(isinstance(v, PaymentResponseHandler) for v in self.custom_handlers.values()): - raise ValueError("All values in custom_handlers must be PaymentResponseHandler instances") - - if self.on_payment_error is not None and not callable(self.on_payment_error): - raise ValueError(f"on_payment_error must be callable, got {type(self.on_payment_error).__name__}") - - if not isinstance(self.max_error_retries, int) or isinstance(self.max_error_retries, bool): - raise ValueError(f"max_error_retries must be an int, got {type(self.max_error_retries).__name__}") - if self.max_error_retries < 0: - raise ValueError(f"max_error_retries must be >= 0, got {self.max_error_retries}") - - def add_to_allowlist(self, *tool_names: str) -> None: - """Add tool names to the payment allowlist. - - Creates the allowlist if it doesn't exist yet (switching from "all tools" - to explicit allowlist mode). - - Args: - tool_names: One or more tool names to add. - """ - if self.payment_tool_allowlist is None: - self.payment_tool_allowlist = [] - for name in tool_names: - if not isinstance(name, str): - raise ValueError(f"Tool name must be a string, got {type(name).__name__}") - if name not in self.payment_tool_allowlist: - self.payment_tool_allowlist.append(name) - - def remove_from_allowlist(self, *tool_names: str) -> None: - """Remove tool names from the payment allowlist. - - If the allowlist becomes empty, sets it to None (all tools eligible). - - Args: - tool_names: One or more tool names to remove. - """ - if self.payment_tool_allowlist is None: - return - for name in tool_names: - if name in self.payment_tool_allowlist: - self.payment_tool_allowlist.remove(name) - if not self.payment_tool_allowlist: - self.payment_tool_allowlist = None diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py index 868e3f81..3ef6eeed 100644 --- a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py +++ b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py @@ -25,7 +25,7 @@ PaymentManager, ) -from .config import AgentCorePaymentsConfig +from ..config import AgentCorePaymentsConfig from .tools import ( make_get_payment_instrument_balance_tool, make_get_payment_instrument_tool, diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py index d0dcb2bb..0563675e 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py @@ -5,7 +5,7 @@ import pytest -from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler, PaymentResponseHandler diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py index 6f7af5cd..dc302348 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py @@ -13,7 +13,7 @@ MCPRequestPaymentHandler, PaymentResponseHandler, ) -from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py index aced5a88..97bdf72c 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py @@ -7,7 +7,7 @@ from langchain.messages import ToolMessage from langgraph.types import Command -from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware from bedrock_agentcore.payments.manager import ( PaymentError, diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py index 0ec1fa29..bbe9f868 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py @@ -6,7 +6,7 @@ import pytest from langchain.messages import ToolMessage -from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware from bedrock_agentcore.payments.manager import ( InsufficientBudget, diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py index 69c1ba79..9dbe4fb0 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py @@ -8,7 +8,7 @@ from langchain.messages import ToolMessage from langgraph.types import Command -from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py index 25880ae9..eb9d9a03 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py @@ -5,7 +5,7 @@ import pytest -from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py index 9c5f4840..6a0bbcb4 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py @@ -8,7 +8,7 @@ from langchain.messages import ToolMessage from langgraph.types import Command -from bedrock_agentcore.payments.integrations.langgraph.config import AgentCorePaymentsConfig +from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.errors import ErrorResolution, PaymentErrorContext from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware from bedrock_agentcore.payments.manager import ( From d56abba02b6eef921c36a33099267dc75c7f145d Mon Sep 17 00:00:00 2001 From: Raghav Sunil Date: Tue, 23 Jun 2026 21:55:25 +0000 Subject: [PATCH 3/6] define _ERROR_MESSAGES dict at the root level of payments SDK instead of in langgraph-specific package --- .../payments/integrations/error_messages.py | 74 +++++++++++++++++++ .../integrations/langgraph/middleware.py | 42 +---------- 2 files changed, 76 insertions(+), 40 deletions(-) create mode 100644 src/bedrock_agentcore/payments/integrations/error_messages.py diff --git a/src/bedrock_agentcore/payments/integrations/error_messages.py b/src/bedrock_agentcore/payments/integrations/error_messages.py new file mode 100644 index 00000000..4c00af43 --- /dev/null +++ b/src/bedrock_agentcore/payments/integrations/error_messages.py @@ -0,0 +1,74 @@ +"""Shared deterministic error messages for payment exceptions. + +These messages are designed to be shown to LLMs via tool results. They instruct +the model not to retry and to inform the user of the specific issue. + +Used by: LangGraph middleware, and available for any future plugin integration. +""" + +from typing import Dict, Type + +from bedrock_agentcore.payments.manager import ( + InsufficientBudget, + PaymentError, + PaymentInstrumentConfigurationRequired, + PaymentInstrumentNotFound, + PaymentSessionConfigurationRequired, + PaymentSessionExpired, + PaymentSessionNotFound, +) + +# Maps exception types to deterministic, LLM-instructive messages. +PAYMENT_ERROR_MESSAGES: Dict[Type[Exception], str] = { + PaymentInstrumentConfigurationRequired: ( + "No payment instrument configured. Do not retry this call. " + "Inform the user they need to configure a payment instrument before making paid requests." + ), + PaymentSessionConfigurationRequired: ( + "No payment session configured. Do not retry this call. " + "Inform the user they need to create a payment session before making paid requests." + ), + PaymentInstrumentNotFound: ( + "Payment instrument not found. Do not retry this call. " + "Inform the user their payment instrument ID is invalid or has been deleted." + ), + PaymentSessionNotFound: ( + "Payment session not found. Do not retry this call. " + "Inform the user their payment session ID is invalid or has expired." + ), + PaymentSessionExpired: ( + "Payment session has expired. Do not retry this call. " + "Inform the user they need to create a new payment session." + ), + InsufficientBudget: ( + "Insufficient budget. The payment amount exceeds the remaining session limit. " + "Do not retry this call. Inform the user they need to increase their session budget " + "or create a new session with higher limits." + ), +} + + +def get_payment_error_message(exception: Exception) -> str: + """Get the deterministic error message for a payment exception. + + Looks up the exception type in the message map. Falls back to a generic + message that includes the exception string for unrecognized types. + + Args: + exception: The payment exception. + + Returns: + Human/LLM-readable error message string. + """ + msg = PAYMENT_ERROR_MESSAGES.get(type(exception)) + if msg is not None: + return msg + if isinstance(exception, PaymentError): + return ( + f"Payment processing failed ({exception}). " + "Do not retry this call. Inform the user that payment could not be completed." + ) + return ( + f"An unexpected error occurred during payment processing ({exception}). " + "Do not retry this call. Inform the user that payment could not be completed." + ) diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py index 3ef6eeed..1251871b 100644 --- a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py +++ b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py @@ -10,6 +10,7 @@ from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.types import Command +from bedrock_agentcore.payments.integrations.error_messages import get_payment_error_message from bedrock_agentcore.payments.integrations.handlers import ( PaymentResponseHandler, get_payment_handler, @@ -37,34 +38,6 @@ logger = logging.getLogger(__name__) # Deterministic error messages per exception type. -# The LLM sees these messages and should follow the "Do not retry" instruction. -_ERROR_MESSAGES: Dict[type, str] = { - PaymentInstrumentConfigurationRequired: ( - "No payment instrument configured. Do not retry this call. " - "Inform the user they need to configure a payment instrument before making paid requests." - ), - PaymentSessionConfigurationRequired: ( - "No payment session configured. Do not retry this call. " - "Inform the user they need to create a payment session before making paid requests." - ), - PaymentInstrumentNotFound: ( - "Payment instrument not found. Do not retry this call. " - "Inform the user their payment instrument ID is invalid or has been deleted." - ), - PaymentSessionNotFound: ( - "Payment session not found. Do not retry this call. " - "Inform the user their payment session ID is invalid or has expired." - ), - PaymentSessionExpired: ( - "Payment session has expired. Do not retry this call. " - "Inform the user they need to create a new payment session." - ), - InsufficientBudget: ( - "Insufficient budget. The payment amount exceeds the remaining session limit. " - "Do not retry this call. Inform the user they need to increase their session budget " - "or create a new session with higher limits." - ), -} class _FallbackHandler: @@ -569,18 +542,7 @@ def _error_tool_message(request: ToolCallRequest, exception: Exception) -> ToolM Returns: ToolMessage with status="error" and deterministic content. """ - msg = _ERROR_MESSAGES.get(type(exception)) - if msg is None: - if isinstance(exception, PaymentError): - msg = ( - f"Payment processing failed ({exception}). " - "Do not retry this call. Inform the user that payment could not be completed." - ) - else: - msg = ( - f"An unexpected error occurred during payment processing ({exception}). " - "Do not retry this call. Inform the user that payment could not be completed." - ) + msg = get_payment_error_message(exception) return ToolMessage( content=f"PAYMENT ERROR: {msg}", From 6b4f356f464b7f84c5465eb17ff196ffedbc07d6 Mon Sep 17 00:00:00 2001 From: Raghav Sunil Date: Thu, 25 Jun 2026 18:51:59 +0000 Subject: [PATCH 4/6] added langgraph dependencies to pyproject.toml for CI --- pyproject.toml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 449cc412..cb659cc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,9 @@ dev = [ "wheel>=0.45.1", "strands-agents>=1.20.0", "strands-agents-evals>=0.1.0", + "langchain>=0.2.0", + "langgraph>=0.2.0", + "langchain-mcp-adapters>=0.1.0", "a2a-sdk[http-server]>=0.3,<1.0", "ag-ui-protocol>=0.1.10", "mcp-proxy-for-aws>=0.1.0", @@ -166,6 +169,12 @@ strands-agents = [ "strands-agents>=1.20.0", "mcp>=1.23.0,<2.0.0", ] +langgraph = [ + "langchain>=0.2.0", + "langgraph>=0.2.0", + "langchain-mcp-adapters>=0.1.0", + "httpx>=0.27.0", +] strands-agents-evals = [ "strands-agents-evals>=0.1.0" ] From 28a49ac3075f3f921b503ca64b7ee7c9eb8bc5c2 Mon Sep 17 00:00:00 2001 From: Raghav Sunil Date: Thu, 25 Jun 2026 21:59:19 +0000 Subject: [PATCH 5/6] fix: linter check failing due to long lines, wrapped and fixed length of all --- .../payments/integrations/config.py | 5 ++- .../integrations/langgraph/middleware.py | 39 +++++++++++++++---- .../payments/integrations/langgraph/tools.py | 10 ++++- .../integrations/langgraph/test_functional.py | 5 ++- .../integrations/langgraph/test_stage3.py | 6 ++- .../integrations/langgraph/test_stage4.py | 6 ++- .../integrations/langgraph/test_stage5.py | 22 ++++++++--- .../integrations/langgraph/test_stage7.py | 6 ++- 8 files changed, 80 insertions(+), 19 deletions(-) diff --git a/src/bedrock_agentcore/payments/integrations/config.py b/src/bedrock_agentcore/payments/integrations/config.py index f0ebd820..7781174c 100644 --- a/src/bedrock_agentcore/payments/integrations/config.py +++ b/src/bedrock_agentcore/payments/integrations/config.py @@ -112,7 +112,10 @@ def __post_init__(self) -> None: if self.custom_handlers is not None: if not isinstance(self.custom_handlers, dict): - raise ValueError("custom_handlers must be a dict mapping tool names to PaymentResponseHandler instances") + raise ValueError( + "custom_handlers must be a dict mapping tool names" + " to PaymentResponseHandler instances" + ) if not all(isinstance(k, str) for k in self.custom_handlers): raise ValueError("All keys in custom_handlers must be strings") if not all(isinstance(v, PaymentResponseHandler) for v in self.custom_handlers.values()): diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py index 1251871b..3c872a94 100644 --- a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py +++ b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py @@ -491,9 +491,15 @@ def _invoke_error_handler( ) if not injection_handler.validate_tool_input(tool_args): - return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials after error recovery."), + ) if not injection_handler.apply_payment_header(tool_args, payment_header): - return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials after error recovery."), + ) delay = self.config.post_payment_retry_delay_seconds if delay > 0: @@ -516,7 +522,12 @@ def _invoke_error_handler( if retry_status == 402: retry_body = _rh.extract_body(retry_prepared) or {} detail = retry_body.get("error", "unknown") if isinstance(retry_body, dict) else "unknown" - return self._error_tool_message(request, PaymentError(f"Payment signed but rejected after recovery ({detail}).")) + return self._error_tool_message( + request, + PaymentError( + f"Payment signed but rejected after recovery ({detail})." + ), + ) return retry_result @@ -742,7 +753,10 @@ async def _ainvoke_error_handler( return None retry_count += 1 - logger.info("on_payment_error returned RETRY (async, attempt %d/%d)", retry_count, self.config.max_error_retries) + logger.info( + "on_payment_error returned RETRY (async, attempt %d/%d)", + retry_count, self.config.max_error_retries, + ) try: payment_header = await asyncio.to_thread( @@ -756,9 +770,15 @@ async def _ainvoke_error_handler( ) if not injection_handler.validate_tool_input(tool_args): - return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials after error recovery."), + ) if not injection_handler.apply_payment_header(tool_args, payment_header): - return self._error_tool_message(request, PaymentError("Could not apply payment credentials after error recovery.")) + return self._error_tool_message( + request, + PaymentError("Could not apply payment credentials after error recovery."), + ) delay = self.config.post_payment_retry_delay_seconds if delay > 0: @@ -780,7 +800,12 @@ async def _ainvoke_error_handler( if retry_status == 402: retry_body = _rh.extract_body(retry_prepared) or {} detail = retry_body.get("error", "unknown") if isinstance(retry_body, dict) else "unknown" - return self._error_tool_message(request, PaymentError(f"Payment signed but rejected after recovery ({detail}).")) + return self._error_tool_message( + request, + PaymentError( + f"Payment signed but rejected after recovery ({detail})." + ), + ) return retry_result diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/tools.py b/src/bedrock_agentcore/payments/integrations/langgraph/tools.py index 142d2373..a0d6926c 100644 --- a/src/bedrock_agentcore/payments/integrations/langgraph/tools.py +++ b/src/bedrock_agentcore/payments/integrations/langgraph/tools.py @@ -88,7 +88,10 @@ def get_payment_instrument( Returns: Payment instrument details dictionary. """ - resolved_id = (payment_instrument_id.strip() if payment_instrument_id else None) or middleware.config.payment_instrument_id + resolved_id = ( + (payment_instrument_id.strip() if payment_instrument_id else None) + or middleware.config.payment_instrument_id + ) resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id return middleware.payment_manager.get_payment_instrument( user_id=resolved_user, @@ -184,7 +187,10 @@ def get_payment_session( Returns: Payment session details dictionary. """ - resolved_id = (payment_session_id.strip() if payment_session_id else None) or middleware.config.payment_session_id + resolved_id = ( + (payment_session_id.strip() if payment_session_id else None) + or middleware.config.payment_session_id + ) resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id return middleware.payment_manager.get_payment_session( user_id=resolved_user, diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py index dd02a1d7..5f4a08ed 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py @@ -38,7 +38,10 @@ "ACP_REGION", "ACP_TESTNET_URL", ]), - reason="Testnet env vars not set (ACP_PAYMENT_MANAGER_ARN, ACP_USER_ID, ACP_PAYMENT_INSTRUMENT_ID, ACP_REGION, ACP_TESTNET_URL)", + reason=( + "Testnet env vars not set (ACP_PAYMENT_MANAGER_ARN, ACP_USER_ID," + " ACP_PAYMENT_INSTRUMENT_ID, ACP_REGION, ACP_TESTNET_URL)" + ), ) diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py index 97bdf72c..6b85d403 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py @@ -30,7 +30,11 @@ def _make_config(**overrides): def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): req = MagicMock() - req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + req.tool_call = { + "name": tool_name, + "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, + "id": tool_id, + } return req diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py index bbe9f868..d7b46a20 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py @@ -33,7 +33,11 @@ def _make_config(**overrides): def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): req = MagicMock() - req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + req.tool_call = { + "name": tool_name, + "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, + "id": tool_id, + } return req diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py index 9dbe4fb0..6c46a155 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py @@ -26,7 +26,11 @@ def _make_config(**overrides): def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): req = MagicMock() - req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + req.tool_call = { + "name": tool_name, + "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, + "id": tool_id, + } return req @@ -85,7 +89,9 @@ def test_402_then_200_on_retry(self, mock_pm_cls): success_msg, ]) - result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + result = asyncio.run(mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, + )) assert result is success_msg @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") @@ -183,7 +189,9 @@ def test_missing_instrument_error(self, mock_pm_cls): handler = AsyncMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) - result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + result = asyncio.run(mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, + )) assert "PAYMENT ERROR" in result.content assert "No payment instrument configured" in result.content assert result.status == "error" @@ -198,7 +206,9 @@ def test_post_payment_rejection(self, mock_pm_cls): ToolMessage(content=_402_content(), tool_call_id="tc-1"), ]) - result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + result = asyncio.run(mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, + )) assert "signed but rejected" in result.content @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") @@ -208,7 +218,9 @@ def test_unexpected_exception(self, mock_pm_cls): handler = AsyncMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) - result = asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) + result = asyncio.run(mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, + )) assert isinstance(result, ToolMessage) assert "unexpected error" in result.content assert "async boom" in result.content diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py index 6a0bbcb4..39b19db6 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py @@ -42,7 +42,11 @@ def _make_config(**overrides): def _make_request(tool_name="http_request", tool_args=None, tool_id="tc-1"): req = MagicMock() - req.tool_call = {"name": tool_name, "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, "id": tool_id} + req.tool_call = { + "name": tool_name, + "args": tool_args if tool_args is not None else {"url": "http://x.com", "headers": {}}, + "id": tool_id, + } return req From bef47636f6c205fbbbba071c947acadf8f513b43 Mon Sep 17 00:00:00 2001 From: Raghav Sunil Date: Thu, 25 Jun 2026 22:15:45 +0000 Subject: [PATCH 6/6] style: apply ruff formatting locally for push and rerun lint check --- .../payments/integrations/config.py | 3 +- .../integrations/langgraph/middleware.py | 72 +++++-------- .../payments/integrations/langgraph/tools.py | 10 +- .../integrations/langgraph/test_functional.py | 47 ++++---- .../integrations/langgraph/test_stage1.py | 5 +- .../integrations/langgraph/test_stage2.py | 1 - .../integrations/langgraph/test_stage3.py | 65 ++++++----- .../integrations/langgraph/test_stage4.py | 13 ++- .../integrations/langgraph/test_stage5.py | 101 +++++++++++------- .../integrations/langgraph/test_stage6.py | 6 +- .../integrations/langgraph/test_stage7.py | 33 +++--- 11 files changed, 185 insertions(+), 171 deletions(-) diff --git a/src/bedrock_agentcore/payments/integrations/config.py b/src/bedrock_agentcore/payments/integrations/config.py index 7781174c..83bc5c9d 100644 --- a/src/bedrock_agentcore/payments/integrations/config.py +++ b/src/bedrock_agentcore/payments/integrations/config.py @@ -113,8 +113,7 @@ def __post_init__(self) -> None: if self.custom_handlers is not None: if not isinstance(self.custom_handlers, dict): raise ValueError( - "custom_handlers must be a dict mapping tool names" - " to PaymentResponseHandler instances" + "custom_handlers must be a dict mapping tool names to PaymentResponseHandler instances" ) if not all(isinstance(k, str) for k in self.custom_handlers): raise ValueError("All keys in custom_handlers must be strings") diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py index 3c872a94..37b8cf90 100644 --- a/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py +++ b/src/bedrock_agentcore/payments/integrations/langgraph/middleware.py @@ -16,14 +16,10 @@ get_payment_handler, ) from bedrock_agentcore.payments.manager import ( - InsufficientBudget, PaymentError, PaymentInstrumentConfigurationRequired, - PaymentInstrumentNotFound, - PaymentSessionConfigurationRequired, - PaymentSessionExpired, - PaymentSessionNotFound, PaymentManager, + PaymentSessionConfigurationRequired, ) from ..config import AgentCorePaymentsConfig @@ -238,16 +234,12 @@ def _generate_payment_header(self, payment_required_request: Dict[str, Any]) -> PaymentError: If payment processing fails. """ if self.config.payment_instrument_id is None: - raise PaymentInstrumentConfigurationRequired( - "payment_instrument_id is required for x402 payments." - ) + raise PaymentInstrumentConfigurationRequired("payment_instrument_id is required for x402 payments.") if self.config.payment_session_id is None: if self.config.auto_session: self._create_auto_session() else: - raise PaymentSessionConfigurationRequired( - "payment_session_id is required for x402 payments." - ) + raise PaymentSessionConfigurationRequired("payment_session_id is required for x402 payments.") return self.payment_manager.generate_payment_header( user_id=self.config.user_id, @@ -311,15 +303,13 @@ def wrap_tool_call( if prepared is None: return result - has_custom_handler = ( - self.config.custom_handlers is not None - and tool_name in self.config.custom_handlers - ) + has_custom_handler = self.config.custom_handlers is not None and tool_name in self.config.custom_handlers if has_custom_handler: detection_handler = self.config.custom_handlers[tool_name] else: from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler + detection_handler = GenericPaymentHandler() status_code = detection_handler.extract_status_code(prepared) @@ -390,6 +380,7 @@ def wrap_tool_call( if retry_prepared is not None: # Use fresh detection on the retry result (not the frozen fallback handler) from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _retry_handler = _GH() retry_status = _retry_handler.extract_status_code(retry_prepared) # Also check via fallback if marker not found @@ -401,15 +392,11 @@ def wrap_tool_call( if retry_status == 402: retry_body = _retry_handler.extract_body(retry_prepared) or {} error_detail = ( - retry_body.get("error", "unknown error") - if isinstance(retry_body, dict) - else "unknown error" + retry_body.get("error", "unknown error") if isinstance(retry_body, dict) else "unknown error" ) return self._error_tool_message( request, - PaymentError( - f"Payment was signed but rejected by the server ({error_detail})." - ), + PaymentError(f"Payment was signed but rejected by the server ({error_detail})."), ) return retry_result @@ -486,8 +473,7 @@ def _invoke_error_handler( has_custom = self.config.custom_handlers and tool_name in self.config.custom_handlers injection_handler = ( - self.config.custom_handlers[tool_name] if has_custom - else self._get_handler(tool_name, tool_args) + self.config.custom_handlers[tool_name] if has_custom else self._get_handler(tool_name, tool_args) ) if not injection_handler.validate_tool_input(tool_args): @@ -513,6 +499,7 @@ def _invoke_error_handler( retry_prepared = self._prepare_for_handler(retry_result.content) if retry_prepared is not None: from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _rh = _GH() retry_status = _rh.extract_status_code(retry_prepared) if retry_status != 402: @@ -524,9 +511,7 @@ def _invoke_error_handler( detail = retry_body.get("error", "unknown") if isinstance(retry_body, dict) else "unknown" return self._error_tool_message( request, - PaymentError( - f"Payment signed but rejected after recovery ({detail})." - ), + PaymentError(f"Payment signed but rejected after recovery ({detail})."), ) return retry_result @@ -597,15 +582,13 @@ async def awrap_tool_call( if prepared is None: return result - has_custom_handler = ( - self.config.custom_handlers is not None - and tool_name in self.config.custom_handlers - ) + has_custom_handler = self.config.custom_handlers is not None and tool_name in self.config.custom_handlers if has_custom_handler: detection_handler = self.config.custom_handlers[tool_name] else: from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler + detection_handler = GenericPaymentHandler() status_code = detection_handler.extract_status_code(prepared) @@ -632,9 +615,7 @@ async def awrap_tool_call( "body": body_402, } - payment_header = await asyncio.to_thread( - self._generate_payment_header, payment_required_request - ) + payment_header = await asyncio.to_thread(self._generate_payment_header, payment_required_request) if has_custom_handler: injection_handler = detection_handler @@ -665,6 +646,7 @@ async def awrap_tool_call( retry_prepared = self._prepare_for_handler(retry_result.content) if retry_prepared is not None: from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _retry_handler = _GH() retry_status = _retry_handler.extract_status_code(retry_prepared) if retry_status != 402: @@ -675,15 +657,11 @@ async def awrap_tool_call( if retry_status == 402: retry_body = _retry_handler.extract_body(retry_prepared) or {} error_detail = ( - retry_body.get("error", "unknown error") - if isinstance(retry_body, dict) - else "unknown error" + retry_body.get("error", "unknown error") if isinstance(retry_body, dict) else "unknown error" ) return self._error_tool_message( request, - PaymentError( - f"Payment was signed but rejected by the server ({error_detail})." - ), + PaymentError(f"Payment was signed but rejected by the server ({error_detail})."), ) return retry_result @@ -715,6 +693,7 @@ async def _ainvoke_error_handler( """Async version of _invoke_error_handler. Supports async callbacks.""" import asyncio import inspect + from .errors import ErrorResolution, PaymentErrorContext retry_count = 0 @@ -755,18 +734,16 @@ async def _ainvoke_error_handler( retry_count += 1 logger.info( "on_payment_error returned RETRY (async, attempt %d/%d)", - retry_count, self.config.max_error_retries, + retry_count, + self.config.max_error_retries, ) try: - payment_header = await asyncio.to_thread( - self._generate_payment_header, payment_required_request or {} - ) + payment_header = await asyncio.to_thread(self._generate_payment_header, payment_required_request or {}) has_custom = self.config.custom_handlers and tool_name in self.config.custom_handlers injection_handler = ( - self.config.custom_handlers[tool_name] if has_custom - else self._get_handler(tool_name, tool_args) + self.config.custom_handlers[tool_name] if has_custom else self._get_handler(tool_name, tool_args) ) if not injection_handler.validate_tool_input(tool_args): @@ -791,6 +768,7 @@ async def _ainvoke_error_handler( retry_prepared = self._prepare_for_handler(retry_result.content) if retry_prepared is not None: from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler as _GH + _rh = _GH() retry_status = _rh.extract_status_code(retry_prepared) if retry_status != 402: @@ -802,9 +780,7 @@ async def _ainvoke_error_handler( detail = retry_body.get("error", "unknown") if isinstance(retry_body, dict) else "unknown" return self._error_tool_message( request, - PaymentError( - f"Payment signed but rejected after recovery ({detail})." - ), + PaymentError(f"Payment signed but rejected after recovery ({detail})."), ) return retry_result diff --git a/src/bedrock_agentcore/payments/integrations/langgraph/tools.py b/src/bedrock_agentcore/payments/integrations/langgraph/tools.py index a0d6926c..e74f0562 100644 --- a/src/bedrock_agentcore/payments/integrations/langgraph/tools.py +++ b/src/bedrock_agentcore/payments/integrations/langgraph/tools.py @@ -89,9 +89,8 @@ def get_payment_instrument( Payment instrument details dictionary. """ resolved_id = ( - (payment_instrument_id.strip() if payment_instrument_id else None) - or middleware.config.payment_instrument_id - ) + payment_instrument_id.strip() if payment_instrument_id else None + ) or middleware.config.payment_instrument_id resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id return middleware.payment_manager.get_payment_instrument( user_id=resolved_user, @@ -188,9 +187,8 @@ def get_payment_session( Payment session details dictionary. """ resolved_id = ( - (payment_session_id.strip() if payment_session_id else None) - or middleware.config.payment_session_id - ) + payment_session_id.strip() if payment_session_id else None + ) or middleware.config.payment_session_id resolved_user = (user_id.strip() if user_id else None) or middleware.config.user_id return middleware.payment_manager.get_payment_session( user_id=resolved_user, diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py index 5f4a08ed..a1ee0db1 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_functional.py @@ -18,10 +18,10 @@ import json import os +from unittest.mock import MagicMock import pytest from langchain.messages import ToolMessage -from unittest.mock import MagicMock from bedrock_agentcore.payments.integrations.langgraph import ( AgentCorePaymentsConfig, @@ -31,13 +31,16 @@ # Skip entire module if required env vars not configured pytestmark = pytest.mark.skipif( - not all(os.environ.get(k) for k in [ - "ACP_PAYMENT_MANAGER_ARN", - "ACP_USER_ID", - "ACP_PAYMENT_INSTRUMENT_ID", - "ACP_REGION", - "ACP_TESTNET_URL", - ]), + not all( + os.environ.get(k) + for k in [ + "ACP_PAYMENT_MANAGER_ARN", + "ACP_USER_ID", + "ACP_PAYMENT_INSTRUMENT_ID", + "ACP_REGION", + "ACP_TESTNET_URL", + ] + ), reason=( "Testnet env vars not set (ACP_PAYMENT_MANAGER_ARN, ACP_USER_ID," " ACP_PAYMENT_INSTRUMENT_ID, ACP_REGION, ACP_TESTNET_URL)" @@ -102,7 +105,7 @@ def test_http_request_tool_gets_402(self, middleware, testnet_url): print(f"\n[http_request raw result]: {result[:200]}...") assert "PAYMENT_REQUIRED:" in result, f"Expected 402 from testnet, got: {result[:100]}" - parsed = json.loads(result[len("PAYMENT_REQUIRED: "):]) + parsed = json.loads(result[len("PAYMENT_REQUIRED: ") :]) assert parsed["statusCode"] == 402 print(f"[402 body keys]: {list(parsed.get('body', {}).keys())}") @@ -207,7 +210,7 @@ def handler(req): injected = tool_args["parameters"]["headers"] print(f"[MCP parameters.headers]: {list(injected.keys())}") assert len(injected) > 0, "No payment header injected into parameters.headers" - print(f"[MCP Gateway flow succeeded with 200]") + print("[MCP Gateway flow succeeded with 200]") def test_payment_header_was_injected(self, middleware, testnet_url): """After wrap_tool_call, the tool_args dict has a payment header.""" @@ -233,10 +236,7 @@ def handler(req): print(f"\n[Injected headers]: {list(injected_headers.keys())}") assert len(injected_headers) > 0, "No payment header was injected" # Common header names: X-PAYMENT (v1) or PAYMENT-SIGNATURE (v2) - has_payment_header = any( - k.upper() in ("X-PAYMENT", "PAYMENT-SIGNATURE", "PAYMENT") - for k in injected_headers - ) + has_payment_header = any(k.upper() in ("X-PAYMENT", "PAYMENT-SIGNATURE", "PAYMENT") for k in injected_headers) assert has_payment_header, f"Expected payment header, got: {list(injected_headers.keys())}" @@ -303,7 +303,7 @@ def handler(req): parsed = json.loads(result.content) assert parsed["statusCode"] == 200, f"Expected 200, got {parsed.get('statusCode')}" - print(f"[Fallback detection flow succeeded — no marker, no custom handler, got 200]") + print("[Fallback detection flow succeeded — no marker, no custom handler, got 200]") class TestCustomHandlerRegistry: @@ -335,6 +335,7 @@ def apply_payment_header(self, tool_input, payment_header): custom_handler = TrackingHandler() from dataclasses import replace + custom_config = replace(config, custom_handlers={"my_http_tool": custom_handler}) mw = AgentCorePaymentsMiddleware(custom_config) @@ -357,18 +358,19 @@ def handler(req): assert custom_handler.detect_called, "Custom handler's extract_status_code was not invoked" assert custom_handler.extract_called, "Custom handler's extract_headers was not invoked" assert custom_handler.inject_called, "Custom handler's apply_payment_header was not invoked" - print(f"[Custom handler used for detection: ✓, extraction: ✓, injection: ✓]") + print("[Custom handler used for detection: ✓, extraction: ✓, injection: ✓]") assert call_count[0] == 2, f"Expected 2 calls, got {call_count[0]}" assert "PAYMENT ERROR" not in result.content, f"Payment failed: {result.content}" parsed = json.loads(result.content) assert parsed["statusCode"] == 200 - print(f"[Custom handler flow succeeded with 200]") + print("[Custom handler flow succeeded with 200]") def test_custom_handler_non_marker_tool(self, config, testnet_url): """Custom handler detects 402 from a tool that does NOT use the PAYMENT_REQUIRED: marker.""" import httpx + from bedrock_agentcore.payments.integrations.handlers import PaymentResponseHandler # Custom handler that detects 402 from raw JSON (no marker prefix) @@ -381,6 +383,7 @@ def __init__(self): def extract_status_code(self, result): self.detect_called = True import json as _json + # result is {"content": [{"text": "..."}]} from _prepare_for_handler content = result.get("content", []) for block in content: @@ -395,6 +398,7 @@ def extract_status_code(self, result): def extract_headers(self, result): import json as _json + content = result.get("content", []) for block in content: text = block.get("text", "") if isinstance(block, dict) else "" @@ -408,6 +412,7 @@ def extract_headers(self, result): def extract_body(self, result): import json as _json + content = result.get("content", []) for block in content: text = block.get("text", "") if isinstance(block, dict) else "" @@ -431,6 +436,7 @@ def apply_payment_header(self, tool_input, payment_header): custom_handler = RawJsonHandler() from dataclasses import replace + custom_config = replace(config, custom_handlers={"raw_http_tool": custom_handler}) mw = AgentCorePaymentsMiddleware(custom_config) @@ -476,7 +482,7 @@ def handler(req): parsed = json.loads(result.content) assert parsed["statusCode"] == 200, f"Expected 200, got {parsed.get('statusCode')}" - print(f"[Non-marker custom handler flow succeeded with 200]") + print("[Non-marker custom handler flow succeeded with 200]") from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler # Custom handler that tracks all three phases @@ -502,6 +508,7 @@ def apply_payment_header(self, tool_input, payment_header): # Create middleware with custom handler for "my_http_tool" from dataclasses import replace + custom_config = replace(config, custom_handlers={"my_http_tool": custom_handler}) mw = AgentCorePaymentsMiddleware(custom_config) @@ -526,7 +533,7 @@ def handler(req): assert custom_handler.detect_called, "Custom handler's extract_status_code was not invoked" assert custom_handler.extract_called, "Custom handler's extract_headers was not invoked" assert custom_handler.inject_called, "Custom handler's apply_payment_header was not invoked" - print(f"[Custom handler used for detection: ✓, extraction: ✓, injection: ✓]") + print("[Custom handler used for detection: ✓, extraction: ✓, injection: ✓]") # Full flow still works (402 → sign → retry → 200) assert call_count[0] == 2, f"Expected 2 calls, got {call_count[0]}" @@ -534,4 +541,4 @@ def handler(req): parsed = json.loads(result.content) assert parsed["statusCode"] == 200 - print(f"[Custom handler flow succeeded with 200]") + print("[Custom handler flow succeeded with 200]") diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py index 0563675e..4ae6c2ff 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage1.py @@ -1,14 +1,13 @@ """Tests for LangGraph AgentCorePaymentsConfig and AgentCorePaymentsMiddleware.""" import asyncio -from unittest.mock import MagicMock, AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware -from bedrock_agentcore.payments.integrations.handlers import GenericPaymentHandler, PaymentResponseHandler - # --------------------------------------------------------------------------- # Config validation tests diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py index dc302348..d4bef343 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage2.py @@ -3,7 +3,6 @@ import json from unittest.mock import MagicMock, patch -import pytest from langchain.messages import ToolMessage from langgraph.types import Command diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py index 6b85d403..7bfc0918 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage3.py @@ -1,11 +1,10 @@ """Tests for Stage 3: Payment Signing + Retry.""" import json -from unittest.mock import MagicMock, patch, call +from unittest.mock import MagicMock, patch import pytest from langchain.messages import ToolMessage -from langgraph.types import Command from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware @@ -170,10 +169,12 @@ def test_handler_called_twice(self, mock_pm_cls): mw = AgentCorePaymentsMiddleware(config) request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - mock_handler = MagicMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + mock_handler = MagicMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) mw.wrap_tool_call(request, mock_handler) assert mock_handler.call_count == 2 @@ -199,10 +200,12 @@ def test_402_after_signing_returns_error(self, mock_pm_cls): request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) # Both calls return 402 - mock_handler = MagicMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ]) + mock_handler = MagicMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ] + ) result = mw.wrap_tool_call(request, mock_handler) @@ -222,17 +225,21 @@ def test_rejection_error_includes_body_error(self, mock_pm_cls): request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - payload_with_error = json.dumps({ - "statusCode": 402, - "headers": {}, - "body": {"error": "insufficient_balance"}, - }) + payload_with_error = json.dumps( + { + "statusCode": 402, + "headers": {}, + "body": {"error": "insufficient_balance"}, + } + ) content_402_with_error = f"PAYMENT_REQUIRED: {payload_with_error}" - mock_handler = MagicMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=content_402_with_error, tool_call_id="tc-1"), - ]) + mock_handler = MagicMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=content_402_with_error, tool_call_id="tc-1"), + ] + ) result = mw.wrap_tool_call(request, mock_handler) assert "insufficient_balance" in result.content @@ -256,10 +263,12 @@ def test_delay_applied_before_retry(self, mock_pm_cls, mock_sleep): mw = AgentCorePaymentsMiddleware(config) request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - mock_handler = MagicMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + mock_handler = MagicMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) mw.wrap_tool_call(request, mock_handler) mock_sleep.assert_called_once_with(3.0) @@ -274,10 +283,12 @@ def test_zero_delay_skips_sleep(self, mock_pm_cls, mock_sleep): mw = AgentCorePaymentsMiddleware(config) request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - mock_handler = MagicMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + mock_handler = MagicMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) mw.wrap_tool_call(request, mock_handler) mock_sleep.assert_not_called() diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py index d7b46a20..b800f954 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage4.py @@ -3,7 +3,6 @@ import json from unittest.mock import MagicMock, patch -import pytest from langchain.messages import ToolMessage from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig @@ -11,9 +10,7 @@ from bedrock_agentcore.payments.manager import ( InsufficientBudget, PaymentError, - PaymentInstrumentConfigurationRequired, PaymentInstrumentNotFound, - PaymentSessionConfigurationRequired, PaymentSessionExpired, PaymentSessionNotFound, ) @@ -147,10 +144,12 @@ def test_post_payment_rejection(self, mock_pm_cls): payload_with_error = json.dumps({"statusCode": 402, "headers": {}, "body": {"error": "bad_sig"}}) request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - handler = MagicMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=f"PAYMENT_REQUIRED: {payload_with_error}", tool_call_id="tc-1"), - ]) + handler = MagicMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=f"PAYMENT_REQUIRED: {payload_with_error}", tool_call_id="tc-1"), + ] + ) result = mw.wrap_tool_call(request, handler) assert "signed but rejected" in result.content diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py index 6c46a155..e3bf364d 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage5.py @@ -4,7 +4,6 @@ import json from unittest.mock import AsyncMock, MagicMock, patch -import pytest from langchain.messages import ToolMessage from langgraph.types import Command @@ -84,14 +83,19 @@ def test_402_then_200_on_retry(self, mock_pm_cls): mw = AgentCorePaymentsMiddleware(_make_config()) success_msg = ToolMessage(content=_200_content(), tool_call_id="tc-1") - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - success_msg, - ]) - - result = asyncio.run(mw.awrap_tool_call( - _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, - )) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + success_msg, + ] + ) + + result = asyncio.run( + mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), + handler, + ) + ) assert result is success_msg @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") @@ -99,10 +103,12 @@ def test_handler_awaited_twice(self, mock_pm_cls): mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} mw = AgentCorePaymentsMiddleware(_make_config()) - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) assert handler.await_count == 2 @@ -122,10 +128,12 @@ def test_asyncio_sleep_called(self, mock_pm_cls): config = _make_config(post_payment_retry_delay_seconds=3.0) mw = AgentCorePaymentsMiddleware(config) - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) with patch("asyncio.sleep", new_callable=AsyncMock) as mock_async_sleep: asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) @@ -138,10 +146,12 @@ def test_time_sleep_not_called(self, mock_pm_cls, mock_time_sleep): config = _make_config(post_payment_retry_delay_seconds=3.0) mw = AgentCorePaymentsMiddleware(config) - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) with patch("asyncio.sleep", new_callable=AsyncMock): asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) @@ -162,10 +172,12 @@ def test_generate_header_runs_in_thread(self, mock_pm_cls): mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} mw = AgentCorePaymentsMiddleware(_make_config()) - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) with patch("asyncio.to_thread", new_callable=AsyncMock, return_value={"X-PAYMENT": "sig"}) as mock_to_thread: asyncio.run(mw.awrap_tool_call(_make_request(tool_args={"url": "http://x.com", "headers": {}}), handler)) @@ -189,9 +201,12 @@ def test_missing_instrument_error(self, mock_pm_cls): handler = AsyncMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) - result = asyncio.run(mw.awrap_tool_call( - _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, - )) + result = asyncio.run( + mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), + handler, + ) + ) assert "PAYMENT ERROR" in result.content assert "No payment instrument configured" in result.content assert result.status == "error" @@ -201,14 +216,19 @@ def test_post_payment_rejection(self, mock_pm_cls): mock_pm_cls.return_value.generate_payment_header.return_value = {"X-PAYMENT": "sig"} mw = AgentCorePaymentsMiddleware(_make_config()) - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ]) - - result = asyncio.run(mw.awrap_tool_call( - _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, - )) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ] + ) + + result = asyncio.run( + mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), + handler, + ) + ) assert "signed but rejected" in result.content @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") @@ -218,9 +238,12 @@ def test_unexpected_exception(self, mock_pm_cls): handler = AsyncMock(return_value=ToolMessage(content=_402_content(), tool_call_id="tc-1")) - result = asyncio.run(mw.awrap_tool_call( - _make_request(tool_args={"url": "http://x.com", "headers": {}}), handler, - )) + result = asyncio.run( + mw.awrap_tool_call( + _make_request(tool_args={"url": "http://x.com", "headers": {}}), + handler, + ) + ) assert isinstance(result, ToolMessage) assert "unexpected error" in result.content assert "async boom" in result.content diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py index eb9d9a03..919091c8 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage6.py @@ -1,9 +1,7 @@ """Tests for Stage 6: Built-in Tools.""" import json -from unittest.mock import MagicMock, patch, PropertyMock - -import pytest +from unittest.mock import MagicMock, patch from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware @@ -85,7 +83,7 @@ def test_402_returns_payment_required_marker(self, mock_pm_cls): result = tool.invoke({"url": "http://example.com"}) assert result.startswith("PAYMENT_REQUIRED: ") - parsed = json.loads(result[len("PAYMENT_REQUIRED: "):]) + parsed = json.loads(result[len("PAYMENT_REQUIRED: ") :]) assert parsed["statusCode"] == 402 @patch("bedrock_agentcore.payments.integrations.langgraph.middleware.PaymentManager") diff --git a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py index 39b19db6..1d03cab1 100644 --- a/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py +++ b/tests/bedrock_agentcore/payments/integrations/langgraph/test_stage7.py @@ -6,10 +6,9 @@ import pytest from langchain.messages import ToolMessage -from langgraph.types import Command from bedrock_agentcore.payments.integrations.langgraph import AgentCorePaymentsConfig -from bedrock_agentcore.payments.integrations.langgraph.errors import ErrorResolution, PaymentErrorContext +from bedrock_agentcore.payments.integrations.langgraph.errors import ErrorResolution from bedrock_agentcore.payments.integrations.langgraph.middleware import AgentCorePaymentsMiddleware from bedrock_agentcore.payments.manager import ( InsufficientBudget, @@ -76,10 +75,12 @@ def handler_cb(ctx): mw = AgentCorePaymentsMiddleware(config) request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - mock_handler = MagicMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + mock_handler = MagicMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) result = mw.wrap_tool_call(request, mock_handler) assert "PAYMENT ERROR" not in result.content @@ -269,10 +270,12 @@ async def async_cb(ctx): mw = AgentCorePaymentsMiddleware(config) request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) result = asyncio.run(mw.awrap_tool_call(request, handler)) assert "PAYMENT ERROR" not in result.content @@ -293,10 +296,12 @@ def sync_cb(ctx): mw = AgentCorePaymentsMiddleware(config) request = _make_request(tool_args={"url": "http://x.com", "headers": {}}) - handler = AsyncMock(side_effect=[ - ToolMessage(content=_402_content(), tool_call_id="tc-1"), - ToolMessage(content=_200_content(), tool_call_id="tc-1"), - ]) + handler = AsyncMock( + side_effect=[ + ToolMessage(content=_402_content(), tool_call_id="tc-1"), + ToolMessage(content=_200_content(), tool_call_id="tc-1"), + ] + ) result = asyncio.run(mw.awrap_tool_call(request, handler)) assert "PAYMENT ERROR" not in result.content