diff --git a/python/packages/hosting-entra/LICENSE b/python/packages/hosting-entra/LICENSE new file mode 100644 index 0000000000..9e841e7a26 --- /dev/null +++ b/python/packages/hosting-entra/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/python/packages/hosting-entra/README.md b/python/packages/hosting-entra/README.md new file mode 100644 index 0000000000..6e0073812b --- /dev/null +++ b/python/packages/hosting-entra/README.md @@ -0,0 +1,39 @@ +# agent-framework-hosting-entra + +Microsoft Entra (Azure AD) identity-linking sidecar channel for +[agent-framework-hosting](../hosting). Owns the OAuth 2.0 Authorization Code +flow that binds a per-channel id (e.g. a Telegram chat id) to the user's +Entra object id, so multiple non-Entra channels can share a single +`entra:` isolation key. + +## Usage + +```python +from pathlib import Path +from agent_framework_hosting import AgentFrameworkHost +from agent_framework_hosting_entra import ( + EntraIdentityLinkChannel, + EntraIdentityStore, +) + +store = EntraIdentityStore(Path("./identity_links.json")) + +host = AgentFrameworkHost( + target=my_agent, + channels=[ + EntraIdentityLinkChannel( + store=store, + tenant_id="", + client_id="", + client_secret="", + public_base_url="https://your.host", + ), + # ... other channels whose run hooks call store.lookup(...) + ], +) +host.serve() +``` + +For tenants that disallow client secrets, pass `certificate_path=` (and +optionally `certificate_password=`) instead of `client_secret`. The PEM +layout matches the one used by `agent-framework-hosting-teams`. diff --git a/python/packages/hosting-entra/agent_framework_hosting_entra/__init__.py b/python/packages/hosting-entra/agent_framework_hosting_entra/__init__.py new file mode 100644 index 0000000000..6e1bba53b8 --- /dev/null +++ b/python/packages/hosting-entra/agent_framework_hosting_entra/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Microsoft Entra (Azure AD) identity channel for :mod:`agent_framework_hosting`.""" + +from ._channel import ( + EntraIdentityLinkChannel, + EntraIdentityStore, + entra_isolation_key, +) + +__all__ = [ + "EntraIdentityLinkChannel", + "EntraIdentityStore", + "entra_isolation_key", +] diff --git a/python/packages/hosting-entra/agent_framework_hosting_entra/_channel.py b/python/packages/hosting-entra/agent_framework_hosting_entra/_channel.py new file mode 100644 index 0000000000..7e6075638c --- /dev/null +++ b/python/packages/hosting-entra/agent_framework_hosting_entra/_channel.py @@ -0,0 +1,505 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Microsoft Entra (Azure AD) identity-linking sidecar channel. + +Implements the OAuth 2.0 Authorization Code flow against Entra so users on +non-Entra channels (Telegram, Responses callers without a verified token, +etc.) can bind their per-channel id to a stable ``entra:`` isolation +key. Once the link is established, channel run-hooks can call +:meth:`EntraIdentityStore.lookup` and rewrite the request to use the Entra +key instead of the channel-native id. + +Two credential modes are supported: + +* ``client_secret`` — confidential-client secret. +* ``certificate_path`` — PEM bundle (private key + cert) for tenants that + disallow secrets. The Teams channel uses the same PEM layout; see + :mod:`agent_framework_hosting_teams` for the openssl recipe. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import hmac +import html +import json +import secrets +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from urllib.parse import urlencode, urlparse + +import httpx +import msal +from agent_framework_hosting import ( + ChannelContext, + ChannelContribution, + logger, +) +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from starlette.requests import Request +from starlette.responses import HTMLResponse, RedirectResponse, Response +from starlette.routing import Route + + +def entra_isolation_key(oid: str) -> str: + """Canonical isolation key for a user identified by Entra object id.""" + return f"entra:{oid}" + + +class EntraIdentityStore: + """Tiny JSON-backed mapping ``: → entra:``. + + Production deployments should swap this for a real KV store. Single-file + JSON is fine for samples because writes are infrequent (only during the + OAuth callback) and we serialize them under an asyncio lock. + """ + + def __init__(self, path: Path) -> None: + """Open an identity store backed by ``path``. + + Loads any existing JSON document; an unreadable or corrupt file is + logged and replaced with an empty in-memory map so callers always + get a usable store. + """ + self._path = path + self._lock = asyncio.Lock() + self._data: dict[str, str] = {} + if path.exists(): + try: + self._data = json.loads(path.read_text()) + except Exception: + logger.exception("identity store load failed; starting empty") + + def lookup(self, channel_key: str) -> str | None: + """Return the linked ``entra:`` key for a per-channel id, or ``None``.""" + return self._data.get(channel_key) + + async def link(self, channel_key: str, oid: str) -> None: + """Bind ``channel_key`` (e.g. ``telegram:123``) to the Entra ``oid`` and persist. + + Overwrites any existing mapping for ``channel_key`` and rewrites the + backing JSON file under the lock so concurrent callers cannot race. + """ + async with self._lock: + self._data[channel_key] = entra_isolation_key(oid) + self._path.write_text(json.dumps(self._data, indent=2, sort_keys=True)) + + async def unlink(self, channel_key: str) -> None: + """Remove the mapping for ``channel_key``; no-op if absent. + + The file is only rewritten when an entry actually existed so we + don't churn disk on idempotent unlink calls. + """ + async with self._lock: + if self._data.pop(channel_key, None) is not None: + self._path.write_text(json.dumps(self._data, indent=2, sort_keys=True)) + + +@dataclass +class _PendingAuth: + """In-memory record of an authorize redirect waiting for its OAuth callback.""" + + channel: str + channel_id: str + expires_at: float + return_to: str | None = None + + +def _link_html(body: str, *, status: int = 200) -> HTMLResponse: + """Wrap ``body`` in a minimal HTML shell suitable for browser link UIs.""" + return HTMLResponse( + f"{body}", + status_code=status, + ) + + +def _load_certificate_credential(certificate_path: str | Path, certificate_password: bytes | None) -> dict[str, str]: + """Build the ``msal`` certificate credential dict from a PEM bundle. + + Expects ``certificate_path`` to point at a single PEM containing the + private key followed by the X.509 certificate (the layout produced by + ``cat key.pem cert.pem > combined.pem``). + """ + pem_bytes = Path(certificate_path).read_bytes() + private_key = serialization.load_pem_private_key(pem_bytes, password=certificate_password) + cert = x509.load_pem_x509_certificate(pem_bytes) + + private_key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + public_cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode() + # SHA-1 thumbprint is required by the Entra ``client_assertion`` spec for cert auth — not a security choice. + thumbprint = cert.fingerprint(hashes.SHA1()).hex() # noqa: S303 + return { + "private_key": private_key_pem, + "thumbprint": thumbprint, + "public_certificate": public_cert_pem, + } + + +class EntraIdentityLinkChannel: + """Sidecar Channel exposing ``GET /auth/start`` and ``GET /auth/callback``. + + Demonstrates that ``Channel`` is a general extensibility point — not just + for chat surfaces. Owns the Entra OAuth Authorization Code flow used to + bind a per-channel id (e.g. Telegram chat id) to the user's Entra object + id. + + Two credential modes are supported (mutually exclusive): + + * ``client_secret`` — classic confidential-client secret. + * ``certificate_path`` — PEM bundle (private key + certificate) for + tenants that disallow secrets. See ``teams.py`` module docstring for + an ``openssl`` recipe; the same PEM works here. + + Flow (OAuth 2.0 Authorization Code, confidential client): + + 1. ``GET /auth/start?channel=&id=`` mints a one-shot + ``state`` token and 302s to the Entra ``authorize`` endpoint. + 2. User signs in; Entra calls ``GET /auth/callback?code=...&state=...``. + 3. We exchange the code for a token (via ``msal`` so secret + cert auth + look identical at the call site), call Microsoft Graph ``/me`` to + read ``id`` (oid), persist ``: → entra:``, and + respond with a friendly HTML page (or 302 to ``return_to``). + + Tokens never leave the host process; only the ``oid`` claim is stored. + """ + + name = "identity" + path = "/auth" + + _AUTHORITY_TEMPLATE = "https://login.microsoftonline.com/{tenant}" + _GRAPH_ME = "https://graph.microsoft.com/v1.0/me" + _PENDING_TTL_SECONDS = 600 # 10 minutes + + def __init__( + self, + *, + store: EntraIdentityStore, + tenant_id: str, + client_id: str, + public_base_url: str, + client_secret: str | None = None, + certificate_path: str | Path | None = None, + certificate_password: bytes | None = None, + scope: str = "openid profile User.Read", + link_token_secret: str | None = None, + link_token_ttl_seconds: int = 600, + ) -> None: + if bool(client_secret) == bool(certificate_path): + raise ValueError("IdentityLinkChannel: pass exactly one of client_secret or certificate_path.") + if certificate_path is not None: + credential: str | dict[str, str] = _load_certificate_credential(certificate_path, certificate_password) + self._auth_kind = "certificate" + else: + credential = client_secret # type: ignore[assignment] + self._auth_kind = "client_secret" + + self._store = store + self._tenant_id = tenant_id + self._client_id = client_id + self._public_base_url = public_base_url.rstrip("/") + self._scopes = [s for s in scope.split() if s and s.lower() not in {"openid", "profile", "offline_access"}] + # MSAL ConfidentialClientApplication is sync; we wrap blocking calls + # in ``asyncio.to_thread`` because token endpoint calls do real I/O. + self._msal_app = msal.ConfidentialClientApplication( + client_id=client_id, + authority=self._AUTHORITY_TEMPLATE.format(tenant=tenant_id), + client_credential=credential, + ) + self._pending: dict[str, _PendingAuth] = {} + self._http: httpx.AsyncClient | None = None + # ``link_token_secret`` is the HMAC key that gates ``/auth/start``. + # Without it any open-internet caller can mint a binding for an + # arbitrary ``(channel, channel_id)`` pair and IDOR the victim's + # isolation key (see PR review on 0026 for the threat model). + # Optional only so dev-mode samples without the integration in + # place don't have to scramble for a secret; unsigned mode logs + # a loud warning at startup and wire-time. + self._link_token_secret = link_token_secret.encode("utf-8") if link_token_secret else None + self._link_token_ttl = link_token_ttl_seconds + # Allowed redirect-back hosts: relative paths and same-origin only. + # ``return_to`` from the unauthenticated /start query string is + # otherwise an open redirect (auth-host phishing vector). + parsed = urlparse(self._public_base_url) + self._allowed_return_host = parsed.netloc.lower() if parsed.netloc else None + + @property + def redirect_uri(self) -> str: + """The fully-qualified OAuth redirect URI registered with Entra ID. + + Computed from ``public_base_url`` plus the channel's mount path so + operators can copy it straight into the app registration's reply URLs. + """ + return f"{self._public_base_url}{self.path}/callback" + + def contribute(self, context: "ChannelContext") -> "ChannelContribution": + """Mount the ``/start`` and ``/callback`` routes plus lifecycle hooks.""" + return ChannelContribution( + routes=[ + Route("/start", self._handle_start, methods=["GET"]), + Route("/callback", self._handle_callback, methods=["GET"]), + ], + on_startup=[self._on_startup], + on_shutdown=[self._on_shutdown], + ) + + async def _on_startup(self) -> None: + """Open the shared HTTP client used for Microsoft Graph calls.""" + self._http = httpx.AsyncClient(timeout=15.0) + if self._link_token_secret is None: + logger.warning( + "EntraIdentityLinkChannel running WITHOUT link_token_secret. " + "GET /auth/start accepts unauthenticated (channel, id) pairs, " + "which means any open-internet caller can bind their Entra " + "account to a victim's per-channel id (IDOR on the identity " + "store). Pass link_token_secret=, mint URLs via " + "mint_start_url(...), and gate /start in front of the " + "channel that issues those URLs." + ) + logger.info( + "IdentityLinkChannel ready (auth=%s, signed_start=%s); redirect_uri=%s", + self._auth_kind, + self._link_token_secret is not None, + self.redirect_uri, + ) + + async def _on_shutdown(self) -> None: + """Close the Graph HTTP client; safe to call when never started.""" + if self._http is not None: + await self._http.aclose() + + # -- link-token helpers ----------------------------------------------- # + + def _sign_link_token(self, channel: str, channel_id: str, expires_at: int) -> str: + """Sign ``(channel, channel_id, expires_at)`` with HMAC-SHA256.""" + if self._link_token_secret is None: # pragma: no cover - guarded by callers + raise RuntimeError("link_token_secret is required to mint link tokens") + msg = f"{channel}|{channel_id}|{expires_at}".encode() + return hmac.new(self._link_token_secret, msg, hashlib.sha256).hexdigest() + + def _verify_link_token(self, channel: str, channel_id: str, expires_at: int, signature: str) -> bool: + """Constant-time verify the link-token signature and TTL.""" + if self._link_token_secret is None: # pragma: no cover - guarded by callers + return False + if expires_at < int(time.time()): + return False + expected = self._sign_link_token(channel, channel_id, expires_at) + return hmac.compare_digest(expected, signature) + + def mint_start_url(self, channel: str, channel_id: str, return_to: str | None = None) -> str: + """Return a one-shot signed URL for ``GET /auth/start``. + + Required when ``link_token_secret`` is set. Channels that issue + these URLs (e.g. a Telegram ``/link`` command after verifying the + inbound webhook signature) call this helper so the resulting URL + proves the caller authorised the ``(channel, channel_id)`` binding. + + Without this layer ``GET /auth/start`` is an IDOR vector: any + anonymous caller can bind a victim's per-channel id to their own + Entra ``oid``. + """ + if self._link_token_secret is None: + raise RuntimeError("mint_start_url requires link_token_secret in the constructor") + if return_to is not None: + self._validate_return_to(return_to) # fail fast at mint time + expires_at = int(time.time()) + self._link_token_ttl + sig = self._sign_link_token(channel, str(channel_id), expires_at) + params = { + "channel": channel, + "id": str(channel_id), + "exp": str(expires_at), + "sig": sig, + } + if return_to: + params["return_to"] = return_to + return f"{self._public_base_url}{self.path}/start?{urlencode(params)}" + + def _validate_return_to(self, return_to: str) -> None: + """Reject open-redirect targets. + + Allows: relative paths starting with ``/``, or absolute URLs whose + host equals the configured ``public_base_url``'s host. Rejects + everything else with ``ValueError``. + """ + if return_to.startswith("/") and not return_to.startswith("//"): + return # relative path, safe. + parsed = urlparse(return_to) + if not parsed.netloc: + return + if self._allowed_return_host and parsed.netloc.lower() == self._allowed_return_host: + return + raise ValueError( + f"return_to must be a relative path or same-origin URL " + f"(public_base_url host={self._allowed_return_host!r}); got {return_to!r}" + ) + + def authorize_url_for(self, channel: str, channel_id: str, return_to: str | None = None) -> str: + """Mint a one-shot authorize URL the user can visit to bind their account.""" + state = secrets.token_urlsafe(24) + self._gc_pending() + self._pending[state] = _PendingAuth( + channel=channel, + channel_id=str(channel_id), + expires_at=time.monotonic() + self._PENDING_TTL_SECONDS, + return_to=return_to, + ) + return str( + self._msal_app.get_authorization_request_url( + scopes=self._scopes, + redirect_uri=self.redirect_uri, + state=state, + prompt="select_account", + ) + ) + + def _gc_pending(self) -> None: + """Drop expired pending-auth entries so the in-memory map cannot grow unbounded.""" + now = time.monotonic() + for key, entry in list(self._pending.items()): + if entry.expires_at < now: + self._pending.pop(key, None) + + async def _handle_start(self, request: Request) -> Response: + """``GET /start?channel=&id=&return_to=&exp=&sig=`` — redirect to Entra to sign in. + + **Security model.** When ``link_token_secret`` is set the + request must include ``exp`` + ``sig`` — an HMAC over + ``(channel, channel_id, expires_at)`` minted by + :meth:`mint_start_url`. Without that gate, any open-internet + caller can bind a victim's per-channel id (e.g. + ``telegram:``) to their own Entra ``oid``: the + callback would persist + ``"telegram:" -> "entra:"`` and any + future inbound message from the victim would resolve to the + attacker's isolation key. We make the unsigned mode opt-in + with a loud startup warning so the dev-mode default doesn't + ship to production. + + ``return_to`` is validated against the configured + ``public_base_url`` host (or restricted to relative paths) to + prevent open-redirect phishing on a successful sign-in. + """ + channel = request.query_params.get("channel") + channel_id = request.query_params.get("id") + return_to = request.query_params.get("return_to") + if not channel or not channel_id: + return _link_html("Missing 'channel' or 'id' query parameter.", status=400) + + if self._link_token_secret is not None: + sig = request.query_params.get("sig") + exp_raw = request.query_params.get("exp") + try: + exp = int(exp_raw) if exp_raw else 0 + except ValueError: + exp = 0 + if not sig or not exp or not self._verify_link_token(channel, channel_id, exp, sig): + logger.warning( + "EntraIdentityLinkChannel /start rejected: missing/invalid signed link-token (channel=%s, id=%s)", + channel, + channel_id, + ) + return _link_html("Invalid or expired sign-in link.", status=403) + else: + # See _on_startup warning. Logged on every wire access so + # operators can't miss the IDOR exposure in their access logs. + logger.warning( + "EntraIdentityLinkChannel /start accepted UNSIGNED request " + "for (channel=%s, id=%s) — set link_token_secret to require " + "HMAC-signed link tokens minted via mint_start_url().", + channel, + channel_id, + ) + if return_to is not None: + try: + self._validate_return_to(return_to) + except ValueError as exc: + logger.warning("EntraIdentityLinkChannel /start invalid return_to: %s", exc) + return _link_html("Invalid return_to URL.", status=400) + url = self.authorize_url_for(channel, channel_id, return_to=return_to) + return RedirectResponse(url, status_code=302) + + async def _handle_callback(self, request: Request) -> Response: + """``GET /callback`` — finish the OAuth flow and persist the link. + + Exchanges the authorization code for a token, reads the user's + ``id``/``userPrincipalName`` from Microsoft Graph, then stores the + ``channel:channel_id -> entra:`` mapping in the identity store. + Renders a small HTML page so a browser-based flow has something to + show; if ``return_to`` was supplied (and validated at /start time + against the same-origin allowlist) it appears as a deep link. + + All values that flow into HTML output (``error``, ``error_description``, + ``channel_key``, ``upn``) are passed through :func:`html.escape` to + avoid reflected XSS — both the OAuth-error path and the + sign-in-success body would otherwise execute attacker-controlled + markup on the auth host's origin. + """ + if self._http is None: # pragma: no cover - guarded by lifecycle + raise RuntimeError("entra identity channel not started") + if error := request.query_params.get("error"): + description = request.query_params.get("error_description", "") + return _link_html( + f"Sign-in failed: {html.escape(error)}
{html.escape(description)}", + status=400, + ) + + code = request.query_params.get("code") + state = request.query_params.get("state") + pending = self._pending.pop(state or "", None) + if not code or pending is None or pending.expires_at < time.monotonic(): + return _link_html("Invalid or expired sign-in state. Please retry.", status=400) + + # MSAL handles client_secret vs client_assertion (cert) under the hood. + result: dict[str, Any] = await asyncio.to_thread( + self._msal_app.acquire_token_by_authorization_code, + code, + scopes=self._scopes, + redirect_uri=self.redirect_uri, + ) + if "access_token" not in result: + logger.warning("Entra token exchange failed: %s", result) + err_text = result.get("error_description") or result.get("error") or "unknown error" + return _link_html( + f"Token exchange failed: {html.escape(str(err_text))}", + status=502, + ) + access_token = result["access_token"] + + me = await self._http.get(self._GRAPH_ME, headers={"Authorization": f"Bearer {access_token}"}) + if me.status_code != 200: + return _link_html("Could not read user profile from Microsoft Graph.", status=502) + profile = me.json() + oid = profile.get("id") + upn = profile.get("userPrincipalName") or profile.get("displayName") or oid + if not oid: + return _link_html("Profile response missing 'id'.", status=502) + + channel_key = f"{pending.channel}:{pending.channel_id}" + await self._store.link(channel_key, oid) + logger.info("Linked %s → entra:%s (%s)", channel_key, oid, upn) + + if pending.return_to: + # ``return_to`` was already validated at /start time against + # the allowlist (relative path or same-origin only). Re-check + # defensively to harden against any future code path that + # bypasses the /start gate. + try: + self._validate_return_to(pending.return_to) + return RedirectResponse(pending.return_to, status_code=302) + except ValueError: + logger.warning( + "EntraIdentityLinkChannel /callback dropping invalid return_to: %s", + pending.return_to, + ) + return _link_html( + f"

