Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion sdks/python/src/agent_control/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ async def handle(message: str):
state.current_agent = next_agent
state.server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000'
state.api_key = api_key
state.runtime_token_cache.clear()
state.target_type = target_type
state.target_id = target_id

Expand Down Expand Up @@ -596,7 +597,8 @@ async def register() -> list[dict[str, Any]] | None:
assert state.current_agent is not None

async with AgentControlClient(
base_url=state.server_url, api_key=state.api_key
base_url=state.server_url,
api_key=state.api_key,
) as client:
# Check server health first
try:
Expand Down Expand Up @@ -714,6 +716,7 @@ def _reset_state() -> None:
state.server_controls = None
state.server_url = None
state.api_key = None
state.runtime_token_cache.clear()
state.target_type = None
state.target_id = None

Expand Down
3 changes: 3 additions & 0 deletions sdks/python/src/agent_control/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from typing import TYPE_CHECKING, Any

from .runtime_auth import RuntimeTokenCache

if TYPE_CHECKING:
from agent_control_models import Agent

Expand All @@ -24,6 +26,7 @@ def __init__(self) -> None:
self.server_controls: list[dict[str, Any]] | None = None
self.server_url: str | None = None
self.api_key: str | None = None
self.runtime_token_cache = RuntimeTokenCache()
# Optional target context fixed at init() time; both fields are set
# together or both remain None.
self.target_type: str | None = None
Expand Down
242 changes: 236 additions & 6 deletions sdks/python/src/agent_control/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,44 @@

import logging
import os
from collections.abc import Generator
from types import TracebackType
from typing import Any, cast

import httpx

from . import __version__ as sdk_version
from .runtime_auth import (
RuntimeAuthMode,
RuntimeTokenCache,
normalize_runtime_auth_mode,
parse_runtime_token_exchange_response,
)

_logger = logging.getLogger(__name__)

_RUNTIME_AUTH_MODE_ENV_VAR = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS = 30
_AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}
_GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES = {404, 503}


class _AgentControlAuth(httpx.Auth):
"""Attach local API-key credentials unless a request already has Bearer auth."""

def __init__(self, api_key: str | None, header_name: str = "X-API-Key") -> None:
self._api_key = api_key
self._header_name = header_name

def auth_flow(
self,
request: httpx.Request,
) -> Generator[httpx.Request, httpx.Response, None]:
if self._api_key and "Authorization" not in request.headers:
if self._header_name not in request.headers:
request.headers[self._header_name] = self._api_key
yield request


class AgentControlClient:
"""
Expand All @@ -20,7 +50,9 @@ class AgentControlClient:
agents, policies, controls, evaluation.

Authentication:
The client supports API key authentication via the X-API-Key header.
The client supports API key authentication. By default the key is
sent on the ``X-API-Key`` header; set ``api_key_header`` (or the
``AGENT_CONTROL_API_KEY_HEADER`` environment variable) to override.
API key can be provided:
1. Directly via the `api_key` parameter
2. Via the AGENT_CONTROL_API_KEY environment variable
Expand All @@ -34,17 +66,32 @@ class AgentControlClient:
os.environ["AGENT_CONTROL_API_KEY"] = "my-secret-key"
async with AgentControlClient() as client:
await client.health_check()

# Custom header name (e.g., when the upstream auth expects something
# other than X-API-Key). The header name applies to every request
# this client sends.
async with AgentControlClient(
api_key="my-secret-key", api_key_header="X-Custom-API-Key"
) as client:
await client.health_check()
"""

# Environment variable name for API key
API_KEY_ENV_VAR = "AGENT_CONTROL_API_KEY"
API_KEY_HEADER_ENV_VAR = "AGENT_CONTROL_API_KEY_HEADER"
DEFAULT_API_KEY_HEADER = "X-API-Key"
BASE_URL_ENV_VAR = "AGENT_CONTROL_URL"

