Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 128 additions & 11 deletions agentex/src/config/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(self):
self.redis_pool: redis.ConnectionPool | None = None
self.database_async_read_only_engine: AsyncEngine | None = None
self.postgres_metrics_collector: PostgresMetricsCollector | None = None
self._mongodb_refresh_task: asyncio.Task | None = None
self._mongodb_close_tasks: set[asyncio.Task] = set()
self._loaded = False

async def create_temporal_client(self):
Expand Down Expand Up @@ -122,17 +124,7 @@ async def load(self):

logger.info("Connecting to MongoDB")

self.mongodb_client = AsyncMongoClient(
mongodb_uri,
serverSelectionTimeoutMS=20000,
connectTimeoutMS=20000,
socketTimeoutMS=20000,
retryWrites=False, # Disable retryable writes for AWS DocumentDB compatibility
maxPoolSize=self.environment_variables.MONGODB_MAX_POOL_SIZE,
minPoolSize=self.environment_variables.MONGODB_MIN_POOL_SIZE,
maxIdleTimeMS=30000, # Close connections after 30 seconds of inactivity
waitQueueTimeoutMS=5000, # Wait up to 5 seconds for a connection from pool
)
self.mongodb_client = self._build_mongodb_client(mongodb_uri)
self.mongodb_database = self.mongodb_client[mongodb_database_name]

# Ping the database to verify connection
Expand Down Expand Up @@ -226,10 +218,132 @@ async def load(self):
service_name=service_name,
)

self._start_mongodb_oidc_refresh_loop()

self._loaded = True

def _build_mongodb_client(self, mongodb_uri: str) -> AsyncMongoClient:
    """Create an AsyncMongoClient with the common pool and timeout options.

    Both the startup path and the OIDC refresh path go through this single
    factory, which keeps their client configuration identical.

    Args:
        mongodb_uri: Connection string handed to the client unchanged.

    Returns:
        A newly constructed client; nothing is pinged or validated here.
    """
    env = self.environment_variables
    client_options = {
        "serverSelectionTimeoutMS": 20000,
        "connectTimeoutMS": 20000,
        "socketTimeoutMS": 20000,
        # Retryable writes stay off for AWS DocumentDB compatibility.
        "retryWrites": False,
        "maxPoolSize": env.MONGODB_MAX_POOL_SIZE,
        "minPoolSize": env.MONGODB_MIN_POOL_SIZE,
        # Connections idle for 30 seconds are closed.
        "maxIdleTimeMS": 30000,
        # A caller waits at most 5 seconds for a pooled connection.
        "waitQueueTimeoutMS": 5000,
    }
    return AsyncMongoClient(mongodb_uri, **client_options)