Linked

{html.escape(channel_key)} is now bound to " + f"{html.escape(str(upn))}.

" + "

You can close this window and return to your chat.

" + ) diff --git a/python/packages/hosting-entra/pyproject.toml b/python/packages/hosting-entra/pyproject.toml new file mode 100644 index 0000000000..45264741d5 --- /dev/null +++ b/python/packages/hosting-entra/pyproject.toml @@ -0,0 +1,108 @@ +[project] +name = "agent-framework-hosting-entra" +description = "Microsoft Entra (Azure AD) OAuth-based identity-linking channel for agent-framework-hosting." +authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}] +readme = "README.md" +requires-python = ">=3.10" +version = "1.0.0a260424" +license-files = ["LICENSE"] +urls.homepage = "https://aka.ms/agent-framework" +urls.source = "https://github.com/microsoft/agent-framework/tree/main/python" +urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true" +urls.issues = "https://github.com/microsoft/agent-framework/issues" +classifiers = [ + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Typing :: Typed", +] +dependencies = [ + "agent-framework-core>=1.2.0,<2", + "agent-framework-hosting==1.0.0a260424", + "httpx>=0.27,<1", + "msal>=1.28,<2", + "cryptography>=42", +] + +[tool.uv] +prerelease = "if-necessary-or-explicit" +environments = [ + "sys_platform == 'darwin'", + "sys_platform == 'linux'", + "sys_platform == 'win32'" +] + +[tool.uv-dynamic-versioning] +fallback-version = "0.0.0" + +[tool.pytest.ini_options] +testpaths = 'tests' +addopts = "-ra -q -r fEX" +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [] +timeout = 120 +markers = [ + "integration: marks tests as integration tests that require external services", +] + +[tool.ruff] +extend = "../../pyproject.toml" + +[tool.coverage.run] +omit = [ + "**/__init__.py" +] + +[tool.pyright] +extends = "../../pyproject.toml" +include = ["agent_framework_hosting_entra"] +exclude = ['tests'] +# Bot Framework activities arrive as loosely-typed JSON-ish maps. Strict +# ``Unknown`` reporting on every ``.get(...)`` adds noise without catching +# real bugs — narrowing happens via runtime isinstance checks instead. +reportUnknownArgumentType = "none" +reportUnknownMemberType = "none" +reportUnknownVariableType = "none" +reportUnknownLambdaType = "none" +reportOptionalMemberAccess = "none" + +[tool.mypy] +plugins = ['pydantic.mypy'] +strict = true +python_version = "3.10" +ignore_missing_imports = true +disallow_untyped_defs = true +no_implicit_optional = true +check_untyped_defs = true +warn_return_any = true +show_error_codes = true +warn_unused_ignores = false +disallow_incomplete_defs = true +disallow_untyped_decorators = true + +[tool.bandit] +targets = ["agent_framework_hosting_entra"] +exclude_dirs = ["tests"] + +[tool.poe] +executor.type = "uv" +include = "../../shared_tasks.toml" + +[tool.poe.tasks.mypy] +help = "Run MyPy for this package." +cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_hosting_entra" + +[tool.poe.tasks.test] +help = "Run the default unit test suite for this package." +cmd = 'pytest -m "not integration" --cov=agent_framework_hosting_entra --cov-report=term-missing:skip-covered tests' + +[build-system] +requires = ["flit-core >= 3.11,<4.0"] +build-backend = "flit_core.buildapi" diff --git a/python/packages/hosting-entra/tests/__init__.py b/python/packages/hosting-entra/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/hosting-entra/tests/test_channel.py b/python/packages/hosting-entra/tests/test_channel.py new file mode 100644 index 0000000000..e28a217a07 --- /dev/null +++ b/python/packages/hosting-entra/tests/test_channel.py @@ -0,0 +1,467 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Unit tests for :mod:`agent_framework_hosting_entra`. + +The MSAL ``ConfidentialClientApplication`` and Microsoft Graph calls are +mocked out so no network access is required. Live OAuth, certificate auth, +and full webhook flow are out of scope here. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from agent_framework_hosting_entra import ( + EntraIdentityLinkChannel, + EntraIdentityStore, + entra_isolation_key, +) + + +def test_entra_isolation_key_format() -> None: + assert entra_isolation_key("abc123") == "entra:abc123" + + +class TestEntraIdentityStore: + @pytest.mark.asyncio + async def test_link_writes_entra_namespaced_value(self, tmp_path: Path) -> None: + store = EntraIdentityStore(tmp_path / "links.json") + await store.link("telegram:42", "oid-xyz") + assert store.lookup("telegram:42") == "entra:oid-xyz" + # Persisted to disk. + saved = json.loads((tmp_path / "links.json").read_text()) + assert saved == {"telegram:42": "entra:oid-xyz"} + + @pytest.mark.asyncio + async def test_unlink_removes_entry(self, tmp_path: Path) -> None: + store = EntraIdentityStore(tmp_path / "links.json") + await store.link("telegram:42", "oid") + await store.unlink("telegram:42") + assert store.lookup("telegram:42") is None + assert json.loads((tmp_path / "links.json").read_text()) == {} + + @pytest.mark.asyncio + async def test_unlink_unknown_is_noop(self, tmp_path: Path) -> None: + store = EntraIdentityStore(tmp_path / "links.json") + await store.unlink("telegram:never") # must not raise + assert not (tmp_path / "links.json").exists() + + def test_loads_existing_file(self, tmp_path: Path) -> None: + path = tmp_path / "links.json" + path.write_text(json.dumps({"telegram:1": "entra:abc"})) + store = EntraIdentityStore(path) + assert store.lookup("telegram:1") == "entra:abc" + + def test_corrupt_file_starts_empty(self, tmp_path: Path) -> None: + path = tmp_path / "links.json" + path.write_text("not-json") + store = EntraIdentityStore(path) + assert store.lookup("anything") is None + + +class TestEntraIdentityLinkChannelConfig: + def test_rejects_neither_credential(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="exactly one"): + EntraIdentityLinkChannel( + store=EntraIdentityStore(tmp_path / "x.json"), + tenant_id="t", + client_id="c", + public_base_url="https://example.com", + ) + + def test_rejects_both_credentials(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="exactly one"): + EntraIdentityLinkChannel( + store=EntraIdentityStore(tmp_path / "x.json"), + tenant_id="t", + client_id="c", + public_base_url="https://example.com", + client_secret="s", + certificate_path="/tmp/does-not-exist.pem", + ) + + def test_redirect_uri_strips_trailing_slash(self, tmp_path: Path) -> None: + with patch( + "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication", + MagicMock(), + ): + ch = EntraIdentityLinkChannel( + store=EntraIdentityStore(tmp_path / "x.json"), + tenant_id="t", + client_id="c", + public_base_url="https://example.com/", + client_secret="s", + ) + assert ch.redirect_uri == "https://example.com/auth/callback" + + +class TestEntraIdentityLinkChannelRoutes: + def _make_channel(self, tmp_path: Path, msal_app: MagicMock) -> tuple[EntraIdentityLinkChannel, EntraIdentityStore]: + store = EntraIdentityStore(tmp_path / "links.json") + with patch( + "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication", + return_value=msal_app, + ): + ch = EntraIdentityLinkChannel( + store=store, + tenant_id="tenant-1", + client_id="client-1", + public_base_url="https://example.com", + client_secret="s", + ) + return ch, store + + def _mount_app(self, ch: EntraIdentityLinkChannel) -> Starlette: + # We don't depend on AgentFrameworkHost here — wire the routes + # directly so we can exercise the channel in isolation. + from starlette.routing import Mount + + contribution = ch.contribute(MagicMock()) + return Starlette(routes=[Mount(ch.path, routes=contribution.routes)]) + + def test_start_missing_params_returns_400(self, tmp_path: Path) -> None: + msal_app = MagicMock() + ch, _ = self._make_channel(tmp_path, msal_app) + with TestClient(self._mount_app(ch)) as client: + r = client.get("/auth/start", follow_redirects=False) + assert r.status_code == 400 + + def test_start_redirects_to_authorize_url(self, tmp_path: Path) -> None: + msal_app = MagicMock() + msal_app.get_authorization_request_url.return_value = ( + "https://login.microsoftonline.com/tenant-1/oauth2/v2.0/authorize?state=X" + ) + ch, _ = self._make_channel(tmp_path, msal_app) + with TestClient(self._mount_app(ch)) as client: + r = client.get( + "/auth/start", + params={"channel": "telegram", "id": "42"}, + follow_redirects=False, + ) + assert r.status_code == 302 + assert "login.microsoftonline.com" in r.headers["location"] + + def test_callback_invalid_state_returns_400(self, tmp_path: Path) -> None: + msal_app = MagicMock() + ch, _ = self._make_channel(tmp_path, msal_app) + ch._http = MagicMock(aclose=AsyncMock()) + with TestClient(self._mount_app(ch)) as client: + r = client.get("/auth/callback", params={"code": "c", "state": "unknown"}) + assert r.status_code == 400 + + def test_callback_links_oid_on_success(self, tmp_path: Path) -> None: + msal_app = MagicMock() + msal_app.get_authorization_request_url.return_value = ( + "https://login.microsoftonline.com/tenant-1/authorize?state=X" + ) + msal_app.acquire_token_by_authorization_code.return_value = {"access_token": "t"} + ch, store = self._make_channel(tmp_path, msal_app) + + # Fake the Graph /me call. + graph_response = MagicMock() + graph_response.status_code = 200 + graph_response.json = MagicMock(return_value={"id": "oid-xyz", "userPrincipalName": "user@x"}) + ch._http = MagicMock() + ch._http.get = AsyncMock(return_value=graph_response) + ch._http.aclose = AsyncMock() + + # Mint a real state via the public API so the pending dict is populated. + ch.authorize_url_for("telegram", "42") + state = next(iter(ch._pending.keys())) + + with TestClient(self._mount_app(ch)) as client: + r = client.get("/auth/callback", params={"code": "abc", "state": state}) + assert r.status_code == 200 + assert store.lookup("telegram:42") == "entra:oid-xyz" + + def test_callback_token_failure_returns_502(self, tmp_path: Path) -> None: + msal_app = MagicMock() + msal_app.get_authorization_request_url.return_value = "https://x" + msal_app.acquire_token_by_authorization_code.return_value = { + "error": "invalid_grant", + "error_description": "expired", + } + ch, store = self._make_channel(tmp_path, msal_app) + ch._http = MagicMock(aclose=AsyncMock()) + ch.authorize_url_for("telegram", "42") + state = next(iter(ch._pending.keys())) + with TestClient(self._mount_app(ch)) as client: + r = client.get("/auth/callback", params={"code": "c", "state": state}) + assert r.status_code == 502 + assert store.lookup("telegram:42") is None + + +# --------------------------------------------------------------------------- # +# Round-2 security hardening # +# --------------------------------------------------------------------------- # + + +class TestSignedLinkToken: + """`/auth/start` must reject unsigned/forged requests when secret is set.""" + + def _make_signed_channel( + self, tmp_path: Path, msal_app: MagicMock, *, secret: str = "test-secret" + ) -> EntraIdentityLinkChannel: + store = EntraIdentityStore(tmp_path / "links.json") + with patch( + "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication", + return_value=msal_app, + ): + return EntraIdentityLinkChannel( + store=store, + tenant_id="tenant-1", + client_id="client-1", + public_base_url="https://example.com", + client_secret="s", + link_token_secret=secret, + ) + + def _mount(self, ch: EntraIdentityLinkChannel) -> Starlette: + from starlette.routing import Mount + + contribution = ch.contribute(MagicMock()) + return Starlette(routes=[Mount(ch.path, routes=contribution.routes)]) + + def test_start_rejects_unsigned_request_when_secret_set(self, tmp_path: Path) -> None: + msal_app = MagicMock() + ch = self._make_signed_channel(tmp_path, msal_app) + with TestClient(self._mount(ch)) as client: + r = client.get( + "/auth/start", + params={"channel": "telegram", "id": "42"}, + follow_redirects=False, + ) + assert r.status_code == 403 + + def test_start_rejects_forged_signature(self, tmp_path: Path) -> None: + msal_app = MagicMock() + ch = self._make_signed_channel(tmp_path, msal_app) + with TestClient(self._mount(ch)) as client: + r = client.get( + "/auth/start", + params={ + "channel": "telegram", + "id": "42", + "exp": "9999999999", + "sig": "deadbeef", + }, + follow_redirects=False, + ) + assert r.status_code == 403 + + def test_start_accepts_valid_signed_url(self, tmp_path: Path) -> None: + msal_app = MagicMock() + msal_app.get_authorization_request_url.return_value = ( + "https://login.microsoftonline.com/tenant-1/authorize?state=X" + ) + ch = self._make_signed_channel(tmp_path, msal_app) + url = ch.mint_start_url("telegram", "42") + # Strip the host prefix to call via the in-process client. + path_and_query = url.split("https://example.com", 1)[1] + with TestClient(self._mount(ch)) as client: + r = client.get(path_and_query, follow_redirects=False) + assert r.status_code == 302 + + def test_start_rejects_expired_signed_url(self, tmp_path: Path) -> None: + import time as time_module + from urllib.parse import urlencode + + msal_app = MagicMock() + ch = self._make_signed_channel(tmp_path, msal_app) + # Hand-craft an expired-but-otherwise-valid token. + expired = int(time_module.time()) - 60 + sig = ch._sign_link_token("telegram", "42", expired) # type: ignore[attr-defined] # pyright: ignore[reportPrivateUsage] + params = {"channel": "telegram", "id": "42", "exp": str(expired), "sig": sig} + with TestClient(self._mount(ch)) as client: + r = client.get(f"/auth/start?{urlencode(params)}", follow_redirects=False) + assert r.status_code == 403 + + def test_mint_start_url_requires_secret(self, tmp_path: Path) -> None: + import pytest + + msal_app = MagicMock() + store = EntraIdentityStore(tmp_path / "links.json") + with patch( + "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication", + return_value=msal_app, + ): + ch = EntraIdentityLinkChannel( + store=store, + tenant_id="tenant-1", + client_id="client-1", + public_base_url="https://example.com", + client_secret="s", + ) + with pytest.raises(RuntimeError, match="link_token_secret"): + ch.mint_start_url("telegram", "42") + + def test_unsigned_mode_logs_warning_at_startup(self, tmp_path: Path, caplog: Any) -> None: + import asyncio as asyncio_mod + import logging + + msal_app = MagicMock() + store = EntraIdentityStore(tmp_path / "links.json") + with patch( + "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication", + return_value=msal_app, + ): + ch = EntraIdentityLinkChannel( + store=store, + tenant_id="tenant-1", + client_id="client-1", + public_base_url="https://example.com", + client_secret="s", + ) + with caplog.at_level(logging.WARNING, logger="agent_framework.hosting"): + asyncio_mod.run(ch._on_startup()) # pyright: ignore[reportPrivateUsage] + asyncio_mod.run(ch._on_shutdown()) # pyright: ignore[reportPrivateUsage] + assert any("WITHOUT link_token_secret" in r.message for r in caplog.records) + + +class TestXssEscaping: + """All inbound query/profile values must be HTML-escaped before output.""" + + def _setup(self, tmp_path: Path) -> tuple[EntraIdentityLinkChannel, EntraIdentityStore, MagicMock]: + store = EntraIdentityStore(tmp_path / "links.json") + msal_app = MagicMock() + msal_app.get_authorization_request_url.return_value = "https://x" + with patch( + "agent_framework_hosting_entra._channel.msal.ConfidentialClientApplication", + return_value=msal_app, + ): + ch = EntraIdentityLinkChannel( + store=store, + tenant_id="tenant-1", + client_id="client-1", + public_base_url="https://example.com", + client_secret="s", + ) + return ch, store, msal_app + + def _mount(self, ch: EntraIdentityLinkChannel) -> Starlette: + from starlette.routing import Mount + + contribution = ch.contribute(MagicMock()) + return Starlette(routes=[Mount(ch.path, routes=contribution.routes)]) + + def test_callback_error_param_is_escaped(self, tmp_path: Path) -> None: + ch, _, _ = self._setup(tmp_path) + ch._http = MagicMock(aclose=AsyncMock()) + with TestClient(self._mount(ch)) as client: + r = client.get( + "/auth/callback", + params={ + "error": "", + "error_description": "", + }, + ) + assert r.status_code == 400 + assert "@x"} + ) + ch._http = MagicMock(aclose=AsyncMock()) + ch._http.get = AsyncMock(return_value=graph_response) + # Mint a binding via authorize_url_for (channel-side trusted call). + ch.authorize_url_for("", "42") + state = next(iter(ch._pending.keys())) + with TestClient(self._mount(ch)) as client: + r = client.get("/auth/callback", params={"code": "abc", "state": state}) + assert r.status_code == 200 + assert "