def __init__(
self,
base_url: str | None = None,
timeout: float = 30.0,
api_key: str | None = None,
api_key_header: str | None = None,
runtime_auth_mode: RuntimeAuthMode | str | None = None,
runtime_token_cache: RuntimeTokenCache | None = None,
runtime_token_refresh_margin_seconds: int = (_DEFAULT_RUNTIME_TOKEN_REFRESH_MARGIN_SECONDS),
transport: httpx.AsyncBaseTransport | None = None,
):
"""
Initialize the client.
Expand All @@ -55,13 +102,38 @@ def __init__(
timeout: Request timeout in seconds
api_key: API key for authentication. If not provided, will attempt
to read from AGENT_CONTROL_API_KEY environment variable.
api_key_header: HTTP header name to send the API key on. Defaults
to ``X-API-Key``; the AGENT_CONTROL_API_KEY_HEADER
environment variable overrides the default. Useful when
the configured upstream auth expects a different header.
runtime_auth_mode: Runtime auth mode for evaluation requests. ``auto``
attempts target-bound JWT exchange and falls back to normal
request auth when the exchange endpoint is unavailable. ``jwt``
requires a successful exchange. ``api_key`` and ``none`` keep
evaluation requests on the normal request-auth path.
runtime_token_cache: Optional cache shared across client instances.
runtime_token_refresh_margin_seconds: Refresh cached runtime tokens
before this many seconds of validity remain.
transport: Optional httpx transport, primarily for tests.
"""
resolved_base_url = base_url or os.environ.get(
self.BASE_URL_ENV_VAR, "http://localhost:8000"
)
self.base_url = resolved_base_url.rstrip("/")
self.timeout = timeout
self._api_key = api_key or os.environ.get(self.API_KEY_ENV_VAR)
self._api_key_header = (
api_key_header
or os.environ.get(self.API_KEY_HEADER_ENV_VAR)
or self.DEFAULT_API_KEY_HEADER
)
configured_runtime_mode = runtime_auth_mode or os.environ.get(_RUNTIME_AUTH_MODE_ENV_VAR)
self._runtime_auth_mode = normalize_runtime_auth_mode(configured_runtime_mode)
if runtime_token_refresh_margin_seconds < 0:
raise ValueError("runtime_token_refresh_margin_seconds must be >= 0.")
self._runtime_token_refresh_margin_seconds = runtime_token_refresh_margin_seconds
self._runtime_token_cache = runtime_token_cache or RuntimeTokenCache()
self._transport = transport
self._client: httpx.AsyncClient | None = None
self._server_version_warning_emitted = False

Expand All @@ -70,15 +142,22 @@ def api_key(self) -> str | None:
"""Get the configured API key (read-only)."""
return self._api_key

@property
def api_key_header(self) -> str:
"""Get the header name the API key is sent on (read-only)."""
return self._api_key_header

@property
def runtime_auth_mode(self) -> RuntimeAuthMode:
"""Get the configured runtime auth mode (read-only)."""
return self._runtime_auth_mode

def _get_headers(self) -> dict[str, str]:
"""Build request headers including authentication."""
headers: dict[str, str] = {
"""Build base SDK metadata headers."""
return {
"X-Agent-Control-SDK": "python",
"X-Agent-Control-SDK-Version": sdk_version,
}
if self._api_key:
headers["X-API-Key"] = self._api_key
return headers

async def _check_server_version(self, response: httpx.Response) -> None:
"""Warn once when the server major version differs from the SDK major."""
Expand Down Expand Up @@ -108,6 +187,8 @@ async def __aenter__(self) -> "AgentControlClient":
base_url=self.base_url,
timeout=self.timeout,
headers=self._get_headers(),
auth=_AgentControlAuth(self._api_key, self._api_key_header),
transport=self._transport,
event_hooks={"response": [self._check_server_version]},
)
return self
Expand Down Expand Up @@ -137,6 +218,7 @@ async def health_check(self) -> dict[str, str]:
response = await self._client.get("/health")
response.raise_for_status()
from typing import cast

return cast(dict[str, str], response.json())

@property
Expand All @@ -145,3 +227,151 @@ def http_client(self) -> httpx.AsyncClient:
if self._client is None:
raise RuntimeError("Client not initialized. Use 'async with' context manager.")
return self._client

