diff --git a/agentex/src/config/dependencies.py b/agentex/src/config/dependencies.py index f37cf043..b2086a5d 100644 --- a/agentex/src/config/dependencies.py +++ b/agentex/src/config/dependencies.py @@ -49,6 +49,8 @@ def __init__(self): self.redis_pool: redis.ConnectionPool | None = None self.database_async_read_only_engine: AsyncEngine | None = None self.postgres_metrics_collector: PostgresMetricsCollector | None = None + self._mongodb_refresh_task: asyncio.Task | None = None + self._mongodb_close_tasks: set[asyncio.Task] = set() self._loaded = False async def create_temporal_client(self): @@ -122,17 +124,7 @@ async def load(self): logger.info("Connecting to MongoDB") - self.mongodb_client = AsyncMongoClient( - mongodb_uri, - serverSelectionTimeoutMS=20000, - connectTimeoutMS=20000, - socketTimeoutMS=20000, - retryWrites=False, # Disable retryable writes for AWS DocumentDB compatibility - maxPoolSize=self.environment_variables.MONGODB_MAX_POOL_SIZE, - minPoolSize=self.environment_variables.MONGODB_MIN_POOL_SIZE, - maxIdleTimeMS=30000, # Close connections after 30 seconds of inactivity - waitQueueTimeoutMS=5000, # Wait up to 5 seconds for a connection from pool - ) + self.mongodb_client = self._build_mongodb_client(mongodb_uri) self.mongodb_database = self.mongodb_client[mongodb_database_name] # Ping the database to verify connection @@ -226,10 +218,132 @@ async def load(self): service_name=service_name, ) + self._start_mongodb_oidc_refresh_loop() + self._loaded = True + def _build_mongodb_client(self, mongodb_uri: str) -> AsyncMongoClient: + """Construct an AsyncMongoClient with the shared pool/timeout settings. + + Used both at startup and by the OIDC refresh, so the two paths can never + drift apart. + """ + return AsyncMongoClient( + mongodb_uri, + serverSelectionTimeoutMS=20000, + connectTimeoutMS=20000, + socketTimeoutMS=20000, + retryWrites=False, # Disable retryable writes for AWS DocumentDB compatibility + maxPoolSize=self.environment_variables.MONGODB_MAX_POOL_SIZE, + minPoolSize=self.environment_variables.MONGODB_MIN_POOL_SIZE, + maxIdleTimeMS=30000, # Close connections after 30 seconds of inactivity + waitQueueTimeoutMS=5000, # Wait up to 5 seconds for a connection from pool + ) + + def _mongodb_uses_oidc(self) -> bool: + """True only when the Mongo URI authenticates via MONGODB-OIDC. + + Gates the refresh loop so standard-auth / AWS DocumentDB deployments are + never churned — only GCP OIDC tokens expire out from under a live client. + """ + uri = self.environment_variables.MONGODB_URI or "" + return "MONGODB-OIDC" in uri.upper() + + async def refresh_mongodb_client(self) -> None: + """Rebuild the Mongo client to renew the cached GCP OIDC token. + + pymongo's built-in GCP OIDC provider caches the access token for the life + of the client and only refreshes it reactively (on a server reauth + challenge). GCP tokens expire after ~1h, so a long-lived client eventually + fails auth. A new client authenticates fresh, picking up a new token. + + The new client is built and pinged (which forces fresh auth) before the + swap, and the old client is closed only after a drain delay, so no in-flight + operation is ever dropped and we never swap to a broken client. + """ + mongodb_uri = self.environment_variables.MONGODB_URI + if not mongodb_uri or not self._mongodb_uses_oidc(): + return + + new_client = self._build_mongodb_client(mongodb_uri) + # Force fresh OIDC auth and validate the new client before trusting it. + # If this raises, close the candidate (so a repeated auth/network outage + # can't accumulate orphaned clients across retries) and keep using the + # existing, working client. + try: + await new_client.admin.command("ping") + except BaseException: + await new_client.close() + raise + + old_client = self.mongodb_client + self.mongodb_client = new_client + self.mongodb_database = new_client[ + self.environment_variables.MONGODB_DATABASE_NAME + ] + logger.info("Refreshed MongoDB client to renew OIDC credentials") + + if old_client is not None and old_client is not new_client: + task = asyncio.create_task( + self._close_mongodb_client_after_delay(old_client) + ) + # Keep a strong reference until done so the task is not GC'd mid-flight. + self._mongodb_close_tasks.add(task) + task.add_done_callback(self._mongodb_close_tasks.discard) + + async def _close_mongodb_client_after_delay( + self, client: AsyncMongoClient, delay: float = 60.0 + ) -> None: + """Close a superseded Mongo client after letting in-flight ops drain.""" + try: + await asyncio.sleep(delay) + await client.close() + except asyncio.CancelledError: + await client.close() + raise + except Exception as e: + logger.warning(f"Error closing superseded MongoDB client: {e}") + + def _start_mongodb_oidc_refresh_loop(self) -> None: + interval = self.environment_variables.MONGODB_OIDC_REFRESH_INTERVAL_SECONDS + if ( + self.mongodb_client is None + or not self._mongodb_uses_oidc() + or interval <= 0 + or self._mongodb_refresh_task is not None + ): + return + self._mongodb_refresh_task = asyncio.create_task( + self._mongodb_oidc_refresh_loop(interval) + ) + logger.info(f"Started MongoDB OIDC refresh loop (interval={interval}s)") + + async def _mongodb_oidc_refresh_loop(self, interval: int) -> None: + while True: + try: + await asyncio.sleep(interval) + await self.refresh_mongodb_client() + except asyncio.CancelledError: + raise + except Exception as e: + logger.error( + f"MongoDB OIDC refresh failed; retrying next interval: {e}" + ) + + async def _stop_mongodb_oidc_refresh_loop(self) -> None: + if self._mongodb_refresh_task is not None: + self._mongodb_refresh_task.cancel() + try: + await self._mongodb_refresh_task + except asyncio.CancelledError: + pass + self._mongodb_refresh_task = None + async def force_reload(self): """Force reload all dependencies with fresh environment variables""" + # Stop the MongoDB OIDC refresh loop before tearing down the client + await self._stop_mongodb_oidc_refresh_loop() + # Stop metrics collection if self.postgres_metrics_collector: await self.postgres_metrics_collector.stop_collection() @@ -272,6 +386,9 @@ def shutdown(): async def async_shutdown(): global_dependencies = GlobalDependencies() + # Stop the MongoDB OIDC refresh loop + await global_dependencies._stop_mongodb_oidc_refresh_loop() + # Stop PostgreSQL metrics collection if global_dependencies.postgres_metrics_collector: await global_dependencies.postgres_metrics_collector.stop_collection() diff --git a/agentex/src/config/environment_variables.py b/agentex/src/config/environment_variables.py index 0872c0cf..a6f1ca15 100644 --- a/agentex/src/config/environment_variables.py +++ b/agentex/src/config/environment_variables.py @@ -37,6 +37,7 @@ class EnvVarKeys(str, Enum): MONGODB_DATABASE_NAME = "MONGODB_DATABASE_NAME" MONGODB_MAX_POOL_SIZE = "MONGODB_MAX_POOL_SIZE" MONGODB_MIN_POOL_SIZE = "MONGODB_MIN_POOL_SIZE" + MONGODB_OIDC_REFRESH_INTERVAL_SECONDS = "MONGODB_OIDC_REFRESH_INTERVAL_SECONDS" REDIS_MAX_CONNECTIONS = "REDIS_MAX_CONNECTIONS" REDIS_CONNECTION_TIMEOUT = "REDIS_CONNECTION_TIMEOUT" REDIS_SOCKET_TIMEOUT = "REDIS_SOCKET_TIMEOUT" @@ -96,6 +97,11 @@ class EnvironmentVariables(BaseModel): MONGODB_DATABASE_NAME: str | None = "agentex" MONGODB_MAX_POOL_SIZE: int = 50 MONGODB_MIN_POOL_SIZE: int = 5 + # Rebuild the Mongo client on this interval to renew GCP OIDC credentials. + # pymongo caches the OIDC token for the life of the client and never refreshes + # it proactively, so a long-lived client fails auth once the ~1h GCP token + # expires. Only applied to MONGODB-OIDC URIs; 0 disables. Default 45 min. + MONGODB_OIDC_REFRESH_INTERVAL_SECONDS: int = 2700 REDIS_MAX_CONNECTIONS: int = 50 # Increased for SSE streaming REDIS_CONNECTION_TIMEOUT: int = 60 # Connection timeout in seconds REDIS_SOCKET_TIMEOUT: int = 30 # Socket timeout in seconds @@ -163,6 +169,9 @@ def refresh(cls, force_refresh: bool = False) -> EnvironmentVariables | None: MONGODB_MIN_POOL_SIZE=int( os.environ.get(EnvVarKeys.MONGODB_MIN_POOL_SIZE, "5") ), + MONGODB_OIDC_REFRESH_INTERVAL_SECONDS=int( + os.environ.get(EnvVarKeys.MONGODB_OIDC_REFRESH_INTERVAL_SECONDS, "2700") + ), REDIS_MAX_CONNECTIONS=int( os.environ.get(EnvVarKeys.REDIS_MAX_CONNECTIONS, "100") ), diff --git a/agentex/src/temporal/activities/retention_cleanup_activities.py b/agentex/src/temporal/activities/retention_cleanup_activities.py index d42def8b..0a6ed731 100644 --- a/agentex/src/temporal/activities/retention_cleanup_activities.py +++ b/agentex/src/temporal/activities/retention_cleanup_activities.py @@ -15,11 +15,11 @@ Pydantic models). """ +from collections.abc import Callable from typing import TypedDict from src.config.environment_variables import EnvironmentVariables from src.domain.exceptions import ClientError -from src.domain.repositories.task_repository import TaskRepository from src.domain.use_cases.task_retention_use_case import TaskRetentionUseCase from src.utils.logging import make_logger from temporalio import activity @@ -46,11 +46,15 @@ class CleanTaskOutcome(TypedDict): class RetentionCleanupActivities: def __init__( self, - task_repository: TaskRepository, - use_case: TaskRetentionUseCase, + use_case_factory: Callable[[], TaskRetentionUseCase], ): - self.task_repository = task_repository - self.use_case = use_case + # Build the use case (and its Mongo-backed repositories) per activity run + # rather than capturing it at worker startup. The OIDC client refresh + # periodically swaps GlobalDependencies.mongodb_database to a fresh client + # and closes the old one; a use case captured once would hold collections + # bound to the stale (eventually closed) client. Resolving per run always + # picks up the current client. + self._use_case_factory = use_case_factory @activity.defn(name=LOAD_CLEANUP_CONFIG_ACTIVITY) async def load_cleanup_config(self) -> dict: @@ -101,7 +105,8 @@ async def find_cleanup_candidates( "find_cleanup_candidates_started", extra={"after_id": after_id, "limit": limit}, ) - result = await self.task_repository.list_cleanup_candidate_ids( + task_repository = self._use_case_factory().task_repository + result = await task_repository.list_cleanup_candidate_ids( idle_days=idle_days, agent_names=agent_names, after_id=after_id, @@ -120,7 +125,8 @@ async def find_multi_agent_cleanup_candidates( Cleanup deletes task-wide content, so these are skipped by the scheduled workflow before any per-task child workflow can run. """ - result = await self.task_repository.list_multi_agent_task_ids(task_ids=task_ids) + task_repository = self._use_case_factory().task_repository + result = await task_repository.list_multi_agent_task_ids(task_ids=task_ids) logger.info( "find_multi_agent_cleanup_candidates_completed", extra={"count": len(result)}, @@ -146,9 +152,10 @@ async def clean_task( (e.g. task is still active, not yet idle long enough, or already cleaned). Other exceptions propagate so Temporal can retry them. """ + use_case = self._use_case_factory() try: if dry_run: - result = await self.use_case.preview_clean_task( + result = await use_case.preview_clean_task( task_id=task_id, force=False, idle_days=idle_days ) logger.info( @@ -163,7 +170,7 @@ async def clean_task( "task_states_deleted": 0, "events_deleted": 0, } - result = await self.use_case.clean_task( + result = await use_case.clean_task( task_id=task_id, force=False, idle_days=idle_days ) return { diff --git a/agentex/src/temporal/run_worker.py b/agentex/src/temporal/run_worker.py index de44cba6..1f500055 100644 --- a/agentex/src/temporal/run_worker.py +++ b/agentex/src/temporal/run_worker.py @@ -153,12 +153,13 @@ def create_agentex_server_worker( http_client=http_client, ) - retention_use_case = build_task_retention_use_case(global_dependencies) - # Reuse the repository the factory already built (avoids a duplicate - # TaskRepository) via the use case's stable accessor. + # Build the retention use case per activity run (not once here): the OIDC + # client refresh swaps global_dependencies.mongodb_database to a fresh client + # and closes the old one, so a use case captured at startup would eventually + # hold Mongo collections bound to a closed client. global_dependencies is the + # singleton, so each call reads the current mongodb_database. retention_activities = RetentionCleanupActivities( - task_repository=retention_use_case.task_repository, - use_case=retention_use_case, + use_case_factory=lambda: build_task_retention_use_case(global_dependencies), ) return asyncio.create_task( diff --git a/agentex/tests/integration/config/__init__.py b/agentex/tests/integration/config/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/agentex/tests/integration/config/test_mongodb_oidc_refresh_integration.py b/agentex/tests/integration/config/test_mongodb_oidc_refresh_integration.py new file mode 100644 index 00000000..2fce71dc --- /dev/null +++ b/agentex/tests/integration/config/test_mongodb_oidc_refresh_integration.py @@ -0,0 +1,71 @@ +"""Integration tests for the MongoDB client-refresh swap against a real Mongo. + +The unit tests mock the client; these prove the build-validate-swap-drain works +end-to-end against a live MongoDB container: data written before the swap is still +readable after it, the post-swap client is fully functional, and the superseded +client is drained and closed. (The container doesn't speak GCP OIDC, so the OIDC +gate is forced on to exercise the swap path itself.) +""" + +from unittest.mock import AsyncMock + +import pytest +from src.config.dependencies import GlobalDependencies, Singleton + + +@pytest.fixture +def deps(mongodb_connection_string): + """Fresh GlobalDependencies wired to the test Mongo container.""" + Singleton._instances.pop(GlobalDependencies, None) + instance = GlobalDependencies() + instance.environment_variables = instance.environment_variables.model_copy( + update={ + "MONGODB_URI": mongodb_connection_string, + "MONGODB_DATABASE_NAME": "agentex_oidc_refresh_test", + } + ) + instance.mongodb_client = instance._build_mongodb_client(mongodb_connection_string) + instance.mongodb_database = instance.mongodb_client["agentex_oidc_refresh_test"] + yield instance + Singleton._instances.pop(GlobalDependencies, None) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_refresh_preserves_data_and_drains_old_client(deps, monkeypatch): + # Treat the container URI as OIDC so the refresh path actually runs, and + # collapse the drain delay so the close completes within the test. + monkeypatch.setattr(deps, "_mongodb_uses_oidc", lambda: True) + original_close_after_delay = deps._close_mongodb_client_after_delay + + async def fast_close(client, delay=0.0): + await original_close_after_delay(client, delay=0.0) + + monkeypatch.setattr(deps, "_close_mongodb_client_after_delay", fast_close) + + collection = "docs" + await deps.mongodb_database[collection].insert_one({"_id": "before", "n": 1}) + + old_client = deps.mongodb_client + old_client.close = AsyncMock(wraps=old_client.close) + + await deps.refresh_mongodb_client() + + # A genuinely new client is now installed. + assert deps.mongodb_client is not old_client + + # The new client can write, and reads the doc written before the swap. + await deps.mongodb_database[collection].insert_one({"_id": "after", "n": 2}) + ids = { + doc["_id"] + async for doc in deps.mongodb_database[collection].find({}, {"_id": 1}) + } + assert ids == {"before", "after"} + + # The superseded client is drained and closed. + for task in list(deps._mongodb_close_tasks): + await task + old_client.close.assert_awaited_once() + + await deps.mongodb_client.drop_database("agentex_oidc_refresh_test") + await deps.mongodb_client.close() diff --git a/agentex/tests/unit/config/test_mongodb_oidc_refresh.py b/agentex/tests/unit/config/test_mongodb_oidc_refresh.py new file mode 100644 index 00000000..96ca3b65 --- /dev/null +++ b/agentex/tests/unit/config/test_mongodb_oidc_refresh.py @@ -0,0 +1,167 @@ +"""Unit tests for the MongoDB OIDC client-refresh path in GlobalDependencies. + +pymongo's built-in GCP OIDC provider caches the access token for the life of the +client and never refreshes it proactively, so a long-lived client fails auth once +the ~1h GCP token expires. `refresh_mongodb_client()` rebuilds the client to renew +the token without bouncing the process; these tests cover the gating, the +build-validate-swap-drain ordering, and the loop lifecycle — all without a real +MongoDB (the client is mocked). +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from src.config.dependencies import GlobalDependencies, Singleton +from src.config.environment_variables import EnvironmentVariables + +OIDC_URI = ( + "mongodb://host/?authMechanism=MONGODB-OIDC" + "&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:FIRESTORE" +) +PLAIN_URI = "mongodb://user:pass@host:27017/?authSource=admin" + + +@pytest.fixture +def deps(): + """A fresh GlobalDependencies, isolated from the process-wide singleton.""" + Singleton._instances.pop(GlobalDependencies, None) + instance = GlobalDependencies() + yield instance + Singleton._instances.pop(GlobalDependencies, None) + + +def _mock_client() -> MagicMock: + client = MagicMock() + client.admin.command = AsyncMock(return_value={"ok": 1}) + client.close = AsyncMock() + return client + + +def _set_uri(deps: GlobalDependencies, uri: str | None) -> None: + deps.environment_variables = deps.environment_variables.model_copy( + update={"MONGODB_URI": uri, "MONGODB_DATABASE_NAME": "agentex"} + ) + + +@pytest.mark.unit +def test_env_refresh_interval_parses_and_defaults(monkeypatch): + monkeypatch.setenv("MONGODB_OIDC_REFRESH_INTERVAL_SECONDS", "900") + assert ( + EnvironmentVariables.refresh( + force_refresh=True + ).MONGODB_OIDC_REFRESH_INTERVAL_SECONDS + == 900 + ) + + monkeypatch.delenv("MONGODB_OIDC_REFRESH_INTERVAL_SECONDS", raising=False) + assert ( + EnvironmentVariables.refresh( + force_refresh=True + ).MONGODB_OIDC_REFRESH_INTERVAL_SECONDS + == 2700 + ) + + +@pytest.mark.unit +def test_uses_oidc_gate(deps): + _set_uri(deps, OIDC_URI) + assert deps._mongodb_uses_oidc() is True + + _set_uri(deps, PLAIN_URI) + assert deps._mongodb_uses_oidc() is False + + _set_uri(deps, None) + assert deps._mongodb_uses_oidc() is False + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_refresh_swaps_client_and_drains_old(deps, monkeypatch): + _set_uri(deps, OIDC_URI) + old_client = _mock_client() + new_client = _mock_client() + deps.mongodb_client = old_client + monkeypatch.setattr(deps, "_build_mongodb_client", lambda uri: new_client) + + await deps.refresh_mongodb_client() + + # New client validated before the swap, then installed. + new_client.admin.command.assert_awaited_once_with("ping") + assert deps.mongodb_client is new_client + assert deps.mongodb_database is new_client["agentex"] + + # Old client is scheduled for a drained close, not closed immediately. + old_client.close.assert_not_awaited() + assert len(deps._mongodb_close_tasks) == 1 + for task in list(deps._mongodb_close_tasks): + task.cancel() + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_refresh_noop_for_non_oidc(deps, monkeypatch): + _set_uri(deps, PLAIN_URI) + old_client = _mock_client() + deps.mongodb_client = old_client + build = MagicMock() + monkeypatch.setattr(deps, "_build_mongodb_client", build) + + await deps.refresh_mongodb_client() + + build.assert_not_called() + assert deps.mongodb_client is old_client + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_refresh_keeps_old_client_when_new_fails_validation(deps, monkeypatch): + _set_uri(deps, OIDC_URI) + old_client = _mock_client() + deps.mongodb_client = old_client + + broken = _mock_client() + broken.admin.command = AsyncMock(side_effect=RuntimeError("auth failed")) + monkeypatch.setattr(deps, "_build_mongodb_client", lambda uri: broken) + + with pytest.raises(RuntimeError): + await deps.refresh_mongodb_client() + + # Never swapped to the broken client; never tore down the working one. + assert deps.mongodb_client is old_client + old_client.close.assert_not_awaited() + # The candidate is closed before re-raising, so repeated failures can't leak + # orphaned clients across retries. + broken.close.assert_awaited_once() + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_loop_start_gating(deps): + deps.mongodb_client = _mock_client() + + # Disabled by interval. + _set_uri(deps, OIDC_URI) + deps.environment_variables = deps.environment_variables.model_copy( + update={"MONGODB_OIDC_REFRESH_INTERVAL_SECONDS": 0} + ) + deps._start_mongodb_oidc_refresh_loop() + assert deps._mongodb_refresh_task is None + + # Disabled by non-OIDC URI. + _set_uri(deps, PLAIN_URI) + deps.environment_variables = deps.environment_variables.model_copy( + update={"MONGODB_OIDC_REFRESH_INTERVAL_SECONDS": 2700} + ) + deps._start_mongodb_oidc_refresh_loop() + assert deps._mongodb_refresh_task is None + + # Enabled: OIDC + positive interval. + _set_uri(deps, OIDC_URI) + deps.environment_variables = deps.environment_variables.model_copy( + update={"MONGODB_OIDC_REFRESH_INTERVAL_SECONDS": 2700} + ) + deps._start_mongodb_oidc_refresh_loop() + assert deps._mongodb_refresh_task is not None + + await deps._stop_mongodb_oidc_refresh_loop() + assert deps._mongodb_refresh_task is None diff --git a/agentex/tests/unit/temporal/test_retention_cleanup_activities.py b/agentex/tests/unit/temporal/test_retention_cleanup_activities.py index f97403bf..a4dcda63 100644 --- a/agentex/tests/unit/temporal/test_retention_cleanup_activities.py +++ b/agentex/tests/unit/temporal/test_retention_cleanup_activities.py @@ -9,12 +9,19 @@ ) +def _activities(use_case) -> RetentionCleanupActivities: + """Wrap a use case in the per-run factory the activities now expect.""" + return RetentionCleanupActivities(use_case_factory=lambda: use_case) + + @pytest.mark.unit @pytest.mark.asyncio async def test_find_cleanup_candidates_delegates_to_repo(): repo = AsyncMock() repo.list_cleanup_candidate_ids.return_value = ["t1", "t2"] - activities = RetentionCleanupActivities(task_repository=repo, use_case=AsyncMock()) + use_case = AsyncMock() + use_case.task_repository = repo + activities = _activities(use_case) result = await activities.find_cleanup_candidates( after_id=None, limit=200, idle_days=7, agent_names=["a"] @@ -31,7 +38,9 @@ async def test_find_cleanup_candidates_delegates_to_repo(): async def test_find_multi_agent_cleanup_candidates_delegates_to_repo(): repo = AsyncMock() repo.list_multi_agent_task_ids.return_value = ["t2"] - activities = RetentionCleanupActivities(task_repository=repo, use_case=AsyncMock()) + use_case = AsyncMock() + use_case.task_repository = repo + activities = _activities(use_case) result = await activities.find_multi_agent_cleanup_candidates(["t1", "t2"]) @@ -50,9 +59,7 @@ async def test_clean_task_cleaned_outcome(): task_states_deleted=1, events_deleted=2, ) - activities = RetentionCleanupActivities( - task_repository=AsyncMock(), use_case=use_case - ) + activities = _activities(use_case) outcome = await activities.clean_task(task_id="t1", idle_days=7, dry_run=False) @@ -73,9 +80,7 @@ async def test_clean_task_defaults_to_dry_run_and_validates_without_writes(): task_states_deleted=0, events_deleted=0, ) - activities = RetentionCleanupActivities( - task_repository=AsyncMock(), use_case=use_case - ) + activities = _activities(use_case) outcome = await activities.clean_task(task_id="t1", idle_days=7) @@ -95,9 +100,7 @@ async def test_clean_task_clienterror_maps_to_skipped(): use_case.clean_task.side_effect = ClientError( "Cannot clean task t1: status is RUNNING (active)" ) - activities = RetentionCleanupActivities( - task_repository=AsyncMock(), use_case=use_case - ) + activities = _activities(use_case) outcome = await activities.clean_task(task_id="t1", idle_days=7, dry_run=False) @@ -111,14 +114,39 @@ async def test_clean_task_clienterror_maps_to_skipped(): async def test_clean_task_unexpected_error_propagates(): use_case = AsyncMock() use_case.clean_task.side_effect = RuntimeError("mongo timeout") - activities = RetentionCleanupActivities( - task_repository=AsyncMock(), use_case=use_case - ) + activities = _activities(use_case) with pytest.raises(RuntimeError): await activities.clean_task(task_id="t1", idle_days=7, dry_run=False) +@pytest.mark.unit +@pytest.mark.asyncio +async def test_use_case_resolved_per_activity_run(): + """Each activity call must resolve the use case fresh from the factory, so a + swapped Mongo client (from the OIDC refresh) is always picked up rather than a + stale captured one.""" + calls = [] + + def factory(): + use_case = AsyncMock() + use_case.task_repository.list_cleanup_candidate_ids.return_value = [] + use_case.task_repository.list_multi_agent_task_ids.return_value = [] + calls.append(use_case) + return use_case + + activities = RetentionCleanupActivities(use_case_factory=factory) + + await activities.find_cleanup_candidates( + after_id=None, limit=10, idle_days=7, agent_names=[] + ) + await activities.find_multi_agent_cleanup_candidates([]) + await activities.clean_task(task_id="t1", idle_days=7) + + # One fresh resolution per activity invocation — never cached on the instance. + assert len(calls) == 3 + + @pytest.mark.unit @pytest.mark.asyncio async def test_load_cleanup_config_reads_env(monkeypatch): @@ -128,9 +156,7 @@ async def test_load_cleanup_config_reads_env(monkeypatch): monkeypatch.setenv("RETENTION_CLEANUP_PAGE_SIZE", "33") monkeypatch.setenv("RETENTION_CLEANUP_MAX_IN_FLIGHT", "4") monkeypatch.setenv("RETENTION_CLEANUP_DRY_RUN", "true") - activities = RetentionCleanupActivities( - task_repository=AsyncMock(), use_case=AsyncMock() - ) + activities = _activities(AsyncMock()) config = await activities.load_cleanup_config() assert config == { "enabled": True,