diff --git a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts index acb364eb..ceca1ec0 100644 --- a/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts +++ b/sdks/typescript/src/generated/funcs/agents-get-evaluator.ts @@ -37,8 +37,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/funcs/agents-get.ts b/sdks/typescript/src/generated/funcs/agents-get.ts index 9724edbf..142f3062 100644 --- a/sdks/typescript/src/generated/funcs/agents-get.ts +++ b/sdks/typescript/src/generated/funcs/agents-get.ts @@ -38,8 +38,7 @@ import { Result } from "../types/fp.js"; * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list diff --git a/sdks/typescript/src/generated/funcs/agents-init.ts b/sdks/typescript/src/generated/funcs/agents-init.ts index 9d63358d..7150b2a4 100644 --- a/sdks/typescript/src/generated/funcs/agents-init.ts +++ b/sdks/typescript/src/generated/funcs/agents-init.ts @@ -51,6 +51,7 @@ import { Result } from "../types/fp.js"; * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls diff --git a/sdks/typescript/src/generated/funcs/agents-list-controls.ts b/sdks/typescript/src/generated/funcs/agents-list-controls.ts index 661c5509..d1e5b27d 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-controls.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-controls.ts @@ -53,7 +53,7 @@ import { Result } from "../types/fp.js"; * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters diff --git a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts index c4d8a4b2..4217e752 100644 --- a/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts +++ b/sdks/typescript/src/generated/funcs/agents-list-evaluators.ts @@ -42,8 +42,7 @@ import { Result } from "../types/fp.js"; * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination diff --git a/sdks/typescript/src/generated/funcs/agents-list.ts b/sdks/typescript/src/generated/funcs/agents-list.ts index fda7574d..f887d0b5 100644 --- a/sdks/typescript/src/generated/funcs/agents-list.ts +++ b/sdks/typescript/src/generated/funcs/agents-list.ts @@ -42,7 +42,7 @@ import { Result } from "../types/fp.js"; * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info diff --git a/sdks/typescript/src/generated/funcs/agents-update.ts b/sdks/typescript/src/generated/funcs/agents-update.ts index e82644cf..aff9d827 100644 --- a/sdks/typescript/src/generated/funcs/agents-update.ts +++ b/sdks/typescript/src/generated/funcs/agents-update.ts @@ -40,6 +40,7 @@ import { Result } from "../types/fp.js"; * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items diff --git a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts index 176693e3..7e8679c8 100644 --- a/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts +++ b/sdks/typescript/src/generated/funcs/auth-runtime-token-exchange.ts @@ -32,11 +32,10 @@ import { Result } from "../types/fp.js"; * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/funcs/control-bindings-create.ts b/sdks/typescript/src/generated/funcs/control-bindings-create.ts index 8412487e..faf99923 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-create.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-create.ts @@ -32,12 +32,8 @@ import { Result } from "../types/fp.js"; * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ export function controlBindingsCreate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts index 9e4d1293..9872a9b4 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-delete.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-delete.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ export function controlBindingsDelete( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-get.ts b/sdks/typescript/src/generated/funcs/control-bindings-get.ts index dafb7c7c..88b4e419 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-get.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-get.ts @@ -34,12 +34,11 @@ import { Result } from "../types/fp.js"; * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ export function controlBindingsGet( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-list.ts b/sdks/typescript/src/generated/funcs/control-bindings-list.ts index 5e7e87c3..a87ca89f 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-list.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-list.ts @@ -35,8 +35,7 @@ import { Result } from "../types/fp.js"; * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved from the authenticated request. */ export function controlBindingsList( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/funcs/control-bindings-update.ts b/sdks/typescript/src/generated/funcs/control-bindings-update.ts index b3faf800..b94520a2 100644 --- a/sdks/typescript/src/generated/funcs/control-bindings-update.ts +++ b/sdks/typescript/src/generated/funcs/control-bindings-update.ts @@ -36,7 +36,7 @@ import { Result } from "../types/fp.js"; * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ export function controlBindingsUpdate( client: AgentControlSDKCore, diff --git a/sdks/typescript/src/generated/sdk/agents.ts b/sdks/typescript/src/generated/sdk/agents.ts index a22f4209..0a70e128 100644 --- a/sdks/typescript/src/generated/sdk/agents.ts +++ b/sdks/typescript/src/generated/sdk/agents.ts @@ -39,7 +39,7 @@ export class Agents extends ClientSDK { * limit: Pagination limit (default 20, max 100) * name: Optional name filter (case-insensitive partial match) * db: Database session (injected) - * namespace_key: Resolved namespace for the request + * principal: Authorized request principal * * Returns: * ListAgentsResponse with agent summaries and pagination info @@ -80,6 +80,7 @@ export class Agents extends ClientSDK { * Args: * request: Agent metadata and step schemas * db: Database session (injected) + * principal: Authorized request principal * * Returns: * InitAgentResponse with created flag and the effective controls @@ -106,8 +107,7 @@ export class Agents extends ClientSDK { * Args: * agent_name: Agent identifier * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * GetAgentResponse with agent metadata and step list @@ -140,6 +140,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * request: Lists of step/evaluator identifiers to remove * db: Database session (injected) + * principal: Authorized request principal * * Returns: * PatchAgentResponse with lists of actually removed items @@ -185,7 +186,7 @@ export class Agents extends ClientSDK { * target_type: Optional opaque target kind (paired with target_id) * target_id: Optional opaque target identifier (paired with target_type) * db: Database session (injected) - * namespace_key: Namespace scoping for the resolution (injected) + * principal: Authorized request principal * * Returns: * AgentControlsResponse with controls matching the requested state filters @@ -256,8 +257,7 @@ export class Agents extends ClientSDK { * cursor: Optional cursor for pagination (name of last evaluator from previous page) * limit: Pagination limit (default 20, max 100) * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * ListEvaluatorsResponse with evaluator schemas and pagination @@ -287,8 +287,7 @@ export class Agents extends ClientSDK { * agent_name: Agent identifier * evaluator_name: Name of the evaluator * db: Database session (injected) - * namespace_key: Resolved namespace; agents in another namespace - * return 404 (non-disclosing). + * principal: Authorized request principal * * Returns: * EvaluatorSchemaItem with schema details diff --git a/sdks/typescript/src/generated/sdk/auth.ts b/sdks/typescript/src/generated/sdk/auth.ts index cf6de9ba..2d0cf74e 100644 --- a/sdks/typescript/src/generated/sdk/auth.ts +++ b/sdks/typescript/src/generated/sdk/auth.ts @@ -14,11 +14,10 @@ export class Auth extends ClientSDK { * @remarks * Mint a short-lived runtime token for the requested target. * - * The caller's credential is authenticated and authorized by the - * installed default authorizer; the resulting :class:`Principal` - * supplies the actor identity and (when the upstream surfaces it) - * the grant scopes and expiry. This endpoint then mints a local HS256 - * token whose lifetime cannot outlive the upstream grant. + * The caller's credential is authenticated and authorized before the + * resolved principal supplies the actor identity, grant scopes, and + * expiry. This endpoint then mints a local HS256 token whose lifetime + * cannot outlive the grant. * * Runtime auth must be enabled via * ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index 5101ce74..5a5bcf2b 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -23,8 +23,7 @@ export class ControlBindings extends ClientSDK { * cursor-based pagination. Bindings are ordered by ID descending * (newest first). The cursor is opaque to clients: pass back the * ``next_cursor`` value verbatim to fetch the following page. The - * storage namespace is resolved by ``get_namespace_key`` so this - * listing stays in lockstep with the rest of the server's reads. + * storage namespace is resolved from the authenticated request. */ async list( request?: @@ -45,12 +44,8 @@ export class ControlBindings extends ClientSDK { * @remarks * Attach a control to an opaque external target. * - * Each binding row is scoped to the request namespace as resolved by - * ``get_namespace_key``. The auth chain still runs via - * ``require_operation`` for authentication and authorization, but the - * storage namespace is taken from the same resolver the rest of the - * server uses so binding writes and runtime reads stay in lockstep - * until auth-derived namespace resolution lands across every endpoint. + * Each binding row is scoped to the namespace associated with the + * authenticated request. */ async create( request: models.CreateControlBindingRequest, @@ -109,7 +104,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``POST /by-key:delete`` for - * target-scoped detach that forwards the target to the authorizer. + * target-scoped detach that includes the target in the request context. */ async delete( request: @@ -130,12 +125,11 @@ export class ControlBindings extends ClientSDK { * Read a single control binding by surrogate ID. * * Authorization is namespace-wide: the binding's target identifiers - * are not forwarded to the upstream because they are only discoverable - * after the row is loaded, and ``require_operation`` is single-pass. + * are not available until after the row is loaded. * Callers whose authorization model requires per-target permissions * should use the natural-key endpoints (``PUT /by-key``, * ``POST /by-key:delete``) and the target-filtered list endpoint, all - * of which forward ``(target_type, target_id)`` to the authorizer. + * of which include ``(target_type, target_id)`` in the request context. */ async get( request: @@ -158,7 +152,7 @@ export class ControlBindings extends ClientSDK { * See the GET-by-id docstring for the authorization scope: this route * is namespace-wide because the target identifiers are not available * before the binding is loaded. Use ``PUT /by-key`` for target-scoped - * upserts that forward the target to the authorizer. + * upserts that include the target in the request context. */ async update( request: diff --git a/server/src/agent_control_server/auth_framework/__init__.py b/server/src/agent_control_server/auth_framework/__init__.py index 57368d57..0333f2cc 100644 --- a/server/src/agent_control_server/auth_framework/__init__.py +++ b/server/src/agent_control_server/auth_framework/__init__.py @@ -2,10 +2,9 @@ Endpoints declare an :class:`Operation` they need; an installed :class:`RequestAuthorizer` decides whether the request is allowed and -returns the resulting :class:`Principal`. Two providers ship in-tree: -:class:`HeaderAuthProvider` (uses local credential checks) and -:class:`HttpUpstreamAuthProvider` (delegates to a configurable -upstream HTTP service). +returns the resulting :class:`Principal`. Providers ship in-tree for +disabled auth, local credential checks, upstream HTTP authorization, +and local runtime-JWT verification. """ from .core import ( diff --git a/server/src/agent_control_server/auth_framework/config.py b/server/src/agent_control_server/auth_framework/config.py index 92107b0e..595c3117 100644 --- a/server/src/agent_control_server/auth_framework/config.py +++ b/server/src/agent_control_server/auth_framework/config.py @@ -8,15 +8,19 @@ - **Default flow** (everything except runtime). One authorizer handles every operation that does not have a specific override: - :class:`HeaderAuthProvider` (local credentials) or + :class:`NoAuthProvider` (no credentials), + :class:`HeaderAuthProvider` (local API keys), or :class:`HttpUpstreamAuthProvider` (forwards to a configurable URL). -- **Runtime flow.** When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is - configured, :class:`LocalJwtVerifyProvider` is registered as the - override for :data:`Operation.RUNTIME_USE`; the - ``runtime.token_exchange`` operation continues to flow through the - default authorizer because the exchange itself is shaped like a - management call (forward credential, get grant). Without the secret, - no runtime override is installed. +- **Runtime flow.** ``AGENT_CONTROL_RUNTIME_AUTH_MODE`` selects the + override for :data:`Operation.RUNTIME_USE`: ``none`` uses + :class:`NoAuthProvider`, ``api_key`` uses + :class:`HeaderAuthProvider`, and ``jwt`` uses + :class:`LocalJwtVerifyProvider`. When the mode is unset, startup + selects ``jwt`` if ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set; + otherwise runtime falls through to the default authorizer. + The ``runtime.token_exchange`` operation continues to flow through + the default authorizer because the exchange itself is shaped like a + management call (forward credential, get grant). """ from __future__ import annotations @@ -30,6 +34,7 @@ HeaderAuthProvider, HttpUpstreamAuthProvider, LocalJwtVerifyProvider, + NoAuthProvider, ) from .providers.http_upstream import HttpUpstreamConfig @@ -41,8 +46,10 @@ _UPSTREAM_TIMEOUT_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_TIMEOUT_SECONDS" _UPSTREAM_TOKEN_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN" _UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER" +_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS" # Runtime flow. +_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE" _RUNTIME_TOKEN_SECRET_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_SECRET" _RUNTIME_TOKEN_TTL_ENV = "AGENT_CONTROL_RUNTIME_TOKEN_TTL_SECONDS" _DEFAULT_RUNTIME_TOKEN_TTL_SECONDS = 300 @@ -80,15 +87,20 @@ def configure_auth_from_env() -> None: Default flow: - - ``AGENT_CONTROL_AUTH_MODE=header`` (default): :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_AUTH_MODE=api_key`` (default): :class:`HeaderAuthProvider`. + ``header`` remains accepted as a backwards-compatible alias. - ``AGENT_CONTROL_AUTH_MODE=http_upstream``: :class:`HttpUpstreamAuthProvider` pointed at ``AGENT_CONTROL_AUTH_UPSTREAM_URL``. Runtime flow: - - When ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET`` is set, register - :class:`LocalJwtVerifyProvider` as an override for - :data:`Operation.RUNTIME_USE`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=none``: :class:`NoAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=api_key``: :class:`HeaderAuthProvider`. + - ``AGENT_CONTROL_RUNTIME_AUTH_MODE=jwt`` (default when a runtime token + secret is configured): :class:`LocalJwtVerifyProvider`. + - unset mode without a runtime token secret: fall through to the default + authorizer. Clears any previously-installed default and operation overrides before installing fresh ones, so reconfiguration cannot leave @@ -101,29 +113,35 @@ def configure_auth_from_env() -> None: global _runtime_auth_config clear_authorizers() _active_providers.clear() - _runtime_auth_config = _load_runtime_auth_config() + runtime_mode = _resolve_runtime_mode() + _runtime_auth_config = ( + _load_runtime_auth_config(require_secret=True) if runtime_mode == "jwt" else None + ) default = _build_default_provider() set_authorizer(default) _active_providers.append(default) - if _runtime_auth_config is not None: - runtime_provider = LocalJwtVerifyProvider(secret=_runtime_auth_config.secret) - set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) - _active_providers.append(runtime_provider) + if runtime_mode == "default": _logger.info( - "Runtime auth enabled: LocalJwtVerifyProvider override installed for %s", + "Runtime auth provider: default authorizer handles %s", Operation.RUNTIME_USE.value, ) else: - _logger.warning( - "Runtime auth disabled (%s not set); %s falls through to the " - "default authorizer, which may grant any authenticated credential. " - "Set the runtime token secret to bind runtime calls to a " - "short-lived target-scoped JWT.", - _RUNTIME_TOKEN_SECRET_ENV, - Operation.RUNTIME_USE.value, - ) + runtime_provider = _build_runtime_provider(runtime_mode, _runtime_auth_config) + set_authorizer(runtime_provider, operation=Operation.RUNTIME_USE) + _active_providers.append(runtime_provider) + if runtime_mode == "jwt": + _logger.info( + "Runtime auth provider: jwt override installed for %s", + Operation.RUNTIME_USE.value, + ) + else: + _logger.info( + "Runtime auth provider: %s override installed for %s", + runtime_mode, + Operation.RUNTIME_USE.value, + ) async def teardown_auth() -> None: @@ -172,9 +190,12 @@ def set_runtime_auth_config(config: RuntimeAuthConfig | None) -> None: def _build_default_provider() -> RequestAuthorizer: - mode = os.environ.get(_MODE_ENV, "header").strip().lower() - if mode == "header": - _logger.info("Default auth provider: header (local credentials)") + mode = os.environ.get(_MODE_ENV, "api_key").strip().lower() + if mode in {"none", "no_auth"}: + _logger.info("Default auth provider: none") + return NoAuthProvider() + if mode in {"api_key", "header"}: + _logger.info("Default auth provider: api_key (local credentials)") return HeaderAuthProvider() if mode == "http_upstream": url = os.environ.get(_UPSTREAM_URL_ENV) @@ -183,6 +204,9 @@ def _build_default_provider() -> RequestAuthorizer: timeout = float(os.environ.get(_UPSTREAM_TIMEOUT_ENV, "5.0")) token = os.environ.get(_UPSTREAM_TOKEN_ENV) token_header = os.environ.get(_UPSTREAM_TOKEN_HEADER_ENV, "X-Agent-Control-Service-Token") + extra_forward_headers = _parse_extra_forward_headers( + os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV) + ) _logger.info("Default auth provider: http_upstream url=%s", url) return HttpUpstreamAuthProvider( HttpUpstreamConfig( @@ -190,21 +214,87 @@ def _build_default_provider() -> RequestAuthorizer: timeout_seconds=timeout, service_token=token, service_token_header=token_header, + extra_forward_headers=extra_forward_headers, ) ) - raise RuntimeError(f"Unknown {_MODE_ENV}={mode!r}; expected 'header' or 'http_upstream'.") + raise RuntimeError( + f"Unknown {_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'http_upstream'." + ) + +def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]: + """Parse a comma-separated header list into a deduplicated tuple. -def _load_runtime_auth_config() -> RuntimeAuthConfig | None: + Empty / unset env var returns an empty tuple. Whitespace around each + name is stripped. Empty entries (e.g. ``"X-A,,X-B"``) are dropped. + Order is preserved; duplicates (case-insensitive) are dropped after + the first occurrence. + """ + if not raw or not raw.strip(): + return () + seen: set[str] = set() + result: list[str] = [] + for raw_name in raw.split(","): + name = raw_name.strip() + if not name: + continue + lower = name.lower() + if lower in seen: + continue + seen.add(lower) + result.append(name) + return tuple(result) + + +def _resolve_runtime_mode() -> str: + raw = os.environ.get(_RUNTIME_MODE_ENV) + if raw is None or not raw.strip(): + return "jwt" if os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) else "default" + + mode = raw.strip().lower() + if mode in {"none", "no_auth"}: + return "none" + if mode in {"api_key", "header"}: + return "api_key" + if mode == "jwt": + return mode + raise RuntimeError( + f"Unknown {_RUNTIME_MODE_ENV}={mode!r}; expected 'none', 'api_key', or 'jwt'." + ) + + +def _build_runtime_provider( + mode: str, + config: RuntimeAuthConfig | None, +) -> RequestAuthorizer: + if mode == "none": + return NoAuthProvider() + if mode == "api_key": + return HeaderAuthProvider() + if mode == "jwt": + if config is None: + raise RuntimeError(f"{_RUNTIME_MODE_ENV}=jwt but runtime auth config is missing.") + return LocalJwtVerifyProvider(secret=config.secret) + raise RuntimeError( + f"Unknown runtime auth mode {mode!r}; expected 'none', 'api_key', or 'jwt'." + ) + + +def _load_runtime_auth_config(*, require_secret: bool = False) -> RuntimeAuthConfig | None: """Parse, validate, and return the runtime-auth config from env. - Returns ``None`` when no runtime secret is configured. Raises - ``RuntimeError`` when the secret is too short or the TTL is invalid - so misconfiguration surfaces at startup, not on the first - request-time mint. + Returns ``None`` when no runtime secret is configured and + ``require_secret`` is false. Raises ``RuntimeError`` when the + secret is required, too short, or the TTL is invalid so + misconfiguration surfaces at startup, not on the first request-time + mint. """ secret = os.environ.get(_RUNTIME_TOKEN_SECRET_ENV) if not secret: + if require_secret: + raise RuntimeError( + f"{_RUNTIME_MODE_ENV}=jwt requires {_RUNTIME_TOKEN_SECRET_ENV} to be set." + ) return None if len(secret.encode("utf-8")) < _RUNTIME_TOKEN_SECRET_MIN_BYTES: raise RuntimeError( diff --git a/server/src/agent_control_server/auth_framework/core.py b/server/src/agent_control_server/auth_framework/core.py index 9299b441..058169de 100644 --- a/server/src/agent_control_server/auth_framework/core.py +++ b/server/src/agent_control_server/auth_framework/core.py @@ -42,14 +42,19 @@ class Operation(StrEnum): CONTROL_BINDINGS_READ = "control_bindings.read" CONTROL_BINDINGS_WRITE = "control_bindings.write" - # Runtime token exchange — wired on the exchange endpoint. + # Runtime token exchange - wired on the exchange endpoint. RUNTIME_TOKEN_EXCHANGE = "runtime.token_exchange" - # Reserved for follow-up migrations; not yet wired on endpoints. CONTROLS_READ = "controls.read" CONTROLS_CREATE = "controls.create" CONTROLS_UPDATE = "controls.update" CONTROLS_DELETE = "controls.delete" + POLICIES_READ = "policies.read" + POLICIES_CREATE = "policies.create" + POLICIES_UPDATE = "policies.update" + AGENTS_READ = "agents.read" + AGENTS_CREATE = "agents.create" + AGENTS_UPDATE = "agents.update" RUNTIME_USE = "runtime.use" @@ -61,8 +66,7 @@ class Principal: namespace_key: The namespace the request runs in. Endpoints use this to scope every read and write. is_admin: Whether the caller has admin privileges in the - current namespace. Mostly informational for endpoints that - still gate on the legacy admin-key contract. + current namespace. caller_id: Opaque, provider-supplied identifier for the caller (e.g., a key fingerprint or user id). Useful for audit logging; never echo back to clients. @@ -122,7 +126,7 @@ def set_authorizer( Without ``operation``, this becomes the default authorizer used by every operation that does not have a specific override. With - ``operation``, it overrides the default for that operation only — + ``operation``, it overrides the default for that operation only - used to route a different family (e.g., runtime) through a different provider. diff --git a/server/src/agent_control_server/auth_framework/providers/__init__.py b/server/src/agent_control_server/auth_framework/providers/__init__.py index e8a68486..ad5d6b38 100644 --- a/server/src/agent_control_server/auth_framework/providers/__init__.py +++ b/server/src/agent_control_server/auth_framework/providers/__init__.py @@ -3,10 +3,12 @@ from .header import AccessLevel, HeaderAuthProvider from .http_upstream import HttpUpstreamAuthProvider from .local_jwt import LocalJwtVerifyProvider +from .no_auth import NoAuthProvider __all__ = [ "AccessLevel", "HeaderAuthProvider", "HttpUpstreamAuthProvider", "LocalJwtVerifyProvider", + "NoAuthProvider", ] diff --git a/server/src/agent_control_server/auth_framework/providers/header.py b/server/src/agent_control_server/auth_framework/providers/header.py index f76936a1..16760768 100644 --- a/server/src/agent_control_server/auth_framework/providers/header.py +++ b/server/src/agent_control_server/auth_framework/providers/header.py @@ -1,23 +1,14 @@ """Default :class:`RequestAuthorizer` that uses local credentials only. -Resolves the namespace from a header (or falls back to -``DEFAULT_NAMESPACE_KEY``) and enforces a per-operation access level -using the legacy API-key + session-cookie credential check from -:mod:`agent_control_server.auth`. Behavior matches the pre-framework -local auth path verbatim: +Returns ``DEFAULT_NAMESPACE_KEY`` and enforces a per-operation access +level using the local API-key + session-cookie credential check from +:mod:`agent_control_server.auth`: - ``ADMIN`` operations require an admin key (or admin session). - ``AUTHENTICATED`` operations require any valid credential. - ``PUBLIC`` operations are open. -- When ``api_key_enabled`` is ``False`` (no-auth mode), every - operation succeeds with a non-admin :class:`Principal` — preserved - by the underlying credential check. - -The header lookup is wired but currently inert: the provider always -returns the default namespace because non-binding write endpoints -still hardcode it. The header is kept here so a follow-up that -threads namespace resolution through the rest of the API can flip it -on without changing the provider contract. +- When the underlying local credential layer is disabled, every + operation succeeds with a non-admin :class:`Principal`. """ from __future__ import annotations @@ -51,6 +42,12 @@ class AccessLevel(Enum): Operation.CONTROLS_CREATE: AccessLevel.ADMIN, Operation.CONTROLS_UPDATE: AccessLevel.ADMIN, Operation.CONTROLS_DELETE: AccessLevel.ADMIN, + Operation.POLICIES_READ: AccessLevel.AUTHENTICATED, + Operation.POLICIES_CREATE: AccessLevel.ADMIN, + Operation.POLICIES_UPDATE: AccessLevel.ADMIN, + Operation.AGENTS_READ: AccessLevel.AUTHENTICATED, + Operation.AGENTS_CREATE: AccessLevel.AUTHENTICATED, + Operation.AGENTS_UPDATE: AccessLevel.ADMIN, Operation.RUNTIME_TOKEN_EXCHANGE: AccessLevel.AUTHENTICATED, Operation.RUNTIME_USE: AccessLevel.AUTHENTICATED, } @@ -60,7 +57,7 @@ class HeaderAuthProvider(RequestAuthorizer): """Default authorizer. For each operation's configured access level, validates the - request's credentials via the legacy local check; on success, + request's credentials via the local credential check; on success, returns a :class:`Principal` scoped to the resolved namespace. """ @@ -100,8 +97,7 @@ async def authorize( ) # Runtime token exchange returns a normalized scope grant so the # exchange endpoint can require ``runtime.use`` uniformly across - # providers; an upstream that explicitly grants no scopes ends - # up with an empty tuple and is rejected. + # providers. scopes: tuple[str, ...] = ( (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () ) @@ -113,10 +109,7 @@ async def authorize( ) def _resolve_namespace_key(self, request: Request) -> str: - # The provider always returns the default namespace because - # non-binding write endpoints still hardcode it; serving - # anything else here would create rows the rest of the API - # cannot find. The branch is preserved so a future change can - # lift the lock without touching the provider contract. + # Local credentials do not carry namespace metadata. Providers + # that resolve a namespace can return a different principal. del request return self._default_namespace_key diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index a97a3de8..78ed9ae2 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -60,15 +60,15 @@ _logger = get_logger(__name__) -_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") +_DEFAULT_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie") class _UpstreamGrant(BaseModel): """Strict schema for the upstream authorization-service response. Unknown fields are tolerated (so the upstream can evolve), but every - *known* field is type-checked. A wrong type on any field — or a - half-supplied target binding — causes the provider to fail closed + *known* field is type-checked. A wrong type on any field - or a + half-supplied target binding - causes the provider to fail closed with a 502. """ @@ -108,7 +108,7 @@ def _target_must_be_paired(self) -> _UpstreamGrant: A target is meaningful only as a ``(target_type, target_id)`` pair; allowing one side without the other would let a malformed grant pass and the exchange endpoint mint a token for the - request's value of the missing half — outside the upstream's + request's value of the missing half - outside the upstream's intended authorization. """ if (self.target_type is None) != (self.target_id is None): @@ -136,6 +136,17 @@ class HttpUpstreamConfig: service_token_header: str = "X-Agent-Control-Service-Token" + extra_forward_headers: tuple[str, ...] = () + """Additional inbound request headers to forward to the upstream + on top of the default ``(X-API-Key, Authorization, Cookie)`` set. + + Use this when the upstream authenticates via a header the provider + does not forward by default (e.g., a deployer-specific API-key + header). Header lookups against the inbound request are + case-insensitive; an empty or absent inbound header is silently + dropped. Names duplicating the default set or each other (after + case-folding) are deduplicated.""" + class HttpUpstreamAuthProvider(RequestAuthorizer): """Delegates authorization to an upstream HTTP service.""" @@ -190,7 +201,12 @@ async def authorize( def _forward_headers(self, request: Request) -> dict[str, str]: headers: dict[str, str] = {} - for name in _FORWARDED_HEADERS: + seen: set[str] = set() + for name in (*_DEFAULT_FORWARDED_HEADERS, *self._config.extra_forward_headers): + lower = name.lower() + if lower in seen: + continue + seen.add(lower) value = request.headers.get(name) if value is not None: headers[name] = value diff --git a/server/src/agent_control_server/auth_framework/providers/local_jwt.py b/server/src/agent_control_server/auth_framework/providers/local_jwt.py index bb448503..8620d3b6 100644 --- a/server/src/agent_control_server/auth_framework/providers/local_jwt.py +++ b/server/src/agent_control_server/auth_framework/providers/local_jwt.py @@ -6,7 +6,7 @@ returns a :class:`Principal` carrying the bound target. When a ``context_builder`` on the dependency surfaces ``target_type`` / ``target_id``, the provider also enforces that they match the token's -binding — runtime endpoints get the request-target check for free. +binding - runtime endpoints get the request-target check for free. """ from __future__ import annotations diff --git a/server/src/agent_control_server/auth_framework/providers/no_auth.py b/server/src/agent_control_server/auth_framework/providers/no_auth.py new file mode 100644 index 00000000..509ca4f3 --- /dev/null +++ b/server/src/agent_control_server/auth_framework/providers/no_auth.py @@ -0,0 +1,29 @@ +"""Authorizer for deployments that intentionally disable authentication.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import Request + +from ...models import DEFAULT_NAMESPACE_KEY +from ..core import Operation, Principal, RequestAuthorizer + + +class NoAuthProvider(RequestAuthorizer): + """Allows every operation and returns the default namespace.""" + + def __init__(self, *, default_namespace_key: str = DEFAULT_NAMESPACE_KEY) -> None: + self._default_namespace_key = default_namespace_key + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request, context + scopes: tuple[str, ...] = ( + (Operation.RUNTIME_USE.value,) if operation is Operation.RUNTIME_TOKEN_EXCHANGE else () + ) + return Principal(namespace_key=self._default_namespace_key, scopes=scopes) diff --git a/server/src/agent_control_server/endpoints/agents.py b/server/src/agent_control_server/endpoints/agents.py index 034ae35f..ac099911 100644 --- a/server/src/agent_control_server/endpoints/agents.py +++ b/server/src/agent_control_server/endpoints/agents.py @@ -36,7 +36,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey, require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ( APIValidationError, @@ -53,7 +53,6 @@ Policy, agent_policies, ) -from ..namespace import get_namespace_key from ..services.agent_names import normalize_agent_name_or_422 from ..services.controls import ( AgentControlEnabledState, @@ -112,7 +111,7 @@ def _validate_controls_for_agent(agent: Agent, controls: list[Control]) -> list[ agent_evaluators = {e.name: e for e in (agent_data.evaluators or [])} for control in controls: - # Skip unrendered template controls — they have no evaluators to validate. + # Skip unrendered template controls - they have no evaluators to validate. if ( isinstance(control.data, dict) and control.data.get("template") is not None @@ -286,7 +285,7 @@ async def list_agents( limit: int = _DEFAULT_PAGINATION_LIMIT, name: str | None = None, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListAgentsResponse: """ List all registered agents with cursor-based pagination. @@ -300,11 +299,13 @@ async def list_agents( limit: Pagination limit (default 20, max 100) name: Optional name filter (case-insensitive partial match) db: Database session (injected) - namespace_key: Resolved namespace for the request + principal: Authorized request principal Returns: ListAgentsResponse with agent summaries and pagination info """ + namespace_key = principal.namespace_key + # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -377,14 +378,20 @@ async def list_agents( agent_policies.c.agent_name, agent_policies.c.policy_id, ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) .order_by(agent_policies.c.agent_name, agent_policies.c.policy_id) ) policy_ids_result = await db.execute(policy_ids_query) for assoc_agent_name, policy_id in policy_ids_result.all(): policy_ids_map.setdefault(assoc_agent_name, []).append(policy_id) - control_counts_map = await control_service.list_active_control_counts_by_agent(agent_names) + control_counts_map = await control_service.list_active_control_counts_by_agent( + agent_names, + namespace_key=namespace_key, + ) # Build summaries summaries: list[AgentSummary] = [] @@ -436,9 +443,8 @@ async def list_agents( ) async def init_agent( request: InitAgentRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_CREATE)), ) -> InitAgentResponse: """ Register a new agent or update an existing agent's steps and metadata. @@ -462,10 +468,13 @@ async def init_agent( Args: request: Agent metadata and step schemas db: Database session (injected) + principal: Authorized request principal Returns: InitAgentResponse with created flag and the effective controls """ + namespace_key = principal.namespace_key + # Check for evaluator name collisions with built-in evaluators builtin_names = _get_builtin_evaluator_names() for ev in request.evaluators: @@ -835,7 +844,7 @@ async def init_agent( async def get_agent( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentResponse: """ Retrieve agent metadata and all registered steps. @@ -845,8 +854,7 @@ async def get_agent( Args: agent_name: Agent identifier db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: GetAgentResponse with agent metadata and step list @@ -855,6 +863,7 @@ async def get_agent( HTTPException 404: Agent not found HTTPException 422: Agent data is corrupted """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -917,7 +926,7 @@ async def _get_agent_or_404( The lookup is always namespace-scoped: an agent that exists only in another namespace surfaces as 404 (non-disclosing) so duplicate - names across namespaces — which the schema explicitly permits — + names across namespaces - which the schema explicitly permits - cannot be addressed across the namespace boundary. """ normalized_agent_name = normalize_agent_name_or_422(agent_name) @@ -940,7 +949,6 @@ async def _get_agent_or_404( @router.post( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate policy with agent", response_description="Success confirmation", @@ -949,9 +957,10 @@ async def add_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a policy with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1017,7 +1026,6 @@ async def add_agent_policy( @router.post( "/{agent_name}/policy/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=SetPolicyResponse, summary="Assign policy to agent (compatibility)", response_description="Success status with previous policy ID", @@ -1026,9 +1034,10 @@ async def set_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> SetPolicyResponse: """Compatibility endpoint that replaces all policy associations with one policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1117,9 +1126,10 @@ async def set_agent_policy( async def get_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetAgentPoliciesResponse: """List policy IDs associated with an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) result = await db.execute( select(agent_policies.c.policy_id) @@ -1141,9 +1151,10 @@ async def get_agent_policies( async def get_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> GetPolicyResponse: """Compatibility endpoint that returns the first associated policy.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( select(Policy.id) @@ -1172,7 +1183,6 @@ async def get_agent_policy( @router.delete( "/{agent_name}/policies/{policy_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove policy association from agent", response_description="Success confirmation", @@ -1181,13 +1191,14 @@ async def remove_agent_policy( agent_name: str, policy_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove a policy association from an agent. Idempotent for existing resources: removing a non-associated link is a no-op. Missing agent/policy resources still return 404. """ + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) policy_result = await db.execute( @@ -1230,7 +1241,6 @@ async def remove_agent_policy( @router.delete( "/{agent_name}/policies", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove all policy associations from agent", response_description="Success confirmation", @@ -1238,9 +1248,10 @@ async def remove_agent_policy( async def remove_all_agent_policies( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Remove all policy associations from an agent.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) try: @@ -1271,7 +1282,6 @@ async def remove_all_agent_policies( @router.delete( "/{agent_name}/policy", - dependencies=[Depends(require_admin_key)], response_model=DeletePolicyResponse, summary="Remove agent's policy assignment (compatibility)", response_description="Success confirmation", @@ -1279,9 +1289,10 @@ async def remove_all_agent_policies( async def delete_agent_policy( agent_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> DeletePolicyResponse: """Compatibility endpoint that removes all policy associations.""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) existing_policy_result = await db.execute( @@ -1328,7 +1339,6 @@ async def delete_agent_policy( @router.post( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Associate control directly with agent", response_description="Success confirmation", @@ -1337,9 +1347,10 @@ async def add_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> AssocResponse: """Associate a control directly with an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) control = await control_service.get_active_control_or_404( @@ -1389,7 +1400,6 @@ async def add_agent_control( @router.delete( "/{agent_name}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=RemoveAgentControlResponse, summary="Remove direct control association from agent", response_description="Success confirmation", @@ -1398,9 +1408,10 @@ async def remove_agent_control( agent_name: str, control_id: int, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> RemoveAgentControlResponse: """Remove a direct control association from an agent (idempotent).""" + namespace_key = principal.namespace_key agent = await _get_agent_or_404(agent_name, db, namespace_key=namespace_key) control_service = ControlService(db) await control_service.get_active_control_or_404(control_id, namespace_key=namespace_key) @@ -1481,7 +1492,7 @@ async def list_agent_controls( description="Optional opaque target identifier. Required when target_type is supplied.", ), db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> AgentControlsResponse: """ List protection controls effective for an agent. @@ -1506,7 +1517,7 @@ async def list_agent_controls( target_type: Optional opaque target kind (paired with target_id) target_id: Optional opaque target identifier (paired with target_type) db: Database session (injected) - namespace_key: Namespace scoping for the resolution (injected) + principal: Authorized request principal Returns: AgentControlsResponse with controls matching the requested state filters @@ -1515,6 +1526,8 @@ async def list_agent_controls( HTTPException 400: target_type and target_id were not supplied together HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key + if (target_type is None) != (target_id is None): raise BadRequestError( error_code=ErrorCode.VALIDATION_ERROR, @@ -1572,7 +1585,7 @@ async def list_agent_evaluators( cursor: str | None = None, limit: int = _DEFAULT_PAGINATION_LIMIT, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> ListEvaluatorsResponse: """ List all evaluator schemas registered with an agent. @@ -1586,8 +1599,7 @@ async def list_agent_evaluators( cursor: Optional cursor for pagination (name of last evaluator from previous page) limit: Pagination limit (default 20, max 100) db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: ListEvaluatorsResponse with evaluator schemas and pagination @@ -1595,6 +1607,7 @@ async def list_agent_evaluators( Raises: HTTPException 404: Agent not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) # Clamp limit limit = min(max(1, limit), _MAX_PAGINATION_LIMIT) @@ -1672,7 +1685,7 @@ async def get_agent_evaluator( agent_name: str, evaluator_name: str, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_READ)), ) -> EvaluatorSchemaItem: """ Get a specific evaluator schema registered with an agent. @@ -1681,8 +1694,7 @@ async def get_agent_evaluator( agent_name: Agent identifier evaluator_name: Name of the evaluator db: Database session (injected) - namespace_key: Resolved namespace; agents in another namespace - return 404 (non-disclosing). + principal: Authorized request principal Returns: EvaluatorSchemaItem with schema details @@ -1690,6 +1702,7 @@ async def get_agent_evaluator( Raises: HTTPException 404: Agent or evaluator not found """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where(Agent.name == agent_name, Agent.namespace_key == namespace_key) @@ -1734,7 +1747,6 @@ async def get_agent_evaluator( @router.patch( "/{agent_name}", - dependencies=[Depends(require_admin_key)], response_model=PatchAgentResponse, summary="Modify agent (remove steps/evaluators)", response_description="Lists of removed items", @@ -1743,7 +1755,7 @@ async def patch_agent( agent_name: str, request: PatchAgentRequest, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.AGENTS_UPDATE)), ) -> PatchAgentResponse: """ Remove steps and/or evaluators from an agent. @@ -1755,6 +1767,7 @@ async def patch_agent( agent_name: Agent identifier request: Lists of step/evaluator identifiers to remove db: Database session (injected) + principal: Authorized request principal Returns: PatchAgentResponse with lists of actually removed items @@ -1763,6 +1776,7 @@ async def patch_agent( HTTPException 404: Agent not found HTTPException 500: Database error during update """ + namespace_key = principal.namespace_key agent_name = normalize_agent_name_or_422(agent_name) result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/auth.py b/server/src/agent_control_server/endpoints/auth.py index 1a23baa8..b1ade969 100644 --- a/server/src/agent_control_server/endpoints/auth.py +++ b/server/src/agent_control_server/endpoints/auth.py @@ -2,14 +2,13 @@ The runtime auth flow is two-phase: this endpoint is phase one. The caller presents a long-lived credential plus ``(target_type, -target_id)``; the default authorizer (typically -:class:`HttpUpstreamAuthProvider` in production) authenticates the +target_id)``; the configured authorization provider authenticates the credential and authorizes the implied -:data:`Operation.RUNTIME_TOKEN_EXCHANGE`. On success, this endpoint +``runtime.token_exchange`` operation. On success, this endpoint mints a short-lived local runtime token bound to the supplied target and returns it. Subsequent target-bearing runtime calls present the returned token, which is verified locally by -:class:`LocalJwtVerifyProvider`. +the runtime JWT provider. """ from __future__ import annotations @@ -57,7 +56,7 @@ class RuntimeTokenExchangeResponse(BaseModel): async def _exchange_context(request: Request) -> dict[str, Any]: - """Surface target identifiers to the authorizer's context. + """Surface target identifiers to the authorization context. Reads the request body once. FastAPI caches the parsed body, so the endpoint's own Pydantic body model still binds normally. @@ -90,11 +89,10 @@ async def runtime_token_exchange( ) -> RuntimeTokenExchangeResponse: """Mint a short-lived runtime token for the requested target. - The caller's credential is authenticated and authorized by the - installed default authorizer; the resulting :class:`Principal` - supplies the actor identity and (when the upstream surfaces it) - the grant scopes and expiry. This endpoint then mints a local HS256 - token whose lifetime cannot outlive the upstream grant. + The caller's credential is authenticated and authorized before the + resolved principal supplies the actor identity, grant scopes, and + expiry. This endpoint then mints a local HS256 token whose lifetime + cannot outlive the grant. Runtime auth must be enabled via ``AGENT_CONTROL_RUNTIME_TOKEN_SECRET``; otherwise the endpoint @@ -130,8 +128,8 @@ async def runtime_token_exchange( actor_id = principal.caller_id or "anonymous" # The exchange endpoint requires the authorizer to explicitly grant - # runtime.use. Providers that do not surface scopes (legacy local - # provider) supply a normalized grant for ``RUNTIME_TOKEN_EXCHANGE``; + # runtime.use. Local providers supply a normalized grant for + # ``RUNTIME_TOKEN_EXCHANGE``; # upstream providers that return an explicit empty scopes array fail # closed here rather than escalating to runtime.use. if Operation.RUNTIME_USE.value not in principal.scopes: @@ -155,7 +153,7 @@ async def runtime_token_exchange( ) except UpstreamGrantExpiredError as exc: # Upstream returned a grant whose ``expires_at`` is already in - # the past — minting would hand the caller a token that's dead + # the past - minting would hand the caller a token that's dead # on arrival. Distinguished from the misconfigured case so the # error code and status reflect "upstream returned bad data." raise APIError( diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 92798ae1..87386723 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -26,7 +26,6 @@ from ..db import get_async_db from ..errors import BadRequestError from ..models import ControlBinding -from ..namespace import get_namespace_key from ..services.control_bindings import ControlBindingsService router = APIRouter(prefix="/control-bindings", tags=["control-bindings"]) @@ -94,26 +93,21 @@ def _to_response(binding: ControlBinding) -> GetControlBindingResponse: async def create_control_binding( request: CreateControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> CreateControlBindingResponse: """Attach a control to an opaque external target. - Each binding row is scoped to the request namespace as resolved by - ``get_namespace_key``. The auth chain still runs via - ``require_operation`` for authentication and authorization, but the - storage namespace is taken from the same resolver the rest of the - server uses so binding writes and runtime reads stay in lockstep - until auth-derived namespace resolution lands across every endpoint. + Each binding row is scoped to the namespace associated with the + authenticated request. """ service = ControlBindingsService(db) binding = await service.create_binding( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -148,20 +142,18 @@ async def list_control_bindings( target_id: str | None = None, control_id: int | None = None, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_READ, context_builder=_binding_list_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> ListControlBindingsResponse: """Return bindings in the request namespace with optional filters and cursor-based pagination. Bindings are ordered by ID descending (newest first). The cursor is opaque to clients: pass back the ``next_cursor`` value verbatim to fetch the following page. The - storage namespace is resolved by ``get_namespace_key`` so this - listing stays in lockstep with the rest of the server's reads. + storage namespace is resolved from the authenticated request. """ parsed_cursor: int | None if cursor is None: @@ -177,7 +169,7 @@ async def list_control_bindings( ) from exc service = ControlBindingsService(db) page = await service.list_bindings( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, cursor=parsed_cursor, limit=limit, target_type=target_type, @@ -204,21 +196,21 @@ async def list_control_bindings( async def get_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_READ)), ) -> GetControlBindingResponse: """Read a single control binding by surrogate ID. Authorization is namespace-wide: the binding's target identifiers - are not forwarded to the upstream because they are only discoverable - after the row is loaded, and ``require_operation`` is single-pass. + are not available until after the row is loaded. Callers whose authorization model requires per-target permissions should use the natural-key endpoints (``PUT /by-key``, ``POST /by-key:delete``) and the target-filtered list endpoint, all - of which forward ``(target_type, target_id)`` to the authorizer. + of which include ``(target_type, target_id)`` in the request context. """ service = ControlBindingsService(db) - binding = await service.get_binding_or_404(namespace_key=namespace_key, binding_id=binding_id) + binding = await service.get_binding_or_404( + namespace_key=principal.namespace_key, binding_id=binding_id + ) return _to_response(binding) @@ -232,19 +224,18 @@ async def patch_control_binding( binding_id: int, request: PatchControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> PatchControlBindingResponse: """Update the ``enabled`` flag on a control binding. See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``PUT /by-key`` for target-scoped - upserts that forward the target to the authorizer. + upserts that include the target in the request context. """ service = ControlBindingsService(db) binding = await service.set_enabled( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, binding_id=binding_id, enabled=request.enabled, ) @@ -261,18 +252,17 @@ async def patch_control_binding( async def delete_control_binding( binding_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends(require_operation(Operation.CONTROL_BINDINGS_WRITE)), ) -> DeleteControlBindingResponse: """Delete a control binding by surrogate ID. See the GET-by-id docstring for the authorization scope: this route is namespace-wide because the target identifiers are not available before the binding is loaded. Use ``POST /by-key:delete`` for - target-scoped detach that forwards the target to the authorizer. + target-scoped detach that includes the target in the request context. """ service = ControlBindingsService(db) - await service.delete_binding(namespace_key=namespace_key, binding_id=binding_id) + await service.delete_binding(namespace_key=principal.namespace_key, binding_id=binding_id) await db.commit() return DeleteControlBindingResponse(success=True) @@ -286,13 +276,12 @@ async def delete_control_binding( async def upsert_control_binding_by_key( request: UpsertControlBindingRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> UpsertControlBindingResponse: """Idempotent attach using ``(target_type, target_id, control_id)`` as the natural key. Updates ``enabled`` on an existing match; creates a new row @@ -300,7 +289,7 @@ async def upsert_control_binding_by_key( """ service = ControlBindingsService(db) binding, created = await service.upsert_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, @@ -324,20 +313,19 @@ async def upsert_control_binding_by_key( async def delete_control_binding_by_key( request: DeleteControlBindingByKeyRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends( + principal: Principal = Depends( require_operation( Operation.CONTROL_BINDINGS_WRITE, context_builder=_binding_body_context, ) ), - namespace_key: str = Depends(get_namespace_key), ) -> DeleteControlBindingByKeyResponse: """Idempotent detach by natural key. Returns ``deleted=False`` when no matching binding exists. """ service = ControlBindingsService(db) deleted = await service.delete_by_natural_key( - namespace_key=namespace_key, + namespace_key=principal.namespace_key, target_type=request.target_type, target_id=request.target_id, control_id=request.control_id, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index fcb7cb18..b4fa8d0b 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -229,7 +229,7 @@ async def _materialize_control_input( enabled=enabled, ) - # Incomplete values — only allowed for new controls or already-unrendered + # Incomplete values - only allowed for new controls or already-unrendered # templates. Updating a rendered control with incomplete values is # rejected to prevent silently stripping rendered fields. current_is_rendered = ( @@ -470,7 +470,7 @@ async def render_control_template( async def create_control( request: CreateControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), ) -> CreateControlResponse: """ Create a new control with a unique name. @@ -492,7 +492,10 @@ async def create_control( control_service = ControlService(db) # Uniqueness check - if await control_service.active_control_name_exists(request.name): + namespace_key = principal.namespace_key + if await control_service.active_control_name_exists( + request.name, namespace_key=namespace_key + ): raise ConflictError( error_code=ErrorCode.CONTROL_NAME_CONFLICT, detail=f"Control with name '{request.name}' already exists", @@ -504,7 +507,11 @@ async def create_control( control_def = await _materialize_control_input(request.data, db=db) control_data = _serialize_control_data(control_def) - control = control_service.create_control(name=request.name, data=control_data) + control = control_service.create_control( + namespace_key=namespace_key, + name=request.name, + data=control_data, + ) try: await control_service.create_version( control, @@ -569,7 +576,7 @@ async def get_control_schema() -> GetControlSchemaResponse: async def get_control( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlResponse: """ Retrieve a control by ID including its name and configuration data. @@ -584,7 +591,9 @@ async def get_control( Raises: HTTPException 404: Control not found """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -608,7 +617,7 @@ async def get_control( async def get_control_data( control_id: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlDataResponse: """ Retrieve the configuration data for a control. @@ -626,7 +635,9 @@ async def get_control_data( HTTPException 404: Control not found HTTPException 422: Control data is corrupted """ - control = await ControlService(db).get_active_control_or_404(control_id) + control = await ControlService(db).get_active_control_or_404( + control_id, namespace_key=principal.namespace_key + ) control_data = _parse_stored_control_data( control.data, control_name=control.name, @@ -648,10 +659,15 @@ async def list_control_versions( ), limit: int = Query(_DEFAULT_PAGINATION_LIMIT, ge=1, le=_MAX_PAGINATION_LIMIT), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlVersionsResponse: """List control versions ordered newest-first using cursor-based pagination.""" - page = await ControlService(db).list_versions(control_id, cursor=cursor, limit=limit) + page = await ControlService(db).list_versions( + control_id, + namespace_key=principal.namespace_key, + cursor=cursor, + limit=limit, + ) return ListControlVersionsResponse( versions=[ @@ -682,10 +698,12 @@ async def get_control_version( control_id: int, version_num: int, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> GetControlVersionResponse: """Return a specific control version, including its raw persisted snapshot.""" - version = await ControlService(db).get_version_or_404(control_id, version_num) + version = await ControlService(db).get_version_or_404( + control_id, version_num, namespace_key=principal.namespace_key + ) return GetControlVersionResponse( version_num=version.version_num, event_type=version.event_type, @@ -705,7 +723,7 @@ async def set_control_data( control_id: int, request: SetControlDataRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> SetControlDataResponse: """ Update the configuration data for a control. @@ -726,7 +744,9 @@ async def set_control_data( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=principal.namespace_key, for_update=True + ) control_def = await _materialize_control_input( request.data, @@ -811,7 +831,7 @@ async def list_controls( execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -837,7 +857,9 @@ async def list_controls( GET /controls?limit=10&enabled=true&step_type=tool """ control_service = ControlService(db) + namespace_key = principal.namespace_key page = await control_service.list_controls_page( + namespace_key=namespace_key, cursor=cursor, limit=limit, name=name, @@ -849,7 +871,8 @@ async def list_controls( tag=tag, ) usage_by_control_id = await control_service.list_control_usage( - [control.id for control in page.controls] + [control.id for control in page.controls], + namespace_key=namespace_key, ) # Build summaries (filtering already done at DB level) @@ -910,7 +933,7 @@ async def delete_control( "If false, fail if control is associated with any policy or agent.", ), db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_DELETE)), ) -> DeleteControlResponse: """ Delete a control by ID. @@ -933,13 +956,18 @@ async def delete_control( """ control_service = ControlService(db) bindings_service = ControlBindingsService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) - associations = await control_service.list_control_associations(control_id) + associations = await control_service.list_control_associations( + control_id, namespace_key=namespace_key + ) associated_policy_ids = associations.policy_ids associated_agent_names = associations.agent_names target_binding_ids = await bindings_service.list_binding_ids_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if ( @@ -996,13 +1024,15 @@ async def delete_control( dissociated_from_policies: list[int] = [] dissociated_from_agents: list[str] = [] if associated_policy_ids or associated_agent_names: - dissociated = await control_service.remove_all_control_associations(control_id) + dissociated = await control_service.remove_all_control_associations( + control_id, namespace_key=namespace_key + ) dissociated_from_policies = dissociated.policy_ids dissociated_from_agents = dissociated.agent_names detached_target_bindings: list[int] = [] if target_binding_ids: detached_target_bindings = await bindings_service.delete_bindings_for_control( - namespace_key=control.namespace_key, control_id=control_id + namespace_key=namespace_key, control_id=control_id ) if dissociated_from_policies or dissociated_from_agents or detached_target_bindings: _logger.info( @@ -1057,7 +1087,7 @@ async def patch_control( control_id: int, request: PatchControlRequest, db: AsyncSession = Depends(get_async_db), - _principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), + principal: Principal = Depends(require_operation(Operation.CONTROLS_UPDATE)), ) -> PatchControlResponse: """ Update control metadata (name and/or enabled status). @@ -1081,7 +1111,10 @@ async def patch_control( HTTPException 500: Database error during update """ control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id, for_update=True) + namespace_key = principal.namespace_key + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key, for_update=True + ) parsed_control = _parse_stored_control_data( control.data, control_name=control.name, @@ -1096,6 +1129,7 @@ async def patch_control( # Check for name collision if await control_service.active_control_name_exists( request.name, + namespace_key=namespace_key, exclude_control_id=control_id, ): raise ConflictError( diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index e018796e..437af8b5 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -10,16 +10,15 @@ EvaluationResponse, ) from agent_control_models.errors import ErrorCode, ValidationErrorItem -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import RequireAPIKey +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent -from ..namespace import get_namespace_key from ..services.controls import ControlService router = APIRouter(prefix="/evaluation", tags=["evaluation"]) @@ -118,6 +117,20 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) +async def _evaluation_context(request: Request) -> dict[str, object]: + """Surface target identifiers to the runtime authorizer.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON, defer to endpoint validation + return {} + if not isinstance(body, dict): + return {} + return { + "target_type": body.get("target_type"), + "target_id": body.get("target_id"), + } + + @router.post( "", response_model=EvaluationResponse, @@ -126,9 +139,10 @@ def _sanitize_evaluation_response(response: EvaluationResponse) -> EvaluationRes ) async def evaluate( request: EvaluationRequest, - client: RequireAPIKey, db: AsyncSession = Depends(get_async_db), - namespace_key: str = Depends(get_namespace_key), + principal: Principal = Depends( + require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context) + ), ) -> EvaluationResponse: """Analyze content for safety and control violations. @@ -144,7 +158,7 @@ async def evaluate( on the server; SDKs reconstruct and emit those events separately through the observability ingestion endpoint. """ - del client # Authentication is still required by dependency injection. + namespace_key = principal.namespace_key agent_result = await db.execute( select(Agent).where( diff --git a/server/src/agent_control_server/endpoints/policies.py b/server/src/agent_control_server/endpoints/policies.py index 7b8b2ef9..ddda7127 100644 --- a/server/src/agent_control_server/endpoints/policies.py +++ b/server/src/agent_control_server/endpoints/policies.py @@ -9,7 +9,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from ..auth import require_admin_key +from ..auth_framework import Operation, Principal, require_operation from ..db import get_async_db from ..errors import ConflictError, DatabaseError, NotFoundError from ..logging_utils import get_logger @@ -23,13 +23,14 @@ @router.put( "", - dependencies=[Depends(require_admin_key)], response_model=CreatePolicyResponse, summary="Create a new policy", response_description="Created policy ID", ) async def create_policy( - request: CreatePolicyRequest, db: AsyncSession = Depends(get_async_db) + request: CreatePolicyRequest, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_CREATE)), ) -> CreatePolicyResponse: """ Create a new empty policy with a unique name. @@ -48,8 +49,14 @@ async def create_policy( HTTPException 409: Policy with this name already exists HTTPException 500: Database error during creation """ + namespace_key = principal.namespace_key # Uniqueness check - existing = await db.execute(select(Policy.id).where(Policy.name == request.name)) + existing = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.name == request.name, + ) + ) if existing.first() is not None: raise ConflictError( error_code=ErrorCode.POLICY_NAME_CONFLICT, @@ -59,7 +66,7 @@ async def create_policy( hint="Choose a different name or update the existing policy.", ) - policy = Policy(name=request.name) + policy = Policy(namespace_key=namespace_key, name=request.name) db.add(policy) try: await db.commit() @@ -80,13 +87,15 @@ async def create_policy( @router.post( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Add control to policy", response_description="Success confirmation", ) async def add_control_to_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Associate a control with a policy. @@ -106,8 +115,14 @@ async def add_control_to_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ + namespace_key = principal.namespace_key # Find policy and control - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -119,11 +134,17 @@ async def add_control_to_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Add association using INSERT ... ON CONFLICT DO NOTHING for idempotency try: - await control_service.add_control_to_policy(policy_id=policy_id, control_id=control_id) + await control_service.add_control_to_policy( + policy_id=policy_id, + control_id=control_id, + namespace_key=namespace_key, + ) await db.commit() except Exception: await db.rollback() @@ -149,13 +170,15 @@ async def add_control_to_policy( @router.delete( "/{policy_id}/controls/{control_id}", - dependencies=[Depends(require_admin_key)], response_model=AssocResponse, summary="Remove control from policy", response_description="Success confirmation", ) async def remove_control_from_policy( - policy_id: int, control_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + control_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_UPDATE)), ) -> AssocResponse: """ Remove a control from a policy. @@ -175,7 +198,13 @@ async def remove_control_from_policy( HTTPException 404: Policy or control not found HTTPException 500: Database error """ - pol_res = await db.execute(select(Policy).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) policy = pol_res.scalars().first() if policy is None: raise NotFoundError( @@ -187,13 +216,16 @@ async def remove_control_from_policy( ) control_service = ControlService(db) - control = await control_service.get_active_control_or_404(control_id) + control = await control_service.get_active_control_or_404( + control_id, namespace_key=namespace_key + ) # Remove association (idempotent - deleting non-existent is no-op) try: await control_service.remove_control_from_policy( policy_id=policy_id, control_id=control_id, + namespace_key=namespace_key, ) await db.commit() except Exception: @@ -222,7 +254,9 @@ async def remove_control_from_policy( response_description="List of control IDs", ) async def list_policy_controls( - policy_id: int, db: AsyncSession = Depends(get_async_db) + policy_id: int, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.POLICIES_READ)), ) -> GetPolicyControlsResponse: """ List all controls associated with a policy. @@ -237,7 +271,13 @@ async def list_policy_controls( Raises: HTTPException 404: Policy not found """ - pol_res = await db.execute(select(Policy.id).where(Policy.id == policy_id)) + namespace_key = principal.namespace_key + pol_res = await db.execute( + select(Policy.id).where( + Policy.namespace_key == namespace_key, + Policy.id == policy_id, + ) + ) if pol_res.first() is None: raise NotFoundError( error_code=ErrorCode.POLICY_NOT_FOUND, @@ -247,5 +287,8 @@ async def list_policy_controls( hint="Verify the policy ID is correct and the policy has been created.", ) - control_ids = await ControlService(db).list_policy_control_ids(policy_id) + control_ids = await ControlService(db).list_policy_control_ids( + policy_id, + namespace_key=namespace_key, + ) return GetPolicyControlsResponse(control_ids=control_ids) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index bc1bf04b..a1561e63 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -252,7 +252,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # Register handler for FastAPI's RequestValidationError (Pydantic validation) app.add_exception_handler(RequestValidationError, validation_exception_handler) # type: ignore[arg-type] -# Register handler for standard HTTPException (legacy code, FastAPI internals) +# Register handler for standard HTTPException (older routes, FastAPI internals) app.add_exception_handler(HTTPException, http_exception_handler) # type: ignore[arg-type] # Register catch-all handler for unexpected exceptions @@ -261,16 +261,18 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- # API v1 prefix for all routes api_v1_prefix = f"{settings.api_prefix}/{settings.api_version}" -# Protected routes (require valid API key) +# API routers. Routers migrated to the auth framework mount the +# non-validating header extractor only so OpenAPI advertises X-API-Key; +# each endpoint's ``require_operation`` dependency owns authn + authz. app.include_router( agent_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( policy_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) app.include_router( # Endpoint dependencies handle auth; this advertises X-API-Key. @@ -281,11 +283,11 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( # The auth framework on each endpoint owns authentication and # authorization for control bindings, so this router is mounted - # without the legacy router-level gate. See ``auth_framework`` for + # without the router-level auth gate. See ``auth_framework`` for # the provider contract. ``get_api_key_from_header`` is a non- # validating extractor (``auto_error=False``); it is attached purely # so the generated OpenAPI spec advertises the X-API-Key requirement - # on these routes — without it, downstream SDK generators would treat + # on these routes - without it, downstream SDK generators would treat # the routes as unauthenticated. control_binding_router, prefix=api_v1_prefix, @@ -309,9 +311,10 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- app.include_router( evaluation_router, prefix=api_v1_prefix, - dependencies=[Depends(require_api_key)], + dependencies=[Depends(get_api_key_from_header)], ) +# Evaluator discovery still uses the local credential dependency. app.include_router( evaluator_router, prefix=api_v1_prefix, @@ -324,7 +327,7 @@ async def attach_version_header(request, call_next): # type: ignore[no-untyped- prefix=api_v1_prefix, ) -# System routes (config, login, logout) — no auth required +# System routes (config, login, logout) - no auth required app.include_router( system_router, prefix=settings.api_prefix, diff --git a/server/src/agent_control_server/namespace.py b/server/src/agent_control_server/namespace.py deleted file mode 100644 index 30e30be5..00000000 --- a/server/src/agent_control_server/namespace.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Namespace resolution for request-scoped scoping. - -V1 always resolves to the default namespace. The function exists as a -single seam so a future change can switch every namespace-scoped -endpoint to a real per-request resolver without touching each call -site. Overriding the dependency in V1 is not supported: only this -binding/evaluation layer reads it; controls, agents, and policies still -write under the default namespace, so an override here would create -inconsistent rows. Future work will thread a single resolver through -every write path together. -""" - -from __future__ import annotations - -from .models import DEFAULT_NAMESPACE_KEY - - -def get_namespace_key() -> str: - """Return the namespace_key for the current request. - - V1 returns ``DEFAULT_NAMESPACE_KEY`` unconditionally. - """ - return DEFAULT_NAMESPACE_KEY diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 263120b7..e3a5fd26 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -96,9 +96,15 @@ class ControlService: def __init__(self, db: AsyncSession) -> None: self._db = db - def create_control(self, *, name: str, data: dict[str, Any]) -> Control: + def create_control( + self, + *, + namespace_key: str, + name: str, + data: dict[str, Any], + ) -> Control: """Create a new pending control row.""" - control = Control(name=name, data=data) + control = Control(namespace_key=namespace_key, name=name, data=data) self._db.add(control) return control @@ -128,10 +134,13 @@ async def get_control_or_404( self, control_id: int, *, + namespace_key: str | None = None, for_update: bool = False, ) -> Control: """Load any control row, including soft-deleted controls.""" stmt = select(Control).where(Control.id == control_id) + if namespace_key is not None: + stmt = stmt.where(Control.namespace_key == namespace_key) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -151,17 +160,19 @@ async def get_active_control_or_404( control_id: int, *, for_update: bool = False, - namespace_key: str | None = None, + namespace_key: str, ) -> Control: """Load an active control row or raise CONTROL_NOT_FOUND. - When ``namespace_key`` is supplied, the lookup is scoped to that - namespace; a control that exists only in another namespace - surfaces as 404 (non-disclosing). + The lookup is scoped to the supplied namespace; a control that + exists only in another namespace surfaces as 404 + (non-disclosing). """ - stmt = select(Control).where(Control.id == control_id, Control.deleted_at.is_(None)) - if namespace_key is not None: - stmt = stmt.where(Control.namespace_key == namespace_key) + stmt = select(Control).where( + Control.id == control_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) if for_update: stmt = stmt.with_for_update() result = await self._db.execute(stmt) @@ -180,10 +191,15 @@ async def active_control_name_exists( self, name: str, *, + namespace_key: str, exclude_control_id: int | None = None, ) -> bool: """Return whether an active control already uses the provided name.""" - stmt = select(Control.id).where(Control.name == name, Control.deleted_at.is_(None)) + stmt = select(Control.id).where( + Control.namespace_key == namespace_key, + Control.name == name, + Control.deleted_at.is_(None), + ) if exclude_control_id is not None: stmt = stmt.where(Control.id != exclude_control_id) result = await self._db.execute(stmt) @@ -216,11 +232,12 @@ async def list_versions( self, control_id: int, *, + namespace_key: str, cursor: int | None, limit: int, ) -> ControlVersionPage: """Return control versions newest-first with cursor pagination.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) total_result = await self._db.execute( select(func.count()) @@ -255,9 +272,11 @@ async def list_versions( next_cursor=next_cursor, ) - async def get_version_or_404(self, control_id: int, version_num: int) -> ControlVersion: + async def get_version_or_404( + self, control_id: int, version_num: int, *, namespace_key: str + ) -> ControlVersion: """Load a specific version row for a control.""" - await self.get_control_or_404(control_id) + await self.get_control_or_404(control_id, namespace_key=namespace_key) result = await self._db.execute( select(ControlVersion).where( @@ -303,12 +322,17 @@ async def list_controls_for_policy( result = await self._db.execute(stmt) return list(result.scalars().unique().all()) - async def list_policy_control_ids(self, policy_id: int) -> list[int]: + async def list_policy_control_ids(self, policy_id: int, *, namespace_key: str) -> list[int]: """Return active control IDs directly associated with a policy.""" result = await self._db.execute( select(policy_controls.c.control_id) .join(Control, Control.id == policy_controls.c.control_id) - .where(policy_controls.c.policy_id == policy_id, Control.deleted_at.is_(None)) + .where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.policy_id == policy_id, + Control.namespace_key == namespace_key, + Control.deleted_at.is_(None), + ) .order_by(policy_controls.c.control_id) ) return [cast(int, row[0]) for row in result.all()] @@ -396,6 +420,7 @@ async def list_runtime_controls_for_agent( async def list_controls_page( self, *, + namespace_key: str, cursor: int | None, limit: int, name: str | None, @@ -407,7 +432,11 @@ async def list_controls_page( tag: str | None, ) -> ControlListPage: """Return paginated active controls for the browse endpoint.""" - query = select(Control).where(Control.deleted_at.is_(None)).order_by(Control.id.desc()) + query = ( + select(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + .order_by(Control.id.desc()) + ) query = self._apply_control_list_filters( query, name=name, @@ -424,7 +453,11 @@ async def list_controls_page( result = await self._db.execute(query.limit(limit + 1)) controls = list(result.scalars().all()) - total_query = select(func.count()).select_from(Control).where(Control.deleted_at.is_(None)) + total_query = ( + select(func.count()) + .select_from(Control) + .where(Control.namespace_key == namespace_key, Control.deleted_at.is_(None)) + ) total_query = self._apply_control_list_filters( total_query, name=name, @@ -453,7 +486,9 @@ async def list_controls_page( next_cursor=next_cursor, ) - async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, ControlUsage]: + async def list_control_usage( + self, control_ids: Sequence[int], *, namespace_key: str + ) -> dict[int, ControlUsage]: """Return representative agent usage and usage counts for the provided controls.""" if not control_ids: return {} @@ -465,8 +500,16 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_policies.c.agent_name, ) .select_from(policy_controls) - .join(agent_policies, policy_controls.c.policy_id == agent_policies.c.policy_id) - .where(policy_controls.c.control_id.in_(control_ids)) + .join( + agent_policies, + (policy_controls.c.policy_id == agent_policies.c.policy_id) + & (policy_controls.c.namespace_key == agent_policies.c.namespace_key), + ) + .where( + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.control_id.in_(control_ids), + ) ) direct_agents_query = ( select( @@ -474,7 +517,10 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont agent_controls.c.agent_name, ) .select_from(agent_controls) - .where(agent_controls.c.control_id.in_(control_ids)) + .where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id.in_(control_ids), + ) ) agents_result = await self._db.execute(union_all(policy_agents_query, direct_agents_query)) for control_id, agent_name in agents_result.all(): @@ -491,6 +537,8 @@ async def list_control_usage(self, control_ids: Sequence[int]) -> dict[int, Cont async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], + *, + namespace_key: str, ) -> dict[str, int]: """Return active control counts keyed by agent name.""" if not agent_names: @@ -503,15 +551,24 @@ async def list_active_control_counts_by_agent( ) .select_from( agent_policies.join( - policy_controls, agent_policies.c.policy_id == policy_controls.c.policy_id + policy_controls, + (agent_policies.c.policy_id == policy_controls.c.policy_id) + & (agent_policies.c.namespace_key == policy_controls.c.namespace_key), ) ) - .where(agent_policies.c.agent_name.in_(agent_names)) + .where( + agent_policies.c.namespace_key == namespace_key, + policy_controls.c.namespace_key == namespace_key, + agent_policies.c.agent_name.in_(agent_names), + ) ) direct_associations = select( agent_controls.c.agent_name.label("agent_name"), agent_controls.c.control_id.label("control_id"), - ).where(agent_controls.c.agent_name.in_(agent_names)) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.agent_name.in_(agent_names), + ) all_associations = union_all(policy_associations, direct_associations).subquery() result = await self._db.execute( @@ -521,6 +578,7 @@ async def list_active_control_counts_by_agent( ) .join(Control, all_associations.c.control_id == Control.id) .where( + Control.namespace_key == namespace_key, Control.deleted_at.is_(None), or_( Control.data["enabled"].astext == "true", @@ -531,19 +589,28 @@ async def list_active_control_counts_by_agent( ) return {cast(str, row[0]): cast(int, row[1]) for row in result.all()} - async def add_control_to_policy(self, *, policy_id: int, control_id: int) -> None: + async def add_control_to_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Create a policy-control association if it does not already exist.""" await self._db.execute( pg_insert(policy_controls) - .values(policy_id=policy_id, control_id=control_id) + .values( + namespace_key=namespace_key, + policy_id=policy_id, + control_id=control_id, + ) .on_conflict_do_nothing() ) - async def remove_control_from_policy(self, *, policy_id: int, control_id: int) -> None: + async def remove_control_from_policy( + self, *, policy_id: int, control_id: int, namespace_key: str + ) -> None: """Remove a policy-control association if it exists.""" await self._db.execute( delete(policy_controls).where( - (policy_controls.c.policy_id == policy_id) + (policy_controls.c.namespace_key == namespace_key) + & (policy_controls.c.policy_id == policy_id) & (policy_controls.c.control_id == control_id) ) ) @@ -613,16 +680,24 @@ async def remove_control_from_agent( control_still_active=policy_inheritance_result.first() is not None, ) - async def list_control_associations(self, control_id: int) -> ControlAssociations: + async def list_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Return all policy and direct agent associations for a control.""" policy_assoc_query = select( policy_controls.c.policy_id.label("policy_id"), literal(None, type_=String).label("agent_name"), - ).where(policy_controls.c.control_id == control_id) + ).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) agent_assoc_query = select( literal(None, type_=Integer).label("policy_id"), agent_controls.c.agent_name.label("agent_name"), - ).where(agent_controls.c.control_id == control_id) + ).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) assoc_result = await self._db.execute(union_all(policy_assoc_query, agent_assoc_query)) policy_ids: set[int] = set() @@ -638,16 +713,26 @@ async def list_control_associations(self, control_id: int) -> ControlAssociation agent_names=sorted(agent_names), ) - async def remove_all_control_associations(self, control_id: int) -> ControlAssociations: + async def remove_all_control_associations( + self, control_id: int, *, namespace_key: str + ) -> ControlAssociations: """Remove all policy and direct agent associations for a control.""" - associations = await self.list_control_associations(control_id) + associations = await self.list_control_associations( + control_id, namespace_key=namespace_key + ) if associations.policy_ids: await self._db.execute( - delete(policy_controls).where(policy_controls.c.control_id == control_id) + delete(policy_controls).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id == control_id, + ) ) if associations.agent_names: await self._db.execute( - delete(agent_controls).where(agent_controls.c.control_id == control_id) + delete(agent_controls).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id == control_id, + ) ) return associations diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index 96c4aad8..20c58aed 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -7,7 +7,6 @@ import httpx import pytest - from agent_control_server.auth_framework.core import ( Operation, Principal, @@ -20,6 +19,8 @@ AccessLevel, HeaderAuthProvider, HttpUpstreamAuthProvider, + LocalJwtVerifyProvider, + NoAuthProvider, ) from agent_control_server.auth_framework.providers.header import ( DEFAULT_OPERATION_ACCESS, @@ -64,6 +65,35 @@ def test_default_operation_access_covers_every_operation(): assert not missing, f"Operations missing default access mapping: {missing}" +# --------------------------------------------------------------------------- +# NoAuthProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_no_auth_provider_allows_any_operation(): + provider = NoAuthProvider(default_namespace_key="ns-local") + + principal = await provider.authorize( + _build_request(), + Operation.CONTROLS_DELETE, + ) + + assert principal == Principal(namespace_key="ns-local") + + +@pytest.mark.asyncio +async def test_no_auth_provider_grants_runtime_exchange_scope(): + provider = NoAuthProvider() + + principal = await provider.authorize( + _build_request(), + Operation.RUNTIME_TOKEN_EXCHANGE, + ) + + assert principal.scopes == (Operation.RUNTIME_USE.value,) + + # --------------------------------------------------------------------------- # HeaderAuthProvider # --------------------------------------------------------------------------- @@ -101,7 +131,7 @@ async def test_header_provider_public_returns_default_namespace(): @pytest.mark.asyncio -async def test_header_provider_authenticated_calls_legacy_validator(): +async def test_header_provider_authenticated_calls_local_validator(): provider = HeaderAuthProvider() expected_client = MagicMock(is_admin=False, key_id="abc12345") @@ -230,6 +260,75 @@ def factory(request: httpx.Request) -> httpx.Response: assert captured["headers"]["x-custom-token"] == "shh" +@pytest.mark.asyncio +async def test_http_upstream_forwards_extra_headers(): + # Given: a provider configured with an extra header in its forward list + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("X-Deployer-Auth",)}, + ) + + # When: the inbound request carries the extra header + inbound = _build_request(headers={"X-Deployer-Auth": "k_abc", "X-API-Key": "k1"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: both the default and the extra header reach the upstream + assert captured["headers"]["x-deployer-auth"] == "k_abc" + assert captured["headers"]["x-api-key"] == "k1" + + +@pytest.mark.asyncio +async def test_http_upstream_default_forward_set_unchanged(): + # Given: a provider with no extra_forward_headers + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream(factory) + + # When: the inbound carries an unlisted header alongside a default one + inbound = _build_request( + headers={"X-API-Key": "k1", "X-Deployer-Auth": "should-not-forward"} + ) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: only the default-set header reaches the upstream + assert captured["headers"].get("x-api-key") == "k1" + assert "x-deployer-auth" not in captured["headers"] + + +@pytest.mark.asyncio +async def test_http_upstream_extra_forward_dedupes_against_defaults(): + # Given: extra list duplicates a default header (different case) + captured: dict[str, Any] = {} + + def factory(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response(200, json={"namespace_key": "ns"}) + + provider = _build_upstream( + factory, + config_overrides={"extra_forward_headers": ("x-api-key", "Authorization")}, + ) + + # When: inbound has both + inbound = _build_request(headers={"X-API-Key": "k1", "Authorization": "Bearer t"}) + await provider.authorize(inbound, Operation.CONTROL_BINDINGS_READ) + + # Then: each header appears exactly once on the upstream request + forwarded = captured["headers"] + assert sum(1 for k in forwarded if k.lower() == "x-api-key") == 1 + assert sum(1 for k in forwarded if k.lower() == "authorization") == 1 + + @pytest.mark.asyncio @pytest.mark.parametrize( "status, expected", @@ -600,7 +699,6 @@ def test_runtime_token_rejects_naive_upstream_expires_at(): def test_runtime_token_rejects_management_token_passed_to_runtime_verify(): """A token without ``domain=runtime`` must be rejected by runtime verify.""" import jwt - from agent_control_server.auth_framework.runtime_token import ( RuntimeTokenError, verify_runtime_token, @@ -945,6 +1043,177 @@ def test_runtime_ttl_loader_accepts_max(monkeypatch): ) +def test_build_default_provider_accepts_none_mode(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + + assert isinstance(auth_config._build_default_provider(), NoAuthProvider) + + +def test_resolve_runtime_mode_defaults_to_default_without_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + assert auth_config._resolve_runtime_mode() == "default" + + +def test_resolve_runtime_mode_defaults_to_jwt_with_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + assert auth_config._resolve_runtime_mode() == "jwt" + + +def test_configure_runtime_none_installs_no_auth_provider(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_api_key_ignores_jwt_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "api_key") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), HeaderAuthProvider) + assert auth_config.runtime_auth_config() is None + + +def test_configure_runtime_unset_preserves_no_auth_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "none") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.RUNTIME_USE), NoAuthProvider) + assert auth_config.runtime_auth_config() is None + + +@pytest.mark.asyncio +async def test_configure_runtime_unset_preserves_http_upstream_default(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", raising=False) + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + try: + auth_config.configure_auth_from_env() + + default_provider = get_authorizer(Operation.CONTROLS_READ) + runtime_provider = get_authorizer(Operation.RUNTIME_USE) + assert isinstance(default_provider, HttpUpstreamAuthProvider) + assert runtime_provider is default_provider + assert auth_config.runtime_auth_config() is None + finally: + await auth_config.teardown_auth() + + +@pytest.mark.asyncio +async def test_configure_http_upstream_management_with_jwt_runtime(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", _TEST_SECRET) + + try: + auth_config.configure_auth_from_env() + + assert isinstance(get_authorizer(Operation.CONTROLS_READ), HttpUpstreamAuthProvider) + assert isinstance(get_authorizer(Operation.RUNTIME_USE), LocalJwtVerifyProvider) + runtime_config = auth_config.runtime_auth_config() + assert runtime_config is not None + assert runtime_config.secret == _TEST_SECRET + finally: + await auth_config.teardown_auth() + + +@pytest.mark.parametrize( + "raw, expected", + [ + (None, ()), + ("", ()), + (" ", ()), + ("X-One", ("X-One",)), + ("X-One,X-Two", ("X-One", "X-Two")), + (" X-One , X-Two ", ("X-One", "X-Two")), + ("X-One,,X-Two", ("X-One", "X-Two")), + ("X-One,x-one,X-One", ("X-One",)), + ("X-A,X-B,x-a,X-C,X-b", ("X-A", "X-B", "X-C")), + ], +) +def test_parse_extra_forward_headers(raw, expected): + from agent_control_server.auth_framework.config import _parse_extra_forward_headers + + assert _parse_extra_forward_headers(raw) == expected + + +@pytest.mark.asyncio +async def test_configure_http_upstream_extra_forward_headers_env(monkeypatch): + """Setting the env var threads extra_forward_headers into the provider.""" + from agent_control_server.auth_framework import config as auth_config + + clear_authorizers() + + monkeypatch.setenv("AGENT_CONTROL_AUTH_MODE", "http_upstream") + monkeypatch.setenv("AGENT_CONTROL_AUTH_UPSTREAM_URL", "https://auth.example.test/check") + monkeypatch.setenv( + "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS", + "X-Deployer-Auth, X-Deployer-Trace", + ) + + try: + auth_config.configure_auth_from_env() + provider = get_authorizer(Operation.CONTROLS_READ) + assert isinstance(provider, HttpUpstreamAuthProvider) + assert provider._config.extra_forward_headers == ( + "X-Deployer-Auth", + "X-Deployer-Trace", + ) + finally: + await auth_config.teardown_auth() + + +def test_configure_runtime_jwt_requires_secret(monkeypatch): + from agent_control_server.auth_framework import config as auth_config + + monkeypatch.setenv("AGENT_CONTROL_RUNTIME_AUTH_MODE", "jwt") + monkeypatch.delenv("AGENT_CONTROL_RUNTIME_TOKEN_SECRET", raising=False) + + with pytest.raises(RuntimeError, match="requires AGENT_CONTROL_RUNTIME_TOKEN_SECRET"): + auth_config.configure_auth_from_env() + + def test_configure_then_reconfigure_clears_runtime_override(monkeypatch): """Reconfiguring without a runtime secret must drop the override.""" from agent_control_server.auth_framework import config as auth_config diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index b4922b9d..dfbb15f5 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -8,19 +8,19 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from agent_control_evaluators import RegexEvaluatorConfig +from agent_control_models import ConditionNode from fastapi.testclient import TestClient from sqlalchemy import text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from agent_control_models import ConditionNode +from agent_control_server.auth_framework import Principal from agent_control_server.db import get_async_db -from agent_control_server.models import Control - -from agent_control_evaluators import RegexEvaluatorConfig from agent_control_server.endpoints import controls as controls_module from agent_control_server.main import app +from agent_control_server.models import DEFAULT_NAMESPACE_KEY, Control from .conftest import engine from .utils import VALID_CONTROL_PAYLOAD @@ -1106,7 +1106,12 @@ def model_dump(self, *args: object, **kwargs: object) -> dict[str, object]: request = SimpleNamespace(data=DummyData(payload)) # When: updating the control data with a non-Pydantic selector - response = await controls_module.set_control_data(control.id, request, async_db) + response = await controls_module.set_control_data( + control.id, + request, + async_db, + principal=Principal(namespace_key=DEFAULT_NAMESPACE_KEY), + ) # Then: the update succeeds and uses the original selector serialization assert response.success is True diff --git a/server/tests/test_controls_auth.py b/server/tests/test_controls_auth.py index 1a2af21f..04f44ca4 100644 --- a/server/tests/test_controls_auth.py +++ b/server/tests/test_controls_auth.py @@ -4,14 +4,12 @@ import uuid -import pytest +from agent_control_server.auth_framework import set_authorizer +from agent_control_server.auth_framework.providers import NoAuthProvider from fastapi.testclient import TestClient -from agent_control_server.config import auth_settings - from .utils import VALID_CONTROL_PAYLOAD - _CONTROLS_URL = "/api/v1/controls" _TEMPLATES_URL = "/api/v1/control-templates" @@ -283,21 +281,16 @@ def test_unauthenticated_cannot_render_template( # --------------------------------------------------------------------------- -# No-auth deployment mode: api_key_enabled=False bypasses every gate. +# No-auth deployment mode: explicit provider bypasses every gate. # --------------------------------------------------------------------------- def test_no_auth_mode_allows_writes_without_credentials( unauthenticated_client: TestClient, - monkeypatch: pytest.MonkeyPatch, ) -> None: - """When ``api_key_enabled`` is False, the ``HeaderAuthProvider`` - short-circuits to a non-admin ``Principal`` for every operation, - including admin-level writes. This pins the "no auth" deployment - path so a future refactor can't silently start enforcing. - """ - # Given: api_key_enabled is False (single-tenant OSS dev mode) - monkeypatch.setattr(auth_settings, "api_key_enabled", False) + """Explicit no-auth provider allows every operation without credentials.""" + # Given: the request-auth framework is in no-auth mode + set_authorizer(NoAuthProvider()) # When: an unauthenticated client creates a control resp = unauthenticated_client.put( @@ -311,4 +304,3 @@ def test_no_auth_mode_allows_writes_without_credentials( # Then: the create succeeds because auth is disabled at the provider assert resp.status_code == 200, resp.text assert "control_id" in resp.json() - diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py new file mode 100644 index 00000000..40ecd216 --- /dev/null +++ b/server/tests/test_principal_namespace_flow.py @@ -0,0 +1,141 @@ +"""HTTP-level coverage for principal-derived namespace scoping.""" + +from __future__ import annotations + +import uuid +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from agent_control_server.auth_framework import ( + Operation, + Principal, + set_authorizer, +) + +from .utils import VALID_CONTROL_PAYLOAD + + +class HeaderNamespaceAuthorizer: + """Test authorizer that maps a request header to ``Principal.namespace_key``.""" + + async def authorize( + self, + request: Request, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del context + scopes = ( + (Operation.RUNTIME_USE.value,) + if operation is Operation.RUNTIME_TOKEN_EXCHANGE + else () + ) + return Principal( + namespace_key=request.headers.get("X-Test-Namespace", "default"), + is_admin=True, + scopes=scopes, + ) + + +def _client(app: FastAPI, namespace_key: str) -> TestClient: + return TestClient( + app, + raise_server_exceptions=True, + headers={"X-Test-Namespace": namespace_key}, + ) + + +def _agent_payload(agent_name: str) -> dict[str, Any]: + return { + "agent": { + "agent_name": agent_name, + "agent_description": "test agent", + "agent_version": "1.0", + }, + "steps": [], + } + + +def _evaluation_payload(agent_name: str) -> dict[str, Any]: + return { + "agent_name": agent_name, + "step": { + "type": "llm", + "name": "test-step", + "input": "x marks the spot", + "context": {}, + }, + "stage": "pre", + "target_type": "env", + "target_id": "prod", + } + + +def test_principal_namespace_scopes_management_and_runtime(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + + register_a = ns_a.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + register_b = ns_b.post("/api/v1/agents/initAgent", json=_agent_payload(agent_name)) + assert register_a.status_code == 200, register_a.text + assert register_b.status_code == 200, register_b.text + + create_control = ns_a.put( + "/api/v1/controls", + json={ + "name": f"control-{uuid.uuid4().hex[:12]}", + "data": VALID_CONTROL_PAYLOAD, + }, + ) + assert create_control.status_code == 200, create_control.text + control_id = int(create_control.json()["control_id"]) + + policy = ns_a.put( + "/api/v1/policies", + json={"name": f"policy-{uuid.uuid4().hex[:12]}"}, + ) + assert policy.status_code == 200, policy.text + policy_id = int(policy.json()["policy_id"]) + attach_to_policy = ns_a.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert attach_to_policy.status_code == 200, attach_to_policy.text + + binding = ns_a.put( + "/api/v1/control-bindings", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": True, + }, + ) + assert binding.status_code == 200, binding.text + + assert ns_b.get(f"/api/v1/controls/{control_id}").status_code == 404 + assert ns_b.get(f"/api/v1/policies/{policy_id}/controls").status_code == 404 + assert ns_b.get("/api/v1/control-bindings").json()["bindings"] == [] + + eval_a = ns_a.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_a.status_code == 200, eval_a.text + assert eval_a.json()["is_safe"] is False + assert eval_a.json()["matches"][0]["control_id"] == control_id + + eval_b = ns_b.post("/api/v1/evaluation", json=_evaluation_payload(agent_name)) + assert eval_b.status_code == 200, eval_b.text + assert eval_b.json()["is_safe"] is True + + +def test_duplicate_control_names_allowed_across_principal_namespaces(app: FastAPI) -> None: + set_authorizer(HeaderNamespaceAuthorizer()) + + ns_a = _client(app, "ns-a") + ns_b = _client(app, "ns-b") + control_name = f"control-{uuid.uuid4().hex[:12]}" + payload = {"name": control_name, "data": VALID_CONTROL_PAYLOAD} + + assert ns_a.put("/api/v1/controls", json=payload).status_code == 200 + assert ns_b.put("/api/v1/controls", json=payload).status_code == 200 diff --git a/server/tests/test_runtime_token_exchange_endpoint.py b/server/tests/test_runtime_token_exchange_endpoint.py index 8d333a5c..1b1edae2 100644 --- a/server/tests/test_runtime_token_exchange_endpoint.py +++ b/server/tests/test_runtime_token_exchange_endpoint.py @@ -11,8 +11,6 @@ from datetime import UTC, datetime, timedelta import pytest -from fastapi.testclient import TestClient - from agent_control_server.auth_framework import Operation, Principal from agent_control_server.auth_framework.config import ( RuntimeAuthConfig, @@ -25,6 +23,7 @@ from agent_control_server.auth_framework.providers import ( LocalJwtVerifyProvider, ) +from fastapi.testclient import TestClient _TEST_SECRET = "test-runtime-secret-12345678901234567890" @@ -180,6 +179,39 @@ async def test_exchange_then_verify_full_round_trip(client: TestClient, runtime_ assert principal.caller_id == "actor-rt" +def test_evaluation_rejects_runtime_jwt_for_wrong_target( + client: TestClient, + runtime_config_enabled, +): + """A runtime JWT minted for one target cannot be used for another target.""" + stub = _StubExchangeAuthorizer(actor_id="actor-rt", scopes=("runtime.use",)) + clear_authorizers() + set_authorizer(stub) + set_authorizer(LocalJwtVerifyProvider(secret=_TEST_SECRET), operation=Operation.RUNTIME_USE) + + exchange = client.post( + "/api/v1/auth/runtime-token-exchange", + json={"target_type": "log_stream", "target_id": "ls-allowed"}, + ) + assert exchange.status_code == 200, exchange.text + token = exchange.json()["token"] + + response = client.post( + "/api/v1/evaluation", + headers={"Authorization": f"Bearer {token}"}, + json={ + "agent_name": "agent", + "step": {"type": "llm", "name": "step", "input": "hello"}, + "stage": "pre", + "target_type": "log_stream", + "target_id": "ls-other", + }, + ) + + assert response.status_code == 403, response.text + assert response.json()["detail"] == "Runtime token target_id does not match the request." + + def test_exchange_endpoint_502_when_upstream_grant_already_expired( client: TestClient, runtime_config_enabled, diff --git a/server/tests/test_services_controls.py b/server/tests/test_services_controls.py index b858c527..3815f26b 100644 --- a/server/tests/test_services_controls.py +++ b/server/tests/test_services_controls.py @@ -8,10 +8,6 @@ import pytest from agent_control_models.errors import ErrorCode -from sqlalchemy import insert, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session - from agent_control_server.errors import APIValidationError from agent_control_server.models import ( DEFAULT_NAMESPACE_KEY, @@ -27,6 +23,9 @@ from agent_control_server.services.controls import ( ControlService, ) +from sqlalchemy import insert, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session from .conftest import AsyncSessionTest, engine from .utils import VALID_CONTROL_PAYLOAD @@ -70,7 +69,11 @@ async def _create_versioned_control( async with AsyncSessionTest() as session: service = ControlService(session) - control = service.create_control(name=control_name, data=control_data) + control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, + name=control_name, + data=control_data, + ) await service.create_version( control, event_type="created", @@ -143,6 +146,7 @@ async def test_create_control_transaction_rollback_does_not_persist_control_or_v async with AsyncSessionTest() as session: service = ControlService(session) control = service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=control_name, data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -167,7 +171,10 @@ async def test_replace_control_data_transaction_rollback_preserves_prior_state() async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = "Should not persist" service.replace_control_data(control, data=updated_data) @@ -194,7 +201,10 @@ async def test_patch_mutation_transaction_rollback_preserves_prior_state() -> No async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.rename_control(control, name=f"{control_name}-renamed") service.set_control_enabled(control, enabled=False) await service.create_version( @@ -221,7 +231,10 @@ async def test_delete_control_transaction_rollback_preserves_active_state() -> N async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) service.mark_control_deleted(control, deleted_at=dt.datetime.now(dt.UTC)) await service.create_version( control, @@ -511,7 +524,10 @@ async def test_list_active_control_counts_by_agent_deduplicates_and_filters_inac await async_db.commit() # When: counting active controls for the agent - counts = await ControlService(async_db).list_active_control_counts_by_agent([agent.name]) + counts = await ControlService(async_db).list_active_control_counts_by_agent( + [agent.name], + namespace_key=DEFAULT_NAMESPACE_KEY, + ) # Then: active controls are deduplicated and inactive controls are excluded assert counts == {agent.name: 2} @@ -572,6 +588,7 @@ async def test_create_version_allocates_sequential_numbers_under_concurrent_muta async with AsyncSessionTest() as setup_session: setup_service = ControlService(setup_session) control = setup_service.create_control( + namespace_key=DEFAULT_NAMESPACE_KEY, name=f"control-{uuid.uuid4()}", data=deepcopy(VALID_CONTROL_PAYLOAD), ) @@ -592,7 +609,10 @@ async def mutate_and_version(description: str) -> None: async with AsyncSessionTest() as session: service = ControlService(session) - control = await service.get_active_control_or_404(control_id) + control = await service.get_active_control_or_404( + control_id, + namespace_key=DEFAULT_NAMESPACE_KEY, + ) updated_data = deepcopy(control.data) updated_data["description"] = description service.replace_control_data(control, data=updated_data) diff --git a/server/tests/test_target_merged_contract.py b/server/tests/test_target_merged_contract.py index 295a85e2..62891ba5 100644 --- a/server/tests/test_target_merged_contract.py +++ b/server/tests/test_target_merged_contract.py @@ -232,9 +232,9 @@ def test_target_binding_de_duplicated_against_direct_attachment( async def _insert_agent_in_namespace(async_db, *, name: str, namespace_key: str) -> None: """Insert an Agent row directly so the test can simulate a foreign namespace. - The endpoint's ``get_namespace_key`` returns the default namespace; this - helper sidesteps the resolver to seed an agent that the request-time - code path should not be able to reach. + The default test authorizer returns the default namespace; this helper + sidesteps the authorizer to seed an agent that the request-time code + path should not be able to reach. """ from agent_control_server.models import Agent