async def post_runtime_evaluation(
self,
*,
json: dict[str, Any],
headers: dict[str, str] | None = None,
target_type: str | None = None,
target_id: str | None = None,
) -> httpx.Response:
"""POST an evaluation request with runtime auth when configured."""
runtime_authorization = await self._runtime_authorization(
target_type=target_type,
target_id=target_id,
)
request_headers = self._merge_runtime_headers(headers, runtime_authorization)
response = await self.http_client.post(
"/api/v1/evaluation",
json=json,
headers=request_headers,
)

if response.status_code == 401 and runtime_authorization is not None:
await response.aread()
runtime_authorization = await self._runtime_authorization(
target_type=target_type,
target_id=target_id,
force_refresh=True,
allow_auto_fallback=False,
)
request_headers = self._merge_runtime_headers(headers, runtime_authorization)
response = await self.http_client.post(
"/api/v1/evaluation",
json=json,
headers=request_headers,
)

return response

def _merge_runtime_headers(
self,
headers: dict[str, str] | None,
runtime_authorization: str | None,
) -> dict[str, str] | None:
"""Merge caller headers with an optional Bearer token."""
if headers is None and runtime_authorization is None:
return None

merged = dict(headers or {})
if runtime_authorization is not None:
merged["Authorization"] = runtime_authorization
return merged

async def _runtime_authorization(
self,
*,
target_type: str | None,
target_id: str | None,
force_refresh: bool = False,
allow_auto_fallback: bool = True,
) -> str | None:
"""Return an Authorization header value for runtime evaluation."""
if self._runtime_auth_mode in {"none", "api_key"}:
return None

if target_type is None or target_id is None:
if self._runtime_auth_mode == "jwt":
raise RuntimeError(
"runtime_auth_mode='jwt' requires target_type and target_id "
"for evaluation requests."
)
return None

if self._runtime_auth_mode == "auto" and self._runtime_token_cache.is_jwt_unavailable(
self.base_url, target_type, target_id
):
return None

if not force_refresh:
cached = self._runtime_token_cache.get(
self.base_url,
target_type,
target_id,
refresh_margin_seconds=self._runtime_token_refresh_margin_seconds,
)
if cached is not None:
return f"Bearer {cached.token}"

exchange_lock = self._runtime_token_cache.exchange_lock(
self.base_url,
target_type,
target_id,
)
async with exchange_lock:
if not force_refresh:
cached = self._runtime_token_cache.get(
self.base_url,
target_type,
target_id,
refresh_margin_seconds=self._runtime_token_refresh_margin_seconds,
)
if cached is not None:
return f"Bearer {cached.token}"

token = await self._exchange_runtime_token(
target_type=target_type,
target_id=target_id,
allow_auto_fallback=allow_auto_fallback,
)
if token is None:
return None
return f"Bearer {token}"

async def _exchange_runtime_token(
self,
*,
target_type: str,
target_id: str,
allow_auto_fallback: bool = True,
) -> str | None:
"""Exchange the configured credential for a target-bound runtime token."""
response = await self.http_client.post(
"/api/v1/auth/runtime-token-exchange",
json={"target_type": target_type, "target_id": target_id},
)

if (
self._runtime_auth_mode == "auto"
and allow_auto_fallback
and response.status_code in _AUTO_RUNTIME_TOKEN_FALLBACK_STATUSES
):
self._runtime_token_cache.mark_jwt_unavailable(
server_url=self.base_url,
target_type=target_type,
target_id=target_id,
globally=response.status_code in _GLOBAL_RUNTIME_TOKEN_FALLBACK_STATUSES,
)
return None

response.raise_for_status()
payload = response.json()
if not isinstance(payload, dict):
raise RuntimeError("Runtime token exchange response was not an object.")
token = parse_runtime_token_exchange_response(
cast(dict[str, object], payload),
server_url=self.base_url,
)
self._runtime_token_cache.set(token)
return token.token
Loading
Loading