diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 95de745f..9de85b39 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -561,6 +561,7 @@ async def handle(message: str): state.current_agent = next_agent state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000' state.api_key = api_key + state.runtime_token_cache.clear() state.target_type = target_type state.target_id = target_id @@ -596,7 +597,8 @@ async def register() -> list[dict[str, Any]] | None: assert state.current_agent is not None async with AgentControlClient( - base_url=state.server_url, api_key=state.api_key + base_url=state.server_url, + api_key=state.api_key, ) as client: # Check server health first try: @@ -714,6 +716,7 @@ def _reset_state() -> None: state.server_controls = None state.server_url = None state.api_key = None + state.runtime_token_cache.clear() state.target_type = None state.target_id = None diff --git a/sdks/python/src/agent_control/_state.py b/sdks/python/src/agent_control/_state.py index 25974567..834c73b0 100644 --- a/sdks/python/src/agent_control/_state.py +++ b/sdks/python/src/agent_control/_state.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Any +from .runtime_auth import RuntimeTokenCache + if TYPE_CHECKING: from agent_control_models import Agent @@ -24,6 +26,7 @@ def __init__(self) -> None: self.server_controls: list[dict[str, Any]] | None = None self.server_url: str | None = None self.api_key: str | None = None + self.runtime_token_cache = RuntimeTokenCache() # Optional target context fixed at init() time; both fields are set # together or both remain None. self.target_type: str | None = None diff --git a/sdks/python/src/agent_control/client.py b/sdks/python/src/agent_control/client.py index 41ce0425..10ce50d6 100644 --- a/sdks/python/src/agent_control/client.py +++ b/sdks/python/src/agent_control/client.py @@ -2,14 +2,44 @@ import logging import os +from collections.abc import Generator from types import TracebackType +from typing import Any, cast import httpx from . import __version__ as sdk_version +from .runtime_auth import ( + RuntimeAuthMode, + RuntimeTokenCache, + normalize_runtime_auth_mode, + parse_runtime_token_exchange_response, +) _logger = logging.getLogger(__name__) +_RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE" +_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30 +_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503} +_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503} + + +class _AgentControlAuth(httpx.Auth): + """Attach local API-key credentials unless a request already has Bearer auth.""" + + def __init__(self, api_key: str | None, header_name: str = "X-API-Key") -> None: + self._api_key = api_key + self._header_name = header_name + + def auth_flow( + self, + request: httpx.Request, + ) -> Generator[httpx.Request, httpx.Response, None]: + if self._api_key and "Authorization" not in request.headers: + if self._header_name not in request.headers: + request.headers[self._header_name] = self._api_key + yield request + class AgentControlClient: """ @@ -20,7 +50,9 @@ class AgentControlClient: agents, policies, controls, evaluation. Authentication: - The client supports API key authentication via the X-API-Key header. + The client supports API key authentication. By default the key is + sent on the ``X-API-Key`` header; set ``api_key_header`` (or the + ``AGENT_CONTROL_API_KEY_HEADER`` environment variable) to override. API key can be provided: 1. Directly via the `api_key` parameter 2. Via the AGENT_CONTROL_API_KEY environment variable @@ -34,10 +66,20 @@ class AgentControlClient: os.environ["AGENT_CONTROL_API_KEY"] = "my-secret-key" async with AgentControlClient() as client: await client.health_check() + + # Custom header name (e.g., when the upstream auth expects something + # other than X-API-Key). The header name applies to every request + # this client sends. + async with AgentControlClient( + api_key="my-secret-key", api_key_header="X-Custom-API-Key" + ) as client: + await client.health_check() """ # Environment variable name for API key API_KEY_ENV_VAR = "AGENT_CONTROL_API_KEY" + API_KEY_HEADER_ENV_VAR = "AGENT_CONTROL_API_KEY_HEADER" + DEFAULT_API_KEY_HEADER = "X-API-Key" BASE_URL_ENV_VAR = "AGENT_CONTROL_URL" def __init__( @@ -45,6 +87,11 @@ def __init__( base_url: str | None = None, timeout: float = 30.0, api_key: str | None = None, + api_key_header: str | None = None, + runtime_auth_mode: RuntimeAuthMode | str | None = None, + runtime_token_cache: RuntimeTokenCache | None = None, + runtime_token_refresh_margin_seconds: int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS), + transport: httpx.AsyncBaseTransport | None = None, ): """ Initialize the client. @@ -55,6 +102,19 @@ def __init__( timeout: Request timeout in seconds api_key: API key for authentication. If not provided, will attempt to read from AGENT_CONTROL_API_KEY environment variable. + api_key_header: HTTP header name to send the API key on. Defaults + to ``X-API-Key``; the AGENT_CONTROL_API_KEY_HEADER + environment variable overrides the default. Useful when + the configured upstream auth expects a different header. + runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto`` + attempts target-bound JWT exchange and falls back to normal + request auth when the exchange endpoint is unavailable. ``jwt`` + requires a successful exchange. ``api_key`` and ``none`` keep + evaluation requests on the normal request-auth path. + runtime_token_cache: Optional cache shared across client instances. + runtime_token_refresh_margin_seconds: Refresh cached runtime tokens + before this many seconds of validity remain. + transport: Optional httpx transport, primarily for tests. """ resolved_base_url = base_url or os.environ.get( self.BASE_URL_ENV_VAR, "http://localhost:8000" @@ -62,6 +122,18 @@ def __init__( self.base_url = resolved_base_url.rstrip("/") self.timeout = timeout self._api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR) + self._api_key_header = ( + api_key_header + or os.environ.get(self.API_KEY_HEADER_ENV_VAR) + or self.DEFAULT_API_KEY_HEADER + ) + configured_runtime_mode = runtime_auth_mode or os.environ.get(_RUNTIME_AUTH_MODE_ENV_VAR) + self._runtime_auth_mode = normalize_runtime_auth_mode(configured_runtime_mode) + if runtime_token_refresh_margin_seconds < 0: + raise ValueError("runtime_token_refresh_margin_seconds must be >= 0.") + self._runtime_token_refresh_margin_seconds = runtime_token_refresh_margin_seconds + self._runtime_token_cache = runtime_token_cache or RuntimeTokenCache() + self._transport = transport self._client: httpx.AsyncClient | None = None self._server_version_warning_emitted = False @@ -70,15 +142,22 @@ def api_key(self) -> str | None: """Get the configured API key (read-only).""" return self._api_key + @property + def api_key_header(self) -> str: + """Get the header name the API key is sent on (read-only).""" + return self._api_key_header + + @property + def runtime_auth_mode(self) -> RuntimeAuthMode: + """Get the configured runtime auth mode (read-only).""" + return self._runtime_auth_mode + def _get_headers(self) -> dict[str, str]: - """Build request headers including authentication.""" - headers: dict[str, str] = { + """Build base SDK metadata headers.""" + return { "X-Agent-Control-SDK": "python", "X-Agent-Control-SDK-Version": sdk_version, } - if self._api_key: - headers["X-API-Key"] = self._api_key - return headers async def _check_server_version(self, response: httpx.Response) -> None: """Warn once when the server major version differs from the SDK major.""" @@ -108,6 +187,8 @@ async def __aenter__(self) -> "AgentControlClient": base_url=self.base_url, timeout=self.timeout, headers=self._get_headers(), + auth=_AgentControlAuth(self._api_key, self._api_key_header), + transport=self._transport, event_hooks={"response": [self._check_server_version]}, ) return self @@ -137,6 +218,7 @@ async def health_check(self) -> dict[str, str]: response = await self._client.get("/health") response.raise_for_status() from typing import cast + return cast(dict[str, str], response.json()) @property @@ -145,3 +227,151 @@ def http_client(self) -> httpx.AsyncClient: if self._client is None: raise RuntimeError("Client not initialized. Use 'async with' context manager.") return self._client + + async def post_runtime_evaluation( + self, + *, + json: dict[str, Any], + headers: dict[str, str] | None = None, + target_type: str | None = None, + target_id: str | None = None, + ) -> httpx.Response: + """POST an evaluation request with runtime auth when configured.""" + runtime_authorization = await self._runtime_authorization( + target_type=target_type, + target_id=target_id, + ) + request_headers = self._merge_runtime_headers(headers, runtime_authorization) + response = await self.http_client.post( + "/api/v1/evaluation", + json=json, + headers=request_headers, + ) + + if response.status_code == 401 and runtime_authorization is not None: + await response.aread() + runtime_authorization = await self._runtime_authorization( + target_type=target_type, + target_id=target_id, + force_refresh=True, + allow_auto_fallback=False, + ) + request_headers = self._merge_runtime_headers(headers, runtime_authorization) + response = await self.http_client.post( + "/api/v1/evaluation", + json=json, + headers=request_headers, + ) + + return response + + def _merge_runtime_headers( + self, + headers: dict[str, str] | None, + runtime_authorization: str | None, + ) -> dict[str, str] | None: + """Merge caller headers with an optional Bearer token.""" + if headers is None and runtime_authorization is None: + return None + + merged = dict(headers or {}) + if runtime_authorization is not None: + merged["Authorization"] = runtime_authorization + return merged + + async def _runtime_authorization( + self, + *, + target_type: str | None, + target_id: str | None, + force_refresh: bool = False, + allow_auto_fallback: bool = True, + ) -> str | None: + """Return an Authorization header value for runtime evaluation.""" + if self._runtime_auth_mode in {"none", "api_key"}: + return None + + if target_type is None or target_id is None: + if self._runtime_auth_mode == "jwt": + raise RuntimeError( + "runtime_auth_mode='jwt' requires target_type and target_id " + "for evaluation requests." + ) + return None + + if self._runtime_auth_mode == "auto" and self._runtime_token_cache.is_jwt_unavailable( + self.base_url, target_type, target_id + ): + return None + + if not force_refresh: + cached = self._runtime_token_cache.get( + self.base_url, + target_type, + target_id, + refresh_margin_seconds=self._runtime_token_refresh_margin_seconds, + ) + if cached is not None: + return f"Bearer {cached.token}" + + exchange_lock = self._runtime_token_cache.exchange_lock( + self.base_url, + target_type, + target_id, + ) + async with exchange_lock: + if not force_refresh: + cached = self._runtime_token_cache.get( + self.base_url, + target_type, + target_id, + refresh_margin_seconds=self._runtime_token_refresh_margin_seconds, + ) + if cached is not None: + return f"Bearer {cached.token}" + + token = await self._exchange_runtime_token( + target_type=target_type, + target_id=target_id, + allow_auto_fallback=allow_auto_fallback, + ) + if token is None: + return None + return f"Bearer {token}" + + async def _exchange_runtime_token( + self, + *, + target_type: str, + target_id: str, + allow_auto_fallback: bool = True, + ) -> str | None: + """Exchange the configured credential for a target-bound runtime token.""" + response = await self.http_client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": target_type, "target_id": target_id}, + ) + + if ( + self._runtime_auth_mode == "auto" + and allow_auto_fallback + and response.status_code in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES + ): + self._runtime_token_cache.mark_jwt_unavailable( + server_url=self.base_url, + target_type=target_type, + target_id=target_id, + globally=response.status_code in _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES, + ) + return None + + response.raise_for_status() + payload = response.json() + if not isinstance(payload, dict): + raise RuntimeError("Runtime token exchange response was not an object.") + token = parse_runtime_token_exchange_response( + cast(dict[str, object], payload), + server_url=self.base_url, + ) + self._runtime_token_cache.set(token) + return token.token diff --git a/sdks/python/src/agent_control/evaluation.py b/sdks/python/src/agent_control/evaluation.py index f1c7da97..2ecfd850 100644 --- a/sdks/python/src/agent_control/evaluation.py +++ b/sdks/python/src/agent_control/evaluation.py @@ -1,8 +1,11 @@ """Evaluation check operations for Agent Control SDK.""" +from collections.abc import Awaitable, Callable from dataclasses import dataclass +from inspect import iscoroutinefunction from typing import Any, Literal, cast +import httpx from agent_control_engine import list_evaluators from agent_control_engine.core import ControlEngine from agent_control_models import ( @@ -22,6 +25,8 @@ from .tracing import get_trace_and_span_ids from .validation import ensure_agent_name +_RuntimePostEvaluation = Callable[..., Awaitable[httpx.Response]] + @dataclass class _ControlAdapter: @@ -43,12 +48,12 @@ def _resolve_session_target( ) -> tuple[str | None, str | None]: """Default per-call target from state, and reject mismatches. - The SDK supports one target per session, fixed at ``init()`` time — + The SDK supports one target per session, fixed at ``init()`` time - including no-target sessions, where the session target is ``(None, None)``. The cached controls (``state.server_controls``) are fetched for that session target. A per-call override that disagrees - with the session target — including supplying an explicit target on a - no-target session — would evaluate against the wrong cache and could + with the session target - including supplying an explicit target on a + no-target session - would evaluate against the wrong cache and could return safe without contacting the server. Reject the mismatch so callers re-init when they need to change targets. @@ -118,7 +123,7 @@ def _has_applicable_prefiltered_server_controls( parsed_server_controls: list[_ControlAdapter] = [] for control in server_control_payloads: - # Skip unrendered template controls — they have no condition to evaluate + # Skip unrendered template controls - they have no condition to evaluate # and should not trigger the server-call fallback. ctrl_data = control.get("control", {}) if ( @@ -206,6 +211,41 @@ def _cached_server_control_lookup( return _build_server_control_lookup(state.server_controls) +def _runtime_post_evaluation(client: Any) -> _RuntimePostEvaluation | None: + """Return a runtime-evaluation callable when the client exposes one.""" + runtime_post = getattr(client, "post_runtime_evaluation", None) + if not callable(runtime_post) or not iscoroutinefunction(runtime_post): + return None + return cast(_RuntimePostEvaluation, runtime_post) + + +async def _post_evaluation_request( + client: AgentControlClient, + *, + request_payload: dict[str, Any], + headers: dict[str, str] | None, + target_type: str | None, + target_id: str | None, +) -> httpx.Response: + """Send an evaluation request, using runtime auth when the client supports it.""" + runtime_post = None + if (target_type is not None and target_id is not None) or client.runtime_auth_mode == "jwt": + runtime_post = _runtime_post_evaluation(client) + if runtime_post is not None: + return await runtime_post( + json=request_payload, + headers=headers, + target_type=target_type, + target_id=target_id, + ) + + return await client.http_client.post( + "/api/v1/evaluation", + json=request_payload, + headers=headers, + ) + + async def check_evaluation( client: AgentControlClient, agent_name: str, @@ -241,10 +281,12 @@ async def check_evaluation( ) request_payload = request.model_dump(mode="json") - response = await client.http_client.post( - "/api/v1/evaluation", - json=request_payload, + response = await _post_evaluation_request( + client, + request_payload=request_payload, headers=None, + target_type=target_type, + target_id=target_id, ) response.raise_for_status() @@ -311,7 +353,7 @@ async def check_evaluation_with_local( for control in controls: control_data = control.get("control", {}) - # Skip unrendered template controls — they cannot be evaluated. + # Skip unrendered template controls - they cannot be evaluated. if ( isinstance(control_data, dict) and control_data.get("template") is not None @@ -424,10 +466,12 @@ def _with_parse_errors(result: EvaluationResult) -> EvaluationResult: headers["X-Span-Id"] = resolved_span_id try: - response = await client.http_client.post( - "/api/v1/evaluation", - json=request_payload, + response = await _post_evaluation_request( + client, + request_payload=request_payload, headers=headers, + target_type=target_type, + target_id=target_id, ) response.raise_for_status() server_result = EvaluationResponse.model_validate(response.json()) @@ -510,7 +554,11 @@ async def evaluate_controls( step_obj = Step(**step_dict) # type: ignore[arg-type] resolved_controls = state.server_controls or [] - async with AgentControlClient(base_url=state.server_url, api_key=state.api_key) as client: + async with AgentControlClient( + base_url=state.server_url, + api_key=state.api_key, + runtime_token_cache=state.runtime_token_cache, + ) as client: return await check_evaluation_with_local( client=client, agent_name=agent_name, diff --git a/sdks/python/src/agent_control/observability.py b/sdks/python/src/agent_control/observability.py index cd249fd9..e8a2ad6a 100644 --- a/sdks/python/src/agent_control/observability.py +++ b/sdks/python/src/agent_control/observability.py @@ -26,6 +26,9 @@ await shutdown_observability() Configuration (Environment Variables): + # Server connection + AGENT_CONTROL_API_KEY_HEADER: API key header name (default: X-API-Key) + # Observability (event batching) AGENT_CONTROL_OBSERVABILITY_ENABLED: Enable observability (default: true) AGENT_CONTROL_BATCH_SIZE: Max events per batch (default: 100) @@ -275,6 +278,7 @@ class EventBatcher: Attributes: server_url: Base URL of the Agent Control server api_key: API key for authentication + api_key_header: HTTP header used to send the API key batch_size: Maximum events per batch flush_interval: Seconds between automatic flushes """ @@ -283,6 +287,7 @@ def __init__( self, server_url: str | None = None, api_key: str | None = None, + api_key_header: str | None = None, batch_size: int | None = None, flush_interval: float | None = None, ): @@ -292,11 +297,13 @@ def __init__( Args: server_url: Server URL (defaults to get_settings().url) api_key: API key (defaults to get_settings().api_key) + api_key_header: API key header (defaults to get_settings().api_key_header) batch_size: Max events per batch (defaults to get_settings().batch_size) flush_interval: Seconds between flushes (defaults to get_settings().flush_interval) """ self.server_url = server_url or get_settings().url self.api_key = api_key or get_settings().api_key + self.api_key_header = api_key_header or get_settings().api_key_header self.batch_size = batch_size if batch_size is not None else get_settings().batch_size if flush_interval is not None: self.flush_interval = flush_interval @@ -413,7 +420,7 @@ def _build_batch_request( url = f"{self.server_url}/api/v1/observability/events" headers = {"Content-Type": "application/json"} if self.api_key: - headers["X-API-Key"] = self.api_key + headers[self.api_key_header] = self.api_key payload = {"events": [event.model_dump(mode="json") for event in events]} return url, headers, payload @@ -1023,6 +1030,7 @@ async def _run_awaitable_during_shutdown(result: Awaitable[Any]) -> None: def init_observability( server_url: str | None = None, api_key: str | None = None, + api_key_header: str | None = None, enabled: bool | None = None, sink_name: str | None = None, sink_config: JSONObject | None = None, @@ -1035,6 +1043,7 @@ def init_observability( Args: server_url: Server URL for sending events api_key: API key for authentication + api_key_header: HTTP header used to send the API key enabled: Override AGENT_CONTROL_OBSERVABILITY_ENABLED sink_name: Override AGENT_CONTROL_OBSERVABILITY_SINK_NAME sink_config: Override AGENT_CONTROL_OBSERVABILITY_SINK_CONFIG @@ -1079,7 +1088,11 @@ def init_observability( return _batcher # Create batcher - _batcher = EventBatcher(server_url=server_url, api_key=api_key) + _batcher = EventBatcher( + server_url=server_url, + api_key=api_key, + api_key_header=api_key_header, + ) _batcher.start() _event_sink = _BatcherControlEventSink(_batcher) diff --git a/sdks/python/src/agent_control/runtime_auth.py b/sdks/python/src/agent_control/runtime_auth.py new file mode 100644 index 00000000..3b3643e5 --- /dev/null +++ b/sdks/python/src/agent_control/runtime_auth.py @@ -0,0 +1,194 @@ +"""Runtime-token cache helpers for the Agent Control SDK.""" + +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Literal + +RuntimeAuthMode = Literal["auto", "none", "api_key", "jwt"] + +_TokenKey = tuple[str, str, str] +_DEFAULT_MAX_CACHE_ENTRIES = 256 + + +@dataclass(frozen=True) +class RuntimeToken: + """Short-lived runtime token bound to one target.""" + + token: str + expires_at: datetime + server_url: str + target_type: str + target_id: str + scopes: tuple[str, ...] + + def is_fresh(self, *, refresh_margin_seconds: int) -> bool: + """Return whether the token is usable beyond the refresh margin.""" + refresh_at = datetime.now(UTC) + timedelta(seconds=refresh_margin_seconds) + return self.expires_at > refresh_at + + +class RuntimeTokenCache: + """Thread-safe runtime token cache keyed by server and target.""" + + def __init__(self, *, max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES) -> None: + if max_entries < 1: + raise ValueError("max_entries must be >= 1.") + self._max_entries = max_entries + self._tokens: dict[_TokenKey, RuntimeToken] = {} + self._jwt_unavailable = False + self._jwt_unavailable_targets: set[_TokenKey] = set() + self._exchange_locks: dict[_TokenKey, asyncio.Lock] = {} + self._lock = threading.Lock() + + def get( + self, + server_url: str, + target_type: str, + target_id: str, + *, + refresh_margin_seconds: int, + ) -> RuntimeToken | None: + """Return a fresh cached token for the target, if present.""" + key = (server_url, target_type, target_id) + with self._lock: + token = self._tokens.get(key) + if token is None: + return None + if token.is_fresh(refresh_margin_seconds=refresh_margin_seconds): + return token + self._tokens.pop(key, None) + return None + + def set(self, token: RuntimeToken) -> None: + """Store a token and clear any fallback marker for its target.""" + key = (token.server_url, token.target_type, token.target_id) + with self._lock: + if key not in self._tokens and len(self._tokens) >= self._max_entries: + oldest_key = next(iter(self._tokens)) + self._tokens.pop(oldest_key, None) + self._jwt_unavailable_targets.discard(oldest_key) + self._exchange_locks.pop(oldest_key, None) + self._tokens[key] = token + self._jwt_unavailable_targets.discard(key) + + def remove(self, server_url: str, target_type: str, target_id: str) -> None: + """Drop the cached token for one target.""" + with self._lock: + self._tokens.pop((server_url, target_type, target_id), None) + + def mark_jwt_unavailable( + self, + *, + server_url: str | None = None, + target_type: str | None = None, + target_id: str | None = None, + globally: bool = False, + ) -> None: + """Record that JWT runtime auth should not be attempted.""" + with self._lock: + if globally: + self._jwt_unavailable = True + self._tokens.clear() + return + if server_url is not None and target_type is not None and target_id is not None: + key = (server_url, target_type, target_id) + if ( + key not in self._jwt_unavailable_targets + and len(self._jwt_unavailable_targets) >= self._max_entries + ): + evicted_key = self._jwt_unavailable_targets.pop() + self._exchange_locks.pop(evicted_key, None) + self._jwt_unavailable_targets.add(key) + self._tokens.pop(key, None) + + def is_jwt_unavailable(self, server_url: str, target_type: str, target_id: str) -> bool: + """Return whether JWT exchange is known unavailable for the target.""" + key = (server_url, target_type, target_id) + with self._lock: + return self._jwt_unavailable or key in self._jwt_unavailable_targets + + def clear(self) -> None: + """Clear every cached token and fallback marker.""" + with self._lock: + self._tokens.clear() + self._jwt_unavailable = False + self._jwt_unavailable_targets.clear() + self._exchange_locks.clear() + + def exchange_lock(self, server_url: str, target_type: str, target_id: str) -> asyncio.Lock: + """Return the async exchange lock for one server and target.""" + key = (server_url, target_type, target_id) + with self._lock: + lock = self._exchange_locks.get(key) + if lock is None: + lock = asyncio.Lock() + self._exchange_locks[key] = lock + return lock + + +def normalize_runtime_auth_mode(raw: str | None) -> RuntimeAuthMode: + """Normalize configured SDK runtime auth mode.""" + if raw is None or not raw.strip(): + return "auto" + + mode = raw.strip().lower() + if mode in {"none", "no_auth"}: + return "none" + if mode in {"api_key", "header"}: + return "api_key" + if mode == "auto": + return "auto" + if mode == "jwt": + return "jwt" + raise ValueError("runtime_auth_mode must be one of 'auto', 'none', 'api_key', or 'jwt'.") + + +def parse_runtime_token_exchange_response( + payload: Mapping[str, object], + *, + server_url: str, +) -> RuntimeToken: + """Parse the runtime token exchange response payload.""" + token = payload.get("token") + expires_at = payload.get("expires_at") + target_type = payload.get("target_type") + target_id = payload.get("target_id") + scopes = payload.get("scopes") + + if not isinstance(token, str) or not token: + raise RuntimeError("Runtime token exchange response did not include a token.") + if not isinstance(expires_at, str) or not expires_at: + raise RuntimeError("Runtime token exchange response did not include expires_at.") + if not isinstance(target_type, str) or not target_type: + raise RuntimeError("Runtime token exchange response did not include target_type.") + if not isinstance(target_id, str) or not target_id: + raise RuntimeError("Runtime token exchange response did not include target_id.") + if not isinstance(scopes, Sequence) or isinstance(scopes, str): + raise RuntimeError("Runtime token exchange response did not include scopes.") + + parsed_scopes: list[str] = [] + for scope in scopes: + if not isinstance(scope, str): + raise RuntimeError("Runtime token exchange response included a non-string scope.") + parsed_scopes.append(scope) + + normalized_expires_at = expires_at + if normalized_expires_at.endswith("Z"): + normalized_expires_at = f"{normalized_expires_at[:-1]}+00:00" + parsed_expires_at = datetime.fromisoformat(normalized_expires_at) + if parsed_expires_at.tzinfo is None: + parsed_expires_at = parsed_expires_at.replace(tzinfo=UTC) + + return RuntimeToken( + token=token, + expires_at=parsed_expires_at.astimezone(UTC), + server_url=server_url, + target_type=target_type, + target_id=target_id, + scopes=tuple(parsed_scopes), + ) diff --git a/sdks/python/src/agent_control/settings.py b/sdks/python/src/agent_control/settings.py index 982811f9..ea73d666 100644 --- a/sdks/python/src/agent_control/settings.py +++ b/sdks/python/src/agent_control/settings.py @@ -56,6 +56,10 @@ class SDKSettings(BaseSettings): default="", description="API key for server authentication", ) + api_key_header: str = Field( + default="X-API-Key", + description="HTTP header used to send the API key", + ) # Observability (event batching) observability_enabled: bool = Field( diff --git a/sdks/python/tests/test_client.py b/sdks/python/tests/test_client.py index aff6796e..54754a87 100644 --- a/sdks/python/tests/test_client.py +++ b/sdks/python/tests/test_client.py @@ -2,12 +2,15 @@ from __future__ import annotations +import asyncio +from datetime import UTC, datetime, timedelta from unittest.mock import patch import httpx import pytest from agent_control.client import AgentControlClient, sdk_version +from agent_control.runtime_auth import RuntimeTokenCache def test_client_uses_agent_control_url_env_var( @@ -36,17 +39,543 @@ def test_explicit_base_url_overrides_env_var( assert client.base_url == "http://explicit.test:8000" -def test_get_headers_include_sdk_metadata_and_api_key() -> None: +def test_get_headers_include_sdk_metadata() -> None: # Given: a client configured with an API key client = AgentControlClient(api_key="test-key") # When: building request headers headers = client._get_headers() - # Then: SDK metadata and authentication headers are included + # Then: SDK metadata headers are included assert headers["X-Agent-Control-SDK"] == "python" assert headers["X-Agent-Control-SDK-Version"] == sdk_version - assert headers["X-API-Key"] == "test-key" + assert "X-API-Key" not in headers + + +def test_client_rejects_negative_runtime_token_refresh_margin() -> None: + with pytest.raises(ValueError, match="runtime_token_refresh_margin_seconds"): + AgentControlClient(runtime_token_refresh_margin_seconds=-1) + + +@pytest.mark.asyncio +async def test_client_adds_api_key_auth_to_regular_requests() -> None: + seen_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_requests.append(request) + return httpx.Response(200, json={"ok": True}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + transport=transport, + ) as client: + response = await client.http_client.get("/api/v1/agents") + + assert response.status_code == 200 + assert seen_requests[0].headers["X-API-Key"] == "test-key" + + +@pytest.mark.asyncio +async def test_client_uses_configured_api_key_header_name() -> None: + # Given: a client configured to send the API key on a custom header + seen_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_requests.append(request) + return httpx.Response(200, json={"ok": True}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + api_key_header="X-Custom-API-Key", + transport=transport, + ) as client: + # When: making a request + response = await client.http_client.get("/api/v1/agents") + + # Then: the key is on the configured header and the default is absent + assert response.status_code == 200 + assert seen_requests[0].headers["X-Custom-API-Key"] == "test-key" + assert "X-API-Key" not in seen_requests[0].headers + + +@pytest.mark.asyncio +async def test_client_reads_api_key_header_from_env( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Given: AGENT_CONTROL_API_KEY_HEADER set in the environment + monkeypatch.setenv("AGENT_CONTROL_API_KEY_HEADER", "X-Custom-API-Key") + seen_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + seen_requests.append(request) + return httpx.Response(200, json={"ok": True}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + transport=transport, + ) as client: + # When: no api_key_header is passed to the constructor + response = await client.http_client.get("/api/v1/agents") + + # Then: the env-var value is used + assert response.status_code == 200 + assert seen_requests[0].headers["X-Custom-API-Key"] == "test-key" + + +def test_client_exposes_default_api_key_header() -> None: + # Given: a client with no explicit header override + client = AgentControlClient(api_key="test-key") + + # Then: the property reports the documented default + assert client.api_key_header == "X-API-Key" + + +@pytest.mark.asyncio +async def test_runtime_evaluation_exchanges_and_caches_bearer_token() -> None: + exchange_calls = 0 + evaluation_authorization_headers: list[str | None] = [] + evaluation_api_key_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + assert request.headers["X-API-Key"] == "test-key" + return httpx.Response( + 200, + json={ + "token": "runtime-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + for _ in range(2): + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_calls == 1 + assert evaluation_authorization_headers == ["Bearer runtime-token", "Bearer runtime-token"] + assert evaluation_api_key_headers == [None, None] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_single_flights_cold_cache_exchange() -> None: + exchange_calls = 0 + evaluation_authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + async def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + await asyncio.sleep(0.01) + return httpx.Response( + 200, + json={ + "token": "runtime-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + responses = await asyncio.gather( + *( + client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + for _ in range(5) + ) + ) + + assert [response.status_code for response in responses] == [200, 200, 200, 200, 200] + assert exchange_calls == 1 + assert evaluation_authorization_headers == ["Bearer runtime-token"] * 5 + + +@pytest.mark.asyncio +async def test_runtime_evaluation_refreshes_token_before_expiry() -> None: + exchange_tokens = ["short-token", "fresh-token"] + exchange_expiries = [ + (datetime.now(UTC) + timedelta(seconds=5)).isoformat(), + (datetime.now(UTC) + timedelta(minutes=5)).isoformat(), + ] + evaluation_authorization_headers: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + return httpx.Response( + 200, + json={ + "token": exchange_tokens.pop(0), + "expires_at": exchange_expiries.pop(0), + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + runtime_token_refresh_margin_seconds=30, + transport=transport, + ) as client: + for _ in range(2): + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_tokens == [] + assert evaluation_authorization_headers == [ + "Bearer short-token", + "Bearer fresh-token", + ] + + +@pytest.mark.asyncio +async def test_runtime_token_cache_is_scoped_to_server_url() -> None: + exchange_paths: list[str] = [] + authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + cache = RuntimeTokenCache() + + def handler(request: httpx.Request) -> httpx.Response: + server_url = f"{request.url.scheme}://{request.url.host}" + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_paths.append(server_url) + return httpx.Response( + 200, + json={ + "token": f"{request.url.host}-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + for base_url in ("https://server-a.test", "https://server-b.test"): + async with AgentControlClient( + base_url=base_url, + api_key="test-key", + runtime_auth_mode="jwt", + runtime_token_cache=cache, + transport=transport, + ) as client: + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_paths == ["https://server-a.test", "https://server-b.test"] + assert authorization_headers == [ + "Bearer server-a.test-token", + "Bearer server-b.test-token", + ] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_auto_falls_back_to_api_key_when_exchange_unavailable() -> None: + exchange_calls = 0 + evaluation_api_key_headers: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + return httpx.Response(503, json={"detail": "runtime auth disabled"}) + + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + for _ in range(2): + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + assert response.status_code == 200 + + assert exchange_calls == 1 + assert evaluation_api_key_headers == ["test-key", "test-key"] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_auto_without_target_uses_api_key_path() -> None: + exchange_calls = 0 + evaluation_api_key_headers: list[str | None] = [] + evaluation_authorization_headers: list[str | None] = [] + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_calls + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_calls += 1 + return httpx.Response(200, json={}) + + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + response = await client.post_runtime_evaluation(json={}) + + assert response.status_code == 200 + assert exchange_calls == 0 + assert evaluation_api_key_headers == ["test-key"] + assert evaluation_authorization_headers == [None] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_retries_once_after_unauthorized_token() -> None: + exchange_tokens = ["expired-token", "fresh-token"] + evaluation_authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + token = exchange_tokens.pop(0) + return httpx.Response( + 200, + json={ + "token": token, + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + authorization = request.headers.get("Authorization") + evaluation_authorization_headers.append(authorization) + if authorization == "Bearer expired-token": + return httpx.Response(401, json={"detail": "expired"}) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + + assert response.status_code == 200 + assert evaluation_authorization_headers == [ + "Bearer expired-token", + "Bearer fresh-token", + ] + assert exchange_tokens == [] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_does_not_auto_fallback_after_unauthorized_token() -> None: + exchange_attempt = 0 + evaluation_authorization_headers: list[str | None] = [] + evaluation_api_key_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + nonlocal exchange_attempt + if request.url.path == "/api/v1/auth/runtime-token-exchange": + exchange_attempt += 1 + if exchange_attempt == 1: + return httpx.Response( + 200, + json={ + "token": "expired-token", + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + return httpx.Response(503, json={"detail": "runtime auth disabled"}) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + evaluation_api_key_headers.append(request.headers.get("X-API-Key")) + return httpx.Response(401, json={"detail": "expired"}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="auto", + transport=transport, + ) as client: + with pytest.raises(httpx.HTTPStatusError) as exc_info: + await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + + assert exc_info.value.response.status_code == 503 + assert exchange_attempt == 2 + assert evaluation_authorization_headers == ["Bearer expired-token"] + assert evaluation_api_key_headers == [None] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_returns_second_unauthorized_response() -> None: + exchange_tokens = ["expired-token", "still-expired-token"] + evaluation_authorization_headers: list[str | None] = [] + expires_at = (datetime.now(UTC) + timedelta(minutes=5)).isoformat() + + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + return httpx.Response( + 200, + json={ + "token": exchange_tokens.pop(0), + "expires_at": expires_at, + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + ) + + evaluation_authorization_headers.append(request.headers.get("Authorization")) + return httpx.Response(401, json={"detail": "expired"}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + response = await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) + + assert response.status_code == 401 + assert exchange_tokens == [] + assert evaluation_authorization_headers == [ + "Bearer expired-token", + "Bearer still-expired-token", + ] + + +@pytest.mark.asyncio +async def test_runtime_evaluation_jwt_mode_requires_target_context() -> None: + transport = httpx.MockTransport(lambda request: httpx.Response(200, json={"ok": True})) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + with pytest.raises(RuntimeError, match="requires target_type and target_id"): + await client.post_runtime_evaluation(json={}) + + +@pytest.mark.asyncio +async def test_runtime_exchange_rejects_non_object_response() -> None: + def handler(request: httpx.Request) -> httpx.Response: + if request.url.path == "/api/v1/auth/runtime-token-exchange": + return httpx.Response(200, json=["not", "an", "object"]) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + with pytest.raises(RuntimeError, match="response was not an object"): + await client.post_runtime_evaluation( + json={"target_type": "log_stream", "target_id": "ls-1"}, + target_type="log_stream", + target_id="ls-1", + ) @pytest.mark.asyncio diff --git a/sdks/python/tests/test_local_evaluation.py b/sdks/python/tests/test_local_evaluation.py index da0115b8..b5b725f6 100644 --- a/sdks/python/tests/test_local_evaluation.py +++ b/sdks/python/tests/test_local_evaluation.py @@ -11,6 +11,7 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest from agent_control.client import AgentControlClient from agent_control.evaluation import ( @@ -29,6 +30,36 @@ # ============================================================================= +class _RuntimeAuthDuckClient: + """Minimal custom client that exposes the runtime-auth evaluation method.""" + + base_url = "https://agent-control.test" + + def __init__(self) -> None: + self.runtime_requests: list[dict[str, Any]] = [] + self.response = MagicMock() + self.response.json.return_value = {"is_safe": True, "confidence": 1.0} + self.response.raise_for_status = MagicMock() + + async def post_runtime_evaluation( + self, + *, + json: dict[str, Any], + headers: dict[str, str] | None = None, + target_type: str | None = None, + target_id: str | None = None, + ) -> MagicMock: + self.runtime_requests.append( + { + "json": json, + "headers": headers, + "target_type": target_type, + "target_id": target_id, + } + ) + return self.response + + @pytest.fixture def agent_name() -> str: """Test agent name.""" @@ -329,6 +360,100 @@ async def test_server_only_controls_calls_server(self, agent_name, llm_payload): assert result.is_safe is True + @pytest.mark.asyncio + async def test_custom_client_with_runtime_method_uses_runtime_auth_path( + self, agent_name, llm_payload + ) -> None: + """Custom clients can opt into runtime auth with post_runtime_evaluation.""" + controls = [ + make_control_dict(1, "server_ctrl", execution="server"), + ] + client = _RuntimeAuthDuckClient() + + result = await check_evaluation_with_local( + client=client, # type: ignore[arg-type] + agent_name=agent_name, + step=llm_payload, + stage="pre", + controls=controls, + target_type="log_stream", + target_id="ls-1", + ) + + assert result.is_safe is True + assert len(client.runtime_requests) == 1 + assert client.runtime_requests[0]["target_type"] == "log_stream" + assert client.runtime_requests[0]["target_id"] == "ls-1" + assert client.runtime_requests[0]["json"]["target_type"] == "log_stream" + assert client.runtime_requests[0]["json"]["target_id"] == "ls-1" + + @pytest.mark.asyncio + async def test_mock_client_with_runtime_method_uses_runtime_auth_path( + self, + agent_name, + llm_payload, + ) -> None: + """Configured mock clients can exercise the runtime-auth path.""" + controls = [ + make_control_dict(1, "server_ctrl", execution="server"), + ] + client = MagicMock(spec=AgentControlClient) + mock_response = MagicMock() + mock_response.json.return_value = {"is_safe": True, "confidence": 1.0} + mock_response.raise_for_status = MagicMock() + client.http_client = AsyncMock() + client.http_client.post = AsyncMock() + client.post_runtime_evaluation = AsyncMock(return_value=mock_response) + + result = await check_evaluation_with_local( + client=client, + agent_name=agent_name, + step=llm_payload, + stage="pre", + controls=controls, + target_type="log_stream", + target_id="ls-1", + ) + + assert result.is_safe is True + client.post_runtime_evaluation.assert_awaited_once() + client.http_client.post.assert_not_called() + + @pytest.mark.asyncio + async def test_jwt_runtime_client_without_target_raises( + self, + agent_name, + llm_payload, + ) -> None: + """JWT runtime mode requires target context through local evaluation.""" + controls = [ + make_control_dict(1, "server_ctrl", execution="server"), + ] + sent_requests: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + sent_requests.append(request) + return httpx.Response(200, json={"is_safe": True, "confidence": 1.0}) + + transport = httpx.MockTransport(handler) + + async with AgentControlClient( + base_url="https://agent-control.test", + api_key="test-key", + runtime_auth_mode="jwt", + transport=transport, + ) as client: + with pytest.raises(RuntimeError, match="requires target_type and target_id"): + await check_evaluation_with_local( + client=client, + agent_name=agent_name, + step=llm_payload, + stage="pre", + controls=controls, + ) + + assert sent_requests == [] + @pytest.mark.asyncio async def test_server_only_template_backed_controls_still_call_server( self, diff --git a/sdks/python/tests/test_observability.py b/sdks/python/tests/test_observability.py index 4f655147..608495a7 100644 --- a/sdks/python/tests/test_observability.py +++ b/sdks/python/tests/test_observability.py @@ -113,6 +113,7 @@ def reset_observability_state() -> None: observability_enabled=True, observability_sink_name=DEFAULT_CONTROL_EVENT_SINK_NAME, observability_sink_config={}, + api_key_header="X-API-Key", ) with obs._external_event_sinks_lock: obs._external_event_sinks.clear() @@ -126,6 +127,7 @@ class TestEventBatcherInit: def test_init_default_values(self): """Test EventBatcher initializes with default values.""" batcher = EventBatcher() + assert batcher.api_key_header == get_settings().api_key_header assert batcher.batch_size == get_settings().batch_size assert batcher.flush_interval == get_settings().flush_interval assert batcher.shutdown_join_timeout == get_settings().shutdown_join_timeout @@ -139,11 +141,13 @@ def test_init_custom_values(self): batcher = EventBatcher( server_url="http://custom:9000", api_key="test-key", + api_key_header="X-Custom-API-Key", batch_size=50, flush_interval=5.0, ) assert batcher.server_url == "http://custom:9000" assert batcher.api_key == "test-key" + assert batcher.api_key_header == "X-Custom-API-Key" assert batcher.batch_size == 50 assert batcher.flush_interval == 5.0 @@ -151,20 +155,23 @@ def test_init_from_settings(self): """Test EventBatcher reads from settings.""" from agent_control.settings import configure_settings - # Save original values - original_url = get_settings().url - original_api_key = get_settings().api_key + original_settings = get_settings().model_dump() try: # Configure settings programmatically - configure_settings(url="http://configured-server:8080", api_key="configured-api-key") + configure_settings( + url="http://configured-server:8080", + api_key="configured-api-key", + api_key_header="X-Custom-API-Key", + ) batcher = EventBatcher() assert batcher.server_url == "http://configured-server:8080" assert batcher.api_key == "configured-api-key" + assert batcher.api_key_header == "X-Custom-API-Key" finally: # Restore original settings - configure_settings(url=original_url, api_key=original_api_key) + configure_settings(**original_settings) class TestEventBatcherStartStop: @@ -547,6 +554,46 @@ def test_send_batch_sync_returns_true_on_202(self): assert result is True client_ctor.assert_called_once_with(timeout=30.0) client.post.assert_called_once() + assert client.post.call_args.kwargs["headers"]["X-API-Key"] == "test-key" + + def test_send_batch_sync_uses_configured_api_key_header(self): + batcher = EventBatcher( + server_url="http://test:8000", + api_key="test-key", + api_key_header="X-Custom-API-Key", + ) + response = MagicMock(status_code=202, text="accepted") + client = MagicMock() + client.post.return_value = response + client_context = MagicMock() + client_context.__enter__.return_value = client + + with patch( + "agent_control.observability.httpx.Client", + return_value=client_context, + ): + result = batcher._send_batch_sync([create_mock_event()]) + + assert result is True + headers = client.post.call_args.kwargs["headers"] + assert headers["X-Custom-API-Key"] == "test-key" + assert "X-API-Key" not in headers + + def test_build_batch_request_uses_settings_api_key_header(self): + original_settings = get_settings().model_dump() + try: + configure_settings( + api_key="settings-key", + api_key_header="X-Custom-API-Key", + ) + batcher = EventBatcher() + + _, headers, _ = batcher._build_batch_request([create_mock_event()]) + + assert headers["X-Custom-API-Key"] == "settings-key" + assert "X-API-Key" not in headers + finally: + configure_settings(**original_settings) def test_send_batch_sync_returns_false_on_401_without_retry(self): batcher = EventBatcher() @@ -1056,10 +1103,12 @@ def test_init_enabled_creates_batcher(self): result = init_observability( server_url="http://test:8000", api_key="test-key", + api_key_header="X-Custom-API-Key", enabled=True, ) assert result is not None assert isinstance(result, EventBatcher) + assert result.api_key_header == "X-Custom-API-Key" assert result._running is True assert get_event_sink() is not None diff --git a/sdks/python/tests/test_runtime_auth.py b/sdks/python/tests/test_runtime_auth.py new file mode 100644 index 00000000..ed12d1b7 --- /dev/null +++ b/sdks/python/tests/test_runtime_auth.py @@ -0,0 +1,232 @@ +"""Tests for Agent Control SDK runtime auth helpers.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest + +from agent_control.runtime_auth import ( + RuntimeToken, + RuntimeTokenCache, + normalize_runtime_auth_mode, + parse_runtime_token_exchange_response, +) + + +def _runtime_token( + *, + token: str = "token", + server_url: str = "https://server-a.test", + target_type: str = "log_stream", + target_id: str = "ls-1", + expires_at: datetime | None = None, +) -> RuntimeToken: + return RuntimeToken( + token=token, + expires_at=expires_at or datetime.now(UTC) + timedelta(minutes=5), + server_url=server_url, + target_type=target_type, + target_id=target_id, + scopes=("runtime.use",), + ) + + +def test_runtime_token_cache_is_keyed_by_server_and_target() -> None: + cache = RuntimeTokenCache() + token = _runtime_token() + + cache.set(token) + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) == token + ) + assert ( + cache.get("https://server-b.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-2", refresh_margin_seconds=0) is None + ) + + +def test_runtime_token_cache_drops_stale_tokens() -> None: + cache = RuntimeTokenCache() + cache.set(_runtime_token(expires_at=datetime.now(UTC) + timedelta(seconds=5))) + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=30) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + + +def test_runtime_token_cache_tracks_jwt_unavailable_by_server_and_target() -> None: + cache = RuntimeTokenCache() + + cache.mark_jwt_unavailable( + server_url="https://server-a.test", + target_type="log_stream", + target_id="ls-1", + ) + + assert cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + assert not cache.is_jwt_unavailable("https://server-b.test", "log_stream", "ls-1") + + cache.set(_runtime_token()) + + assert not cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + + +def test_runtime_token_cache_global_unavailable_clears_cache() -> None: + cache = RuntimeTokenCache() + cache.set(_runtime_token()) + + cache.mark_jwt_unavailable(globally=True) + + assert cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + + cache.clear() + + assert not cache.is_jwt_unavailable("https://server-a.test", "log_stream", "ls-1") + + +def test_runtime_token_cache_remove_drops_one_token() -> None: + cache = RuntimeTokenCache() + cache.set(_runtime_token(target_id="ls-1")) + token_2 = _runtime_token(token="token-2", target_id="ls-2") + cache.set(token_2) + + cache.remove("https://server-a.test", "log_stream", "ls-1") + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-2", refresh_margin_seconds=0) + == token_2 + ) + + +def test_runtime_token_cache_evicts_oldest_token_when_full() -> None: + cache = RuntimeTokenCache(max_entries=1) + token_1 = _runtime_token(target_id="ls-1") + token_2 = _runtime_token(token="token-2", target_id="ls-2") + + cache.set(token_1) + cache.set(token_2) + + assert ( + cache.get("https://server-a.test", "log_stream", "ls-1", refresh_margin_seconds=0) is None + ) + assert ( + cache.get("https://server-a.test", "log_stream", "ls-2", refresh_margin_seconds=0) + == token_2 + ) + + +def test_runtime_token_cache_rejects_empty_capacity() -> None: + with pytest.raises(ValueError, match="max_entries"): + RuntimeTokenCache(max_entries=0) + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + (None, "auto"), + ("", "auto"), + (" NO_AUTH ", "none"), + ("header", "api_key"), + ("api_key", "api_key"), + ("jwt", "jwt"), + ], +) +def test_normalize_runtime_auth_mode(raw: str | None, expected: str) -> None: + assert normalize_runtime_auth_mode(raw) == expected + + +def test_normalize_runtime_auth_mode_rejects_unknown_mode() -> None: + with pytest.raises(ValueError, match="runtime_auth_mode must be one of"): + normalize_runtime_auth_mode("cookie") + + +def test_parse_runtime_token_exchange_response_normalizes_zulu_expiry() -> None: + token = parse_runtime_token_exchange_response( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + server_url="https://server-a.test", + ) + + assert token.token == "runtime-token" + assert token.expires_at == datetime(2026, 5, 7, 15, 0, tzinfo=UTC) + assert token.server_url == "https://server-a.test" + assert token.target_type == "log_stream" + assert token.target_id == "ls-1" + assert token.scopes == ("runtime.use",) + + +def test_parse_runtime_token_exchange_response_treats_naive_expiry_as_utc() -> None: + token = parse_runtime_token_exchange_response( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use"], + }, + server_url="https://server-a.test", + ) + + assert token.expires_at == datetime(2026, 5, 7, 15, 0, tzinfo=UTC) + + +@pytest.mark.parametrize( + ("payload", "match"), + [ + ({}, "token"), + ({"token": "runtime-token"}, "expires_at"), + ({"token": "runtime-token", "expires_at": "2026-05-07T15:00:00Z"}, "target_type"), + ( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + }, + "target_id", + ), + ( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": "runtime.use", + }, + "scopes", + ), + ( + { + "token": "runtime-token", + "expires_at": "2026-05-07T15:00:00Z", + "target_type": "log_stream", + "target_id": "ls-1", + "scopes": ["runtime.use", 1], + }, + "non-string scope", + ), + ], +) +def test_parse_runtime_token_exchange_response_rejects_invalid_payloads( + payload: dict[str, object], + match: str, +) -> None: + with pytest.raises(RuntimeError, match=match): + parse_runtime_token_exchange_response(payload, server_url="https://server-a.test")