diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 582abd017c0..98433b2e29b 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -16,6 +16,7 @@ BedrockAgentFunctionResolver, BedrockFunctionResponse, ) +from aws_lambda_powertools.event_handler.depends import DependencyResolutionError, Depends from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver from aws_lambda_powertools.event_handler.http_resolver import HttpResolverLocal from aws_lambda_powertools.event_handler.lambda_function_url import ( @@ -36,6 +37,8 @@ "BedrockResponse", "BedrockFunctionResponse", "CORSConfig", + "Depends", + "DependencyResolutionError", "HttpResolverLocal", "LambdaFunctionUrlResolver", "Request", diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 4a9fc142cf0..7ffaa884761 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -472,6 +472,9 @@ def __init__( self.custom_response_validation_http_code = custom_response_validation_http_code + # Cache whether this route's handler declares Depends() parameters + self._has_dependencies: bool | None = None + # Caches the name of any Request-typed parameter in the handler. # Avoids re-scanning the signature on every invocation. self.request_param_name: str | None = None @@ -613,6 +616,15 @@ def dependant(self) -> Dependant: return self._dependant + @property + def has_dependencies(self) -> bool: + """Check if handler declares Depends() parameters without triggering full dependant computation.""" + if self._has_dependencies is None: + from aws_lambda_powertools.event_handler.depends import _has_depends + + self._has_dependencies = _has_depends(self.func) + return self._has_dependencies + @property def body_field(self) -> ModelField | None: if self._body_field is None: @@ -1428,6 +1440,17 @@ def _registered_api_adapter( if route.request_param_name: route_args = {**route_args, route.request_param_name: app.request} + # Resolve Depends() parameters + if route.has_dependencies: + from aws_lambda_powertools.event_handler.depends import build_dependency_tree, solve_dependencies + + dep_values = solve_dependencies( + dependant=build_dependency_tree(route.func), + request=app.request, + dependency_overrides=app.dependency_overrides or None, + ) + route_args.update(dep_values) + return app._to_response(next_middleware(**route_args)) @@ -1497,6 +1520,7 @@ def __init__( function to deserialize `str`, `bytes`, `bytearray` containing a JSON document to a Python `dict`, by default json.loads when integrating with EventSource data class """ + self.dependency_overrides: dict[Callable, Callable] = {} self._proxy_type = proxy_type or self._proxy_event_type self._dynamic_routes: list[Route] = [] self._static_routes: list[Route] = [] diff --git a/aws_lambda_powertools/event_handler/depends.py b/aws_lambda_powertools/event_handler/depends.py new file mode 100644 index 00000000000..f05167c63d9 --- /dev/null +++ b/aws_lambda_powertools/event_handler/depends.py @@ -0,0 +1,222 @@ +"""Lightweight dependency injection primitives — no pydantic import.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated, Any, get_args, get_origin, get_type_hints + +if TYPE_CHECKING: + from collections.abc import Callable + + from aws_lambda_powertools.event_handler.openapi.params import Dependant + from aws_lambda_powertools.event_handler.request import Request + + +class DependencyResolutionError(Exception): + """Raised when a dependency cannot be resolved.""" + + +class Depends: + """ + Declares a dependency for a route handler parameter. + + Dependencies are resolved automatically before the handler is called. The return value + of the dependency callable is injected as the parameter value. + + Parameters + ---------- + dependency: Callable[..., Any] + A callable whose return value will be injected into the handler parameter. + The callable can itself declare ``Depends()`` parameters to form a dependency tree. + use_cache: bool + If ``True`` (default), the dependency result is cached per invocation so that + the same dependency used multiple times is only called once. + + Examples + -------- + + ```python + from typing import Annotated + + from aws_lambda_powertools.event_handler import APIGatewayHttpResolver, Depends + + app = APIGatewayHttpResolver() + + def get_tenant() -> str: + return "default-tenant" + + @app.get("/orders") + def list_orders(tenant_id: Annotated[str, Depends(get_tenant)]): + return {"tenant": tenant_id} + ``` + """ + + def __init__(self, dependency: Callable[..., Any], *, use_cache: bool = True) -> None: + if not callable(dependency): + raise DependencyResolutionError( + f"Depends() requires a callable, got {type(dependency).__name__}: {dependency!r}", + ) + self.dependency = dependency + self.use_cache = use_cache + + +class _DependencyNode: + """Lightweight node in a dependency tree — used by ``build_dependency_tree``.""" + + def __init__(self, *, param_name: str, depends: Depends, sub_tree: DependencyTree) -> None: + self.param_name = param_name + self.depends = depends + self.dependant = sub_tree + + +class DependencyTree: + """Lightweight dependency tree — no pydantic required. + + This mirrors the shape that ``solve_dependencies`` expects (a ``.dependencies`` + attribute containing nodes with ``.param_name``, ``.depends``, and ``.dependant``), + but can be built without importing pydantic. + """ + + def __init__(self, *, dependencies: list[_DependencyNode] | None = None) -> None: + self.dependencies: list[_DependencyNode] = dependencies or [] + + +class DependencyParam: + """Holds a dependency's parameter name and its resolved Dependant sub-tree (OpenAPI path).""" + + def __init__(self, *, param_name: str, depends: Depends, dependant: Dependant) -> None: + self.param_name = param_name + self.depends = depends + self.dependant = dependant + + +def _get_depends_from_annotation(annotation: Any) -> Depends | None: + """Extract a Depends instance from an Annotated[Type, Depends(...)] annotation.""" + if get_origin(annotation) is Annotated: + for arg in get_args(annotation)[1:]: + if isinstance(arg, Depends): + return arg + return None + + +def _has_depends(func: Callable[..., Any]) -> bool: + """Check if a callable has any Depends() parameters, without importing pydantic.""" + try: + hints = get_type_hints(func, include_extras=True) + except Exception: + return False + + for annotation in hints.values(): + if _get_depends_from_annotation(annotation) is not None: + return True + return False + + +def build_dependency_tree(func: Callable[..., Any]) -> DependencyTree: + """Build a lightweight dependency tree from a callable's signature. + + This inspects the function parameters for ``Annotated[Type, Depends(...)]`` + annotations and recursively builds the tree — all without importing pydantic. + """ + try: + hints = get_type_hints(func, include_extras=True) + except Exception: + return DependencyTree() + + dependencies: list[_DependencyNode] = [] + + for param_name, annotation in hints.items(): + if param_name == "return": + continue + + depends_instance = _get_depends_from_annotation(annotation) + if depends_instance is not None: + sub_tree = build_dependency_tree(depends_instance.dependency) + dependencies.append( + _DependencyNode( + param_name=param_name, + depends=depends_instance, + sub_tree=sub_tree, + ), + ) + + return DependencyTree(dependencies=dependencies) + + +def solve_dependencies( + *, + dependant: Dependant | DependencyTree, + request: Request | None = None, + dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None, + dependency_cache: dict[Callable[..., Any], Any] | None = None, +) -> dict[str, Any]: + """ + Recursively resolve all ``Depends()`` parameters for a given dependant. + + Parameters + ---------- + dependant: Dependant + The dependant model containing dependency declarations + request: Request, optional + The current request object, injected into dependencies that declare a Request parameter + dependency_overrides: dict, optional + Mapping of original dependency callable to override callable (for testing) + dependency_cache: dict, optional + Per-invocation cache of resolved dependency values + + Returns + ------- + dict[str, Any] + Mapping of parameter name to resolved dependency value + """ + from aws_lambda_powertools.event_handler.request import Request as RequestClass + + if dependency_cache is None: + dependency_cache = {} + + values: dict[str, Any] = {} + + for dep in dependant.dependencies: + use_fn = dep.depends.dependency + + # Apply overrides (for testing) + if dependency_overrides and use_fn in dependency_overrides: + use_fn = dependency_overrides[use_fn] + + # Check cache + if dep.depends.use_cache and use_fn in dependency_cache: + values[dep.param_name] = dependency_cache[use_fn] + continue + + # Recursively resolve sub-dependencies + sub_values = solve_dependencies( + dependant=dep.dependant, + request=request, + dependency_overrides=dependency_overrides, + dependency_cache=dependency_cache, + ) + + # Inject Request if the dependency declares it + if request is not None: + try: + hints = get_type_hints(use_fn) + except Exception: # pragma: no cover - defensive for broken annotations + hints = {} + for param_name, annotation in hints.items(): + if annotation is RequestClass: + sub_values[param_name] = request + + try: + solved = use_fn(**sub_values) + except Exception as exc: + dep_name = getattr(use_fn, "__name__", repr(use_fn)) + raise DependencyResolutionError( + f"Failed to resolve dependency '{dep_name}' for parameter '{dep.param_name}': {exc}", + ) from exc + + # Cache result + if dep.depends.use_cache: + dependency_cache[use_fn] = solved + + values[dep.param_name] = solved + + return values diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index ec8414a7dd2..1e7f4327602 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -4,6 +4,7 @@ import re from typing import TYPE_CHECKING, Any, ForwardRef, cast +from aws_lambda_powertools.event_handler.depends import DependencyParam, _get_depends_from_annotation from aws_lambda_powertools.event_handler.openapi.compat import ( ModelField, create_body_model, @@ -193,6 +194,22 @@ def get_dependant( if param.annotation is Request: continue + # Depends() parameters (via Annotated[Type, Depends(fn)]) are resolved at call time. + depends_instance = _get_depends_from_annotation(param.annotation) + if depends_instance is not None: + sub_dependant = get_dependant( + path=path, + call=depends_instance.dependency, + ) + dependant.dependencies.append( + DependencyParam( + param_name=param_name, + depends=depends_instance, + dependant=sub_dependant, + ), + ) + continue + # If the parameter is a path parameter, we need to set the in_ field to "path". is_path_param = param_name in path_param_names diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index 0f13e1e1990..1ade081959f 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from aws_lambda_powertools.event_handler.depends import DependencyParam from aws_lambda_powertools.event_handler.openapi.models import Example from aws_lambda_powertools.event_handler.openapi.types import CacheKey @@ -64,6 +65,7 @@ def __init__( http_connection_param_name: str | None = None, response_param_name: str | None = None, background_tasks_param_name: str | None = None, + dependencies: list[DependencyParam] | None = None, path: str | None = None, ) -> None: self.path_params = path_params or [] @@ -78,6 +80,7 @@ def __init__( self.http_connection_param_name = http_connection_param_name self.response_param_name = response_param_name self.background_tasks_param_name = background_tasks_param_name + self.dependencies = dependencies or [] self.name = name self.call = call # Store the path to be able to re-generate a dependable from it in overrides @@ -816,7 +819,7 @@ def get_flat_dependant( visited = [] visited.append(dependant.cache_key) - return Dependant( + flat = Dependant( path_params=dependant.path_params.copy(), query_params=dependant.query_params.copy(), header_params=dependant.header_params.copy(), @@ -825,6 +828,18 @@ def get_flat_dependant( path=dependant.path, ) + # Flatten sub-dependencies that declare HTTP params (query, header, etc.) + for dep in dependant.dependencies: + if dep.dependant.cache_key not in visited: + sub_flat = get_flat_dependant(dep.dependant, visited=visited) + flat.path_params.extend(sub_flat.path_params) + flat.query_params.extend(sub_flat.query_params) + flat.header_params.extend(sub_flat.header_params) + flat.cookie_params.extend(sub_flat.cookie_params) + flat.body_params.extend(sub_flat.body_params) + + return flat + def analyze_param( *, diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index b9666457dc8..0ffd8cee15c 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -1365,6 +1365,48 @@ You can use `append_context` when you want to share data between your App and Ro --8<-- "examples/event_handler_rest/src/split_route_append_context_module.py" ``` +### Dependency injection + +You can use `Depends()` to declare dependencies that are automatically resolved and injected into your route handlers. This provides type-safe, composable, and testable dependency injection. + +#### Basic usage + +Use `Annotated[Type, Depends(fn)]` to declare a dependency. The return value of `fn` is injected into the parameter automatically. + +```python hl_lines="5 8 20 25" +--8<-- "examples/event_handler_rest/src/dependency_injection.py" +``` + +#### Nested dependencies + +Dependencies can depend on other dependencies, forming a composable tree. Shared sub-dependencies are resolved once per invocation and cached automatically. + +```python hl_lines="18 22 29-30" +--8<-- "examples/event_handler_rest/src/dependency_injection_nested.py" +``` + +#### Accessing the request + +Dependencies that need access to the current request can declare a parameter typed as `Request`. It will be injected automatically. + +```python hl_lines="5-6 12 20" +--8<-- "examples/event_handler_rest/src/dependency_injection_with_request.py" +``` + +#### Testing with dependency overrides + +Use `dependency_overrides` to replace any dependency with a mock or stub during testing - no monkeypatching needed. + +```python hl_lines="3 12 26" +--8<-- "examples/event_handler_rest/src/dependency_injection_testing.py" +``` + +???+ tip "Caching behavior" + By default, dependencies are cached within the same invocation (`use_cache=True`). If the same dependency is used by multiple handlers or sub-dependencies, it is resolved once and the result is reused. Use `Depends(fn, use_cache=False)` to resolve every time. + +???+ info "`append_context` vs `Depends()`" + `append_context` remains available for backward compatibility. `Depends()` is recommended for new code because it provides type safety, IDE autocomplete, composable dependency trees, and `dependency_overrides` for testing. + #### Sample layout This is a sample project layout for a monolithic function with routes split in different files (`/todos`, `/health`). diff --git a/examples/event_handler_rest/src/dependency_injection.py b/examples/event_handler_rest/src/dependency_injection.py new file mode 100644 index 00000000000..664bc56951d --- /dev/null +++ b/examples/event_handler_rest/src/dependency_injection.py @@ -0,0 +1,32 @@ +import os +from typing import Any + +import boto3 +from typing_extensions import Annotated + +from aws_lambda_powertools.event_handler import APIGatewayHttpResolver +from aws_lambda_powertools.event_handler.depends import Depends +from aws_lambda_powertools.utilities.typing import LambdaContext + +app = APIGatewayHttpResolver() + + +def get_dynamodb_table(): + dynamodb = boto3.resource("dynamodb") + return dynamodb.Table(os.environ["TABLE_NAME"]) + + +@app.get("/orders") +def list_orders(table: Annotated[Any, Depends(get_dynamodb_table)]): + return table.scan()["Items"] + + +@app.post("/orders") +def create_order(table: Annotated[Any, Depends(get_dynamodb_table)]): + order = app.current_event.json_body + table.put_item(Item=order) + return {"message": "Order created"} + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/dependency_injection_nested.py b/examples/event_handler_rest/src/dependency_injection_nested.py new file mode 100644 index 00000000000..f8245439538 --- /dev/null +++ b/examples/event_handler_rest/src/dependency_injection_nested.py @@ -0,0 +1,38 @@ +import os +from typing import Any + +import boto3 +from typing_extensions import Annotated + +from aws_lambda_powertools.event_handler import APIGatewayHttpResolver +from aws_lambda_powertools.event_handler.depends import Depends +from aws_lambda_powertools.utilities.typing import LambdaContext + +app = APIGatewayHttpResolver() + + +def get_dynamodb_resource(): + return boto3.resource("dynamodb") + + +def get_orders_table(dynamodb: Annotated[Any, Depends(get_dynamodb_resource)]): + return dynamodb.Table(os.environ["ORDERS_TABLE"]) + + +def get_users_table(dynamodb: Annotated[Any, Depends(get_dynamodb_resource)]): + return dynamodb.Table(os.environ["USERS_TABLE"]) + + +@app.get("/orders/") +def get_user_orders( + user_id: str, + orders_table: Annotated[Any, Depends(get_orders_table)], + users_table: Annotated[Any, Depends(get_users_table)], +): + user = users_table.get_item(Key={"pk": user_id})["Item"] + orders = orders_table.query(KeyConditionExpression="pk = :uid", ExpressionAttributeValues={":uid": user_id}) + return {"user": user["name"], "orders": orders["Items"]} + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/dependency_injection_testing.py b/examples/event_handler_rest/src/dependency_injection_testing.py new file mode 100644 index 00000000000..3b9f41c5330 --- /dev/null +++ b/examples/event_handler_rest/src/dependency_injection_testing.py @@ -0,0 +1,26 @@ +from unittest.mock import MagicMock + +from dependency_injection import app, get_dynamodb_table + + +def test_list_orders(): + # Create a mock table + mock_table = MagicMock() + mock_table.scan.return_value = {"Items": [{"id": "order-1"}]} + + # Override the dependency with a lambda that returns the mock + app.dependency_overrides[get_dynamodb_table] = lambda: mock_table + + result = app( + { + "requestContext": {"http": {"method": "GET", "path": "/orders"}, "stage": "$default"}, + "rawPath": "/orders", + "headers": {}, + }, + {}, + ) + + assert result["statusCode"] == 200 + + # Clean up overrides after testing + app.dependency_overrides.clear() diff --git a/examples/event_handler_rest/src/dependency_injection_with_request.py b/examples/event_handler_rest/src/dependency_injection_with_request.py new file mode 100644 index 00000000000..c918b646f46 --- /dev/null +++ b/examples/event_handler_rest/src/dependency_injection_with_request.py @@ -0,0 +1,25 @@ +from typing_extensions import Annotated + +from aws_lambda_powertools.event_handler import APIGatewayHttpResolver +from aws_lambda_powertools.event_handler.depends import Depends +from aws_lambda_powertools.event_handler.exceptions import UnauthorizedError +from aws_lambda_powertools.event_handler.request import Request +from aws_lambda_powertools.utilities.typing import LambdaContext + +app = APIGatewayHttpResolver() + + +def get_authenticated_user(request: Request) -> str: + user_id = request.headers.get("x-user-id") + if not user_id: + raise UnauthorizedError("Missing authentication") + return user_id + + +@app.get("/profile") +def get_profile(user_id: Annotated[str, Depends(get_authenticated_user)]): + return {"user_id": user_id} + + +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/tests/functional/event_handler/_pydantic/test_depends.py b/tests/functional/event_handler/_pydantic/test_depends.py new file mode 100644 index 00000000000..0cbc05b83c7 --- /dev/null +++ b/tests/functional/event_handler/_pydantic/test_depends.py @@ -0,0 +1,216 @@ +"""Tests for Depends() with OpenAPI schema generation and validation.""" + +import json +from typing import Annotated + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayHttpResolver, Depends +from aws_lambda_powertools.event_handler.request import Request +from tests.functional.utils import load_event + +API_GW_V2_EVENT = load_event("apiGatewayProxyV2Event.json") + + +# --- Fixtures --- + + +class AppConfig(BaseModel): + region: str = "us-east-1" + debug: bool = False + + +def get_config() -> AppConfig: + return AppConfig(region="eu-west-1", debug=True) + + +def get_tenant() -> str: + return "tenant-abc" + + +# --- OpenAPI schema tests --- + + +def test_depends_excluded_from_openapi_schema(): + """Depends() parameters must NOT appear in the OpenAPI schema.""" + app = APIGatewayHttpResolver(enable_validation=True) + + @app.get("/orders") + def handler(tenant: Annotated[str, Depends(get_tenant)], status: str = "active"): + return {"tenant": tenant, "status": status} + + schema = app.get_openapi_schema() + get_op = schema.paths["/orders"].get + param_names = [p.name for p in (get_op.parameters or [])] + + assert "tenant" not in param_names + assert "status" in param_names + + +def test_depends_with_pydantic_model_excluded_from_schema(): + """Depends() returning a Pydantic model must NOT appear as a body param in the schema.""" + app = APIGatewayHttpResolver(enable_validation=True) + + @app.get("/info") + def handler(config: Annotated[AppConfig, Depends(get_config)]): + return {"region": config.region} + + schema = app.get_openapi_schema() + get_op = schema.paths["/info"].get + param_names = [p.name for p in (get_op.parameters or [])] + + assert "config" not in param_names + # Should have no request body either + assert get_op.requestBody is None + + +def test_depends_nested_excluded_from_openapi_schema(): + """Nested Depends() parameters must NOT appear in the OpenAPI schema.""" + app = APIGatewayHttpResolver(enable_validation=True) + + def get_prefix() -> str: + return "Hello" + + def get_greeting(prefix: Annotated[str, Depends(get_prefix)]) -> str: + return f"{prefix}, world!" + + @app.get("/greet") + def handler(greeting: Annotated[str, Depends(get_greeting)]): + return {"greeting": greeting} + + schema = app.get_openapi_schema() + get_op = schema.paths["/greet"].get + param_names = [p.name for p in (get_op.parameters or [])] + + assert "greeting" not in param_names + assert "prefix" not in param_names + + +# --- Validation + Depends integration tests --- + + +def test_depends_with_validation_resolves_and_validates(): + """Depends() values are injected alongside validated query params.""" + app = APIGatewayHttpResolver(enable_validation=True) + + @app.get("/orders") + def handler(tenant: Annotated[str, Depends(get_tenant)], limit: int = 10): + return {"tenant": tenant, "limit": limit} + + event = {**API_GW_V2_EVENT} + event["rawPath"] = "/orders" + event["requestContext"] = { + **event["requestContext"], + "http": {"method": "GET", "path": "/orders"}, + } + event["queryStringParameters"] = {"limit": "5"} + + result = app(event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["tenant"] == "tenant-abc" + assert body["limit"] == 5 + + +def test_depends_pydantic_model_with_validation(): + """Depends() returning a Pydantic model works with enable_validation.""" + app = APIGatewayHttpResolver(enable_validation=True) + + @app.get("/config") + def handler(config: Annotated[AppConfig, Depends(get_config)]): + return {"region": config.region, "debug": config.debug} + + event = {**API_GW_V2_EVENT} + event["rawPath"] = "/config" + event["requestContext"] = { + **event["requestContext"], + "http": {"method": "GET", "path": "/config"}, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["region"] == "eu-west-1" + assert body["debug"] is True + + +def test_depends_with_request_and_validation(): + """Depends() with Request injection works alongside validation.""" + app = APIGatewayHttpResolver(enable_validation=True) + + def get_method(request: Request) -> str: + return request.method + + @app.post("/my/path") + def handler(method: Annotated[str, Depends(get_method)], name: str = "world"): + return {"method": method, "name": name} + + event = {**API_GW_V2_EVENT, "queryStringParameters": {"name": "Lambda"}} + result = app(event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["method"] == "POST" + assert body["name"] == "Lambda" + + +def test_depends_override_with_validation(): + """dependency_overrides works with enable_validation.""" + app = APIGatewayHttpResolver(enable_validation=True) + + @app.get("/orders") + def handler(tenant: Annotated[str, Depends(get_tenant)]): + return {"tenant": tenant} + + app.dependency_overrides[get_tenant] = lambda: "test-tenant" + + event = {**API_GW_V2_EVENT} + event["rawPath"] = "/orders" + event["requestContext"] = { + **event["requestContext"], + "http": {"method": "GET", "path": "/orders"}, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"tenant": "test-tenant"} + + app.dependency_overrides.clear() + + +def test_depends_with_path_params_and_validation(): + """Depends() works with path parameters and validation.""" + app = APIGatewayHttpResolver(enable_validation=True) + + @app.get("/orders/") + def handler(order_id: str, tenant: Annotated[str, Depends(get_tenant)]): + return {"order_id": order_id, "tenant": tenant} + + event = {**API_GW_V2_EVENT} + event["rawPath"] = "/orders/abc-123" + event["requestContext"] = { + **event["requestContext"], + "http": {"method": "GET", "path": "/orders/abc-123"}, + } + + result = app(event, {}) + assert result["statusCode"] == 200 + body = json.loads(result["body"]) + assert body["order_id"] == "abc-123" + assert body["tenant"] == "tenant-abc" + + +def test_depends_with_regular_params_and_validation(): + """Depends() works alongside regular handler parameters with validation.""" + app = APIGatewayHttpResolver(enable_validation=True) + + def get_greeting() -> str: + return "hello" + + @app.post("/my/path") + def handler(name: str = "world", greeting: Annotated[str, Depends(get_greeting)] = ""): + return {"message": f"{greeting}, {name}!"} + + event = {**API_GW_V2_EVENT, "queryStringParameters": {"name": "Lambda"}} + result = app(event, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"message": "hello, Lambda!"} diff --git a/tests/functional/event_handler/required_dependencies/test_depends.py b/tests/functional/event_handler/required_dependencies/test_depends.py new file mode 100644 index 00000000000..3131be2430e --- /dev/null +++ b/tests/functional/event_handler/required_dependencies/test_depends.py @@ -0,0 +1,416 @@ +"""Tests for the Depends() dependency injection feature using Annotated.""" + +import json + +import pytest +from typing_extensions import Annotated + +from aws_lambda_powertools.event_handler import APIGatewayHttpResolver +from aws_lambda_powertools.event_handler.depends import DependencyResolutionError, Depends +from aws_lambda_powertools.event_handler.request import Request +from tests.functional.utils import load_event + +API_GW_V2_EVENT = load_event("apiGatewayProxyV2Event.json") + + +def test_depends_simple(): + """A simple dependency is resolved and injected into the handler.""" + app = APIGatewayHttpResolver() + + def get_greeting() -> str: + return "hello" + + @app.post("/my/path") + def handler(greeting: Annotated[str, Depends(get_greeting)]): + return {"greeting": greeting} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"greeting": "hello"} + + +def test_depends_nested(): + """Dependencies can depend on other dependencies.""" + app = APIGatewayHttpResolver() + + def get_prefix() -> str: + return "Hello" + + def get_greeting(prefix: Annotated[str, Depends(get_prefix)]) -> str: + return f"{prefix}, world!" + + @app.post("/my/path") + def handler(greeting: Annotated[str, Depends(get_greeting)]): + return {"greeting": greeting} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"greeting": "Hello, world!"} + + +def test_depends_cache_per_invocation(): + """Same dependency used twice in one invocation is only resolved once (use_cache=True).""" + app = APIGatewayHttpResolver() + call_count = 0 + + def get_config() -> dict: + nonlocal call_count + call_count += 1 + return {"key": "value"} + + def get_a(config: Annotated[dict, Depends(get_config)]) -> str: + return config["key"] + + def get_b(config: Annotated[dict, Depends(get_config)]) -> str: + return config["key"] + + @app.post("/my/path") + def handler(a: Annotated[str, Depends(get_a)], b: Annotated[str, Depends(get_b)]): + return {"a": a, "b": b} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert call_count == 1 # get_config called once despite being used by both get_a and get_b + + +def test_depends_no_cache(): + """use_cache=False resolves every time.""" + app = APIGatewayHttpResolver() + call_count = 0 + + def get_value() -> int: + nonlocal call_count + call_count += 1 + return call_count + + @app.post("/my/path") + def handler( + a: Annotated[int, Depends(get_value, use_cache=False)], + b: Annotated[int, Depends(get_value, use_cache=False)], + ): + return {"a": a, "b": b} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert call_count == 2 + + +def test_depends_with_request(): + """A dependency can receive the Request object.""" + app = APIGatewayHttpResolver() + + def get_method(request: Request) -> str: + return request.method + + @app.post("/my/path") + def handler(method: Annotated[str, Depends(get_method)]): + return {"method": method} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"method": "POST"} + + +def test_depends_override(): + """dependency_overrides replaces a dependency callable for testing.""" + app = APIGatewayHttpResolver() + + def get_tenant() -> str: + return "real-tenant" + + @app.post("/my/path") + def handler(tenant: Annotated[str, Depends(get_tenant)]): + return {"tenant": tenant} + + app.dependency_overrides[get_tenant] = lambda: "test-tenant" + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"tenant": "test-tenant"} + + app.dependency_overrides.clear() + + +def test_depends_override_nested(): + """dependency_overrides works for nested dependencies too.""" + app = APIGatewayHttpResolver() + + def get_db_client(): + return "real-db" + + def get_table(db: Annotated[str, Depends(get_db_client)]) -> str: + return f"table-from-{db}" + + @app.post("/my/path") + def handler(table: Annotated[str, Depends(get_table)]): + return {"table": table} + + app.dependency_overrides[get_db_client] = lambda: "mock-db" + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"table": "table-from-mock-db"} + + app.dependency_overrides.clear() + + +def test_depends_multiple_handlers(): + """Dependencies work across different route handlers.""" + app = APIGatewayHttpResolver() + + def get_user() -> str: + return "user-123" + + @app.get("/my/path") + def get_handler(user: Annotated[str, Depends(get_user)]): + return {"user": user, "action": "get"} + + @app.post("/my/path") + def post_handler(user: Annotated[str, Depends(get_user)]): + return {"user": user, "action": "post"} + + # Test POST (matches the event) + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"user": "user-123", "action": "post"} + + +def test_depends_reusable_type_alias(): + """Annotated type aliases can be reused across handlers.""" + app = APIGatewayHttpResolver() + + def get_tenant() -> str: + return "tenant-abc" + + TenantId = Annotated[str, Depends(get_tenant)] + + @app.post("/my/path") + def handler(tenant: TenantId): + return {"tenant": tenant} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"tenant": "tenant-abc"} + + +def test_handler_without_depends_works_normally(): + """A plain handler with no Depends() params is not affected by DI.""" + app = APIGatewayHttpResolver() + + @app.post("/my/path") + def handler(): + return {"ok": True} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"ok": True} + + +def test_depends_not_cached_across_invocations(): + """Each app() call resolves dependencies fresh — no cross-request leakage.""" + app = APIGatewayHttpResolver() + call_count = 0 + + def get_counter() -> int: + nonlocal call_count + call_count += 1 + return call_count + + @app.post("/my/path") + def handler(c: Annotated[int, Depends(get_counter)]): + return {"c": c} + + result1 = app(API_GW_V2_EVENT, {}) + result2 = app(API_GW_V2_EVENT, {}) + + assert json.loads(result1["body"]) == {"c": 1} + assert json.loads(result2["body"]) == {"c": 2} + assert call_count == 2 + + +def test_depends_deeply_nested(): + """Three-level dependency chain resolves correctly.""" + app = APIGatewayHttpResolver() + + def get_url() -> str: + return "postgres://localhost" + + def get_conn(url: Annotated[str, Depends(get_url)]) -> str: + return f"conn({url})" + + def get_session(conn: Annotated[str, Depends(get_conn)]) -> str: + return f"session({conn})" + + @app.post("/my/path") + def handler(session: Annotated[str, Depends(get_session)]): + return {"session": session} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"session": "session(conn(postgres://localhost))"} + + +def test_depends_with_request_reads_headers(): + """A dependency using Request can read actual request headers.""" + app = APIGatewayHttpResolver() + + def get_user_agent(request: Request) -> str: + return request.headers.get("user-agent", "unknown") + + @app.post("/my/path") + def handler(ua: Annotated[str, Depends(get_user_agent)]): + return {"ua": ua} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert isinstance(json.loads(result["body"])["ua"], str) + + +def test_depends_returning_none(): + """A dependency can return None without breaking.""" + app = APIGatewayHttpResolver() + + def get_nothing() -> None: + return None + + @app.post("/my/path") + def handler(val: Annotated[None, Depends(get_nothing)]): + return {"val": val} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"val": None} + + +def test_depends_exception_raises_dependency_resolution_error(): + """If a dependency raises, a DependencyResolutionError wraps the original exception.""" + app = APIGatewayHttpResolver() + + def broken() -> str: + raise ValueError("boom") + + @app.post("/my/path") + def handler(val: Annotated[str, Depends(broken)]): + return {"val": val} + + with pytest.raises(DependencyResolutionError, match="broken.*boom"): + app(API_GW_V2_EVENT, {}) + + +def test_depends_non_callable_raises_dependency_resolution_error(): + """Passing a non-callable to Depends() raises DependencyResolutionError immediately.""" + with pytest.raises(DependencyResolutionError, match="requires a callable"): + Depends("not_a_function") # type: ignore + + with pytest.raises(DependencyResolutionError, match="requires a callable"): + Depends(42) # type: ignore + + with pytest.raises(DependencyResolutionError, match="requires a callable"): + Depends(None) # type: ignore + + +def test_depends_accepts_lambda(): + """Depends() works with a lambda as the dependency.""" + app = APIGatewayHttpResolver() + + @app.post("/my/path") + def handler(val: Annotated[str, Depends(lambda: "from-lambda")]): + return {"val": val} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"val": "from-lambda"} + + +def test_depends_accepts_class_with_call(): + """Depends() works with a class that implements __call__.""" + app = APIGatewayHttpResolver() + + class TenantProvider: + def __call__(self) -> str: + return "tenant-from-class" + + @app.post("/my/path") + def handler(tenant: Annotated[str, Depends(TenantProvider())]): + return {"tenant": tenant} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"tenant": "tenant-from-class"} + + +def test_depends_accepts_class_as_factory(): + """Depends() works with a class itself (constructor as callable).""" + app = APIGatewayHttpResolver() + + class Config: + def __init__(self): + self.region = "us-east-1" + + @app.post("/my/path") + def handler(config: Annotated[Config, Depends(Config)]): + return {"region": config.region} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"region": "us-east-1"} + + +def test_depends_with_unresolvable_annotations_is_ignored(): + """A handler whose annotations cannot be resolved by get_type_hints is treated as having no deps.""" + app = APIGatewayHttpResolver() + + # Build a function with broken annotations that get_type_hints cannot resolve. + # The param has a default so the handler can still be called without it. + def make_handler(): + def handler(x: "CompletelyBogusType" = None): # noqa: F821 + return {"ok": True} + + return handler + + app.post("/my/path")(make_handler()) + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"ok": True} + + +def test_depends_without_request_does_not_inject(): + """A dependency that does NOT declare Request still works when request is available.""" + app = APIGatewayHttpResolver() + + def get_static() -> str: + return "no-request-needed" + + @app.post("/my/path") + def handler(val: Annotated[str, Depends(get_static)]): + return {"val": val} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"val": "no-request-needed"} + + +def test_depends_with_broken_type_hints_on_dependency(): + """A dependency callable with broken annotations still resolves (get_type_hints fails gracefully).""" + app = APIGatewayHttpResolver() + + # Create a callable whose annotations reference a nonexistent type + # so get_type_hints() will raise inside solve_dependencies + broken_dep = type( + "BrokenDep", + (), + { + "__call__": lambda self: "it-works", + "__annotations__": {"x": "NonExistentType"}, + "__module__": __name__, + }, + )() + + @app.post("/my/path") + def handler(val: Annotated[str, Depends(broken_dep)]): + return {"val": val} + + result = app(API_GW_V2_EVENT, {}) + assert result["statusCode"] == 200 + assert json.loads(result["body"]) == {"val": "it-works"}