def _mongodb_uses_oidc(self) -> bool:
    """Return True only when the Mongo URI selects the MONGODB-OIDC mechanism.

    Gates the refresh loop so standard-auth / AWS DocumentDB deployments are
    never churned — only GCP OIDC tokens expire out from under a live client.

    Only the ``authMechanism`` connection option is inspected. Searching the
    whole URI would also match a username, host or database name that merely
    contains "mongodb-oidc" and would needlessly rebuild a non-OIDC client.

    Returns:
        True when an ``authMechanism`` option equal to ``MONGODB-OIDC``
        (compared case-insensitively) is present; False otherwise, including
        when the URI is unset or empty.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import unquote_plus

    uri = self.environment_variables.MONGODB_URI or ""
    # Connection options are everything after the first "?".
    _, _, options = uri.partition("?")
    # Treat the legacy ";" separator the same as "&".
    for option in options.replace(";", "&").split("&"):
        key, _, value = option.partition("=")
        if key.strip().casefold() != "authmechanism":
            continue
        if unquote_plus(value).strip().casefold() == "mongodb-oidc":
            return True
    return False

async def refresh_mongodb_client(self) -> None:
    """Rebuild the Mongo client to renew the cached GCP OIDC token.

    pymongo's built-in GCP OIDC provider caches the access token for the life
    of the client and only refreshes it reactively (on a server reauth
    challenge). GCP tokens expire after ~1h, so a long-lived client eventually
    fails auth. A new client authenticates fresh, picking up a new token.

    The replacement client is built, pinged (which forces fresh auth) and has
    its database handle resolved before any attribute on ``self`` changes, so
    ``mongodb_client`` and ``mongodb_database`` always refer to the same
    client. The superseded client is closed by a background task after a
    drain delay rather than immediately.

    Does nothing when no URI is configured or the URI is not OIDC.

    Raises:
        BaseException: Whatever the ping or database lookup raised. The
            candidate client is closed first and the existing client is left
            installed.
    """
    mongodb_uri = self.environment_variables.MONGODB_URI
    if not mongodb_uri or not self._mongodb_uses_oidc():
        return

    new_client = self._build_mongodb_client(mongodb_uri)
    # Validate the candidate fully before trusting it. On any failure close it
    # (so repeated auth/network outages cannot accumulate orphaned clients
    # across retries) and keep using the existing, working client.
    try:
        await new_client.admin.command("ping")
        # Resolved here, before the swap: if the lookup fails (the configured
        # name is typed as optional) nothing has been reassigned yet.
        new_database = new_client[
            self.environment_variables.MONGODB_DATABASE_NAME
        ]
    except BaseException:
        await new_client.close()
        raise

    # Nothing below awaits or raises between the two assignments, so the
    # client and database are swapped together.
    old_client = self.mongodb_client
    self.mongodb_client = new_client
    self.mongodb_database = new_database
    logger.info("Refreshed MongoDB client to renew OIDC credentials")

    if old_client is not None and old_client is not new_client:
        task = asyncio.create_task(
            self._close_mongodb_client_after_delay(old_client)
        )
        # Keep a strong reference until done so the task is not GC'd mid-flight.
        self._mongodb_close_tasks.add(task)
        task.add_done_callback(self._mongodb_close_tasks.discard)
Comment on lines +286 to +292

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Old clients still used

This refresh swaps only GlobalDependencies.mongodb_database, but long-lived consumers can keep using AsyncDatabase and collection objects from the old client. For example, the Temporal worker builds retention repositories once at startup, and MongoDBCRUDRepository.__init__ stores self.collection = db[collection_name]. After the first refresh, those activities still use collections bound to old_client; once this delayed close runs, retention cleanup Mongo operations can fail even though the global dependency points at a fresh client.

Artifacts

Repro: pytest harness for cached Mongo collection across OIDC refresh

  • Contains supporting evidence from the run (text/x-python; charset=utf-8).

Repro: verbose pytest output showing cached collection uses the closed superseded client

  • Keeps the command output available without making the summary code-heavy.

View artifacts

T-Rex Ran code and verified through T-Rex

Prompt To Fix With AI
This is a comment left during a code review.
Path: agentex/src/config/dependencies.py
Line: 280-286

Comment:
**Old clients still used**

This refresh swaps only `GlobalDependencies.mongodb_database`, but long-lived consumers can keep using `AsyncDatabase` and collection objects from the old client. For example, the Temporal worker builds retention repositories once at startup, and `MongoDBCRUDRepository.__init__` stores `self.collection = db[collection_name]`. After the first refresh, those activities still use collections bound to `old_client`; once this delayed close runs, retention cleanup Mongo operations can fail even though the global dependency points at a fresh client.

How can I resolve this? If you propose a fix, please make it concise.

Fix in Cursor Fix in Claude Code Fix in Codex


async def _close_mongodb_client_after_delay(
    self, client: AsyncMongoClient, delay: float = 60.0
) -> None:
    """Wait ``delay`` seconds, then close a Mongo client that was replaced.

    The pause gives operations still running on the old client time to finish.

    Args:
        client: The superseded client to close.
        delay: Seconds to sleep before closing.

    Raises:
        asyncio.CancelledError: Re-raised after the client has been closed
            when this coroutine is cancelled.
    """
    try:
        await asyncio.sleep(delay)
        await client.close()
    except Exception as e:
        # Best effort: a failed close of an already-replaced client is only
        # worth a warning.
        logger.warning(f"Error closing superseded MongoDB client: {e}")
    except asyncio.CancelledError:
        # Cancellation still closes the client before propagating.
        await client.close()
        raise

def _start_mongodb_oidc_refresh_loop(self) -> None:
    """Launch the background task that periodically refreshes the client.

    Returns without starting anything when no Mongo client is loaded, the URI
    is not OIDC, the configured interval is zero or negative, or a refresh
    task is already registered.
    """
    interval = self.environment_variables.MONGODB_OIDC_REFRESH_INTERVAL_SECONDS

    # Guards are checked in the same order as a short-circuiting `or` chain.
    if self.mongodb_client is None:
        return
    if not self._mongodb_uses_oidc():
        return
    if interval <= 0:
        return
    if self._mongodb_refresh_task is not None:
        return

    refresh_coro = self._mongodb_oidc_refresh_loop(interval)
    self._mongodb_refresh_task = asyncio.create_task(refresh_coro)
    logger.info(f"Started MongoDB OIDC refresh loop (interval={interval}s)")

async def _mongodb_oidc_refresh_loop(self, interval: int) -> None:
    """Sleep ``interval`` seconds, refresh the Mongo client, and repeat forever.

    A failed refresh is logged and retried after the next sleep. The loop ends
    only when the task is cancelled.

    Args:
        interval: Seconds to wait before each refresh attempt.
    """
    while True:
        try:
            await asyncio.sleep(interval)
            await self.refresh_mongodb_client()
        # asyncio.CancelledError derives from BaseException, so it is not
        # caught here and cancellation propagates out of the loop.
        except Exception as e:
            logger.error(
                f"MongoDB OIDC refresh failed; retrying next interval: {e}"
            )

async def _stop_mongodb_oidc_refresh_loop(self) -> None:
    """Cancel the background refresh task, if any, and wait for it to finish.

    Does nothing when no refresh task is registered.

    Raises:
        asyncio.CancelledError: If the *caller* is cancelled while waiting.
            The refresh task's own cancellation is expected and not raised.
        BaseException: If the refresh task ended with an error other than
            cancellation, that error is re-raised.
    """
    task = self._mongodb_refresh_task
    if task is None:
        return

    # Clear the reference first so a failure below cannot leave a dead task
    # registered, which would stop the loop from ever being started again.
    self._mongodb_refresh_task = None
    task.cancel()

    # asyncio.wait does not re-raise the task's CancelledError, so a
    # CancelledError seen here can only mean the caller was cancelled.
    # Awaiting the task inside `except CancelledError: pass` would swallow
    # that as well.
    await asyncio.wait({task})

    if not task.cancelled():
        # Retrieve the outcome so an unexpected failure is surfaced rather
        # than reported later as a never-retrieved task exception.
        error = task.exception()
        if error is not None:
            raise error

async def force_reload(self):
"""Force reload all dependencies with fresh environment variables"""
# Stop the MongoDB OIDC refresh loop before tearing down the client
await self._stop_mongodb_oidc_refresh_loop()

# Stop metrics collection
if self.postgres_metrics_collector:
await self.postgres_metrics_collector.stop_collection()
Expand Down Expand Up @@ -272,6 +386,9 @@ def shutdown():
async def async_shutdown():
global_dependencies = GlobalDependencies()

# Stop the MongoDB OIDC refresh loop
await global_dependencies._stop_mongodb_oidc_refresh_loop()

# Stop PostgreSQL metrics collection
if global_dependencies.postgres_metrics_collector:
await global_dependencies.postgres_metrics_collector.stop_collection()
Expand Down
9 changes: 9 additions & 0 deletions agentex/src/config/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class EnvVarKeys(str, Enum):
MONGODB_DATABASE_NAME = "MONGODB_DATABASE_NAME"
MONGODB_MAX_POOL_SIZE = "MONGODB_MAX_POOL_SIZE"
MONGODB_MIN_POOL_SIZE = "MONGODB_MIN_POOL_SIZE"
MONGODB_OIDC_REFRESH_INTERVAL_SECONDS = "MONGODB_OIDC_REFRESH_INTERVAL_SECONDS"
REDIS_MAX_CONNECTIONS = "REDIS_MAX_CONNECTIONS"
REDIS_CONNECTION_TIMEOUT = "REDIS_CONNECTION_TIMEOUT"
REDIS_SOCKET_TIMEOUT = "REDIS_SOCKET_TIMEOUT"
Expand Down Expand Up @@ -96,6 +97,11 @@ class EnvironmentVariables(BaseModel):
MONGODB_DATABASE_NAME: str | None = "agentex"
MONGODB_MAX_POOL_SIZE: int = 50
MONGODB_MIN_POOL_SIZE: int = 5
# Rebuild the Mongo client on this interval to renew GCP OIDC credentials.
# pymongo caches the OIDC token for the life of the client and never refreshes
# it proactively, so a long-lived client fails auth once the ~1h GCP token
# expires. Only applied to MONGODB-OIDC URIs; 0 disables. Default 45 min.
MONGODB_OIDC_REFRESH_INTERVAL_SECONDS: int = 2700
REDIS_MAX_CONNECTIONS: int = 50 # Increased for SSE streaming
REDIS_CONNECTION_TIMEOUT: int = 60 # Connection timeout in seconds
REDIS_SOCKET_TIMEOUT: int = 30 # Socket timeout in seconds
Expand Down Expand Up @@ -163,6 +169,9 @@ def refresh(cls, force_refresh: bool = False) -> EnvironmentVariables | None:
MONGODB_MIN_POOL_SIZE=int(
os.environ.get(EnvVarKeys.MONGODB_MIN_POOL_SIZE, "5")
),
MONGODB_OIDC_REFRESH_INTERVAL_SECONDS=int(
os.environ.get(EnvVarKeys.MONGODB_OIDC_REFRESH_INTERVAL_SECONDS, "2700")
),
REDIS_MAX_CONNECTIONS=int(
os.environ.get(EnvVarKeys.REDIS_MAX_CONNECTIONS, "100")
),
Expand Down
25 changes: 16 additions & 9 deletions agentex/src/temporal/activities/retention_cleanup_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
Pydantic models).
"""

