diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index e5f36b2c..40442a01 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -8,6 +8,10 @@ body: attributes: value: | Thanks for suggesting a new feature for Strands Agents Tools! + + > **Note:** We are not accepting new tools into this repository. If you'd like to build a new tool, we recommend using our [extension template](https://github.com/strands-agents/extension-template-python) to publish it as a standalone package. You can then get it featured in our [community catalog](https://strandsagents.com/docs/community/get-featured/). + > + > We still welcome feature requests for **improvements to existing tools**. - type: textarea id: problem-statement attributes: diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index aec71af0..f0888b3e 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -14,8 +14,9 @@ +> Please note that we are not accepting new tools into this repository. Instead, we recommend using our [extension template](https://github.com/strands-agents/extension-template-python) to publish your own tool package and get it featured in our [community catalog](https://strandsagents.com/docs/community/get-featured/). + Bug fix -New Tool Breaking change Documentation update Other (please describe): diff --git a/.github/workflows/strands-command.yml b/.github/workflows/strands-command.yml new file mode 100644 index 00000000..33874e24 --- /dev/null +++ b/.github/workflows/strands-command.yml @@ -0,0 +1,92 @@ +name: Strands Command Handler + +on: + issue_comment: + types: [created] + workflow_dispatch: + inputs: + issue_id: + description: 'Issue ID to process (can be issue or PR number)' + required: true + type: string + command: + description: 'Strands command to execute' + required: false + type: string + default: '' + session_id: + description: 'Optional session ID to use' + required: false + type: string + default: '' + +jobs: + authorization-check: + if: startsWith(github.event.comment.body, '/strands') || github.event_name == 'workflow_dispatch' + name: Check access + permissions: read-all + runs-on: ubuntu-latest + outputs: + approval-env: ${{ steps.auth.outputs.approval-env }} + steps: + - name: Check Authorization + id: auth + uses: strands-agents/devtools/authorization-check@main + with: + skip-check: ${{ github.event_name == 'workflow_dispatch' }} + username: ${{ github.event.comment.user.login || 'invalid' }} + allowed-roles: 'maintain,triage,write,admin' + + setup-and-process: + needs: [authorization-check] + environment: ${{ needs.authorization-check.outputs.approval-env }} + permissions: + # Needed to create a branch for the Implementer Agent + contents: write + # These both are needed to add the `strands-running` label to issues and prs + issues: write + pull-requests: write + runs-on: ubuntu-latest + steps: + - name: Parse input + id: parse + uses: strands-agents/devtools/strands-command/actions/strands-input-parser@main + with: + issue_id: ${{ inputs.issue_id }} + command: ${{ inputs.command }} + session_id: ${{ inputs.session_id }} + + execute-readonly-agent: + needs: [setup-and-process] + permissions: + contents: read + issues: read + pull-requests: read + id-token: write # Required for OIDC + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + + # Add any steps here to set up the environment for the Agent in your repo + # setup node, setup 
python, or any other dependencies + + - name: Run Strands Agent + id: agent-runner + uses: strands-agents/devtools/strands-command/actions/strands-agent-runner@main + with: + aws_role_arn: ${{ secrets.AWS_ROLE_ARN }} + sessions_bucket: ${{ secrets.AGENT_SESSIONS_BUCKET }} + write_permission: 'false' + + finalize: + if: always() && (startsWith(github.event.comment.body, '/strands') || github.event_name == 'workflow_dispatch') + needs: [setup-and-process, execute-readonly-agent] + permissions: + contents: write + issues: write + pull-requests: write + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Execute write operations + uses: strands-agents/devtools/strands-command/actions/strands-finalize@main diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4e9ea627..e8f71776 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -7,6 +7,27 @@ Please read through this document before submitting any issues or pull requests information to effectively respond to your bug report or contribution. +## New Tools Policy + +**We are not accepting new tools into this repository.** Instead, we recommend publishing new tools as standalone community packages — this way you own your release cycle and can iterate independently. + +**What we accept:** +- Bug fixes for existing tools +- Documentation improvements +- Performance enhancements to existing tools +- Test coverage improvements + +**What we don't accept:** +- New tool submissions (PRs adding new tools will be closed) +- New tool feature requests (issues requesting new tools will be closed) + +**Want to build a tool?** Use our [extension template](https://github.com/strands-agents/extension-template-python) to scaffold your own tool package and publish it to PyPI. Once published, you can get it featured in our docs and community catalog: + +- Extension template: https://github.com/strands-agents/extension-template-python +- Get featured in docs: https://strandsagents.com/docs/community/get-featured/ +- Contribution guide: https://strandsagents.com/docs/contribute/ + + ## Reporting Bugs/Feature Requests We welcome you to use the [Bug Reports](../../issues/new?template=bug_report.yml) file to report bugs or [Feature Requests](../../issues/new?template=feature_request.yml) to suggest features. 
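A standalone tool package built from the extension template mentioned in the New Tools Policy ends up exposing ordinary Strands tools. The sketch below is a minimal illustration rather than the template's actual layout: the module, function name, and docstring are made up, and only the `@tool` decorator and `Agent` import follow the patterns used elsewhere in this repository.

```python
# Hypothetical module in a standalone package scaffolded from the extension
# template; the names below are illustrative only.
from strands import Agent, tool


@tool
def shout(text: str) -> str:
    """Toy custom tool: return the input text in upper case."""
    return text.upper()


# Once published to PyPI, users register the tool like any built-in one.
agent = Agent(tools=[shout])
result = agent.tool.shout(text="hello from a community tool")
```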
diff --git a/README.md b/README.md index e945edf4..bca32b1d 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,8 @@ Strands Agents Tools is a community-driven project that provides a powerful set - ⏱️ **Task Scheduling** - Schedule and manage cron jobs - 🧠 **Advanced Reasoning** - Tools for complex thinking and reasoning capabilities - 🐝 **Swarm Intelligence** - Coordinate multiple AI agents for parallel problem solving with shared memory +- 🤖 **Agent as Tool** - Create nested agent instances with model switching support for multi-model workflows and specialized sub-tasks +- 🔗 **Multi-Agent Graph** - Create and manage deterministic DAG-based multi-agent pipelines with output propagation and per-node model configuration - 🔌 **Dynamic MCP Client** - ⚠️ Dynamically connect to external MCP servers and load remote tools (use with caution - see security warnings) - 🔄 **Multiple tools in Parallel** - Call multiple other tools at the same time in parallel with Batch Tool - 🔍 **Browser Tool** - Tool giving an agent access to perform automated actions on a browser (chromium) @@ -99,6 +101,9 @@ Below is a comprehensive table of all available tools, how to use them with an a | Tool | Agent Usage | Use Case | |------|-------------|----------| | a2a_client | `provider = A2AClientToolProvider(known_agent_urls=["http://localhost:9000"]); agent = Agent(tools=provider.tools)` | Discover and communicate with A2A-compliant agents, send messages between agents | +| apify_run_actor | `agent.tool.apify_run_actor(actor_id="apify/website-content-crawler", run_input={"startUrls": [{"url": "https://example.com"}]})` | Run any Apify Actor with arbitrary input | +| apify_scrape_url | `agent.tool.apify_scrape_url(url="https://example.com")` | Scrape a URL and return its content as markdown | +| apify_google_search_scraper | `agent.tool.apify_google_search_scraper(search_query="best AI frameworks")` | Search Google and return structured results | | file_read | `agent.tool.file_read(path="path/to/file.txt")` | Reading configuration files, parsing code files, loading datasets | | file_write | `agent.tool.file_write(path="path/to/file.txt", content="file content")` | Writing results to files, creating new files, saving output data | | editor | `agent.tool.editor(command="view", path="path/to/file.py")` | Advanced file operations like syntax highlighting, pattern replacement, and multi-file edits | @@ -108,7 +113,7 @@ Below is a comprehensive table of all available tools, how to use them with an a | tavily_extract | `agent.tool.tavily_extract(urls=["www.tavily.com"], extract_depth="advanced")` | Extract clean, structured content from web pages with advanced processing and noise removal | | tavily_crawl | `agent.tool.tavily_crawl(url="www.tavily.com", max_depth=2, instructions="Find API docs")` | Crawl websites intelligently starting from a base URL with filtering and extraction | | tavily_map | `agent.tool.tavily_map(url="www.tavily.com", max_depth=2, instructions="Find all pages")` | Map website structure and discover URLs starting from a base URL without content extraction | -| exa_search | `agent.tool.exa_search(query="Best project management tools", text=True)` | Intelligent web search with auto mode (default) that combines neural and keyword search for optimal results | +| exa_search | `agent.tool.exa_search(query="Best project management tools", text=True)` | Intelligent web search with auto mode (default) for optimal results, plus fast and deep search modes | | exa_get_contents | 
`agent.tool.exa_get_contents(urls=["https://example.com/article"], text=True, summary={"query": "key points"})` | Extract full content and summaries from specific URLs with live crawling fallback | | python_repl* | `agent.tool.python_repl(code="import pandas as pd\ndf = pd.read_csv('data.csv')\nprint(df.head())")` | Running Python code snippets, data analysis, executing complex logic with user confirmation for security | | calculator | `agent.tool.calculator(expression="2 * sin(pi/4) + log(e**2)")` | Performing mathematical operations, symbolic math, equation solving | @@ -131,12 +136,14 @@ Below is a comprehensive table of all available tools, how to use them with an a | current_time | `agent.tool.current_time(timezone="US/Pacific")` | Get the current time in ISO 8601 format for a specified timezone | | sleep | `agent.tool.sleep(seconds=5)` | Pause execution for the specified number of seconds, interruptible with SIGINT (Ctrl+C) | | agent_graph | `agent.tool.agent_graph(agents=["agent1", "agent2"], connections=[{"from": "agent1", "to": "agent2"}])` | Create and visualize agent relationship graphs for complex multi-agent systems | +| graph | `agent.tool.graph(action="create", graph_id="pipeline", topology={"nodes": [...], "edges": [...]})` | Create and manage deterministic DAG-based multi-agent graphs using Strands SDK Graph implementation with per-node model configuration | | cron* | `agent.tool.cron(action="schedule", name="task", schedule="0 * * * *", command="backup.sh")` | Schedule and manage recurring tasks with cron job syntax
**Does not work on Windows | | slack | `agent.tool.slack(action="post_message", channel="general", text="Hello team!")` | Interact with Slack workspace for messaging and monitoring | | speak | `agent.tool.speak(text="Operation completed successfully", style="green", mode="polly")` | Output status messages with rich formatting and optional text-to-speech | | stop | `agent.tool.stop(message="Process terminated by user request")` | Gracefully terminate agent execution with custom message | | handoff_to_user | `agent.tool.handoff_to_user(message="Please confirm action", breakout_of_loop=False)` | Hand off control to user for confirmation, input, or complete task handoff | | use_llm | `agent.tool.use_llm(prompt="Analyze this data", system_prompt="You are a data analyst")` | Create nested AI loops with customized system prompts for specialized tasks | +| use_agent | `agent.tool.use_agent(prompt="Analyze this code", system_prompt="You are a code analyst.", model_provider="bedrock")` | Create nested agent instances with model switching, multi-model workflows, cost optimization, and specialized sub-tasks | | workflow | `agent.tool.workflow(action="create", name="data_pipeline", steps=[{"tool": "file_read"}, {"tool": "python_repl"}])` | Define, execute, and manage multi-step automated workflows | | mcp_client | `agent.tool.mcp_client(action="connect", connection_id="my_server", transport="stdio", command="python", args=["server.py"])` | ⚠️ **SECURITY WARNING**: Dynamically connect to external MCP servers via stdio, sse, or streamable_http, list tools, and call remote tools. This can pose security risks as agents may connect to malicious servers. Use with caution in production. | | batch| `agent.tool.batch(invocations=[{"name": "current_time", "arguments": {"timezone": "Europe/London"}}, {"name": "stop", "arguments": {}}])` | Call multiple other tools in parallel. | @@ -147,6 +154,7 @@ Below is a comprehensive table of all available tools, how to use them with an a | search_video | `agent.tool.search_video(query="people discussing AI")` | Semantic video search using TwelveLabs' Marengo model | | chat_video | `agent.tool.chat_video(prompt="What are the main topics?", video_id="video_123")` | Interactive video analysis using TwelveLabs' Pegasus model | | mongodb_memory | `agent.tool.mongodb_memory(action="record", content="User prefers vegetarian pizza", connection_string="mongodb+srv://...", database_name="memories")` | Store and retrieve memories using MongoDB Atlas with semantic search via AWS Bedrock Titan embeddings | +| elasticsearch_memory | `agent.tool.elasticsearch_memory(action="record", content="User prefers dark mode", cloud_id="...", api_key="...")` | Store and retrieve memories using Elasticsearch with semantic search via AWS Bedrock Titan embeddings | \* *These tools do not work on windows* @@ -679,6 +687,53 @@ agent.tool.handoff_to_user( ) ``` +### Use Agent (Agent as Tool) + +```python +from strands import Agent +from strands_tools import use_agent + +agent = Agent(tools=[use_agent]) + +# Basic usage - inherits parent agent's model +result = agent.tool.use_agent( + prompt="Tell me about the advantages of tool-building in AI agents", + system_prompt="You are a helpful AI assistant specializing in AI development concepts." 
+) + +# Use a different model provider for specialized tasks +result = agent.tool.use_agent( + prompt="Calculate 2 + 2 and explain the result", + system_prompt="You are a helpful math assistant.", + model_provider="bedrock", + model_settings={ + "model_id": "us.anthropic.claude-sonnet-4-20250514-v1:0" + }, + tools=["calculator"] +) + +# Use environment variables to determine model +import os +os.environ["STRANDS_PROVIDER"] = "ollama" +os.environ["STRANDS_MODEL_ID"] = "qwen3:4b" +result = agent.tool.use_agent( + prompt="Analyze this code", + system_prompt="You are a code review assistant.", + model_provider="env" +) + +# Custom model configuration with specific parameters +result = agent.tool.use_agent( + prompt="Write a creative story", + system_prompt="You are a creative writing assistant.", + model_provider="github", + model_settings={ + "model_id": "openai/o4-mini", + "params": {"temperature": 1, "max_tokens": 4000} + } +) +``` + ### A2A Client ```python @@ -814,6 +869,68 @@ result = agent.tool.use_computer( ) ``` +### Graph (Multi-Agent DAG) + +Create deterministic DAG-based multi-agent pipelines where agents are nodes with dependency relationships. Unlike `agent_graph` (which uses persistent message-passing), `graph` uses task-based execution with output propagation. + +```python +from strands import Agent +from strands_tools.graph import graph + +agent = Agent(tools=[graph]) + +# Create a multi-agent research pipeline +result = agent.tool.graph( + action="create", + graph_id="research_pipeline", + topology={ + "nodes": [ + { + "id": "researcher", + "role": "researcher", + "system_prompt": "You research topics thoroughly.", + "model_provider": "bedrock", + "model_settings": {"model_id": "us.anthropic.claude-sonnet-4-20250514-v1:0"} + }, + { + "id": "analyst", + "role": "analyst", + "system_prompt": "You analyze research data.", + "model_provider": "bedrock", + "model_settings": {"model_id": "us.anthropic.claude-3-5-haiku-20241022-v1:0"} + }, + { + "id": "reporter", + "role": "reporter", + "system_prompt": "You create comprehensive reports.", + "tools": ["file_write", "editor"] + } + ], + "edges": [ + {"from": "researcher", "to": "analyst"}, + {"from": "analyst", "to": "reporter"} + ], + "entry_points": ["researcher"] + } +) + +# Execute a task through the graph +result = agent.tool.graph( + action="execute", + graph_id="research_pipeline", + task="Research and analyze the impact of AI on healthcare" +) + +# Get graph status +result = agent.tool.graph(action="status", graph_id="research_pipeline") + +# List all graphs +result = agent.tool.graph(action="list") + +# Delete a graph +result = agent.tool.graph(action="delete", graph_id="research_pipeline") +``` + ### Elasticsearch Memory **Note**: This tool requires AWS account credentials to generate embeddings using Amazon Bedrock Titan models. 
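Beyond the note above, a minimal record call can mirror the entry in the tool table earlier in this README. The import path and the environment-variable-based connection shown here are assumptions based on the other memory tools and on the Elasticsearch Memory Tool configuration table later in this document; check the tool's own documentation for the exact connection arguments.

```python
import os

from strands import Agent
from strands_tools import elasticsearch_memory  # assumed import path, matching the other tools

# Connection settings via environment variables (see the configuration tables below).
os.environ["ELASTICSEARCH_CLOUD_ID"] = "<your-cloud-id>"
os.environ["ELASTICSEARCH_API_KEY"] = "<your-api-key>"

agent = Agent(tools=[elasticsearch_memory])

# Record a memory, as shown in the tool table above.
result = agent.tool.elasticsearch_memory(
    action="record",
    content="User prefers dark mode",
)
```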
@@ -960,6 +1077,54 @@ result = agent.tool.mongodb_memory( ) ``` +### Apify + +```python +from strands import Agent +from strands_tools.apify import APIFY_ALL_TOOLS + +agent = Agent(tools=APIFY_ALL_TOOLS) + +# Scrape a single URL and get Markdown content +content = agent.tool.apify_scrape_url(url="https://example.com") + +# Run an Actor and get results in one step +result = agent.tool.apify_run_actor_and_get_dataset( + actor_id="apify/website-content-crawler", + run_input={"startUrls": [{"url": "https://example.com"}]}, + dataset_items_limit=50, +) + +# Run a saved task (pre-configured Actor with default inputs) +run_info = agent.tool.apify_run_task(task_id="user/my-task") + +# Run a task and get results in one step +result = agent.tool.apify_run_task_and_get_dataset( + task_id="user/my-task", + task_input={"query": "override default input"}, + dataset_items_limit=50, +) + +# Run an Actor (get metadata only) +run_info = agent.tool.apify_run_actor( + actor_id="apify/google-search-scraper", + run_input={"queries": "AI agent frameworks"}, +) + +# Fetch dataset items separately +items = agent.tool.apify_get_dataset_items( + dataset_id="abc123", + limit=100, +) + +# Search Google +results = agent.tool.apify_google_search_scraper( + search_query="best AI frameworks 2025", + results_limit=10, +) + +``` + ## 🌍 Environment Variables Configuration Agents Tools provides extensive customization through environment variables. This allows you to configure tool behavior without modifying code, making it ideal for different environments (development, testing, production). @@ -1068,6 +1233,12 @@ The Mem0 Memory Tool supports three different backend configurations: - If `NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER` is set, the tool will configure Neptune Analytics as graph store to enhance memory search - LLM configuration applies to all backend modes and allows customization of the language model used for memory processing +#### Apify Tool + +| Environment Variable | Description | Default | +|----------------------|-------------|---------| +| APIFY_API_TOKEN | Apify API token for authentication (required) | None | + #### Bright Data Tool | Environment Variable | Description | Default | @@ -1139,6 +1310,12 @@ The Mem0 Memory Tool supports three different backend configurations: |----------------------|-------------|---------| | ENV_VARS_MASKED_DEFAULT | Default setting for masking sensitive values | true | +#### HTTP Request Tool + +| Environment Variable | Description | Default | +|----------------------|-------------|---------| +| STRANDS_HTTP_ALLOW_INSECURE_SSL | Allow disabling SSL certificate verification via verify_ssl parameter | false | + #### Dynamic MCP Client Tool | Environment Variable | Description | Default | @@ -1184,6 +1361,34 @@ The Mem0 Memory Tool supports three different backend configurations: |----------------------|-------------|---------| | RETRIEVE_ENABLE_METADATA_DEFAULT | Default setting for enabling metadata in retrieve tool responses | false | +#### Use Agent Tool + +| Environment Variable | Description | Default | +|----------------------|-------------|---------| +| STRANDS_PROVIDER | Default model provider when using model_provider="env" | ollama | +| STRANDS_MODEL_ID | Default model identifier for environment-based model selection | None | +| STRANDS_MAX_TOKENS | Maximum tokens for the nested agent model | None | +| STRANDS_TEMPERATURE | Sampling temperature for the nested agent model | None | + + +#### Elasticsearch Memory Tool + +| Environment Variable | Description | Default | 
+|----------------------|-------------|---------| +| ELASTICSEARCH_CLOUD_ID | Elasticsearch Cloud ID for connection | None | +| ELASTICSEARCH_URL | Elasticsearch URL for serverless connection | None | +| ELASTICSEARCH_API_KEY | Elasticsearch API key for authentication | None | +| ELASTICSEARCH_INDEX_NAME | Elasticsearch index name for memory storage | strands_memory | +| ELASTICSEARCH_NAMESPACE | Namespace for memory isolation | default | +| ELASTICSEARCH_EMBEDDING_MODEL | Amazon Bedrock model for embeddings | amazon.titan-embed-text-v2:0 | +| AWS_REGION | AWS region for Bedrock embedding service | us-west-2 | + +**Note**: This tool requires AWS account credentials to generate embeddings using Amazon Bedrock Titan models. + +#### Graph Tool + +The `graph` tool uses the same model provider environment variables as `use_agent` for per-node model configuration. No additional environment variables are required. + #### Video Tools | Environment Variable | Description | Default | diff --git a/docs/apify_tool.md b/docs/apify_tool.md new file mode 100644 index 00000000..cd5ab715 --- /dev/null +++ b/docs/apify_tool.md @@ -0,0 +1,364 @@ +# Apify + +The Apify tools (`apify.py`) enable [Strands Agents](https://strandsagents.com/) to interact with the [Apify](https://apify.com) platform — running any [Actor](https://apify.com/store) or [task](https://docs.apify.com/platform/actors/running/tasks) by ID, fetching dataset results, and scraping individual URLs. + +## Installation + +```bash +pip install strands-agents-tools[apify] +``` + +## Configuration + +Set your Apify API token as an environment variable: + +```bash +export APIFY_API_TOKEN=apify_api_your_token_here +``` + +Get your token from [Apify Console](https://console.apify.com/account/integrations) → Settings → API & Integrations → Personal API tokens. + +## Usage + +Register all core tools at once: + +```python +from strands import Agent +from strands_tools.apify import APIFY_CORE_TOOLS + +agent = Agent(tools=APIFY_CORE_TOOLS) +``` + +Or pick individual tools: + +```python +from strands import Agent +from strands_tools import apify + +agent = Agent(tools=[ + apify.apify_run_actor, + apify.apify_scrape_url, +]) +``` + +### Scrape a URL + +The simplest way to extract content from any web page. Uses the [Website Content Crawler](https://apify.com/apify/website-content-crawler) Actor under the hood and returns the page content as Markdown: + +```python +content = agent.tool.apify_scrape_url(url="https://example.com") +``` + +### Run an Actor + +Execute any Actor from [Apify Store](https://apify.com/store) by its ID. The call blocks until the Actor run finishes or the timeout is reached: + +```python +result = agent.tool.apify_run_actor( + actor_id="apify/website-content-crawler", + run_input={"startUrls": [{"url": "https://example.com"}]}, + timeout_secs=300, +) +``` + +The result is a JSON string containing run metadata: `run_id`, `status`, `dataset_id`, `started_at`, and `finished_at`. + +### Run an Actor and Get Results + +Combine running an Actor and fetching its dataset results in a single call: + +```python +result = agent.tool.apify_run_actor_and_get_dataset( + actor_id="apify/website-content-crawler", + run_input={"startUrls": [{"url": "https://example.com"}]}, + dataset_items_limit=50, +) +``` + +### Run a task + +Execute a saved [Actor task](https://docs.apify.com/platform/actors/running/tasks) — a pre-configured Actor with preset inputs. 
Use this when a task has already been set up in Apify Console: + +```python +result = agent.tool.apify_run_task( + task_id="user~my-task", + task_input={"query": "override input"}, + timeout_secs=300, +) +``` + +The result is a JSON string containing run metadata: `run_id`, `status`, `dataset_id`, `started_at`, and `finished_at`. + +### Run a task and get results + +Combine running a task and fetching its dataset results in a single call: + +```python +result = agent.tool.apify_run_task_and_get_dataset( + task_id="user~my-task", + dataset_items_limit=50, +) +``` + +### Fetch dataset items + +Retrieve results from a dataset by its ID. Useful after running an Actor to get the structured results separately, or to access any existing dataset: + +```python +items = agent.tool.apify_get_dataset_items( + dataset_id="abc123", + limit=100, + offset=0, +) +``` + +## Tool Parameters + +### apify_scrape_url + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `url` | string | Yes | — | The URL to scrape | +| `timeout_secs` | int | No | 120 | Maximum time in seconds to wait for scraping to finish | +| `crawler_type` | string | No | `"cheerio"` | Crawler engine to use. One of `"cheerio"` (fastest, no JS rendering), `"playwright:adaptive"` (fast, renders JS if present), or `"playwright:firefox"` (reliable, renders JS, best at avoiding blocking but slower) | + +**Returns:** Markdown content of the scraped page as a plain string. + +### apify_run_actor + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `actor_id` | string | Yes | — | Actor identifier (e.g., `apify/website-content-crawler`) | +| `run_input` | dict | No | None | JSON-serializable input for the Actor | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait for the Actor run to finish | +| `memory_mbytes` | int | No | None | Memory allocation in MB for the Actor run (uses Actor default if not set) | +| `build` | string | No | None | Actor build tag or number to run a specific version (uses latest build if not set) | + +**Returns:** JSON string with run metadata: `run_id`, `status`, `dataset_id`, `started_at`, `finished_at`. + +### apify_run_task + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `task_id` | string | Yes | — | Task identifier (e.g., `user~my-task` or a task ID) | +| `task_input` | dict | No | None | JSON-serializable input to override the task's default input | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait for the task run to finish | +| `memory_mbytes` | int | No | None | Memory allocation in MB for the task run (uses task default if not set) | + +**Returns:** JSON string with run metadata: `run_id`, `status`, `dataset_id`, `started_at`, `finished_at`. 
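Since the metadata includes the run's `dataset_id`, a separate `apify_get_dataset_items` call can pull the task output afterwards. The sketch below assumes the `{"status", "content": [{"text": ...}]}` result shape produced by the tool functions in `apify.py`; unwrap the payload differently if you consume the tools through another interface.

```python
import json

# Run a saved task, then fetch its results with a separate dataset call.
run_result = agent.tool.apify_run_task(task_id="user~my-task")

# Tool results carry their payload as JSON text in content[0]["text"].
run_metadata = json.loads(run_result["content"][0]["text"])

items_result = agent.tool.apify_get_dataset_items(
    dataset_id=run_metadata["dataset_id"],
    limit=50,
)
items = json.loads(items_result["content"][0]["text"])
print(f"Fetched {len(items)} items from run {run_metadata['run_id']}")
```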
+ +### apify_run_task_and_get_dataset + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `task_id` | string | Yes | — | Task identifier (e.g., `user~my-task` or a task ID) | +| `task_input` | dict | No | None | JSON-serializable input to override the task's default input | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait for the task run to finish | +| `memory_mbytes` | int | No | None | Memory allocation in MB for the task run (uses task default if not set) | +| `dataset_items_limit` | int | No | 100 | Maximum number of dataset items to return | +| `dataset_items_offset` | int | No | 0 | Number of dataset items to skip for pagination | + +**Returns:** JSON string with run metadata plus an `items` array containing the dataset results. + +### apify_get_dataset_items + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `dataset_id` | string | Yes | — | The Apify dataset ID to fetch items from | +| `limit` | int | No | 100 | Maximum number of items to return | +| `offset` | int | No | 0 | Number of items to skip for pagination | + +**Returns:** JSON string containing an array of dataset items. + +### apify_run_actor_and_get_dataset + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `actor_id` | string | Yes | — | Actor identifier (e.g., `apify/website-content-crawler`) | +| `run_input` | dict | No | None | JSON-serializable input for the Actor | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait for the Actor run to finish | +| `memory_mbytes` | int | No | None | Memory allocation in MB for the Actor run (uses Actor default if not set) | +| `build` | string | No | None | Actor build tag or number to run a specific version (uses latest build if not set) | +| `dataset_items_limit` | int | No | 100 | Maximum number of dataset items to return | +| `dataset_items_offset` | int | No | 0 | Number of dataset items to skip for pagination | + +**Returns:** JSON string with run metadata plus an `items` array containing the dataset results. + +## Search & Crawling + +Specialized tools for common search and crawling use cases. 
Register all search tools at once: + +```python +from strands import Agent +from strands_tools.apify import APIFY_SEARCH_TOOLS + +agent = Agent(tools=APIFY_SEARCH_TOOLS) +``` + +Or register all Apify tools (core + search): + +```python +from strands_tools.apify import APIFY_ALL_TOOLS + +agent = Agent(tools=APIFY_ALL_TOOLS) +``` + +### Search Google + +Search Google and return structured results using the [Google Search Scraper](https://apify.com/apify/google-search-scraper) Actor: + +```python +result = agent.tool.apify_google_search_scraper( + search_query="best AI frameworks 2025", + results_limit=10, + country_code="us", +) +``` + +### Search Google Maps + +Search Google Maps for businesses and places using the [Google Maps Scraper](https://apify.com/compass/crawler-google-places) Actor: + +```python +result = agent.tool.apify_google_places_scraper( + search_query="restaurants in Prague", + results_limit=20, + include_reviews=True, + max_reviews=5, +) +``` + +### Scrape YouTube + +Scrape YouTube videos, channels, or search results using the [YouTube Scraper](https://apify.com/streamers/youtube-scraper) Actor: + +```python +# Search YouTube +result = agent.tool.apify_youtube_scraper( + search_query="python tutorial", + results_limit=10, +) + +# Scrape specific videos +result = agent.tool.apify_youtube_scraper( + urls=["https://www.youtube.com/watch?v=dQw4w9WgXcQ"], +) +``` + +### Crawl a website + +Crawl a website and extract content from multiple pages using the [Website Content Crawler](https://apify.com/apify/website-content-crawler) Actor. This is the multi-page version — distinct from `apify_scrape_url` which is limited to a single page: + +```python +result = agent.tool.apify_website_content_crawler( + start_url="https://docs.example.com", + max_pages=20, + max_depth=3, +) +``` + +### Scrape e-commerce products + +Scrape product data from e-commerce websites using the [E-commerce Scraping Tool](https://apify.com/apify/e-commerce-scraping-tool) Actor. Supports Amazon, eBay, Walmart, and other platforms: + +```python +# Scrape a single product page +result = agent.tool.apify_ecommerce_scraper( + url="https://www.amazon.com/dp/B0TEST", +) + +# Scrape a category or search results page +result = agent.tool.apify_ecommerce_scraper( + url="https://www.amazon.com/s?k=headphones", + url_type="listing", + results_limit=20, +) +``` + +## Search & Crawling Tool Parameters + +### apify_google_search_scraper + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `search_query` | string | Yes | — | The search query string. Supports advanced Google operators like `"site:example.com"` | +| `results_limit` | int | No | 10 | Maximum number of results to return. Google returns ~10 per page, so requesting more triggers additional page scraping | +| `country_code` | string | No | None | Two-letter country code for localized results (e.g., `"us"`, `"de"`) | +| `language_code` | string | No | None | Two-letter language code (e.g., `"en"`, `"de"`) | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait | + +**Returns:** JSON string with run metadata and an `items` array containing structured search results (organic results, ads, People Also Ask). 
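For downstream processing, the items can be unwrapped from the tool result the same way as with the other Apify tools. In the sketch below, the nested field names (`organicResults`, `title`, `url`) follow the Google Search Scraper Actor's output schema and are not guaranteed by this package; verify them against the Actor's README if results look empty.

```python
import json

# Search Google, then iterate the structured organic results.
search_result = agent.tool.apify_google_search_scraper(
    search_query="site:strandsagents.com agent tools",
    results_limit=10,
)

# Unwrap the JSON payload from the tool result.
payload = json.loads(search_result["content"][0]["text"])

# Each dataset item represents one scraped results page; "organicResults" is
# assumed from the Actor's output schema, not defined by this package.
for page in payload["items"]:
    for hit in page.get("organicResults", []):
        print(hit.get("title"), "-", hit.get("url"))
```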
+ +### apify_google_places_scraper + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `search_query` | string | Yes | — | Search query for Google Maps (e.g., `"restaurants in Prague"`) | +| `results_limit` | int | No | 20 | Maximum number of places to return | +| `language` | string | No | None | Language for results (e.g., `"en"`, `"de"`) | +| `include_reviews` | bool | No | False | Whether to include user reviews | +| `max_reviews` | int | No | 5 | Maximum reviews per place when `include_reviews` is True | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait | + +**Returns:** JSON string with run metadata and an `items` array containing place data (name, address, rating, phone, website). + +### apify_youtube_scraper + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `search_query` | string | No | None | YouTube search query | +| `urls` | list[str] | No | None | Specific YouTube video or channel URLs | +| `results_limit` | int | No | 20 | Maximum number of results to return | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait | + +At least one of `search_query` or `urls` must be provided. + +**Returns:** JSON string with run metadata and an `items` array containing video/channel data. + +### apify_website_content_crawler + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `start_url` | string | Yes | — | The starting URL to crawl | +| `max_pages` | int | No | 10 | Maximum number of pages to crawl | +| `max_depth` | int | No | 2 | Maximum crawl depth from the start URL | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait | + +**Returns:** JSON string with run metadata and an `items` array containing crawled page data with markdown content. + +### apify_ecommerce_scraper + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `url` | string | Yes | — | The URL to scrape | +| `url_type` | string | No | `"product"` | Type of URL: `"product"` for a product detail page, `"listing"` for a category or search results page | +| `results_limit` | int | No | 20 | Maximum number of products to return | +| `timeout_secs` | int | No | 300 | Maximum time in seconds to wait | + +**Returns:** JSON string with run metadata and an `items` array containing structured product data. + +## Troubleshooting + +| Error | Cause | Fix | +|-------|-------|-----| +| `APIFY_API_TOKEN environment variable is not set` | Token not configured | Set the `APIFY_API_TOKEN` environment variable | +| `apify-client package is required` | Optional dependency not installed | Run `pip install strands-agents-tools[apify]` | +| `Actor ... finished with status FAILED` | Actor execution error | Check Actor input parameters and run logs in [Apify Console](https://console.apify.com) | +| `Task ... finished with status FAILED` | Task execution error | Check task configuration and run logs in [Apify Console](https://console.apify.com) | +| `Actor/task ... finished with status TIMED-OUT` | Timeout too short for the workload | Increase the `timeout_secs` parameter | +| `Task ... 
returned no run data` | Task `call()` returned `None` (wait timeout) | Increase the `timeout_secs` parameter | +| `No content returned for URL` | Website Content Crawler returned empty results | Verify the URL is accessible and returns content | +| `At least one of 'search_query' or 'urls' must be provided` | YouTube Scraper called without input | Provide a `search_query`, `urls`, or both | + +## References + +- [Strands Agents Tools](https://strandsagents.com/latest/user-guide/concepts/tools/tools_overview/) +- [Apify Platform](https://apify.com) +- [Apify API Documentation](https://docs.apify.com/api/v2) +- [Apify Store](https://apify.com/store) +- [Apify Python Client](https://docs.apify.com/api/client/python/docs) +- [Google Search Scraper Actor](https://apify.com/apify/google-search-scraper) +- [Google Maps Scraper Actor](https://apify.com/compass/crawler-google-places) +- [YouTube Scraper Actor](https://apify.com/streamers/youtube-scraper) +- [Website Content Crawler Actor](https://apify.com/apify/website-content-crawler) +- [E-commerce Scraping Tool Actor](https://apify.com/apify/e-commerce-scraping-tool) diff --git a/pyproject.toml b/pyproject.toml index bf00325f..93e05c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,9 @@ Homepage = "https://github.com/strands-agents/tools" Documentation = "https://strandsagents.com/" [project.optional-dependencies] +apify = [ + "apify-client>=2.5.0,<3.0.0", +] build = [ "hatch>=1.16.5", ] @@ -122,7 +125,7 @@ mongodb-memory = [ ] [tool.hatch.envs.hatch-static-analysis] -features = ["mem0-memory", "local-chromium-browser", "agent-core-browser", "agent-core-code-interpreter", "a2a-client", "diagram", "rss", "use-computer", "twelvelabs", "elasticsearch-memory", "mongodb-memory"] +features = ["mem0-memory", "local-chromium-browser", "agent-core-browser", "agent-core-code-interpreter", "a2a-client", "diagram", "rss", "use-computer", "twelvelabs", "elasticsearch-memory", "mongodb-memory", "apify"] dependencies = [ "strands-agents>=1.0.0", "mypy>=0.981,<1.0.0", @@ -141,7 +144,7 @@ lint-check = [ lint-fix = ["ruff check --fix"] [tool.hatch.envs.hatch-test] -features = ["mem0-memory", "local-chromium-browser", "agent-core-browser", "agent-core-code-interpreter", "a2a-client", "diagram", "rss", "use-computer", "twelvelabs", "elasticsearch-memory", "mongodb-memory"] +features = ["mem0-memory", "local-chromium-browser", "agent-core-browser", "agent-core-code-interpreter", "a2a-client", "diagram", "rss", "use-computer", "twelvelabs", "elasticsearch-memory", "mongodb-memory", "apify"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<10.0.0", diff --git a/src/strands_tools/apify.py b/src/strands_tools/apify.py new file mode 100644 index 00000000..d5b67276 --- /dev/null +++ b/src/strands_tools/apify.py @@ -0,0 +1,990 @@ +"""Apify platform tools for Strands Agents. + + +Apify is the world's largest marketplace of tools for web scraping, crawling, data extraction, and web automation. +These tools are called Actors, serverless cloud programs that take JSON input and store results +in a dataset (structured, tabular output) or key-value store (files and unstructured data). +Get structured data from social media, e-commerce, search engines, maps, travel sites, or any other website. 
+ +Available Tools: +--------------- +Core: +- apify_run_actor: Run any Apify Actor with custom input +- apify_get_dataset_items: Fetch items from an Apify dataset with pagination +- apify_run_actor_and_get_dataset: Run an Actor and fetch results in one step +- apify_run_task: Run a saved Actor task with optional input overrides +- apify_run_task_and_get_dataset: Run a task and fetch results in one step +- apify_scrape_url: Scrape a single URL and return content as Markdown + +Search & Crawling: +- apify_google_search_scraper: Search Google and return structured results +- apify_google_places_scraper: Search Google Maps for businesses and places +- apify_youtube_scraper: Scrape YouTube videos, channels, or search results +- apify_website_content_crawler: Crawl a website and extract content from multiple pages +- apify_ecommerce_scraper: Scrape product data from e-commerce websites + +Setup Requirements: +------------------ +1. Create an Apify account at https://apify.com +2. Get your API token: Apify Console > Settings > API & Integrations > Personal API tokens +3. Install the optional dependency: pip install strands-agents-tools[apify] +4. Set the environment variable: + APIFY_API_TOKEN=your_api_token_here + +Usage Examples: +-------------- +Register all core tools at once via the preset list: + +```python +from strands import Agent +from strands_tools.apify import APIFY_CORE_TOOLS + +agent = Agent(tools=APIFY_CORE_TOOLS) +``` + +Register all search & crawling tools: + +```python +from strands import Agent +from strands_tools.apify import APIFY_SEARCH_TOOLS + +agent = Agent(tools=APIFY_SEARCH_TOOLS) +``` + +Register all Apify tools (core + search): + +```python +from strands import Agent +from strands_tools.apify import APIFY_ALL_TOOLS + +agent = Agent(tools=APIFY_ALL_TOOLS) +``` + +Or pick individual tools for a smaller LLM tool surface: + +```python +from strands import Agent +from strands_tools import apify + +agent = Agent(tools=[ + apify.apify_scrape_url, + apify.apify_run_actor, + apify.apify_google_search_scraper, +]) + +# Scrape a single URL +content = agent.tool.apify_scrape_url(url="https://example.com") + +# Run an Actor +result = agent.tool.apify_run_actor( + actor_id="apify/website-content-crawler", + run_input={"startUrls": [{"url": "https://example.com"}]}, +) + +# Search Google +results = agent.tool.apify_google_search_scraper( + search_query="best AI frameworks 2025", + results_limit=10, +) +``` +""" + +import json +import logging +import os +from typing import Any, Dict, List, Literal, Optional, get_args +from urllib.parse import urlparse + +from rich.panel import Panel +from rich.text import Text +from strands import tool + +from strands_tools.utils import console_util + +logger = logging.getLogger(__name__) +console = console_util.create() + +try: + from apify_client import ApifyClient + from apify_client.errors import ApifyApiError + + HAS_APIFY_CLIENT = True +except ImportError: + HAS_APIFY_CLIENT = False + +# Attribution header - lets Apify track usage originating from strands-agents (analytics only) +TRACKING_HEADER = {"x-apify-integration-platform": "strands-agents"} +ERROR_PANEL_TITLE = "[bold red]Apify Error[/bold red]" +DEFAULT_TIMEOUT_SECS = 300 +DEFAULT_SCRAPE_TIMEOUT_SECS = 120 +DEFAULT_DATASET_ITEMS_LIMIT = 100 + +WEBSITE_CONTENT_CRAWLER = "apify/website-content-crawler" +CrawlerType = Literal["playwright:adaptive", "playwright:firefox", "cheerio"] +WEBSITE_CONTENT_CRAWLER_TYPES = get_args(CrawlerType) + + +# --- Helper functions --- + + +def 
_check_dependency() -> None: + """Raise ImportError if apify-client is not installed.""" + if not HAS_APIFY_CLIENT: + raise ImportError("apify-client package is required. Install with: pip install strands-agents-tools[apify]") + + +def _format_error(e: Exception) -> str: + """Map exceptions to user-friendly error messages, with special handling for ApifyApiError.""" + if HAS_APIFY_CLIENT and isinstance(e, ApifyApiError): + status_code = getattr(e, "status_code", None) + msg = getattr(e, "message", str(e)) + match status_code: + case 400: + return f"Invalid request: {msg}" + case 401: + return "Authentication failed. Verify your APIFY_API_TOKEN is valid." + case 402: + return "Insufficient Apify plan credits or subscription limits exceeded." + case 404: + return f"Resource not found: {msg}" + case 408: + return f"Actor run timed out: {msg}" + case 429: + return ( + "Rate limit exceeded. The Apify client retries automatically; " + "if this persists, reduce request frequency." + ) + case None: + return f"Apify API error: {msg}" + case _: + return f"Apify API error ({status_code}): {msg}" + return str(e) + + +def _error_result(e: Exception, tool_name: str) -> Dict[str, Any]: + """Build a structured error response and display an error panel.""" + message = _format_error(e) + logger.error("%s failed: %s", tool_name, message) + console.print(Panel(Text(message, style="red"), title=ERROR_PANEL_TITLE, border_style="red")) + return {"status": "error", "content": [{"text": message}]} + + +def _success_result(text: str, panel_body: str, panel_title: str) -> Dict[str, Any]: + """Build a structured success response and display a success panel.""" + console.print(Panel(panel_body, title=f"[bold cyan]{panel_title}[/bold cyan]", border_style="green")) + return {"status": "success", "content": [{"text": text}]} + + +class ApifyToolClient: + """Helper class encapsulating Apify API interactions via apify-client.""" + + def __init__(self) -> None: + token = os.getenv("APIFY_API_TOKEN", "") + if not token: + raise ValueError( + "APIFY_API_TOKEN environment variable is not set. " + "Get your token at https://console.apify.com/account/integrations" + ) + self.client: "ApifyClient" = ApifyClient(token, headers=TRACKING_HEADER) + + @staticmethod + def _check_run_status(actor_run: Dict[str, Any], label: str) -> None: + """Raise RuntimeError if the Actor run did not succeed.""" + status = actor_run.get("status", "UNKNOWN") + if status != "SUCCEEDED": + run_id = actor_run.get("id", "N/A") + raise RuntimeError(f"{label} finished with status {status}. Run ID: {run_id}") + + @staticmethod + def _validate_url(url: str) -> None: + """Raise ValueError if the URL does not have a valid HTTP(S) scheme and domain.""" + parsed = urlparse(url) + if parsed.scheme not in ("http", "https"): + raise ValueError(f"Invalid URL scheme '{parsed.scheme}'. Only http and https URLs are supported.") + if not parsed.netloc: + raise ValueError(f"Invalid URL '{url}'. 
A domain is required.") + + @staticmethod + def _validate_identifier(value: str, name: str) -> None: + """Raise ValueError if a required string identifier is empty or whitespace-only.""" + if not value.strip(): + raise ValueError(f"'{name}' must be a non-empty string.") + + @staticmethod + def _validate_positive(value: int, name: str) -> None: + """Raise ValueError if the value is not a positive integer (> 0).""" + if value <= 0: + raise ValueError(f"'{name}' must be a positive integer, got {value}.") + + @staticmethod + def _validate_non_negative(value: int, name: str) -> None: + """Raise ValueError if the value is negative.""" + if value < 0: + raise ValueError(f"'{name}' must be a non-negative integer, got {value}.") + + def run_actor( + self, + actor_id: str, + run_input: Optional[Dict[str, Any]] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, + build: Optional[str] = None, + ) -> Dict[str, Any]: + """Run an Apify Actor synchronously and return run metadata.""" + self._validate_identifier(actor_id, "actor_id") + self._validate_positive(timeout_secs, "timeout_secs") + if memory_mbytes is not None: + self._validate_positive(memory_mbytes, "memory_mbytes") + + call_kwargs: Dict[str, Any] = { + "run_input": run_input if run_input is not None else {}, + "timeout_secs": timeout_secs, + "logger": None, # Suppress verbose apify-client logging not useful to end users + } + if memory_mbytes is not None: + call_kwargs["memory_mbytes"] = memory_mbytes + if build is not None: + call_kwargs["build"] = build + + actor_run = self.client.actor(actor_id).call(**call_kwargs) + if actor_run is None: + raise RuntimeError(f"Actor {actor_id} returned no run data (possible wait timeout).") + self._check_run_status(actor_run, f"Actor {actor_id}") + + return { + "run_id": actor_run.get("id"), + "status": actor_run.get("status"), + "dataset_id": actor_run.get("defaultDatasetId"), + "started_at": actor_run.get("startedAt"), + "finished_at": actor_run.get("finishedAt"), + } + + def get_dataset_items( + self, + dataset_id: str, + limit: int = DEFAULT_DATASET_ITEMS_LIMIT, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """Fetch items from an Apify dataset.""" + self._validate_identifier(dataset_id, "dataset_id") + self._validate_positive(limit, "limit") + self._validate_non_negative(offset, "offset") + + result = self.client.dataset(dataset_id).list_items(limit=limit, offset=offset) + return list(result.items) + + def run_actor_and_get_dataset( + self, + actor_id: str, + run_input: Optional[Dict[str, Any]] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, + build: Optional[str] = None, + dataset_items_limit: int = DEFAULT_DATASET_ITEMS_LIMIT, + dataset_items_offset: int = 0, + ) -> Dict[str, Any]: + """Run an Actor synchronously, then fetch its default dataset items.""" + self._validate_positive(dataset_items_limit, "dataset_items_limit") + self._validate_non_negative(dataset_items_offset, "dataset_items_offset") + + run_metadata = self.run_actor( + actor_id=actor_id, + run_input=run_input, + timeout_secs=timeout_secs, + memory_mbytes=memory_mbytes, + build=build, + ) + dataset_id = run_metadata["dataset_id"] + if not dataset_id: + raise RuntimeError(f"Actor {actor_id} run has no default dataset.") + items = self.get_dataset_items(dataset_id=dataset_id, limit=dataset_items_limit, offset=dataset_items_offset) + return {**run_metadata, "items": items} + + def run_task( + self, + task_id: str, + task_input: Optional[Dict[str, Any]] = 
None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, + ) -> Dict[str, Any]: + """Run an Apify task synchronously and return run metadata.""" + self._validate_identifier(task_id, "task_id") + self._validate_positive(timeout_secs, "timeout_secs") + if memory_mbytes is not None: + self._validate_positive(memory_mbytes, "memory_mbytes") + + call_kwargs: Dict[str, Any] = {"timeout_secs": timeout_secs} + if task_input is not None: + call_kwargs["task_input"] = task_input + if memory_mbytes is not None: + call_kwargs["memory_mbytes"] = memory_mbytes + + task_run = self.client.task(task_id).call(**call_kwargs) + if task_run is None: + raise RuntimeError(f"Task {task_id} returned no run data (possible wait timeout).") + self._check_run_status(task_run, f"Task {task_id}") + + return { + "run_id": task_run.get("id"), + "status": task_run.get("status"), + "dataset_id": task_run.get("defaultDatasetId"), + "started_at": task_run.get("startedAt"), + "finished_at": task_run.get("finishedAt"), + } + + def run_task_and_get_dataset( + self, + task_id: str, + task_input: Optional[Dict[str, Any]] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, + dataset_items_limit: int = DEFAULT_DATASET_ITEMS_LIMIT, + dataset_items_offset: int = 0, + ) -> Dict[str, Any]: + """Run a task synchronously, then fetch its default dataset items.""" + self._validate_positive(dataset_items_limit, "dataset_items_limit") + self._validate_non_negative(dataset_items_offset, "dataset_items_offset") + + run_metadata = self.run_task( + task_id=task_id, + task_input=task_input, + timeout_secs=timeout_secs, + memory_mbytes=memory_mbytes, + ) + dataset_id = run_metadata["dataset_id"] + if not dataset_id: + raise RuntimeError(f"Task {task_id} run has no default dataset.") + items = self.get_dataset_items(dataset_id=dataset_id, limit=dataset_items_limit, offset=dataset_items_offset) + return {**run_metadata, "items": items} + + def scrape_url( + self, + url: str, + timeout_secs: int = DEFAULT_SCRAPE_TIMEOUT_SECS, + crawler_type: CrawlerType = "cheerio", + ) -> str: + """Scrape a single URL using Website Content Crawler and return Markdown.""" + self._validate_url(url) + self._validate_positive(timeout_secs, "timeout_secs") + if crawler_type not in WEBSITE_CONTENT_CRAWLER_TYPES: + raise ValueError( + f"Invalid crawler_type '{crawler_type}'. Must be one of: {', '.join(WEBSITE_CONTENT_CRAWLER_TYPES)}." 
+ ) + + run_input: Dict[str, Any] = { + "startUrls": [{"url": url}], + "maxCrawlPages": 1, + "crawlerType": crawler_type, + } + actor_run = self.client.actor(WEBSITE_CONTENT_CRAWLER).call( + run_input=run_input, + timeout_secs=timeout_secs, + logger=None, # Suppress verbose apify-client logging not useful to end users + ) + if actor_run is None: + raise RuntimeError("Website Content Crawler returned no run data (possible wait timeout).") + self._check_run_status(actor_run, "Website Content Crawler") + + dataset_id = actor_run.get("defaultDatasetId") + if not dataset_id: + raise RuntimeError("Website Content Crawler run has no default dataset.") + result = self.client.dataset(dataset_id).list_items(limit=1) + items = list(result.items) + + if not items: + raise RuntimeError(f"No content returned for URL: {url}") + + return str(items[0].get("markdown") or items[0].get("text", "")) + + +# --- Tool functions --- + + +@tool +def apify_run_actor( + actor_id: str, + run_input: Optional[Dict[str, Any]] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, + build: Optional[str] = None, +) -> Dict[str, Any]: + """Run any Apify Actor and return the run metadata as JSON. + + An Actor is a serverless cloud app on the Apify platform — it takes JSON input, + runs the scraping or automation job, and writes results to a dataset. This tool + executes the Actor synchronously and returns run metadata only (run_id, status, + dataset_id, timestamps). Use apify_run_actor_and_get_dataset to also fetch the + output data in one call, or apify_scrape_url for quick single-URL extraction. + + Common Actors: + - "apify/website-content-crawler" - scrape websites and extract content as Markdown + - "apify/web-scraper" - general-purpose web scraper with JS rendering + - "apify/google-search-scraper" — scrape Google search results + + Args: + actor_id: Actor identifier in "username/actor-name" format, + e.g. "apify/website-content-crawler". Find Actors at https://apify.com/store. + run_input: JSON-serializable input for the Actor. Each Actor defines its own + input schema - check the Actor README on Apify Store for required fields. + timeout_secs: Maximum time in seconds to wait for the Actor run to finish. Defaults to 300. + memory_mbytes: Memory allocation in MB for the Actor run. Uses Actor default `memory` value if not set. + build: Actor build tag or number to run a specific version. Uses latest build if not set. + + Returns: + Dict with status and content containing run metadata: run_id, status, dataset_id, + started_at, finished_at. + """ + try: + _check_dependency() + client = ApifyToolClient() + result = client.run_actor( + actor_id=actor_id, + run_input=run_input, + timeout_secs=timeout_secs, + memory_mbytes=memory_mbytes, + build=build, + ) + return _success_result( + text=json.dumps(result, indent=2, default=str), + panel_body=( + f"[green]Actor run completed[/green]\n" + f"Actor: {actor_id}\n" + f"Run ID: {result['run_id']}\n" + f"Status: {result['status']}\n" + f"Dataset ID: {result['dataset_id']}" + ), + panel_title="Apify: Run Actor", + ) + except Exception as e: + return _error_result(e, "apify_run_actor") + + +@tool +def apify_get_dataset_items( + dataset_id: str, + limit: int = DEFAULT_DATASET_ITEMS_LIMIT, + offset: int = 0, +) -> Dict[str, Any]: + """Fetch items from an existing Apify dataset and return them as JSON. + + Every Actor run writes its output to a dataset — a structured, append-only store + for tabular data. 
Use the dataset_id from the run metadata returned by apify_run_actor + or apify_run_task. Use offset for pagination through large datasets. + + Args: + dataset_id: The Apify dataset ID to fetch items from. + limit: Maximum number of items to return. Defaults to 100. + offset: Number of items to skip for pagination. Defaults to 0. + + Returns: + Dict with status and content containing an array of dataset items. + """ + try: + _check_dependency() + client = ApifyToolClient() + items = client.get_dataset_items(dataset_id=dataset_id, limit=limit, offset=offset) + return _success_result( + text=json.dumps(items, indent=2, default=str), + panel_body=( + f"[green]Dataset items retrieved[/green]\nDataset ID: {dataset_id}\nItems returned: {len(items)}" + ), + panel_title="Apify: Dataset Items", + ) + except Exception as e: + return _error_result(e, "apify_get_dataset_items") + + +@tool +def apify_run_actor_and_get_dataset( + actor_id: str, + run_input: Optional[Dict[str, Any]] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, + build: Optional[str] = None, + dataset_items_limit: int = DEFAULT_DATASET_ITEMS_LIMIT, + dataset_items_offset: int = 0, +) -> Dict[str, Any]: + """Run an Apify Actor and fetch its dataset results in one step. + + Convenience tool that combines running an Actor and fetching its default dataset + items into a single call. Use this when you want both the run metadata and the + result data without making two separate tool calls. + + Args: + actor_id: Actor identifier in "username/actor-name" format, + e.g. "apify/website-content-crawler". Find Actors at https://apify.com/store. + run_input: JSON-serializable input for the Actor. Each Actor defines its own + input schema - check the Actor README on Apify Store for required fields. + timeout_secs: Maximum time in seconds to wait for the Actor run to finish. Defaults to 300. + memory_mbytes: Memory allocation in MB for the Actor run. Uses Actor default `memory` value if not set. + build: Actor build tag or number to run a specific version. Uses latest build if not set. + dataset_items_limit: Maximum number of dataset items to return. Defaults to 100. + dataset_items_offset: Number of dataset items to skip for pagination. Defaults to 0. + + Returns: + Dict with status and content containing run metadata (run_id, status, dataset_id, + started_at, finished_at) plus an "items" array containing the dataset results. + """ + try: + _check_dependency() + client = ApifyToolClient() + result = client.run_actor_and_get_dataset( + actor_id=actor_id, + run_input=run_input, + timeout_secs=timeout_secs, + memory_mbytes=memory_mbytes, + build=build, + dataset_items_limit=dataset_items_limit, + dataset_items_offset=dataset_items_offset, + ) + return _success_result( + text=json.dumps(result, indent=2, default=str), + panel_body=( + f"[green]Actor run completed with dataset[/green]\n" + f"Actor: {actor_id}\n" + f"Run ID: {result['run_id']}\n" + f"Status: {result['status']}\n" + f"Dataset ID: {result['dataset_id']}\n" + f"Items returned: {len(result['items'])}" + ), + panel_title="Apify: Run Actor + Dataset", + ) + except Exception as e: + return _error_result(e, "apify_run_actor_and_get_dataset") + + +@tool +def apify_run_task( + task_id: str, + task_input: Optional[Dict[str, Any]] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, +) -> Dict[str, Any]: + """Run a saved Apify task and return the run metadata as JSON. 
+ + Tasks are saved Actor configurations with preset inputs, managed in Apify Console. + Use this when a task has already been configured, so you don't need to specify + the full Actor input every time. Use apify_run_task_and_get_dataset to also fetch + the output data in one call. + + Args: + task_id: Task identifier in "username/task-name" format or a task ID string. + task_input: Optional JSON-serializable input to override the task's default input fields. + timeout_secs: Maximum time in seconds to wait for the task run to finish. Defaults to 300. + memory_mbytes: Memory allocation in MB for the task run. Uses task default `memory` value if not set. + + Returns: + Dict with status and content containing run metadata: run_id, status, dataset_id, + started_at, finished_at. + """ + try: + _check_dependency() + client = ApifyToolClient() + result = client.run_task( + task_id=task_id, + task_input=task_input, + timeout_secs=timeout_secs, + memory_mbytes=memory_mbytes, + ) + return _success_result( + text=json.dumps(result, indent=2, default=str), + panel_body=( + f"[green]Task run completed[/green]\n" + f"Task: {task_id}\n" + f"Run ID: {result['run_id']}\n" + f"Status: {result['status']}\n" + f"Dataset ID: {result['dataset_id']}" + ), + panel_title="Apify: Run Task", + ) + except Exception as e: + return _error_result(e, "apify_run_task") + + +@tool +def apify_run_task_and_get_dataset( + task_id: str, + task_input: Optional[Dict[str, Any]] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, + memory_mbytes: Optional[int] = None, + dataset_items_limit: int = DEFAULT_DATASET_ITEMS_LIMIT, + dataset_items_offset: int = 0, +) -> Dict[str, Any]: + """Run a saved Apify task and fetch its dataset results in one step. + + Convenience tool that combines running a task and fetching its default dataset + items into a single call. Use this when you want both the run metadata and the + result data without making two separate tool calls. + + Args: + task_id: Task identifier in "username/task-name" format or a task ID string. + task_input: Optional JSON-serializable input to override the task's default input fields. + timeout_secs: Maximum time in seconds to wait for the task run to finish. Defaults to 300. + memory_mbytes: Memory allocation in MB for the task run. Uses task default `memory` value if not set. + dataset_items_limit: Maximum number of dataset items to return. Defaults to 100. + dataset_items_offset: Number of dataset items to skip for pagination. Defaults to 0. + + Returns: + Dict with status and content containing run metadata (run_id, status, dataset_id, + started_at, finished_at) plus an "items" array containing the dataset results. 
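+
+ Example (illustrative sketch; "username/my-task" and the input below are placeholder values, not real resources):
+ apify_run_task_and_get_dataset(
+ task_id="username/my-task",
+ task_input={"query": "coffee shops in Prague"},
+ dataset_items_limit=10,
+ )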
+ """ + try: + _check_dependency() + client = ApifyToolClient() + result = client.run_task_and_get_dataset( + task_id=task_id, + task_input=task_input, + timeout_secs=timeout_secs, + memory_mbytes=memory_mbytes, + dataset_items_limit=dataset_items_limit, + dataset_items_offset=dataset_items_offset, + ) + return _success_result( + text=json.dumps(result, indent=2, default=str), + panel_body=( + f"[green]Task run completed with dataset[/green]\n" + f"Task: {task_id}\n" + f"Run ID: {result['run_id']}\n" + f"Status: {result['status']}\n" + f"Dataset ID: {result['dataset_id']}\n" + f"Items returned: {len(result['items'])}" + ), + panel_title="Apify: Run Task + Dataset", + ) + except Exception as e: + return _error_result(e, "apify_run_task_and_get_dataset") + + +@tool +def apify_scrape_url( + url: str, + timeout_secs: int = DEFAULT_SCRAPE_TIMEOUT_SECS, + crawler_type: CrawlerType = "cheerio", +) -> Dict[str, Any]: + """Scrape a single URL and return its content as Markdown. + + Uses the Website Content Crawler Actor under the hood, pre-configured for + fast single-page scraping. This is the simplest way to extract readable content + from any web page — no Actor input schema needed. For multi-page crawls, use + apify_run_actor_and_get_dataset with "apify/website-content-crawler" directly. + + Args: + url: The URL to scrape, e.g. "https://example.com". + timeout_secs: Maximum time in seconds to wait for scraping to finish. Defaults to 120. + crawler_type: Crawler engine to use. One of: + - "cheerio" (default): Fastest, no JavaScript rendering. Best for static HTML. + - "playwright:adaptive": Renders JS only when needed. Good general-purpose choice. + - "playwright:firefox": Full JS rendering, best at bypassing anti-bot protection but slowest. + + Returns: + Dict with status and content containing the Markdown content of the scraped page. 
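+
+ Example (illustrative; the URL and crawler choice are placeholders):
+ apify_scrape_url(
+ url="https://example.com",
+ crawler_type="playwright:adaptive",
+ )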
+ """ + try: + _check_dependency() + client = ApifyToolClient() + content = client.scrape_url(url=url, timeout_secs=timeout_secs, crawler_type=crawler_type) + return _success_result( + text=content, + panel_body=( + f"[green]URL scraped successfully[/green]\nURL: {url}\nContent length: {len(content)} characters" + ), + panel_title="Apify: Scrape URL", + ) + except Exception as e: + return _error_result(e, "apify_scrape_url") + + +APIFY_CORE_TOOLS = [ + apify_run_actor, + apify_get_dataset_items, + apify_run_actor_and_get_dataset, + apify_run_task, + apify_run_task_and_get_dataset, + apify_scrape_url, +] + + +# --- Search & crawling tool constants --- + +GOOGLE_SEARCH_SCRAPER_ID = "apify/google-search-scraper" +GOOGLE_PLACES_SCRAPER_ID = "compass/crawler-google-places" +YOUTUBE_SCRAPER_ID = "streamers/youtube-scraper" +ECOMMERCE_SCRAPER_ID = "apify/e-commerce-scraping-tool" +DEFAULT_SEARCH_RESULTS_LIMIT = 20 + + +# --- Search & crawling helpers --- + + +def _search_crawl_result( + actor_name: str, + client: ApifyToolClient, + run_input: Dict[str, Any], + actor_id: str, + timeout_secs: int, + results_limit: int, +) -> Dict[str, Any]: + """Run a search/crawling Actor and return formatted results.""" + result = client.run_actor_and_get_dataset( + actor_id=actor_id, + run_input=run_input, + timeout_secs=timeout_secs, + dataset_items_limit=results_limit, + ) + return _success_result( + text=json.dumps(result, indent=2, default=str), + panel_body=( + f"[green]{actor_name} completed[/green]\nRun ID: {result['run_id']}\nItems returned: {len(result['items'])}" + ), + panel_title=f"Apify: {actor_name}", + ) + + +# --- Search & crawling tool functions --- + + +@tool +def apify_google_search_scraper( + search_query: str, + results_limit: int = 10, + country_code: Optional[str] = None, + language_code: Optional[str] = None, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, +) -> Dict[str, Any]: + """Search Google and return structured search results. + + Uses the Google Search Scraper Actor to perform a Google search and return + organic results, ads, People Also Ask, and related queries in a structured format. + + Args: + search_query: The search query string, e.g. "best AI frameworks 2025". + Supports advanced Google operators like "site:example.com" or "AI OR ML". + results_limit: Maximum number of results to return. Google returns ~10 results + per page, so requesting more triggers additional page scraping. Defaults to 10. + country_code: Two-letter country code for localized results, e.g. "us", "de". + language_code: Two-letter language code for the interface, e.g. "en", "de". + timeout_secs: Maximum time in seconds to wait for the run to finish. Defaults to 300. + + Returns: + Dict with status and content containing structured Google search results including + organic results, ads, and People Also Ask data. 
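+
+ Example (illustrative; query and locale values are placeholders):
+ apify_google_search_scraper(
+ search_query="best AI frameworks 2025",
+ results_limit=10,
+ country_code="us",
+ language_code="en",
+ )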
+ """ + try: + _check_dependency() + client = ApifyToolClient() + max_pages = max(1, (results_limit + 9) // 10) + run_input: Dict[str, Any] = { + "queries": search_query, + "maxPagesPerQuery": max_pages, + } + if country_code is not None: + run_input["countryCode"] = country_code + if language_code is not None: + run_input["languageCode"] = language_code + return _search_crawl_result( + actor_name="Google Search Scraper", + client=client, + run_input=run_input, + actor_id=GOOGLE_SEARCH_SCRAPER_ID, + timeout_secs=timeout_secs, + results_limit=results_limit, + ) + except Exception as e: + return _error_result(e, "apify_google_search_scraper") + + +@tool +def apify_google_places_scraper( + search_query: str, + results_limit: int = DEFAULT_SEARCH_RESULTS_LIMIT, + language: Optional[str] = None, + include_reviews: bool = False, + max_reviews: int = 5, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, +) -> Dict[str, Any]: + """Search Google Maps for businesses and places, optionally including reviews. + + Uses the Google Maps Scraper Actor to find places matching a search query + and return structured data including name, address, rating, phone, and website. + + Args: + search_query: Search query for Google Maps, e.g. "restaurants in Prague". + results_limit: Maximum number of places to return. Defaults to 20. + language: Language for results, e.g. "en", "de". Defaults to English. + include_reviews: Whether to include user reviews for each place. Defaults to False. + max_reviews: Maximum reviews per place when include_reviews is True. Defaults to 5. + timeout_secs: Maximum time in seconds to wait for the run to finish. Defaults to 300. + + Returns: + Dict with status and content containing structured Google Maps place data. + """ + try: + _check_dependency() + client = ApifyToolClient() + run_input: Dict[str, Any] = { + "searchStringsArray": [search_query], + "maxCrawledPlacesPerSearch": results_limit, + "maxReviews": max_reviews if include_reviews else 0, + } + if language is not None: + run_input["language"] = language + return _search_crawl_result( + actor_name="Google Places Scraper", + client=client, + run_input=run_input, + actor_id=GOOGLE_PLACES_SCRAPER_ID, + timeout_secs=timeout_secs, + results_limit=results_limit, + ) + except Exception as e: + return _error_result(e, "apify_google_places_scraper") + + +@tool +def apify_youtube_scraper( + search_query: Optional[str] = None, + urls: Optional[List[str]] = None, + results_limit: int = DEFAULT_SEARCH_RESULTS_LIMIT, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, +) -> Dict[str, Any]: + """Scrape YouTube videos, channels, or search results. + + Uses the YouTube Scraper Actor to search YouTube or scrape specific video/channel + URLs. Provide either a search query, specific URLs, or both. + + Args: + search_query: YouTube search query, e.g. "python tutorial". + urls: Specific YouTube video or channel URLs to scrape. + results_limit: Maximum number of results to return. Defaults to 20. + timeout_secs: Maximum time in seconds to wait for the run to finish. Defaults to 300. + + Returns: + Dict with status and content containing structured YouTube video/channel data. 
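+
+ Example (illustrative; pass search_query, urls, or both; the values below are placeholders):
+ apify_youtube_scraper(
+ search_query="python tutorial",
+ results_limit=5,
+ )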
+ """ + try: + _check_dependency() + if not search_query and not urls: + raise ValueError("At least one of 'search_query' or 'urls' must be provided.") + client = ApifyToolClient() + run_input: Dict[str, Any] = { + "maxResults": results_limit, + } + if search_query is not None: + run_input["searchQueries"] = [search_query] + if urls is not None: + run_input["startUrls"] = [{"url": u} for u in urls] + return _search_crawl_result( + actor_name="YouTube Scraper", + client=client, + run_input=run_input, + actor_id=YOUTUBE_SCRAPER_ID, + timeout_secs=timeout_secs, + results_limit=results_limit, + ) + except Exception as e: + return _error_result(e, "apify_youtube_scraper") + + +@tool +def apify_website_content_crawler( + start_url: str, + max_pages: int = 10, + max_depth: int = 2, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, +) -> Dict[str, Any]: + """Crawl a website and extract content from multiple pages. + + Uses the Website Content Crawler Actor to perform a multi-page crawl starting + from the given URL. Returns page content as markdown. This is the extended + multi-page version — distinct from apify_scrape_url which scrapes a single page. + + Args: + start_url: The starting URL to crawl, e.g. "https://docs.example.com". + max_pages: Maximum number of pages to crawl. Defaults to 10. + max_depth: Maximum crawl depth from the start URL. Defaults to 2. + timeout_secs: Maximum time in seconds to wait for the run to finish. Defaults to 300. + + Returns: + Dict with status and content containing crawled page data with markdown content. + """ + try: + _check_dependency() + client = ApifyToolClient() + client._validate_url(start_url) + run_input: Dict[str, Any] = { + "startUrls": [{"url": start_url}], + "maxCrawlPages": max_pages, + "maxCrawlDepth": max_depth, + "proxyConfiguration": {"useApifyProxy": True}, + } + return _search_crawl_result( + actor_name="Website Content Crawler", + client=client, + run_input=run_input, + actor_id=WEBSITE_CONTENT_CRAWLER, + timeout_secs=timeout_secs, + results_limit=max_pages, + ) + except Exception as e: + return _error_result(e, "apify_website_content_crawler") + + +VALID_ECOMMERCE_URL_TYPES = ("product", "listing") + + +@tool +def apify_ecommerce_scraper( + url: str, + url_type: str = "product", + results_limit: int = DEFAULT_SEARCH_RESULTS_LIMIT, + timeout_secs: int = DEFAULT_TIMEOUT_SECS, +) -> Dict[str, Any]: + """Scrape product data from e-commerce websites. + + Uses the E-commerce Scraping Tool Actor to extract structured product data + (title, price, description, images, etc.) from supported e-commerce platforms + including Amazon, eBay, Walmart, and others. The Actor auto-detects the site. + + Args: + url: The URL to scrape. + url_type: Type of URL being scraped. Use "product" (default) for a direct product + detail page, or "listing" for a category page or search results page containing + multiple products. + results_limit: Maximum number of products to return. Defaults to 20. + timeout_secs: Maximum time in seconds to wait for the run to finish. Defaults to 300. + + Returns: + Dict with status and content containing structured product data. + """ + try: + _check_dependency() + client = ApifyToolClient() + client._validate_url(url) + if url_type not in VALID_ECOMMERCE_URL_TYPES: + raise ValueError(f"Invalid url_type '{url_type}'. 
Must be one of: {', '.join(VALID_ECOMMERCE_URL_TYPES)}.") + url_field = "listingUrls" if url_type == "listing" else "detailsUrls" + run_input: Dict[str, Any] = { + url_field: [{"url": url}], + "maxProductResults": results_limit, + } + return _search_crawl_result( + actor_name="E-commerce Scraper", + client=client, + run_input=run_input, + actor_id=ECOMMERCE_SCRAPER_ID, + timeout_secs=timeout_secs, + results_limit=results_limit, + ) + except Exception as e: + return _error_result(e, "apify_ecommerce_scraper") + + +APIFY_SEARCH_TOOLS = [ + apify_google_search_scraper, + apify_google_places_scraper, + apify_youtube_scraper, + apify_website_content_crawler, + apify_ecommerce_scraper, +] + +APIFY_ALL_TOOLS = APIFY_CORE_TOOLS + APIFY_SEARCH_TOOLS diff --git a/src/strands_tools/code_interpreter/agent_core_code_interpreter.py b/src/strands_tools/code_interpreter/agent_core_code_interpreter.py index 7fe637b6..4e5ebf42 100644 --- a/src/strands_tools/code_interpreter/agent_core_code_interpreter.py +++ b/src/strands_tools/code_interpreter/agent_core_code_interpreter.py @@ -53,6 +53,7 @@ def __init__( session_name: Optional[str] = None, auto_create: bool = True, persist_sessions: bool = True, + session_timeout_seconds: int = 900, ) -> None: """ Initialize the Bedrock AgentCore code interpreter with session persistence support. @@ -100,6 +101,10 @@ def __init__( sessions to survive across invocations and be reconnected by subsequent instances via module-level cache. + session_timeout_seconds (int): Timeout in seconds for sessions created + by this instance. Sessions automatically terminate after the timeout period. + Default: 900 (15 minutes). + Session Lifecycle: Invocation 1 (Instance #1): 1. Create new instance with session_name="user-abc-123" @@ -180,6 +185,7 @@ def invoke(payload, context): self.identifier = identifier or "aws.codeinterpreter.v1" self.auto_create = auto_create self.persist_sessions = persist_sessions + self.session_timeout_seconds = session_timeout_seconds if session_name is None: self.default_session = f"session-{uuid.uuid4().hex[:12]}" @@ -262,8 +268,11 @@ def init_session(self, action: InitSessionAction) -> Dict[str, Any]: # Create new sandbox client client = BedrockAgentCoreCodeInterpreterClient(region=self.region) - # Start session with identifier and name - client.start(identifier=self.identifier, name=session_name) + client.start( + identifier=self.identifier, + name=session_name, + session_timeout_seconds=self.session_timeout_seconds, + ) aws_session_id = client.session_id diff --git a/src/strands_tools/elasticsearch_memory.py b/src/strands_tools/elasticsearch_memory.py index 6bf78930..5ad68f95 100644 --- a/src/strands_tools/elasticsearch_memory.py +++ b/src/strands_tools/elasticsearch_memory.py @@ -113,14 +113,15 @@ import json import logging import os +import re import time import uuid from datetime import datetime, timezone from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import boto3 -from elasticsearch import Elasticsearch, NotFoundError +from elasticsearch import Elasticsearch from strands import tool # Set up logging @@ -183,6 +184,47 @@ class MemoryAction(str, Enum): DEFAULT_EMBEDDING_MODEL = "amazon.titan-embed-text-v2:0" DEFAULT_EMBEDDING_DIMS = 1024 # Titan v2 returns 1024 dimensions DEFAULT_MAX_RESULTS = 10 +DEFAULT_NAMESPACE = "default" + + +def _validate_namespace(namespace: Any) -> str: + """Validate and sanitize namespace parameter to prevent injection attacks. 
+ + This function treats namespace as a trusted identifier by requiring it to be + a simple string matching the pattern ^[A-Za-z0-9_-]{1,64}$ before including + it in Elasticsearch queries. This prevents potential injection attacks and + ensures consistent namespace handling across all memory operations. + + Args: + namespace: The namespace value to validate (can be any type) + + Returns: + str: A validated string namespace (1-64 chars, alphanumeric + underscore + hyphen only) + + Raises: + ElasticsearchValidationError: If namespace cannot be converted to a safe string + """ + if namespace is None: + return DEFAULT_NAMESPACE + + if not isinstance(namespace, str): + raise ElasticsearchValidationError(f"Namespace must be a string, got {type(namespace).__name__}. ") + + clean_namespace = str(namespace).strip() + + if not clean_namespace: + raise ElasticsearchValidationError("Invalid namespace: Namespace cannot be empty.") + + if len(clean_namespace) > 64: + raise ElasticsearchValidationError("Invalid namespace: Namespace too long. Maximum 64 characters allowed.") + + if not re.match(r"^[A-Za-z0-9_-]{1,64}$", clean_namespace): + raise ElasticsearchValidationError( + f"Invalid namespace: Namespace '{clean_namespace}' contains invalid characters. " + "Must match pattern ^[A-Za-z0-9_-]{1,64}$" + ) + + return clean_namespace def _ensure_index_exists(es_client: Elasticsearch, index_name: str, es_url: Optional[str] = None): @@ -465,13 +507,27 @@ def _get_memory(es_client: Elasticsearch, index_name: str, namespace: str, memor Exception: If memory not found or not in correct namespace """ try: - response = es_client.get(index=index_name, id=memory_id) - source = response["_source"] + # Query with both memory_id and namespace to enforce tenant isolation server-side + search_body = { + "query": { + "bool": { + "must": [ + {"term": {"memory_id": memory_id}}, + {"term": {"namespace": namespace}}, + ] + } + }, + "size": 1, + "_source": ["memory_id", "content", "timestamp", "metadata", "namespace"], + } + + response = es_client.search(index=index_name, body=search_body) - # Verify namespace - if source.get("namespace") != namespace: + if not response["hits"]["hits"]: raise ElasticsearchMemoryNotFoundError(f"Memory {memory_id} not found in namespace {namespace}") + source = response["hits"]["hits"][0]["_source"] + return { "memory_id": source["memory_id"], "content": source["content"], @@ -480,8 +536,6 @@ def _get_memory(es_client: Elasticsearch, index_name: str, namespace: str, memor "namespace": source["namespace"], } - except NotFoundError: - raise ElasticsearchMemoryNotFoundError(f"Memory {memory_id} not found") from None except ElasticsearchMemoryNotFoundError: raise except Exception as e: @@ -492,6 +546,10 @@ def _delete_memory(es_client: Elasticsearch, index_name: str, namespace: str, me """ Delete a specific memory by ID. + Uses delete_by_query with both memory_id and namespace constraints to + atomically verify ownership and delete in a single operation, preventing + TOCTOU (Time-of-Check to Time-of-Use) race conditions. 
+ Args: es_client: Elasticsearch client index_name: Elasticsearch index name @@ -505,18 +563,29 @@ def _delete_memory(es_client: Elasticsearch, index_name: str, namespace: str, me Exception: If memory not found or deletion fails """ try: - # First verify the memory exists and is in correct namespace - _get_memory(es_client, index_name, namespace, memory_id) + # Atomically delete only if both memory_id and namespace match, + # preventing TOCTOU race conditions between check and delete + response = es_client.delete_by_query( + index=index_name, + body={ + "query": { + "bool": { + "must": [ + {"term": {"memory_id": memory_id}}, + {"term": {"namespace": namespace}}, + ] + } + } + }, + ) - # Delete the memory - response = es_client.delete(index=index_name, id=memory_id) + if response.get("deleted", 0) == 0: + raise ElasticsearchMemoryNotFoundError(f"Memory {memory_id} not found in namespace {namespace}") - return {"memory_id": memory_id, "result": response["result"]} + return {"memory_id": memory_id, "result": "deleted"} except ElasticsearchMemoryNotFoundError: raise - except NotFoundError: - raise ElasticsearchMemoryNotFoundError(f"Memory {memory_id} not found") from None except Exception as e: raise ElasticsearchMemoryError(f"Failed to delete memory {memory_id}: {str(e)}") from e @@ -603,11 +672,21 @@ def elasticsearch_memory( # Set defaults index_name = index_name or os.getenv("ELASTICSEARCH_INDEX_NAME", DEFAULT_INDEX_NAME) - namespace = namespace or os.getenv("ELASTICSEARCH_NAMESPACE", "default") + if namespace is None: + namespace = os.getenv("ELASTICSEARCH_NAMESPACE", DEFAULT_NAMESPACE) embedding_model = embedding_model or os.getenv("ELASTICSEARCH_EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL) region = region or os.getenv("AWS_REGION", "us-west-2") max_results = max_results or DEFAULT_MAX_RESULTS + # Validate namespace to prevent injection attacks + try: + safe_namespace = _validate_namespace(namespace) + except ElasticsearchValidationError as e: + return { + "status": "error", + "content": [{"text": f"Invalid namespace: {str(e)}"}], + } + # Initialize Elasticsearch client try: if es_url: @@ -685,7 +764,7 @@ def elasticsearch_memory( try: if action_enum == MemoryAction.RECORD: response = _record_memory( - es_client, bedrock_runtime, index_name, namespace, embedding_model, content, metadata + es_client, bedrock_runtime, index_name, safe_namespace, embedding_model, content, metadata ) return { "status": "success", @@ -694,7 +773,14 @@ def elasticsearch_memory( elif action_enum == MemoryAction.RETRIEVE: response = _retrieve_memories( - es_client, bedrock_runtime, index_name, namespace, embedding_model, query, max_results, next_token + es_client, + bedrock_runtime, + index_name, + safe_namespace, + embedding_model, + query, + max_results, + next_token, ) return { "status": "success", @@ -702,21 +788,21 @@ def elasticsearch_memory( } elif action_enum == MemoryAction.LIST: - response = _list_memories(es_client, index_name, namespace, max_results, next_token) + response = _list_memories(es_client, index_name, safe_namespace, max_results, next_token) return { "status": "success", "content": [{"text": f"Memories listed successfully: {json.dumps(response, default=str)}"}], } elif action_enum == MemoryAction.GET: - response = _get_memory(es_client, index_name, namespace, memory_id) + response = _get_memory(es_client, index_name, safe_namespace, memory_id) return { "status": "success", "content": [{"text": f"Memory retrieved successfully: {json.dumps(response, default=str)}"}], } elif action_enum == 
MemoryAction.DELETE: - response = _delete_memory(es_client, index_name, namespace, memory_id) + response = _delete_memory(es_client, index_name, safe_namespace, memory_id) return { "status": "success", "content": [{"text": f"Memory deleted successfully: {memory_id}"}], diff --git a/src/strands_tools/exa.py b/src/strands_tools/exa.py index db0da875..39a0b81f 100644 --- a/src/strands_tools/exa.py +++ b/src/strands_tools/exa.py @@ -1,12 +1,12 @@ """ Exa Search and Contents tools for intelligent web search and content processing. -This module provides access to Exa's API, which offers neural search capabilities optimized for LLMs and AI agents. -The "auto" mode intelligently combines neural embeddings-based search with traditional keyword search for best results. +This module provides access to Exa's API, which offers advanced search capabilities optimized for LLMs and AI agents. +The "auto" mode intelligently selects the best search approach for optimal results. Key Features: - Auto mode that intelligently selects the best search approach (default) -- Neural and keyword search capabilities +- Deep search for thorough, comprehensive results - Advanced content filtering and domain management - Full page content extraction with summaries - Support for general web search, company info, news, PDFs, GitHub repos, and more @@ -47,10 +47,11 @@ from typing import Any, Dict, List, Literal, Optional, Union import aiohttp -from rich.console import Console from rich.panel import Panel from strands import tool +from strands_tools.utils import console_util + logger = logging.getLogger(__name__) # Exa API configuration @@ -59,7 +60,7 @@ EXA_CONTENTS_ENDPOINT = "/contents" # Initialize Rich console -console = Console() +console = console_util.create() def _get_api_key() -> str: @@ -191,7 +192,7 @@ def format_contents_response(data: Dict[str, Any]) -> Panel: @tool async def exa_search( query: str, - type: Optional[Literal["keyword", "neural", "fast", "auto"]] = "auto", + type: Optional[Literal["auto", "fast", "deep"]] = "auto", category: Optional[ Literal["company", "news", "pdf", "github", "personal site", "linkedin profile", "financial report"] ] = None, @@ -217,25 +218,22 @@ async def exa_search( extras: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ - Search the web intelligently using Exa's neural and keyword search capabilities. + Search the web intelligently using Exa's advanced search capabilities. Exa provides advanced web search optimized for LLMs and AI agents. The "auto" mode (default) - intelligently combines neural embeddings-based search with traditional keyword search to find - the most relevant results for your query. + intelligently selects the best search approach to find the most relevant results for your query. 
Key Features: - Auto mode that intelligently selects the best search approach (default) - - Neural search using embeddings for semantic understanding - - Traditional keyword search for exact matches + - Deep search for thorough, comprehensive results - Advanced filtering by domain, date, and content - Live crawling with fallback options - Rich content extraction with summaries Search Types: - - auto: Intelligently combines neural and keyword approaches (recommended default) - - neural: Uses embeddings-based model for semantic search - - keyword: Google-like SERP search for exact matches - - fast: Streamlined versions of neural and keyword models + - auto: Intelligently selects the best search approach (recommended default) + - fast: Optimized for speed + - deep: Thorough search for comprehensive results Categories (optional - general web search works best): - company: Focus on company websites and information when specifically needed @@ -249,7 +247,7 @@ async def exa_search( Args: query: The search query string. Examples: "Latest developments in artificial intelligence", "Best project management tools" - type: Search type - "auto" (default, recommended), "neural", "keyword", or "fast" + type: Search type - "auto" (default, recommended), "fast", or "deep" category: Optional data category - use sparingly as general search works best. Use "company" when specifically looking for company information user_location: Two-letter ISO country code (e.g., "US", "UK") for geo-localized results diff --git a/src/strands_tools/http_request.py b/src/strands_tools/http_request.py index ba107d71..95176184 100644 --- a/src/strands_tools/http_request.py +++ b/src/strands_tools/http_request.py @@ -2,18 +2,26 @@ Make HTTP requests with comprehensive authentication, session management, and metrics. Supports all major authentication types and enterprise patterns. -Environment Variable Support: -1. Authentication tokens: - - Uses auth_env_var parameter to read tokens from environment (e.g., GITHUB_TOKEN, GITLAB_TOKEN) - - Example: http_request(method="GET", url="...", auth_type="token", auth_env_var="GITHUB_TOKEN") - - Supported variables: GITHUB_TOKEN, GITLAB_TOKEN, SLACK_BOT_TOKEN, AWS_ACCESS_KEY_ID, etc. -2. AWS credentials: - - Reads AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN, AWS_REGION automatically - - Example: http_request(method="GET", url="...", auth_type="aws_sig_v4", aws_auth={"service": "s3"}) -Use the environment tool (agent.tool.environment) to view available environment variables: -- List all: environment(action="list") -- Get specific: environment(action="get", name="GITHUB_TOKEN") -- Set new: environment(action="set", name="CUSTOM_TOKEN", value="your-token") +Authentication Support: +1. Direct tokens: Pass auth_token directly for Bearer, token, custom, or api_key auth types +2. Basic auth: Provide username/password via basic_auth parameter +3. Digest auth: Provide credentials via digest_auth parameter +4. JWT: Provide secret/algorithm/expiry via jwt_config parameter +5. AWS SigV4: Uses boto3 credential chain automatically via aws_auth parameter + +Environment Variable Token Config: + Import and populate HTTP_REQUEST_TOKEN_CONFIG to allow specific environment variables + to be used as auth tokens for requests to matching domains. 
+ + Format: {"ENV_VAR_NAME": ["allowed.domain.com", "*.other.com"]} + + Example: + from strands_tools.http_request import HTTP_REQUEST_TOKEN_CONFIG + HTTP_REQUEST_TOKEN_CONFIG["GITHUB_TOKEN"] = ["api.github.com"] + HTTP_REQUEST_TOKEN_CONFIG["GITLAB_TOKEN"] = ["gitlab.com"] + + When auth_env_var is passed to the tool, the token is only injected if the request + domain matches one of the allowed domains for that variable. """ import base64 @@ -21,6 +29,7 @@ import datetime import http.cookiejar import json +import logging import os import time from typing import Any, Dict, Optional, Union @@ -44,13 +53,14 @@ from strands_tools.utils import console_util from strands_tools.utils.user_input import get_user_input +logger = logging.getLogger(__name__) + TOOL_SPEC = { "name": "http_request", "description": ( "Make HTTP requests to any API with comprehensive authentication including Bearer tokens, Basic auth, " - "JWT, AWS SigV4, Digest auth, and enterprise authentication patterns. Automatically reads tokens from " - "environment variables (GITHUB_TOKEN, GITLAB_TOKEN, AWS credentials, etc.) when auth_env_var is specified. " - "Use environment(action='list') to view available variables. Includes session management, metrics, " + "JWT, AWS SigV4, Digest auth, and enterprise authentication patterns. " + "Includes session management, metrics, " "streaming support, cookie handling, redirect control, proxy support, and optional HTML to markdown conversion." ), "inputSchema": { @@ -82,11 +92,15 @@ }, "auth_token": { "type": "string", - "description": "Authentication token (if not provided, will check environment variables)", + "description": "Authentication token (if not provided, will check auth_env_var if configured)", }, "auth_env_var": { "type": "string", - "description": "Name of environment variable containing the auth token", + "description": ( + "Name of an environment variable containing the auth token. " + "The variable must be listed in HTTP_REQUEST_TOKEN_CONFIG " + "with an allowed domain that matches the request URL." + ), }, "headers": { "type": "object", @@ -98,7 +112,7 @@ }, "verify_ssl": { "type": "boolean", - "description": "Whether to verify SSL certificates", + "description": "Whether to verify SSL certificates. Disabling may be restricted.", }, "cookie": { "type": "string", @@ -198,6 +212,12 @@ # Metrics storage REQUEST_METRICS = collections.defaultdict(list) +# Token config: maps env var names to lists of allowed domains. +# Import and populate this dict to enable auth_env_var support: +# from strands_tools.http_request import HTTP_REQUEST_TOKEN_CONFIG +# HTTP_REQUEST_TOKEN_CONFIG["GITHUB_TOKEN"] = ["api.github.com"] +HTTP_REQUEST_TOKEN_CONFIG: Dict[str, list] = {} + def extract_content_from_html(html: str) -> str: """Convert HTML content to Markdown format. @@ -372,12 +392,14 @@ def format_headers_table(headers: Dict) -> Table: def process_auth_headers(headers: Dict[str, Any], tool_input: Dict[str, Any]) -> Dict[str, Any]: - """ - Process authentication headers based on input parameters. + """Process authentication headers based on input parameters. Supports multiple authentication methods: - 1. Environment variables: Uses auth_env_var to read tokens - 2. Direct token: Uses auth_token parameter + 1. Direct token: Uses auth_token parameter + 2. Env var token: Uses auth_env_var parameter, validated against HTTP_REQUEST_TOKEN_CONFIG + 3. Basic auth: Handled separately via handle_basic_auth + 4. JWT: Handled separately via handle_jwt + 5. 
AWS SigV4: Handled separately via handle_aws_sigv4
+
+ Special handling for different APIs:
+ - GitHub: Uses "token" prefix (auth_type="token")
@@ -385,25 +407,40 @@ def process_auth_headers(headers: Dict[str, Any], tool_input: Dict[str, Any]) ->
+ - AWS: Uses SigV4 signing (auth_type="aws_sig_v4")
+
+ Examples:
- # GitHub API with environment variable
- process_auth_headers({}, {"auth_type": "token", "auth_env_var": "GITHUB_TOKEN"})
+ # GitHub API with env var (requires HTTP_REQUEST_TOKEN_CONFIG["GITHUB_TOKEN"] = ["api.github.com"])
+ process_auth_headers({}, {"auth_type": "token", "auth_env_var": "GITHUB_TOKEN", "url": "https://api.github.com/user"})
- # GitLab API with environment variable
- process_auth_headers({}, {"auth_type": "Bearer", "auth_env_var": "GITLAB_TOKEN"})
+ # Direct token
+ process_auth_headers({}, {"auth_type": "Bearer", "auth_token": "my-token"})
"""
headers = headers or {}
- # Get auth token from input or environment
auth_token = tool_input.get("auth_token")
+
+ # Resolve token from environment variable if auth_env_var is provided
if not auth_token and "auth_env_var" in tool_input:
env_var_name = tool_input["auth_env_var"]
- auth_token = os.getenv(env_var_name)
- if not auth_token:
+ allowed_domains = HTTP_REQUEST_TOKEN_CONFIG.get(env_var_name)
+
+ if allowed_domains is None:
+ raise ValueError(
+ f"Environment variable '{env_var_name}' is not listed in HTTP_REQUEST_TOKEN_CONFIG. "
+ f"Add it with an explicit list of allowed domains before using it as an auth token."
+ )
+
+ # Validate the request URL against the allowed domains using URL parsing
+ request_url = tool_input.get("url", "")
+ request_host = urlparse(request_url).hostname or ""
+ if request_host not in allowed_domains:
raise ValueError(
- f"Environment variable '{env_var_name}' not found or empty. "
- f"Use environment(action='list') to see available variables."
+ f"Request to '{request_host}' is not in the allowed domains for '{env_var_name}': {allowed_domains}"
)
+ auth_token = os.environ.get(env_var_name)
+ if not auth_token:
+ raise ValueError(f"Environment variable '{env_var_name}' is not set or is empty.")
+ logger.info(f"Resolved auth token from environment variable '{env_var_name}' for domain '{request_host}'")
+
auth_type = tool_input.get("auth_type")
if auth_token:
@@ -554,6 +591,16 @@ def http_request(tool: ToolUse, **kwargs: Any) -> ToolResult:
Common API Examples:
1. GitHub API (uses "token" auth_type):
+ ```python
+ http_request(
+ method="GET",
+ url="https://api.github.com/user",
+ auth_type="token",
+ auth_token="",
+ )
+ ```
+
+ Or with env var (requires HTTP_REQUEST_TOKEN_CONFIG["GITHUB_TOKEN"] = ["api.github.com"]):
```python
http_request(
method="GET",
@@ -564,6 +611,16 @@ def http_request(tool: ToolUse, **kwargs: Any) -> ToolResult:
```
2. 
GitLab API (uses "Bearer" auth_type): + ```python + http_request( + method="GET", + url="https://gitlab.com/api/v4/user", + auth_type="Bearer", + auth_token="", + ) + ``` + + Or with env var (requires HTTP_REQUEST_TOKEN_CONFIG["GITLAB_TOKEN"] = ["gitlab.com"]): ```python http_request( method="GET", @@ -622,9 +679,10 @@ def http_request(tool: ToolUse, **kwargs: Any) -> ToolResult: ``` Environment Variables: - - Authentication tokens are read from environment when auth_env_var is specified - AWS credentials are automatically loaded from environment variables or credentials file - - Use environment(action='list') to view all available environment variables + + Token Config: + - Use HTTP_REQUEST_TOKEN_CONFIG to allow specific env vars as auth tokens for permitted domains """ console = console_util.create() @@ -643,7 +701,17 @@ def http_request(tool: ToolUse, **kwargs: Any) -> ToolResult: url = tool_input["url"] headers = process_auth_headers(tool_input.get("headers", {}), tool_input) body = tool_input.get("body") - verify = tool_input.get("verify_ssl", True) + + # verify_ssl=False is opt-in via STRANDS_HTTP_ALLOW_INSECURE_SSL env var + verify_ssl_input = tool_input.get("verify_ssl", True) + if verify_ssl_input is False: + if os.environ.get("STRANDS_HTTP_ALLOW_INSECURE_SSL", "").lower() != "true": + raise ValueError( + "SSL verification cannot be disabled unless the STRANDS_HTTP_ALLOW_INSECURE_SSL " + "environment variable is set to 'true'." + ) + verify = verify_ssl_input + cookie = tool_input.get("cookie") cookie_jar = tool_input.get("cookie_jar") @@ -905,8 +973,10 @@ def http_request(tool: ToolUse, **kwargs: Any) -> ToolResult: result_text.append(f"Redirects: {redirect_count} redirects followed ({redirect_chain})") # Add minimal headers to text response - important_headers = ["Content-Type", "Content-Length", "Date", "Server"] - headers_text = {k: v for k, v in response.headers.items() if k in important_headers} + important_headers_lower = { + h.lower() for h in ["Content-Type", "Content-Length", "Date", "Server", "Payment-Required"] + } + headers_text = {k: v for k, v in response.headers.items() if k.lower() in important_headers_lower} result_text.append(f"Headers: {headers_text}") # Add body to text response @@ -937,9 +1007,9 @@ def http_request(tool: ToolUse, **kwargs: Any) -> ToolResult: if "auth" in error_str or "token" in error_str or "credential" in error_str or "unauthorized" in error_str: suggestion = ( "\n\nSuggestion: Check your authentication setup. 
Common solutions:\n" - "- For GitHub API: Use auth_type='token' with auth_env_var='GITHUB_TOKEN'\n" - "- For GitLab API: Use auth_type='Bearer' with auth_env_var='GITLAB_TOKEN'\n" - "- Use environment(action='list') to view available environment variables" + "- For GitHub API: Use auth_type='token' with auth_token=''\n" + "- For GitLab API: Use auth_type='Bearer' with auth_token=''\n" + "- For AWS APIs: Use auth_type='aws_sig_v4' with aws_auth configuration" ) # Special handling for ImportError to help with test assertions diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py index e815e9ef..3b55f79d 100644 --- a/src/strands_tools/mem0_memory.py +++ b/src/strands_tools/mem0_memory.py @@ -74,17 +74,18 @@ from mem0 import Memory as Mem0Memory from mem0 import MemoryClient from opensearchpy import AWSV4SignerAuth, RequestsHttpConnection -from rich.console import Console from rich.panel import Panel from rich.table import Table from rich.text import Text from strands.types.tools import ToolResult, ToolResultContent, ToolUse +from strands_tools.utils import console_util + # Set up logging logger = logging.getLogger(__name__) # Initialize Rich console -console = Console() +console = console_util.create() TOOL_SPEC = { "name": "mem0_memory", diff --git a/src/strands_tools/rss.py b/src/strands_tools/rss.py index 9a91de0a..952501c4 100644 --- a/src/strands_tools/rss.py +++ b/src/strands_tools/rss.py @@ -34,7 +34,11 @@ def __init__(self): os.makedirs(self.storage_path, exist_ok=True) def get_feed_file_path(self, feed_id: str) -> str: - return os.path.join(self.storage_path, f"{feed_id}.json") + file_path = os.path.realpath(os.path.join(self.storage_path, f"{feed_id}.json")) + storage_real = os.path.realpath(self.storage_path) + if not file_path.startswith(storage_real + os.sep): + raise ValueError(f"Invalid feed_id: path traversal detected in '{feed_id}'") + return file_path def get_subscription_file_path(self) -> str: return os.path.join(self.storage_path, "subscriptions.json") diff --git a/src/strands_tools/tavily.py b/src/strands_tools/tavily.py index 437c3506..3915c800 100644 --- a/src/strands_tools/tavily.py +++ b/src/strands_tools/tavily.py @@ -52,10 +52,11 @@ from typing import Any, Dict, List, Literal, Optional, Union import aiohttp -from rich.console import Console from rich.panel import Panel from strands import tool +from strands_tools.utils import console_util + logger = logging.getLogger(__name__) # Tavily API configuration @@ -66,7 +67,7 @@ TAVILY_MAP_ENDPOINT = "/map" # Initialize Rich console -console = Console() +console = console_util.create() def _get_api_key() -> str: diff --git a/tests/code_interpreter/test_agent_core_code_interpreter.py b/tests/code_interpreter/test_agent_core_code_interpreter.py index 1fff0789..89a7bb4e 100644 --- a/tests/code_interpreter/test_agent_core_code_interpreter.py +++ b/tests/code_interpreter/test_agent_core_code_interpreter.py @@ -67,6 +67,7 @@ def test_initialization(interpreter): assert interpreter.default_session.startswith("session-") assert interpreter.auto_create is True assert interpreter.persist_sessions is True + assert interpreter.session_timeout_seconds == 900 def test_initialization_with_new_parameters(): @@ -77,6 +78,22 @@ def test_initialization_with_new_parameters(): assert interpreter.persist_sessions is False +def test_initialization_with_session_timeout(): + """Test initialization with custom session timeout.""" + with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as 
mock_resolve: + mock_resolve.return_value = "us-west-2" + interpreter = AgentCoreCodeInterpreter(region="us-west-2", session_timeout_seconds=1800) + assert interpreter.session_timeout_seconds == 1800 + + +def test_initialization_without_session_timeout(): + """Test initialization without session timeout defaults to 900.""" + with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: + mock_resolve.return_value = "us-west-2" + interpreter = AgentCoreCodeInterpreter(region="us-west-2") + assert interpreter.session_timeout_seconds == 900 + + def test_session_name_no_cleaning(): """Test that session names are used as-is without cleaning.""" with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: @@ -396,7 +413,9 @@ def test_init_session_success(mock_client_class, interpreter, mock_client): assert result["content"][0]["json"]["sessionId"] == "test-session-id-123" mock_client_class.assert_called_once_with(region="us-west-2") - mock_client.start.assert_called_once_with(identifier="aws.codeinterpreter.v1", name="my-session") + mock_client.start.assert_called_once_with( + identifier="aws.codeinterpreter.v1", name="my-session", session_timeout_seconds=900 + ) assert "my-session" in interpreter._sessions session_info = interpreter._sessions["my-session"] @@ -429,7 +448,9 @@ def test_init_session_with_custom_identifier(mock_client_class, mock_client): assert result["content"][0]["json"]["sessionId"] == "test-session-id-123" mock_client_class.assert_called_once_with(region="us-west-2") - mock_client.start.assert_called_once_with(identifier=custom_id, name="custom-session") + mock_client.start.assert_called_once_with( + identifier=custom_id, name="custom-session", session_timeout_seconds=900 + ) assert "custom-session" in interpreter._sessions session_info = interpreter._sessions["custom-session"] @@ -458,7 +479,9 @@ def test_init_session_with_default_identifier(mock_client_class, mock_client): assert result["content"][0]["json"]["sessionId"] == "test-session-id-123" mock_client_class.assert_called_once_with(region="us-west-2") - mock_client.start.assert_called_once_with(identifier="aws.codeinterpreter.v1", name="default-session") + mock_client.start.assert_called_once_with( + identifier="aws.codeinterpreter.v1", name="default-session", session_timeout_seconds=900 + ) assert "default-session" in interpreter._sessions session_info = interpreter._sessions["default-session"] @@ -468,6 +491,44 @@ def test_init_session_with_default_identifier(mock_client_class, mock_client): assert session_info.client == mock_client +@patch("strands_tools.code_interpreter.agent_core_code_interpreter.BedrockAgentCoreCodeInterpreterClient") +def test_init_session_with_session_timeout(mock_client_class, mock_client): + """Test session initialization passes session_timeout_seconds to client.start() when set.""" + with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: + mock_resolve.return_value = "us-west-2" + mock_client_class.return_value = mock_client + + interpreter = AgentCoreCodeInterpreter(region="us-west-2", session_timeout_seconds=1800) + + action = InitSessionAction(type="initSession", description="Test session", session_name="timeout-session") + + result = interpreter.init_session(action) + + assert result["status"] == "success" + mock_client.start.assert_called_once_with( + identifier="aws.codeinterpreter.v1", name="timeout-session", session_timeout_seconds=1800 + ) + + 
+@patch("strands_tools.code_interpreter.agent_core_code_interpreter.BedrockAgentCoreCodeInterpreterClient") +def test_init_session_without_session_timeout(mock_client_class, mock_client): + """Test session initialization passes default session_timeout_seconds to client.start().""" + with patch("strands_tools.code_interpreter.agent_core_code_interpreter.resolve_region") as mock_resolve: + mock_resolve.return_value = "us-west-2" + mock_client_class.return_value = mock_client + + interpreter = AgentCoreCodeInterpreter(region="us-west-2") + + action = InitSessionAction(type="initSession", description="Test session", session_name="no-timeout-session") + + result = interpreter.init_session(action) + + assert result["status"] == "success" + mock_client.start.assert_called_once_with( + identifier="aws.codeinterpreter.v1", name="no-timeout-session", session_timeout_seconds=900 + ) + + @patch("strands_tools.code_interpreter.agent_core_code_interpreter.BedrockAgentCoreCodeInterpreterClient") def test_init_session_multiple_identifiers_verification(mock_client_class, mock_client): """Test that different interpreter instances with different identifiers work correctly.""" @@ -498,9 +559,12 @@ def test_init_session_multiple_identifiers_verification(mock_client_class, mock_ assert mock_client.start.call_count == 3 call_args_list = mock_client.start.call_args_list - assert call_args_list[0] == ((), {"identifier": custom_id1, "name": "session1"}) - assert call_args_list[1] == ((), {"identifier": custom_id2, "name": "session2"}) - assert call_args_list[2] == ((), {"identifier": "aws.codeinterpreter.v1", "name": "session3"}) + assert call_args_list[0] == ((), {"identifier": custom_id1, "name": "session1", "session_timeout_seconds": 900}) + assert call_args_list[1] == ((), {"identifier": custom_id2, "name": "session2", "session_timeout_seconds": 900}) + assert call_args_list[2] == ( + (), + {"identifier": "aws.codeinterpreter.v1", "name": "session3", "session_timeout_seconds": 900}, + ) @patch("strands_tools.code_interpreter.agent_core_code_interpreter.BedrockAgentCoreCodeInterpreterClient") diff --git a/tests/test_apify.py b/tests/test_apify.py new file mode 100644 index 00000000..34fc6537 --- /dev/null +++ b/tests/test_apify.py @@ -0,0 +1,1050 @@ +"""Tests for the Apify tools.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from strands_tools import apify +from strands_tools.apify import ( + ApifyToolClient, + apify_ecommerce_scraper, + apify_get_dataset_items, + apify_google_places_scraper, + apify_google_search_scraper, + apify_run_actor, + apify_run_actor_and_get_dataset, + apify_run_task, + apify_run_task_and_get_dataset, + apify_scrape_url, + apify_website_content_crawler, + apify_youtube_scraper, +) + +MOCK_ACTOR_RUN = { + "id": "run-HG7ml5fB1hCp8YEBA", + "actId": "actor~my-scraper", + "userId": "user-abc123", + "startedAt": "2026-03-15T14:30:00.000Z", + "finishedAt": "2026-03-15T14:35:22.000Z", + "status": "SUCCEEDED", + "statusMessage": "Actor finished successfully", + "defaultDatasetId": "dataset-WkC9gct8rq1uR5vDZ", + "defaultKeyValueStoreId": "kvs-Xb3A8gct8rq1uR5vD", + "buildNumber": "1.2.3", +} + +MOCK_FAILED_RUN = { + **MOCK_ACTOR_RUN, + "status": "FAILED", + "statusMessage": "Actor failed with an error", +} + +MOCK_TIMED_OUT_RUN = { + **MOCK_ACTOR_RUN, + "status": "TIMED-OUT", + "statusMessage": "Actor run timed out", +} + +MOCK_DATASET_ITEMS = [ + {"url": "https://example.com/product/1", "title": "Widget A", "price": 19.99, "currency": "USD"}, + {"url": 
"https://example.com/product/2", "title": "Widget B", "price": 29.99, "currency": "USD"}, + {"url": "https://example.com/product/3", "title": "Widget C", "price": 39.99, "currency": "EUR"}, +] + +MOCK_SCRAPED_ITEM = { + "url": "https://example.com", + "markdown": "# Example Domain\n\nThis domain is for use in illustrative examples.", + "text": "Example Domain. This domain is for use in illustrative examples.", +} + + +def _make_apify_api_error(status_code: int, message: str) -> Exception: + """Create an ApifyApiError instance for testing without calling its real __init__.""" + from apify_client.errors import ApifyApiError + + error = ApifyApiError.__new__(ApifyApiError) + Exception.__init__(error, message) + error.status_code = status_code + error.message = message + return error + + +@pytest.fixture +def mock_apify_client(): + """Create a mock ApifyClient with pre-configured responses.""" + client = MagicMock() + + mock_actor = MagicMock() + mock_actor.call.return_value = MOCK_ACTOR_RUN + client.actor.return_value = mock_actor + + mock_task = MagicMock() + mock_task.call.return_value = MOCK_ACTOR_RUN + client.task.return_value = mock_task + + mock_dataset = MagicMock() + mock_list_result = MagicMock() + mock_list_result.items = MOCK_DATASET_ITEMS + mock_dataset.list_items.return_value = mock_list_result + client.dataset.return_value = mock_dataset + + return client + + +@pytest.fixture +def mock_apify_env(monkeypatch): + """Set required Apify environment variables.""" + monkeypatch.setenv("APIFY_API_TOKEN", "test-token-12345") + + +# --- Module import --- + + +def test_apify_module_is_importable(): + """Verify that the apify module can be imported from strands_tools.""" + assert apify is not None + assert apify.__name__ == "strands_tools.apify" + + +# --- ApifyToolClient --- + + +def test_client_missing_token(monkeypatch): + """ApifyToolClient raises ValueError when APIFY_API_TOKEN is not set.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + with pytest.raises(ValueError, match="APIFY_API_TOKEN"): + ApifyToolClient() + + +def test_client_uses_env_token(mock_apify_env): + """ApifyToolClient passes the env token to ApifyClient.""" + with patch("strands_tools.apify.ApifyClient") as MockClient: + ApifyToolClient() + MockClient.assert_called_once_with( + "test-token-12345", + headers={"x-apify-integration-platform": "strands-agents"}, + ) + + +# --- apify_run_actor --- + + +def test_run_actor_success(mock_apify_env, mock_apify_client): + """Successful Actor run returns structured result with run metadata.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper", run_input={"url": "https://example.com"}) + + assert result["status"] == "success" + data = json.loads(result["content"][0]["text"]) + assert data["run_id"] == "run-HG7ml5fB1hCp8YEBA" + assert data["status"] == "SUCCEEDED" + assert data["dataset_id"] == "dataset-WkC9gct8rq1uR5vDZ" + assert "started_at" in data + assert "finished_at" in data + mock_apify_client.actor.assert_called_once_with("actor/my-scraper") + + +def test_run_actor_default_input(mock_apify_env, mock_apify_client): + """Actor run defaults run_input to empty dict when not provided.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper") + + assert result["status"] == "success" + call_kwargs = mock_apify_client.actor.return_value.call.call_args.kwargs + assert call_kwargs["run_input"] == {} + + 
+def test_run_actor_explicit_empty_input(mock_apify_env, mock_apify_client): + """Actor run passes through an explicitly empty dict instead of treating it as falsy.""" + empty_input: dict = {} + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper", run_input=empty_input) + + assert result["status"] == "success" + call_kwargs = mock_apify_client.actor.return_value.call.call_args.kwargs + assert call_kwargs["run_input"] is empty_input + + +def test_run_actor_with_memory(mock_apify_env, mock_apify_client): + """Actor run passes memory_mbytes when provided.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_run_actor(actor_id="actor/my-scraper", memory_mbytes=512) + + call_kwargs = mock_apify_client.actor.return_value.call.call_args.kwargs + assert call_kwargs["memory_mbytes"] == 512 + + +def test_run_actor_failure(mock_apify_env, mock_apify_client): + """Actor run returns error dict when Actor fails.""" + mock_apify_client.actor.return_value.call.return_value = MOCK_FAILED_RUN + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper") + + assert result["status"] == "error" + assert "FAILED" in result["content"][0]["text"] + + +def test_run_actor_timeout(mock_apify_env, mock_apify_client): + """Actor run returns error dict when Actor times out.""" + mock_apify_client.actor.return_value.call.return_value = MOCK_TIMED_OUT_RUN + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper") + + assert result["status"] == "error" + assert "TIMED-OUT" in result["content"][0]["text"] + + +def test_run_actor_api_exception(mock_apify_env, mock_apify_client): + """Actor run returns error dict on API exceptions.""" + mock_apify_client.actor.return_value.call.side_effect = Exception("Connection failed") + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper") + + assert result["status"] == "error" + assert "Connection failed" in result["content"][0]["text"] + + +def test_run_actor_none_response(mock_apify_env, mock_apify_client): + """Actor run returns error dict when ActorClient.call() returns None.""" + mock_apify_client.actor.return_value.call.return_value = None + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper") + + assert result["status"] == "error" + assert "no run data" in result["content"][0]["text"] + + +def test_run_actor_apify_api_error_401(mock_apify_env, mock_apify_client): + """Actor run returns friendly message for 401 authentication errors.""" + error = _make_apify_api_error(401, "Unauthorized") + mock_apify_client.actor.return_value.call.side_effect = error + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor(actor_id="actor/my-scraper") + + assert result["status"] == "error" + assert "Authentication failed" in result["content"][0]["text"] + + +def test_run_actor_apify_api_error_404(mock_apify_env, mock_apify_client): + """Actor run returns friendly message for 404 not-found errors.""" + error = _make_apify_api_error(404, "Actor not found") + mock_apify_client.actor.return_value.call.side_effect = error + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result 
= apify_run_actor(actor_id="actor/nonexistent") + + assert result["status"] == "error" + assert "Resource not found" in result["content"][0]["text"] + + +# --- apify_get_dataset_items --- + + +def test_get_dataset_items_success(mock_apify_env, mock_apify_client): + """Successful dataset retrieval returns structured result with items.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_get_dataset_items(dataset_id="dataset-WkC9gct8rq1uR5vDZ") + + assert result["status"] == "success" + items = json.loads(result["content"][0]["text"]) + assert len(items) == 3 + assert items[0]["title"] == "Widget A" + assert items[2]["currency"] == "EUR" + mock_apify_client.dataset.assert_called_once_with("dataset-WkC9gct8rq1uR5vDZ") + + +def test_get_dataset_items_with_pagination(mock_apify_env, mock_apify_client): + """dataset retrieval passes limit and offset.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_get_dataset_items(dataset_id="dataset-xyz", limit=50, offset=10) + + mock_apify_client.dataset.return_value.list_items.assert_called_once_with(limit=50, offset=10) + + +def test_get_dataset_items_empty(mock_apify_env, mock_apify_client): + """Empty dataset returns a structured result with empty JSON array.""" + mock_list_result = MagicMock() + mock_list_result.items = [] + mock_apify_client.dataset.return_value.list_items.return_value = mock_list_result + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_get_dataset_items(dataset_id="dataset-empty") + + assert result["status"] == "success" + items = json.loads(result["content"][0]["text"]) + assert items == [] + + +# --- apify_run_actor_and_get_dataset --- + + +def test_run_actor_and_get_dataset_success(mock_apify_env, mock_apify_client): + """Combined run + dataset fetch returns structured result with metadata and items.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor_and_get_dataset( + actor_id="actor/my-scraper", + run_input={"url": "https://example.com"}, + dataset_items_limit=50, + ) + + assert result["status"] == "success" + data = json.loads(result["content"][0]["text"]) + assert data["run_id"] == "run-HG7ml5fB1hCp8YEBA" + assert data["status"] == "SUCCEEDED" + assert data["dataset_id"] == "dataset-WkC9gct8rq1uR5vDZ" + assert len(data["items"]) == 3 + assert data["items"][0]["title"] == "Widget A" + + +def test_run_actor_and_get_dataset_no_dataset_id(mock_apify_env, mock_apify_client): + """Combined tool returns error when the Actor run has no default dataset.""" + run_no_dataset = {**MOCK_ACTOR_RUN, "defaultDatasetId": None} + mock_apify_client.actor.return_value.call.return_value = run_no_dataset + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor_and_get_dataset(actor_id="actor/my-scraper") + + assert result["status"] == "error" + assert "no default dataset" in result["content"][0]["text"] + + +def test_run_actor_and_get_dataset_actor_failure(mock_apify_env, mock_apify_client): + """Combined tool returns error dict when the Actor fails.""" + mock_apify_client.actor.return_value.call.return_value = MOCK_FAILED_RUN + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_actor_and_get_dataset(actor_id="actor/my-scraper") + + assert result["status"] == "error" + assert "FAILED" in result["content"][0]["text"] + + +# --- apify_run_task --- + + 
+def test_run_task_success(mock_apify_env, mock_apify_client):
+    """Successful task run returns structured result with run metadata."""
+    with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client):
+        result = apify_run_task(task_id="user~my-task", task_input={"query": "test"})
+
+    assert result["status"] == "success"
+    data = json.loads(result["content"][0]["text"])
+    assert data["run_id"] == "run-HG7ml5fB1hCp8YEBA"
+    assert data["status"] == "SUCCEEDED"
+    assert data["dataset_id"] == "dataset-WkC9gct8rq1uR5vDZ"
+    mock_apify_client.task.assert_called_once_with("user~my-task")
+
+
+def test_run_task_no_input(mock_apify_env, mock_apify_client):
+    """Task run omits the task_input kwarg when not provided."""
+    with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client):
+        result = apify_run_task(task_id="user~my-task")
+
+    assert result["status"] == "success"
+    call_kwargs = mock_apify_client.task.return_value.call.call_args.kwargs
+    assert "task_input" not in call_kwargs
+
+
+def test_run_task_with_memory(mock_apify_env, mock_apify_client):
+    """Task run passes memory_mbytes when provided."""
+    with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client):
+        apify_run_task(task_id="user~my-task", memory_mbytes=1024)
+
+    call_kwargs = mock_apify_client.task.return_value.call.call_args.kwargs
+    assert call_kwargs["memory_mbytes"] == 1024
+
+
+def test_run_task_failure(mock_apify_env, mock_apify_client):
+    """Task run returns error dict when the task fails."""
+    mock_apify_client.task.return_value.call.return_value = MOCK_FAILED_RUN
+
+    with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client):
+        result = apify_run_task(task_id="user~my-task")
+
+    assert result["status"] == "error"
+    assert "FAILED" in result["content"][0]["text"]
+
+
+def test_run_task_none_response(mock_apify_env, mock_apify_client):
+    """Task run returns error dict when TaskClient.call() returns None."""
+    mock_apify_client.task.return_value.call.return_value = None
+
+    with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client):
+        result = apify_run_task(task_id="user~my-task")
+
+    assert result["status"] == "error"
+    assert "no run data" in result["content"][0]["text"]
+
+
+def test_run_task_apify_api_error_401(mock_apify_env, mock_apify_client):
+    """Task run returns a friendly message for 401 authentication errors."""
+    error = _make_apify_api_error(401, "Unauthorized")
+    mock_apify_client.task.return_value.call.side_effect = error
+
+    with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client):
+        result = apify_run_task(task_id="user~my-task")
+
+    assert result["status"] == "error"
+    assert "Authentication failed" in result["content"][0]["text"]
+
+
+# --- apify_run_task_and_get_dataset ---
+
+
+def test_run_task_and_get_dataset_success(mock_apify_env, mock_apify_client):
+    """Combined task run + dataset fetch returns structured result with metadata and items."""
+    with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client):
+        result = apify_run_task_and_get_dataset(
+            task_id="user~my-task",
+            task_input={"query": "test"},
+            dataset_items_limit=50,
+        )
+
+    assert result["status"] == "success"
+    data = json.loads(result["content"][0]["text"])
+    assert data["run_id"] == "run-HG7ml5fB1hCp8YEBA"
+    assert len(data["items"]) == 3
+    assert data["items"][0]["title"] == "Widget A"
+
+
+def test_run_task_and_get_dataset_no_dataset_id(mock_apify_env, mock_apify_client):
+    """Combined task tool returns error when the
task run has no default dataset.""" + run_no_dataset = {**MOCK_ACTOR_RUN, "defaultDatasetId": None} + mock_apify_client.task.return_value.call.return_value = run_no_dataset + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_task_and_get_dataset(task_id="user~my-task") + + assert result["status"] == "error" + assert "no default dataset" in result["content"][0]["text"] + + +def test_run_task_and_get_dataset_task_failure(mock_apify_env, mock_apify_client): + """Combined task tool returns error dict when the task fails.""" + mock_apify_client.task.return_value.call.return_value = MOCK_FAILED_RUN + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_run_task_and_get_dataset(task_id="user~my-task") + + assert result["status"] == "error" + assert "FAILED" in result["content"][0]["text"] + + +# --- apify_scrape_url --- + + +def test_scrape_url_success(mock_apify_env, mock_apify_client): + """Scrape URL returns structured result with markdown content.""" + mock_list_result = MagicMock() + mock_list_result.items = [MOCK_SCRAPED_ITEM] + mock_apify_client.dataset.return_value.list_items.return_value = mock_list_result + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "success" + assert "Example Domain" in result["content"][0]["text"] + mock_apify_client.actor.assert_called_once_with("apify/website-content-crawler") + + +def test_scrape_url_none_response(mock_apify_env, mock_apify_client): + """Scrape URL returns error dict when ActorClient.call() returns None.""" + mock_apify_client.actor.return_value.call.return_value = None + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "error" + assert "no run data" in result["content"][0]["text"] + + +def test_scrape_url_no_dataset_id(mock_apify_env, mock_apify_client): + """Scrape URL returns error when the crawler run has no default dataset.""" + run_no_dataset = {**MOCK_ACTOR_RUN, "defaultDatasetId": None} + mock_apify_client.actor.return_value.call.return_value = run_no_dataset + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "error" + assert "no default dataset" in result["content"][0]["text"] + + +def test_scrape_url_no_content(mock_apify_env, mock_apify_client): + """Scrape URL returns error dict when no content is returned.""" + mock_list_result = MagicMock() + mock_list_result.items = [] + mock_apify_client.dataset.return_value.list_items.return_value = mock_list_result + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "error" + assert "No content returned" in result["content"][0]["text"] + + +def test_scrape_url_crawler_failure(mock_apify_env, mock_apify_client): + """Scrape URL returns error dict when the crawler Actor fails.""" + mock_apify_client.actor.return_value.call.return_value = MOCK_FAILED_RUN + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "error" + assert "FAILED" in result["content"][0]["text"] + + +def test_scrape_url_falls_back_to_text(mock_apify_env, 
mock_apify_client): + """Scrape URL falls back to text field when markdown is absent.""" + item_without_markdown = {"url": "https://example.com", "text": "Plain text content"} + mock_list_result = MagicMock() + mock_list_result.items = [item_without_markdown] + mock_apify_client.dataset.return_value.list_items.return_value = mock_list_result + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "success" + assert result["content"][0]["text"] == "Plain text content" + + +def test_scrape_url_invalid_url_scheme(mock_apify_env): + """apify_scrape_url returns error for invalid URL scheme.""" + result = apify_scrape_url(url="ftp://example.com") + + assert result["status"] == "error" + assert "Invalid URL scheme" in result["content"][0]["text"] + + +def test_scrape_url_missing_scheme(mock_apify_env): + """apify_scrape_url returns error for URL without http/https scheme.""" + result = apify_scrape_url(url="example.com") + + assert result["status"] == "error" + assert "Invalid URL scheme" in result["content"][0]["text"] + + +# --- Parameter validation --- + + +def test_run_actor_empty_actor_id(mock_apify_env): + """apify_run_actor returns error for whitespace-only actor_id.""" + result = apify_run_actor(actor_id=" ") + + assert result["status"] == "error" + assert "actor_id" in result["content"][0]["text"] + + +def test_run_actor_zero_timeout(mock_apify_env): + """apify_run_actor returns error for non-positive timeout_secs.""" + result = apify_run_actor(actor_id="actor/valid", timeout_secs=0) + + assert result["status"] == "error" + assert "timeout_secs" in result["content"][0]["text"] + + +def test_run_actor_negative_timeout(mock_apify_env): + """apify_run_actor returns error for negative timeout_secs.""" + result = apify_run_actor(actor_id="actor/valid", timeout_secs=-5) + + assert result["status"] == "error" + assert "timeout_secs" in result["content"][0]["text"] + + +def test_run_actor_zero_memory(mock_apify_env): + """apify_run_actor returns error for non-positive memory_mbytes.""" + result = apify_run_actor(actor_id="actor/valid", memory_mbytes=0) + + assert result["status"] == "error" + assert "memory_mbytes" in result["content"][0]["text"] + + +def test_run_task_empty_task_id(mock_apify_env): + """apify_run_task returns error for whitespace-only task_id.""" + result = apify_run_task(task_id=" ") + + assert result["status"] == "error" + assert "task_id" in result["content"][0]["text"] + + +def test_run_task_zero_timeout(mock_apify_env): + """apify_run_task returns error for non-positive timeout_secs.""" + result = apify_run_task(task_id="user~my-task", timeout_secs=0) + + assert result["status"] == "error" + assert "timeout_secs" in result["content"][0]["text"] + + +def test_run_task_zero_memory(mock_apify_env): + """apify_run_task returns error for non-positive memory_mbytes.""" + result = apify_run_task(task_id="user~my-task", memory_mbytes=0) + + assert result["status"] == "error" + assert "memory_mbytes" in result["content"][0]["text"] + + +def test_get_dataset_items_empty_dataset_id(mock_apify_env): + """apify_get_dataset_items returns error for whitespace-only dataset_id.""" + result = apify_get_dataset_items(dataset_id=" ") + + assert result["status"] == "error" + assert "dataset_id" in result["content"][0]["text"] + + +def test_get_dataset_items_zero_limit(mock_apify_env): + """apify_get_dataset_items returns error for non-positive limit.""" + result = 
apify_get_dataset_items(dataset_id="dataset-abc", limit=0) + + assert result["status"] == "error" + assert "limit" in result["content"][0]["text"] + + +def test_get_dataset_items_negative_offset(mock_apify_env): + """apify_get_dataset_items returns error for negative offset.""" + result = apify_get_dataset_items(dataset_id="dataset-abc", offset=-1) + + assert result["status"] == "error" + assert "offset" in result["content"][0]["text"] + + +def test_run_actor_and_get_dataset_zero_dataset_limit(mock_apify_env): + """apify_run_actor_and_get_dataset returns error for non-positive dataset_items_limit.""" + result = apify_run_actor_and_get_dataset(actor_id="actor/valid", dataset_items_limit=0) + + assert result["status"] == "error" + assert "dataset_items_limit" in result["content"][0]["text"] + + +def test_run_actor_and_get_dataset_negative_dataset_offset(mock_apify_env): + """apify_run_actor_and_get_dataset returns error for negative dataset_items_offset.""" + result = apify_run_actor_and_get_dataset(actor_id="actor/valid", dataset_items_offset=-1) + + assert result["status"] == "error" + assert "dataset_items_offset" in result["content"][0]["text"] + + +def test_run_task_and_get_dataset_zero_dataset_limit(mock_apify_env): + """apify_run_task_and_get_dataset returns error for non-positive dataset_items_limit.""" + result = apify_run_task_and_get_dataset(task_id="user~my-task", dataset_items_limit=0) + + assert result["status"] == "error" + assert "dataset_items_limit" in result["content"][0]["text"] + + +def test_run_task_and_get_dataset_negative_dataset_offset(mock_apify_env): + """apify_run_task_and_get_dataset returns error for negative dataset_items_offset.""" + result = apify_run_task_and_get_dataset(task_id="user~my-task", dataset_items_offset=-1) + + assert result["status"] == "error" + assert "dataset_items_offset" in result["content"][0]["text"] + + +def test_scrape_url_zero_timeout(mock_apify_env): + """apify_scrape_url returns error for non-positive timeout_secs.""" + result = apify_scrape_url(url="https://example.com", timeout_secs=0) + + assert result["status"] == "error" + assert "timeout_secs" in result["content"][0]["text"] + + +def test_scrape_url_invalid_crawler_type(mock_apify_env): + """apify_scrape_url returns error for unsupported crawler_type.""" + result = apify_scrape_url(url="https://example.com", crawler_type="invalid") + + assert result["status"] == "error" + assert "crawler_type" in result["content"][0]["text"] + + +def test_scrape_url_missing_domain(mock_apify_env): + """apify_scrape_url returns error for URL with no domain.""" + result = apify_scrape_url(url="https://") + + assert result["status"] == "error" + assert "domain" in result["content"][0]["text"].lower() + + +# --- Dependency guard --- + + +def test_missing_apify_client_run_actor(mock_apify_env): + """apify_run_actor returns error dict when apify-client is not installed.""" + with patch("strands_tools.apify.HAS_APIFY_CLIENT", False): + result = apify_run_actor(actor_id="test/actor") + + assert result["status"] == "error" + assert "apify-client" in result["content"][0]["text"] + + +def test_missing_apify_client_get_dataset(mock_apify_env): + """apify_get_dataset_items returns error dict when apify-client is not installed.""" + with patch("strands_tools.apify.HAS_APIFY_CLIENT", False): + result = apify_get_dataset_items(dataset_id="dataset-123") + + assert result["status"] == "error" + assert "apify-client" in result["content"][0]["text"] + + +def test_missing_apify_client_run_and_get(mock_apify_env): + 
"""apify_run_actor_and_get_dataset returns error dict when apify-client is not installed.""" + with patch("strands_tools.apify.HAS_APIFY_CLIENT", False): + result = apify_run_actor_and_get_dataset(actor_id="test/actor") + + assert result["status"] == "error" + assert "apify-client" in result["content"][0]["text"] + + +def test_missing_apify_client_run_task(mock_apify_env): + """apify_run_task returns error dict when apify-client is not installed.""" + with patch("strands_tools.apify.HAS_APIFY_CLIENT", False): + result = apify_run_task(task_id="user~my-task") + + assert result["status"] == "error" + assert "apify-client" in result["content"][0]["text"] + + +def test_missing_apify_client_run_task_and_get(mock_apify_env): + """apify_run_task_and_get_dataset returns error dict when apify-client is not installed.""" + with patch("strands_tools.apify.HAS_APIFY_CLIENT", False): + result = apify_run_task_and_get_dataset(task_id="user~my-task") + + assert result["status"] == "error" + assert "apify-client" in result["content"][0]["text"] + + +def test_missing_apify_client_scrape_url(mock_apify_env): + """apify_scrape_url returns error dict when apify-client is not installed.""" + with patch("strands_tools.apify.HAS_APIFY_CLIENT", False): + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "error" + assert "apify-client" in result["content"][0]["text"] + + +# --- Missing token from tool entry points --- + + +def test_run_actor_missing_token(monkeypatch): + """apify_run_actor returns error dict when APIFY_API_TOKEN is missing.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + result = apify_run_actor(actor_id="test/actor") + + assert result["status"] == "error" + assert "APIFY_API_TOKEN" in result["content"][0]["text"] + + +def test_get_dataset_items_missing_token(monkeypatch): + """apify_get_dataset_items returns error dict when APIFY_API_TOKEN is missing.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + result = apify_get_dataset_items(dataset_id="dataset-123") + + assert result["status"] == "error" + assert "APIFY_API_TOKEN" in result["content"][0]["text"] + + +def test_run_actor_and_get_dataset_missing_token(monkeypatch): + """apify_run_actor_and_get_dataset returns error dict when APIFY_API_TOKEN is missing.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + result = apify_run_actor_and_get_dataset(actor_id="test/actor") + + assert result["status"] == "error" + assert "APIFY_API_TOKEN" in result["content"][0]["text"] + + +def test_run_task_missing_token(monkeypatch): + """apify_run_task returns error dict when APIFY_API_TOKEN is missing.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + result = apify_run_task(task_id="user~my-task") + + assert result["status"] == "error" + assert "APIFY_API_TOKEN" in result["content"][0]["text"] + + +def test_run_task_and_get_dataset_missing_token(monkeypatch): + """apify_run_task_and_get_dataset returns error dict when APIFY_API_TOKEN is missing.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + result = apify_run_task_and_get_dataset(task_id="user~my-task") + + assert result["status"] == "error" + assert "APIFY_API_TOKEN" in result["content"][0]["text"] + + +def test_scrape_url_missing_token(monkeypatch): + """apify_scrape_url returns error dict when APIFY_API_TOKEN is missing.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + result = apify_scrape_url(url="https://example.com") + + assert result["status"] == "error" + assert "APIFY_API_TOKEN" in 
result["content"][0]["text"] + + +# --- apify_google_search_scraper --- + + +def test_google_search_scraper_success(mock_apify_env, mock_apify_client): + """Google Search Scraper returns structured results with correct input mapping.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_google_search_scraper(search_query="best AI frameworks", results_limit=5) + + assert result["status"] == "success" + data = json.loads(result["content"][0]["text"]) + assert data["run_id"] == "run-HG7ml5fB1hCp8YEBA" + assert len(data["items"]) == 3 + + mock_apify_client.actor.assert_called_once_with("apify/google-search-scraper") + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["queries"] == "best AI frameworks" + assert run_input["maxPagesPerQuery"] == 1 + assert "resultsPerPage" not in run_input + + +def test_google_search_scraper_multi_page(mock_apify_env, mock_apify_client): + """Google Search Scraper calculates correct page count when results_limit exceeds 10.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_google_search_scraper(search_query="AI", results_limit=25) + + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["maxPagesPerQuery"] == 3 + assert "resultsPerPage" not in run_input + + +def test_google_search_scraper_optional_params(mock_apify_env, mock_apify_client): + """Google Search Scraper includes optional country and language codes when provided.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_google_search_scraper(search_query="AI", results_limit=10, country_code="de", language_code="de") + + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["countryCode"] == "de" + assert run_input["languageCode"] == "de" + + +def test_google_search_scraper_optional_params_omitted(mock_apify_env, mock_apify_client): + """Google Search Scraper omits optional fields when not provided.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_google_search_scraper(search_query="AI") + + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert "countryCode" not in run_input + assert "languageCode" not in run_input + + +def test_google_search_scraper_missing_dependency(mock_apify_env): + """Google Search Scraper returns error when apify-client is not installed.""" + with patch("strands_tools.apify.HAS_APIFY_CLIENT", False): + result = apify_google_search_scraper(search_query="test") + + assert result["status"] == "error" + assert "apify-client" in result["content"][0]["text"] + + +def test_google_search_scraper_missing_token(monkeypatch): + """Google Search Scraper returns error when APIFY_API_TOKEN is missing.""" + monkeypatch.delenv("APIFY_API_TOKEN", raising=False) + result = apify_google_search_scraper(search_query="test") + + assert result["status"] == "error" + assert "APIFY_API_TOKEN" in result["content"][0]["text"] + + +def test_google_search_scraper_actor_failure(mock_apify_env, mock_apify_client): + """Google Search Scraper returns error when Actor fails.""" + mock_apify_client.actor.return_value.call.return_value = MOCK_FAILED_RUN + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_google_search_scraper(search_query="test") + + assert result["status"] == "error" + assert "FAILED" in 
result["content"][0]["text"] + + +# --- apify_google_places_scraper --- + + +def test_google_places_scraper_success(mock_apify_env, mock_apify_client): + """Google Places Scraper returns structured results with correct input mapping.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_google_places_scraper(search_query="restaurants in Prague", results_limit=10) + + assert result["status"] == "success" + data = json.loads(result["content"][0]["text"]) + assert data["run_id"] == "run-HG7ml5fB1hCp8YEBA" + + mock_apify_client.actor.assert_called_once_with("compass/crawler-google-places") + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["searchStringsArray"] == ["restaurants in Prague"] + assert run_input["maxCrawledPlacesPerSearch"] == 10 + assert run_input["maxReviews"] == 0 + + +def test_google_places_scraper_with_reviews(mock_apify_env, mock_apify_client): + """Google Places Scraper sets maxReviews when include_reviews is True.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_google_places_scraper(search_query="hotels in Berlin", include_reviews=True, max_reviews=10) + + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["maxReviews"] == 10 + + +def test_google_places_scraper_reviews_disabled(mock_apify_env, mock_apify_client): + """Google Places Scraper sets maxReviews to 0 when include_reviews is False.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_google_places_scraper(search_query="cafes", include_reviews=False, max_reviews=10) + + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["maxReviews"] == 0 + + +def test_google_places_scraper_optional_language(mock_apify_env, mock_apify_client): + """Google Places Scraper includes language when provided.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_google_places_scraper(search_query="cafes", language="de") + + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["language"] == "de" + + +# --- apify_youtube_scraper --- + + +def test_youtube_scraper_search_query(mock_apify_env, mock_apify_client): + """YouTube Scraper returns results when given a search query.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_youtube_scraper(search_query="python tutorial", results_limit=5) + + assert result["status"] == "success" + mock_apify_client.actor.assert_called_once_with("streamers/youtube-scraper") + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["searchQueries"] == ["python tutorial"] + assert run_input["maxResults"] == 5 + assert "startUrls" not in run_input + + +def test_youtube_scraper_urls(mock_apify_env, mock_apify_client): + """YouTube Scraper returns results when given specific URLs.""" + urls = ["https://www.youtube.com/watch?v=abc123"] + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_youtube_scraper(urls=urls) + + assert result["status"] == "success" + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["startUrls"] == [{"url": "https://www.youtube.com/watch?v=abc123"}] + assert "searchQueries" not in run_input + + +def 
test_youtube_scraper_both_query_and_urls(mock_apify_env, mock_apify_client): + """YouTube Scraper accepts both search_query and urls simultaneously.""" + urls = ["https://www.youtube.com/watch?v=abc123"] + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_youtube_scraper(search_query="python", urls=urls) + + assert result["status"] == "success" + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["searchQueries"] == ["python"] + assert run_input["startUrls"] == [{"url": "https://www.youtube.com/watch?v=abc123"}] + + +def test_youtube_scraper_no_input(mock_apify_env): + """YouTube Scraper returns error when neither search_query nor urls is provided.""" + result = apify_youtube_scraper() + + assert result["status"] == "error" + assert "search_query" in result["content"][0]["text"] + + +# --- apify_website_content_crawler --- + + +def test_website_content_crawler_success(mock_apify_env, mock_apify_client): + """Website Content Crawler returns results with correct input mapping.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_website_content_crawler(start_url="https://docs.example.com", max_pages=5, max_depth=3) + + assert result["status"] == "success" + mock_apify_client.actor.assert_called_once_with("apify/website-content-crawler") + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["startUrls"] == [{"url": "https://docs.example.com"}] + assert run_input["maxCrawlPages"] == 5 + assert run_input["maxCrawlDepth"] == 3 + assert run_input["proxyConfiguration"] == {"useApifyProxy": True} + + +def test_website_content_crawler_defaults(mock_apify_env, mock_apify_client): + """Website Content Crawler uses correct defaults for max_pages and max_depth.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + apify_website_content_crawler(start_url="https://example.com") + + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["maxCrawlPages"] == 10 + assert run_input["maxCrawlDepth"] == 2 + + +def test_website_content_crawler_invalid_url(mock_apify_env): + """Website Content Crawler returns error for invalid URL.""" + result = apify_website_content_crawler(start_url="not-a-url") + + assert result["status"] == "error" + assert "Invalid URL" in result["content"][0]["text"] + + +# --- apify_ecommerce_scraper --- + + +def test_ecommerce_scraper_success(mock_apify_env, mock_apify_client): + """E-commerce Scraper returns results with correct input mapping for product URL.""" + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_ecommerce_scraper(url="https://www.amazon.com/dp/B0TEST", results_limit=10) + + assert result["status"] == "success" + data = json.loads(result["content"][0]["text"]) + assert data["run_id"] == "run-HG7ml5fB1hCp8YEBA" + + mock_apify_client.actor.assert_called_once_with("apify/e-commerce-scraping-tool") + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["detailsUrls"] == [{"url": "https://www.amazon.com/dp/B0TEST"}] + assert "listingUrls" not in run_input + assert run_input["maxProductResults"] == 10 + + +def test_ecommerce_scraper_listing_url(mock_apify_env, mock_apify_client): + """E-commerce Scraper uses listingUrls when url_type is 'listing'.""" + with patch("strands_tools.apify.ApifyClient", 
return_value=mock_apify_client): + result = apify_ecommerce_scraper( + url="https://www.amazon.com/s?k=headphones", url_type="listing", results_limit=10 + ) + + assert result["status"] == "success" + run_input = mock_apify_client.actor.return_value.call.call_args.kwargs["run_input"] + assert run_input["listingUrls"] == [{"url": "https://www.amazon.com/s?k=headphones"}] + assert "detailsUrls" not in run_input + + +def test_ecommerce_scraper_invalid_url_type(mock_apify_env): + """E-commerce Scraper returns error for invalid url_type.""" + result = apify_ecommerce_scraper(url="https://www.amazon.com/dp/B0TEST", url_type="invalid") + + assert result["status"] == "error" + assert "url_type" in result["content"][0]["text"] + + +def test_ecommerce_scraper_invalid_url(mock_apify_env): + """E-commerce Scraper returns error for invalid URL.""" + result = apify_ecommerce_scraper(url="not-a-url") + + assert result["status"] == "error" + assert "Invalid URL" in result["content"][0]["text"] + + +def test_ecommerce_scraper_actor_failure(mock_apify_env, mock_apify_client): + """E-commerce Scraper returns error when Actor fails.""" + mock_apify_client.actor.return_value.call.return_value = MOCK_FAILED_RUN + + with patch("strands_tools.apify.ApifyClient", return_value=mock_apify_client): + result = apify_ecommerce_scraper(url="https://www.amazon.com/dp/B0TEST") + + assert result["status"] == "error" + assert "FAILED" in result["content"][0]["text"] diff --git a/tests/test_elasticsearch_memory.py b/tests/test_elasticsearch_memory.py index 01f0ce51..b95b5bc8 100644 --- a/tests/test_elasticsearch_memory.py +++ b/tests/test_elasticsearch_memory.py @@ -260,14 +260,21 @@ def test_get_memory(mock_elasticsearch_client, mock_bedrock_client, config): """Test getting a specific memory by ID.""" agent = Agent(tools=[elasticsearch_memory]) - # Configure mock get response - mock_elasticsearch_client["client"].get.return_value = { - "_source": { - "memory_id": "mem_123", - "content": "Test content", - "timestamp": "2023-01-01T00:00:00Z", - "metadata": {"category": "test"}, - "namespace": "test_namespace", + # Configure mock search response (now uses search instead of get for namespace enforcement) + mock_elasticsearch_client["client"].search.return_value = { + "hits": { + "hits": [ + { + "_source": { + "memory_id": "mem_123", + "content": "Test content", + "timestamp": "2023-01-01T00:00:00Z", + "metadata": {"category": "test"}, + "namespace": "test_namespace", + } + } + ], + "total": {"value": 1}, } } @@ -278,25 +285,20 @@ def test_get_memory(mock_elasticsearch_client, mock_bedrock_client, config): assert result["status"] == "success" assert "Memory retrieved successfully" in result["content"][0]["text"] - # Verify get was called - mock_elasticsearch_client["client"].get.assert_called_once_with(index="test_index", id="mem_123") + # Verify search was called with both memory_id and namespace for security + mock_elasticsearch_client["client"].search.assert_called_once() + search_call = mock_elasticsearch_client["client"].search.call_args[1] + query = search_call["body"]["query"]["bool"]["must"] + assert {"term": {"memory_id": "mem_123"}} in query + assert {"term": {"namespace": "test_namespace"}} in query def test_delete_memory(mock_elasticsearch_client, mock_bedrock_client, config): """Test deleting a memory.""" agent = Agent(tools=[elasticsearch_memory]) - # Configure mock responses - mock_elasticsearch_client["client"].get.return_value = { - "_source": { - "memory_id": "mem_123", - "content": "Test content", - 
"timestamp": "2023-01-01T00:00:00Z", - "metadata": {}, - "namespace": "test_namespace", - } - } - mock_elasticsearch_client["client"].delete.return_value = {"result": "deleted"} + # Configure mock delete_by_query response (atomic delete with namespace constraint) + mock_elasticsearch_client["client"].delete_by_query.return_value = {"deleted": 1} # Call the tool result = agent.tool.elasticsearch_memory(action="delete", memory_id="mem_123", **config) @@ -305,8 +307,12 @@ def test_delete_memory(mock_elasticsearch_client, mock_bedrock_client, config): assert result["status"] == "success" assert "Memory deleted successfully: mem_123" in result["content"][0]["text"] - # Verify delete was called - mock_elasticsearch_client["client"].delete.assert_called_once_with(index="test_index", id="mem_123") + # Verify delete_by_query was called with both memory_id and namespace + mock_elasticsearch_client["client"].delete_by_query.assert_called_once() + call_args = mock_elasticsearch_client["client"].delete_by_query.call_args[1] + query = call_args["body"]["query"]["bool"]["must"] + assert {"term": {"memory_id": "mem_123"}} in query + assert {"term": {"namespace": "test_namespace"}} in query def test_unsupported_action(mock_elasticsearch_client, mock_bedrock_client, config): @@ -387,26 +393,32 @@ def test_memory_not_found(mock_elasticsearch_client, mock_bedrock_client, config """Test handling when memory is not found.""" agent = Agent(tools=[elasticsearch_memory]) - from elasticsearch import NotFoundError - - # Configure mock to raise NotFoundError - mock_elasticsearch_client["client"].get.side_effect = NotFoundError("404", "not_found_exception", {}) + # Configure mock search to return empty results (memory not found in namespace) + mock_elasticsearch_client["client"].search.return_value = { + "hits": { + "hits": [], + "total": {"value": 0}, + } + } # Call the tool result = agent.tool.elasticsearch_memory(action="get", memory_id="nonexistent", **config) # Verify error response assert result["status"] == "error" - assert "Memory nonexistent not found" in result["content"][0]["text"] + assert "Memory nonexistent not found in namespace test_namespace" in result["content"][0]["text"] def test_namespace_validation(mock_elasticsearch_client, mock_bedrock_client, config): """Test that memories are properly filtered by namespace.""" agent = Agent(tools=[elasticsearch_memory]) - # Configure mock get response with wrong namespace - mock_elasticsearch_client["client"].get.return_value = { - "_source": {"memory_id": "mem_123", "content": "Test content", "namespace": "wrong_namespace"} + # Configure mock search to return empty results (memory not in this namespace) + mock_elasticsearch_client["client"].search.return_value = { + "hits": { + "hits": [], + "total": {"value": 0}, + } } # Call the tool @@ -416,6 +428,13 @@ def test_namespace_validation(mock_elasticsearch_client, mock_bedrock_client, co assert result["status"] == "error" assert "not found in namespace test_namespace" in result["content"][0]["text"] + # Verify search was called with both memory_id and namespace + mock_elasticsearch_client["client"].search.assert_called_once() + search_call = mock_elasticsearch_client["client"].search.call_args[1] + query = search_call["body"]["query"]["bool"]["must"] + assert {"term": {"memory_id": "mem_123"}} in query + assert {"term": {"namespace": "test_namespace"}} in query + def test_pagination_support(mock_elasticsearch_client, mock_bedrock_client, config): """Test pagination support in list and retrieve operations.""" 
@@ -754,12 +773,15 @@ def test_security_scenarios(mock_elasticsearch_client, mock_bedrock_client): """Test security-related scenarios like namespace isolation.""" agent = Agent(tools=[elasticsearch_memory]) - # Configure mock get response with wrong namespace - mock_elasticsearch_client["client"].get.return_value = { - "_source": {"memory_id": "mem_123", "content": "Test content", "namespace": "wrong_namespace"} + # Configure mock search to return empty results (memory not in this namespace) + mock_elasticsearch_client["client"].search.return_value = { + "hits": { + "hits": [], + "total": {"value": 0}, + } } - # Test namespace validation + # Test namespace validation - memory exists but not in requested namespace result = agent.tool.elasticsearch_memory( action="get", memory_id="mem_123", @@ -791,3 +813,109 @@ def test_troubleshooting_scenarios(mock_elasticsearch_client, mock_bedrock_clien result = agent.tool.elasticsearch_memory(action="record", content="test", **config) assert result["status"] == "error" assert "Unable to connect to Elasticsearch cluster" in result["content"][0]["text"] + + +def test_injection_prevention(mock_elasticsearch_client, mock_bedrock_client, config): + """Test that injection attempts via namespace are blocked.""" + agent = Agent(tools=[elasticsearch_memory]) + + # Remove namespace from config to avoid conflict + test_config = {k: v for k, v in config.items() if k != "namespace"} + + # Test dict-based injection (analogous to MongoDB {"$ne": ""} attack) + malicious_namespace = {"$ne": ""} + result = agent.tool.elasticsearch_memory(action="list", namespace=malicious_namespace, **test_config) + assert result["status"] == "error" + error_text = result["content"][0]["text"] + assert "Invalid namespace" in error_text or "Input should be a valid string" in error_text + + # Test other injection payloads + injection_attempts = [ + {"$gt": ""}, + {"$regex": ".*"}, + {"$exists": True}, + {"$in": ["tenant1", "tenant2"]}, + ] + + for injection_payload in injection_attempts: + result = agent.tool.elasticsearch_memory(action="list", namespace=injection_payload, **test_config) + assert result["status"] == "error", f"Injection {injection_payload} should be blocked" + error_text = result["content"][0]["text"] + assert "Invalid namespace" in error_text or "Input should be a valid string" in error_text + + +def test_namespace_validation_strict_rules(mock_elasticsearch_client, mock_bedrock_client, config): + """Test strict namespace validation rules.""" + agent = Agent(tools=[elasticsearch_memory]) + + # Remove namespace from config to avoid conflict + test_config = {k: v for k, v in config.items() if k != "namespace"} + + # Test invalid characters (should be rejected) + invalid_namespaces = [ + "user.name", # Dots not allowed + "user@domain", # @ symbol + "user$name", # $ symbol + "user name", # Space + "user/path", # Forward slash + "user:name", # Colon + "a" * 65, # Too long (over 64 chars) + "", # Empty + " ", # Whitespace only + ] + + for invalid_namespace in invalid_namespaces: + result = agent.tool.elasticsearch_memory(action="list", namespace=invalid_namespace, **test_config) + assert result["status"] == "error", f"Invalid namespace '{invalid_namespace}' should be rejected" + error_text = result["content"][0]["text"] + assert "Invalid namespace" in error_text + + +def test_valid_namespaces_accepted(mock_elasticsearch_client, mock_bedrock_client, config): + """Test that valid namespaces are accepted.""" + agent = Agent(tools=[elasticsearch_memory]) + + # Configure mock 
responses + mock_elasticsearch_client["client"].search.return_value = { + "hits": { + "hits": [], + "total": {"value": 0}, + } + } + + # Remove namespace from config + test_config = {k: v for k, v in config.items() if k != "namespace"} + + valid_namespaces = [ + "default", + "user_123", + "tenant-abc", + "MyNamespace", + "a", + "A" * 64, # Max length + ] + + for valid_namespace in valid_namespaces: + result = agent.tool.elasticsearch_memory(action="list", namespace=valid_namespace, **test_config) + assert result["status"] == "success", f"Valid namespace '{valid_namespace}' should be accepted" + + +def test_delete_memory_namespace_enforcement(mock_elasticsearch_client, mock_bedrock_client, config): + """Test that delete enforces namespace atomically (no TOCTOU).""" + agent = Agent(tools=[elasticsearch_memory]) + + # Configure delete_by_query to return 0 deleted (memory not in namespace) + mock_elasticsearch_client["client"].delete_by_query.return_value = {"deleted": 0} + + result = agent.tool.elasticsearch_memory(action="delete", memory_id="mem_123", **config) + + # Should fail because memory not found in the requested namespace + assert result["status"] == "error" + assert "not found in namespace test_namespace" in result["content"][0]["text"] + + # Verify delete_by_query was called with namespace constraint + mock_elasticsearch_client["client"].delete_by_query.assert_called_once() + call_args = mock_elasticsearch_client["client"].delete_by_query.call_args[1] + query = call_args["body"]["query"]["bool"]["must"] + assert {"term": {"memory_id": "mem_123"}} in query + assert {"term": {"namespace": "test_namespace"}} in query diff --git a/tests/test_exa.py b/tests/test_exa.py index fdbb3e68..f555b84d 100644 --- a/tests/test_exa.py +++ b/tests/test_exa.py @@ -17,7 +17,7 @@ def mock_aiohttp_response(): mock_response = AsyncMock() mock_response.json.return_value = { "requestId": "b5947044c4b78efa9552a7c89b306d95", - "resolvedSearchType": "neural", + "resolvedSearchType": "auto", "searchType": "auto", "results": [ { @@ -40,8 +40,7 @@ def mock_aiohttp_response(): "search": 0.005, "contents": 0, "breakdown": { - "keywordSearch": 0, - "neuralSearch": 0.005, + "search": 0.005, "contentText": 0, "contentHighlight": 0, "contentSummary": 0, @@ -249,7 +248,7 @@ def test_format_search_response(): data = { "requestId": "test-request-123", "searchType": "auto", - "resolvedSearchType": "neural", + "resolvedSearchType": "auto", "results": [ { "title": "Test Result", diff --git a/tests/test_http_request.py b/tests/test_http_request.py index bd8e3fee..d3a084cc 100644 --- a/tests/test_http_request.py +++ b/tests/test_http_request.py @@ -232,7 +232,7 @@ def test_disable_redirects(): @responses.activate -def test_auth_token_direct(mock_env_vars): +def test_auth_token_direct(): """Test using auth_token parameter directly.""" responses.add( responses.GET, @@ -262,8 +262,8 @@ def test_auth_token_direct(mock_env_vars): @responses.activate -def test_auth_token_from_env(mock_env_vars): - """Test getting auth token from environment variable.""" +def test_auth_token_bearer(): + """Test Bearer auth with direct auth_token.""" responses.add( responses.GET, "https://api.example.com/protected", @@ -278,7 +278,7 @@ def test_auth_token_from_env(mock_env_vars): "method": "GET", "url": "https://api.example.com/protected", "auth_type": "Bearer", - "auth_env_var": "TEST_TOKEN", + "auth_token": "test-token-value", }, } @@ -292,7 +292,7 @@ def test_auth_token_from_env(mock_env_vars): @responses.activate -def 
test_github_api_auth(mock_env_vars): +def test_github_api_auth(): """Test GitHub API authentication with token prefix.""" responses.add( responses.GET, @@ -315,7 +315,7 @@ def test_github_api_auth(mock_env_vars): "method": "GET", "url": "https://api.github.com/user", "auth_type": "token", - "auth_env_var": "GITHUB_TOKEN", + "auth_token": "github-token-1234", }, } @@ -330,6 +330,85 @@ def test_github_api_auth(mock_env_vars): assert responses.calls[0].request.headers["Accept"] == "application/vnd.github.v3+json" +@responses.activate +def test_auth_env_var_allowed_domain(mock_env_vars): + """Test auth_env_var resolves token when domain is in the allowlist.""" + responses.add( + responses.GET, + "https://api.github.com/user", + json={"login": "testuser"}, + status=200, + ) + + tool_use = { + "toolUseId": "test-env-var-allowed-id", + "input": { + "method": "GET", + "url": "https://api.github.com/user", + "auth_type": "token", + "auth_env_var": "GITHUB_TOKEN", + }, + } + + token_config = {"GITHUB_TOKEN": ["api.github.com"]} + with ( + patch("strands_tools.http_request.HTTP_REQUEST_TOKEN_CONFIG", token_config), + patch("strands_tools.http_request.get_user_input") as mock_input, + ): + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + + assert result["status"] == "success" + assert responses.calls[0].request.headers["Authorization"] == "token github-token-1234" + + +def test_auth_env_var_domain_not_allowed(mock_env_vars): + """Test auth_env_var raises error when domain is not in the allowlist.""" + tool_use = { + "toolUseId": "test-env-var-denied-id", + "input": { + "method": "GET", + "url": "https://evil.example.com/steal", + "auth_type": "token", + "auth_env_var": "GITHUB_TOKEN", + }, + } + + token_config = {"GITHUB_TOKEN": ["api.github.com"]} + with ( + patch("strands_tools.http_request.HTTP_REQUEST_TOKEN_CONFIG", token_config), + patch("strands_tools.http_request.get_user_input") as mock_input, + ): + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + + assert result["status"] == "error" + assert "not in the allowed domains" in result["content"][0]["text"] + + +def test_auth_env_var_not_in_config(): + """Test auth_env_var raises error when env var is not in token config at all.""" + tool_use = { + "toolUseId": "test-env-var-no-config-id", + "input": { + "method": "GET", + "url": "https://api.github.com/user", + "auth_type": "token", + "auth_env_var": "SOME_UNKNOWN_TOKEN", + }, + } + + with ( + patch("strands_tools.http_request.HTTP_REQUEST_TOKEN_CONFIG", {}), + patch("strands_tools.http_request.get_user_input") as mock_input, + ): + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + + assert result["status"] == "error" + assert "STRANDS_HTTP_REQUEST_TOKEN_CONFIG" in result["content"][0]["text"] + + @responses.activate def test_basic_auth(): """Test basic authentication.""" @@ -441,27 +520,6 @@ def test_cancellation(monkeypatch): monkeypatch.delenv("BYPASS_TOOL_CONSENT", raising=False) -@responses.activate -def test_missing_env_var(): - """Test error when environment variable doesn't exist.""" - tool_use = { - "toolUseId": "test-missing-env-id", - "input": { - "method": "GET", - "url": "https://api.example.com/", - "auth_type": "Bearer", - "auth_env_var": "NON_EXISTENT_TOKEN", - }, - } - - with patch("strands_tools.http_request.get_user_input") as mock_input: - mock_input.return_value = "y" - result = http_request.http_request(tool=tool_use) - - assert result["status"] == "error" - assert 
"Environment variable 'NON_EXISTENT_TOKEN' not found" in result["content"][0]["text"] - - def test_aws_sigv4_auth(): """Test AWS SigV4 authentication.""" tool_use = { @@ -605,23 +663,55 @@ def test_verify_ssl_option(): }, } - # Call http_request with verify_ssl=False - with patch("strands_tools.http_request.get_user_input") as mock_input: - mock_input.return_value = "y" - # Use a real request but don't actually send it over the network - with responses.RequestsMock() as rsps: - rsps.add( - responses.GET, - "https://example.com/api/insecure", - json={"status": "insecure"}, - status=200, - ) - result = http_request.http_request(tool=tool_use) + # Call http_request with verify_ssl=False (requires STRANDS_HTTP_ALLOW_INSECURE_SSL) + original_env = os.environ.copy() + os.environ["STRANDS_HTTP_ALLOW_INSECURE_SSL"] = "true" + try: + with patch("strands_tools.http_request.get_user_input") as mock_input: + mock_input.return_value = "y" + # Use a real request but don't actually send it over the network + with responses.RequestsMock() as rsps: + rsps.add( + responses.GET, + "https://example.com/api/insecure", + json={"status": "insecure"}, + status=200, + ) + result = http_request.http_request(tool=tool_use) + finally: + os.environ.clear() + os.environ.update(original_env) # Verify the result assert result["status"] == "success" +def test_verify_ssl_blocked_without_env_var(): + """Test that verify_ssl=False is blocked without STRANDS_HTTP_ALLOW_INSECURE_SSL.""" + tool_use = { + "toolUseId": "test-ssl-blocked-id", + "input": { + "method": "GET", + "url": "https://example.com/api/insecure", + "verify_ssl": False, + }, + } + + # Ensure the env var is NOT set + original_env = os.environ.copy() + os.environ.pop("STRANDS_HTTP_ALLOW_INSECURE_SSL", None) + try: + with patch("strands_tools.http_request.get_user_input") as mock_input: + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + finally: + os.environ.clear() + os.environ.update(original_env) + + assert result["status"] == "error" + assert "STRANDS_HTTP_ALLOW_INSECURE_SSL" in result["content"][0]["text"] + + @responses.activate def test_dev_mode_no_confirmation(): """Test that in BYPASS_TOOL_CONSENT mode, no confirmation is requested for modifying requests.""" @@ -1092,3 +1182,147 @@ def test_proxy_support(): assert result["status"] == "success" result_text = extract_result_text(result) assert "Status Code: 200" in result_text + + +@responses.activate +def test_payment_required_header_in_response(): + """Test that Payment-Required header is captured in response.""" + # Set up mock response with Payment-Required header + responses.add( + responses.GET, + "https://api.example.com/premium-feature", + json={"error": "payment required"}, + status=402, + headers={"Payment-Required": "true"}, + content_type="application/json", + ) + + tool_use = { + "toolUseId": "test-payment-required-id", + "input": { + "method": "GET", + "url": "https://api.example.com/premium-feature", + }, + } + + with patch("strands_tools.http_request.get_user_input") as mock_input: + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + + assert result["status"] == "success" + result_text = extract_result_text(result) + + # Verify Payment-Required header is in the response + assert "Payment-Required" in result_text + assert "true" in result_text + assert "Status Code: 402" in result_text + + +@responses.activate +def test_payment_required_header_with_other_headers(): + """Test Payment-Required header is captured alongside 
other important headers.""" + # Set up mock response with multiple important headers + responses.add( + responses.GET, + "https://api.example.com/data", + json={"data": "test"}, + status=200, + headers={ + "Date": "Mon, 24 Mar 2026 12:00:00 GMT", + "Server": "nginx/1.20.0", + "Payment-Required": "false", + "X-Custom-Header": "should-not-appear", + }, + content_type="application/json", + ) + + tool_use = { + "toolUseId": "test-multiple-headers-id", + "input": { + "method": "GET", + "url": "https://api.example.com/data", + }, + } + + with patch("strands_tools.http_request.get_user_input") as mock_input: + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + + assert result["status"] == "success" + result_text = extract_result_text(result) + + # Verify important headers are present + assert "Content-Type" in result_text + assert "Server" in result_text + assert "Payment-Required" in result_text + + # Verify custom headers are not included + assert "X-Custom-Header" not in result_text + + +@responses.activate +def test_payment_required_header_case_insensitive(): + """Test that Payment-Required header is matched case-insensitively.""" + # Set up mock response with lowercase payment-required header + responses.add( + responses.GET, + "https://api.example.com/check", + json={"status": "ok"}, + status=200, + headers={"payment-required": "false"}, + content_type="application/json", + ) + + tool_use = { + "toolUseId": "test-case-insensitive-id", + "input": { + "method": "GET", + "url": "https://api.example.com/check", + }, + } + + with patch("strands_tools.http_request.get_user_input") as mock_input: + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + + assert result["status"] == "success" + result_text = extract_result_text(result) + + # Verify the header is captured regardless of case + assert "payment-required" in result_text.lower() + + +@responses.activate +def test_payment_required_header_missing(): + """Test response when Payment-Required header is not present.""" + # Set up mock response without Payment-Required header + responses.add( + responses.GET, + "https://api.example.com/free-feature", + json={"data": "free content"}, + status=200, + headers={ + "Server": "nginx", + }, + content_type="application/json", + ) + + tool_use = { + "toolUseId": "test-no-payment-header-id", + "input": { + "method": "GET", + "url": "https://api.example.com/free-feature", + }, + } + + with patch("strands_tools.http_request.get_user_input") as mock_input: + mock_input.return_value = "y" + result = http_request.http_request(tool=tool_use) + + assert result["status"] == "success" + result_text = extract_result_text(result) + + # Verify response is successful even without Payment-Required header + assert "Status Code: 200" in result_text + # The headers dict should still be present but without Payment-Required + assert "Headers:" in result_text diff --git a/tests/test_rss.py b/tests/test_rss.py index 5a31cbe0..50b1db41 100644 --- a/tests/test_rss.py +++ b/tests/test_rss.py @@ -1,6 +1,7 @@ """Comprehensive tests for RSS feed tool with improved organization.""" import json +import os from unittest.mock import MagicMock, call, mock_open, patch import pytest @@ -131,6 +132,28 @@ def test_content_processing(self): result = manager.format_entry(entry_no_content, include_content=True) assert result["content"] == "No content available" + @pytest.mark.parametrize( + "feed_id", + [ + "../outside", + "../../etc/config", + "subdir/../../../escape", + 
"/absolute/path", + ], + ) + def test_get_feed_file_path_rejects_traversal(self, feed_id): + """Test that path traversal sequences in feed_id are rejected.""" + manager = RSSManager() + with pytest.raises(ValueError, match="path traversal detected"): + manager.get_feed_file_path(feed_id) + + def test_get_feed_file_path_allows_valid_ids(self): + """Test that valid feed_ids are accepted.""" + manager = RSSManager() + path = manager.get_feed_file_path("my_valid_feed") + assert path.endswith("my_valid_feed.json") + assert os.path.realpath(manager.storage_path) in path + @pytest.mark.parametrize( "url,expected_id", [