from collections.abc import Callable
from typing import TypedDict

from src.config.environment_variables import EnvironmentVariables
from src.domain.exceptions import ClientError
from src.domain.repositories.task_repository import TaskRepository
from src.domain.use_cases.task_retention_use_case import TaskRetentionUseCase
from src.utils.logging import make_logger
from temporalio import activity
Expand All @@ -46,11 +46,15 @@ class CleanTaskOutcome(TypedDict):
class RetentionCleanupActivities:
def __init__(
self,
task_repository: TaskRepository,
use_case: TaskRetentionUseCase,
use_case_factory: Callable[[], TaskRetentionUseCase],
):
self.task_repository = task_repository
self.use_case = use_case
# Build the use case (and its Mongo-backed repositories) per activity run
# rather than capturing it at worker startup. The OIDC client refresh
# periodically swaps GlobalDependencies.mongodb_database to a fresh client
# and closes the old one; a use case captured once would hold collections
# bound to the stale (eventually closed) client. Resolving per run always
# picks up the current client.
self._use_case_factory = use_case_factory

@activity.defn(name=LOAD_CLEANUP_CONFIG_ACTIVITY)
async def load_cleanup_config(self) -> dict:
Expand Down Expand Up @@ -101,7 +105,8 @@ async def find_cleanup_candidates(
"find_cleanup_candidates_started",
extra={"after_id": after_id, "limit": limit},
)
result = await self.task_repository.list_cleanup_candidate_ids(
task_repository = self._use_case_factory().task_repository
result = await task_repository.list_cleanup_candidate_ids(
idle_days=idle_days,
agent_names=agent_names,
after_id=after_id,
Expand All @@ -120,7 +125,8 @@ async def find_multi_agent_cleanup_candidates(
Cleanup deletes task-wide content, so these are skipped by the scheduled
workflow before any per-task child workflow can run.
"""
result = await self.task_repository.list_multi_agent_task_ids(task_ids=task_ids)
task_repository = self._use_case_factory().task_repository
result = await task_repository.list_multi_agent_task_ids(task_ids=task_ids)
logger.info(
"find_multi_agent_cleanup_candidates_completed",
extra={"count": len(result)},
Expand All @@ -146,9 +152,10 @@ async def clean_task(
(e.g. task is still active, not yet idle long enough, or already
cleaned). Other exceptions propagate so Temporal can retry them.
"""
use_case = self._use_case_factory()
try:
if dry_run:
result = await self.use_case.preview_clean_task(
result = await use_case.preview_clean_task(
task_id=task_id, force=False, idle_days=idle_days
)
logger.info(
Expand All @@ -163,7 +170,7 @@ async def clean_task(
"task_states_deleted": 0,
"events_deleted": 0,
}
result = await self.use_case.clean_task(
result = await use_case.clean_task(
task_id=task_id, force=False, idle_days=idle_days
)
return {
Expand Down
11 changes: 6 additions & 5 deletions agentex/src/temporal/run_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,13 @@ def create_agentex_server_worker(
http_client=http_client,
)

retention_use_case = build_task_retention_use_case(global_dependencies)
# Reuse the repository the factory already built (avoids a duplicate
# TaskRepository) via the use case's stable accessor.
# Build the retention use case per activity run (not once here): the OIDC
# client refresh swaps global_dependencies.mongodb_database to a fresh client
# and closes the old one, so a use case captured at startup would eventually
# hold Mongo collections bound to a closed client. global_dependencies is the
# singleton, so each call reads the current mongodb_database.
retention_activities = RetentionCleanupActivities(
task_repository=retention_use_case.task_repository,
use_case=retention_use_case,
use_case_factory=lambda: build_task_retention_use_case(global_dependencies),
)

return asyncio.create_task(
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Integration tests for the MongoDB client-refresh swap against a real Mongo.

The unit tests mock the client; these prove the build-validate-swap-drain works
end-to-end against a live MongoDB container: data written before the swap is still
readable after it, the post-swap client is fully functional, and the superseded
client is drained and closed. (The container doesn't speak GCP OIDC, so the OIDC
gate is forced on to exercise the swap path itself.)
"""

from unittest.mock import AsyncMock

import pytest
from src.config.dependencies import GlobalDependencies, Singleton


@pytest.fixture
def deps(mongodb_connection_string):
    """Yield a fresh GlobalDependencies bound to the test Mongo container.

    Any cached singleton instance is evicted before the object is built and
    again after the test, so state does not carry over between tests.
    """
    database_name = "agentex_oidc_refresh_test"
    Singleton._instances.pop(GlobalDependencies, None)

    dependencies = GlobalDependencies()
    # Point the copied environment at the container URI and the test database.
    overrides = {
        "MONGODB_URI": mongodb_connection_string,
        "MONGODB_DATABASE_NAME": database_name,
    }
    dependencies.environment_variables = (
        dependencies.environment_variables.model_copy(update=overrides)
    )

    client = dependencies._build_mongodb_client(mongodb_connection_string)
    dependencies.mongodb_client = client
    dependencies.mongodb_database = client[database_name]

    yield dependencies

    Singleton._instances.pop(GlobalDependencies, None)


@pytest.mark.asyncio
@pytest.mark.integration
async def test_refresh_preserves_data_and_drains_old_client(deps, monkeypatch):
    """Refresh swaps in a working client, keeps data, and closes the old client.

    Cleanup runs in ``finally`` so a failed assertion cannot leave the test
    database behind; a leftover "before" document would make the next run fail
    on a duplicate ``_id``.
    """
    # Treat the container URI as OIDC so the refresh path actually runs, and
    # collapse the drain delay so the close completes within the test.
    monkeypatch.setattr(deps, "_mongodb_uses_oidc", lambda: True)
    original_close_after_delay = deps._close_mongodb_client_after_delay

    async def fast_close(client, delay=0.0):
        await original_close_after_delay(client, delay=0.0)

    monkeypatch.setattr(deps, "_close_mongodb_client_after_delay", fast_close)

    collection = "docs"
    try:
        await deps.mongodb_database[collection].insert_one({"_id": "before", "n": 1})

        old_client = deps.mongodb_client
        # Wrap close so the call can be asserted while still really closing.
        old_client.close = AsyncMock(wraps=old_client.close)

        await deps.refresh_mongodb_client()

        # A genuinely new client is now installed.
        assert deps.mongodb_client is not old_client

        # The new client can write, and reads the doc written before the swap.
        await deps.mongodb_database[collection].insert_one({"_id": "after", "n": 2})
        ids = {
            doc["_id"]
            async for doc in deps.mongodb_database[collection].find({}, {"_id": 1})
        }
        assert ids == {"before", "after"}

        # The superseded client is drained and closed. Iterate a snapshot:
        # finished tasks remove themselves from the set.
        for task in list(deps._mongodb_close_tasks):
            await task
        old_client.close.assert_awaited_once()
    finally:
        # Always drop the test database and close whichever client is current.
        await deps.mongodb_client.drop_database("agentex_oidc_refresh_test")
        await deps.mongodb_client.close()
Loading